]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Suggest `extend-ignore` over `ignore` for flake8 (#1165)
[etc/vim.git] / black.py
1 import ast
2 import asyncio
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from contextlib import contextmanager
5 from datetime import datetime
6 from enum import Enum
7 from functools import lru_cache, partial, wraps
8 import io
9 import itertools
10 import logging
11 from multiprocessing import Manager, freeze_support
12 import os
13 from pathlib import Path
14 import pickle
15 import regex as re
16 import signal
17 import sys
18 import tempfile
19 import tokenize
20 import traceback
21 from typing import (
22     Any,
23     Callable,
24     Collection,
25     Dict,
26     Generator,
27     Generic,
28     Iterable,
29     Iterator,
30     List,
31     Optional,
32     Pattern,
33     Sequence,
34     Set,
35     Tuple,
36     TypeVar,
37     Union,
38     cast,
39 )
40 from typing_extensions import Final
41 from mypy_extensions import mypyc_attr
42
43 from appdirs import user_cache_dir
44 from dataclasses import dataclass, field, replace
45 import click
46 import toml
47 from typed_ast import ast3, ast27
48 from pathspec import PathSpec
49
50 # lib2to3 fork
51 from blib2to3.pytree import Node, Leaf, type_repr
52 from blib2to3 import pygram, pytree
53 from blib2to3.pgen2 import driver, token
54 from blib2to3.pgen2.grammar import Grammar
55 from blib2to3.pgen2.parse import ParseError
56
57 from _black_version import version as __version__
58
59 DEFAULT_LINE_LENGTH = 88
60 DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
61 DEFAULT_INCLUDES = r"\.pyi?$"
62 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
63
64
65 # types
66 FileContent = str
67 Encoding = str
68 NewLine = str
69 Depth = int
70 NodeType = int
71 LeafID = int
72 Priority = int
73 Index = int
74 LN = Union[Leaf, Node]
75 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
76 Timestamp = float
77 FileSize = int
78 CacheInfo = Tuple[Timestamp, FileSize]
79 Cache = Dict[Path, CacheInfo]
80 out = partial(click.secho, bold=True, err=True)
81 err = partial(click.secho, fg="red", err=True)
82
83 pygram.initialize(CACHE_DIR)
84 syms = pygram.python_symbols
85
86
87 class NothingChanged(UserWarning):
88     """Raised when reformatted code is the same as source."""
89
90
91 class CannotSplit(Exception):
92     """A readable split that fits the allotted line length is impossible."""
93
94
95 class InvalidInput(ValueError):
96     """Raised when input source code fails all parse attempts."""
97
98
99 class WriteBack(Enum):
100     NO = 0
101     YES = 1
102     DIFF = 2
103     CHECK = 3
104
105     @classmethod
106     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
107         if check and not diff:
108             return cls.CHECK
109
110         return cls.DIFF if diff else cls.YES
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 class TargetVersion(Enum):
120     PY27 = 2
121     PY33 = 3
122     PY34 = 4
123     PY35 = 5
124     PY36 = 6
125     PY37 = 7
126     PY38 = 8
127
128     def is_python2(self) -> bool:
129         return self is TargetVersion.PY27
130
131
132 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
133
134
135 class Feature(Enum):
136     # All string literals are unicode
137     UNICODE_LITERALS = 1
138     F_STRINGS = 2
139     NUMERIC_UNDERSCORES = 3
140     TRAILING_COMMA_IN_CALL = 4
141     TRAILING_COMMA_IN_DEF = 5
142     # The following two feature-flags are mutually exclusive, and exactly one should be
143     # set for every version of python.
144     ASYNC_IDENTIFIERS = 6
145     ASYNC_KEYWORDS = 7
146     ASSIGNMENT_EXPRESSIONS = 8
147     POS_ONLY_ARGUMENTS = 9
148
149
150 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
151     TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
152     TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
153     TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
154     TargetVersion.PY35: {
155         Feature.UNICODE_LITERALS,
156         Feature.TRAILING_COMMA_IN_CALL,
157         Feature.ASYNC_IDENTIFIERS,
158     },
159     TargetVersion.PY36: {
160         Feature.UNICODE_LITERALS,
161         Feature.F_STRINGS,
162         Feature.NUMERIC_UNDERSCORES,
163         Feature.TRAILING_COMMA_IN_CALL,
164         Feature.TRAILING_COMMA_IN_DEF,
165         Feature.ASYNC_IDENTIFIERS,
166     },
167     TargetVersion.PY37: {
168         Feature.UNICODE_LITERALS,
169         Feature.F_STRINGS,
170         Feature.NUMERIC_UNDERSCORES,
171         Feature.TRAILING_COMMA_IN_CALL,
172         Feature.TRAILING_COMMA_IN_DEF,
173         Feature.ASYNC_KEYWORDS,
174     },
175     TargetVersion.PY38: {
176         Feature.UNICODE_LITERALS,
177         Feature.F_STRINGS,
178         Feature.NUMERIC_UNDERSCORES,
179         Feature.TRAILING_COMMA_IN_CALL,
180         Feature.TRAILING_COMMA_IN_DEF,
181         Feature.ASYNC_KEYWORDS,
182         Feature.ASSIGNMENT_EXPRESSIONS,
183         Feature.POS_ONLY_ARGUMENTS,
184     },
185 }
186
187
188 @dataclass
189 class FileMode:
190     target_versions: Set[TargetVersion] = field(default_factory=set)
191     line_length: int = DEFAULT_LINE_LENGTH
192     string_normalization: bool = True
193     is_pyi: bool = False
194
195     def get_cache_key(self) -> str:
196         if self.target_versions:
197             version_str = ",".join(
198                 str(version.value)
199                 for version in sorted(self.target_versions, key=lambda v: v.value)
200             )
201         else:
202             version_str = "-"
203         parts = [
204             version_str,
205             str(self.line_length),
206             str(int(self.string_normalization)),
207             str(int(self.is_pyi)),
208         ]
209         return ".".join(parts)
210
211
212 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
213     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
214
215
216 def read_pyproject_toml(
217     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
218 ) -> Optional[str]:
219     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
220
221     Returns the path to a successfully found and read configuration file, None
222     otherwise.
223     """
224     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
225     if not value:
226         root = find_project_root(ctx.params.get("src", ()))
227         path = root / "pyproject.toml"
228         if path.is_file():
229             value = str(path)
230         else:
231             return None
232
233     try:
234         pyproject_toml = toml.load(value)
235         config = pyproject_toml.get("tool", {}).get("black", {})
236     except (toml.TomlDecodeError, OSError) as e:
237         raise click.FileError(
238             filename=value, hint=f"Error reading configuration file: {e}"
239         )
240
241     if not config:
242         return None
243
244     if ctx.default_map is None:
245         ctx.default_map = {}
246     ctx.default_map.update(  # type: ignore  # bad types in .pyi
247         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
248     )
249     return value
250
251
252 def target_version_option_callback(
253     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
254 ) -> List[TargetVersion]:
255     """Compute the target versions from a --target-version flag.
256
257     This is its own function because mypy couldn't infer the type correctly
258     when it was a lambda, causing mypyc trouble.
259     """
260     return [TargetVersion[val.upper()] for val in v]
261
262
263 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
264 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
265 @click.option(
266     "-l",
267     "--line-length",
268     type=int,
269     default=DEFAULT_LINE_LENGTH,
270     help="How many characters per line to allow.",
271     show_default=True,
272 )
273 @click.option(
274     "-t",
275     "--target-version",
276     type=click.Choice([v.name.lower() for v in TargetVersion]),
277     callback=target_version_option_callback,
278     multiple=True,
279     help=(
280         "Python versions that should be supported by Black's output. [default: "
281         "per-file auto-detection]"
282     ),
283 )
284 @click.option(
285     "--py36",
286     is_flag=True,
287     help=(
288         "Allow using Python 3.6-only syntax on all input files.  This will put "
289         "trailing commas in function signatures and calls also after *args and "
290         "**kwargs. Deprecated; use --target-version instead. "
291         "[default: per-file auto-detection]"
292     ),
293 )
294 @click.option(
295     "--pyi",
296     is_flag=True,
297     help=(
298         "Format all input files like typing stubs regardless of file extension "
299         "(useful when piping source on standard input)."
300     ),
301 )
302 @click.option(
303     "-S",
304     "--skip-string-normalization",
305     is_flag=True,
306     help="Don't normalize string quotes or prefixes.",
307 )
308 @click.option(
309     "--check",
310     is_flag=True,
311     help=(
312         "Don't write the files back, just return the status.  Return code 0 "
313         "means nothing would change.  Return code 1 means some files would be "
314         "reformatted.  Return code 123 means there was an internal error."
315     ),
316 )
317 @click.option(
318     "--diff",
319     is_flag=True,
320     help="Don't write the files back, just output a diff for each file on stdout.",
321 )
322 @click.option(
323     "--fast/--safe",
324     is_flag=True,
325     help="If --fast given, skip temporary sanity checks. [default: --safe]",
326 )
327 @click.option(
328     "--include",
329     type=str,
330     default=DEFAULT_INCLUDES,
331     help=(
332         "A regular expression that matches files and directories that should be "
333         "included on recursive searches.  An empty value means all files are "
334         "included regardless of the name.  Use forward slashes for directories on "
335         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
336         "later."
337     ),
338     show_default=True,
339 )
340 @click.option(
341     "--exclude",
342     type=str,
343     default=DEFAULT_EXCLUDES,
344     help=(
345         "A regular expression that matches files and directories that should be "
346         "excluded on recursive searches.  An empty value means no paths are excluded. "
347         "Use forward slashes for directories on all platforms (Windows, too).  "
348         "Exclusions are calculated first, inclusions later."
349     ),
350     show_default=True,
351 )
352 @click.option(
353     "-q",
354     "--quiet",
355     is_flag=True,
356     help=(
357         "Don't emit non-error messages to stderr. Errors are still emitted; "
358         "silence those with 2>/dev/null."
359     ),
360 )
361 @click.option(
362     "-v",
363     "--verbose",
364     is_flag=True,
365     help=(
366         "Also emit messages to stderr about files that were not changed or were "
367         "ignored due to --exclude=."
368     ),
369 )
370 @click.version_option(version=__version__)
371 @click.argument(
372     "src",
373     nargs=-1,
374     type=click.Path(
375         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
376     ),
377     is_eager=True,
378 )
379 @click.option(
380     "--config",
381     type=click.Path(
382         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
383     ),
384     is_eager=True,
385     callback=read_pyproject_toml,
386     help="Read configuration from PATH.",
387 )
388 @click.pass_context
389 def main(
390     ctx: click.Context,
391     code: Optional[str],
392     line_length: int,
393     target_version: List[TargetVersion],
394     check: bool,
395     diff: bool,
396     fast: bool,
397     pyi: bool,
398     py36: bool,
399     skip_string_normalization: bool,
400     quiet: bool,
401     verbose: bool,
402     include: str,
403     exclude: str,
404     src: Tuple[str, ...],
405     config: Optional[str],
406 ) -> None:
407     """The uncompromising code formatter."""
408     write_back = WriteBack.from_configuration(check=check, diff=diff)
409     if target_version:
410         if py36:
411             err(f"Cannot use both --target-version and --py36")
412             ctx.exit(2)
413         else:
414             versions = set(target_version)
415     elif py36:
416         err(
417             "--py36 is deprecated and will be removed in a future version. "
418             "Use --target-version py36 instead."
419         )
420         versions = PY36_VERSIONS
421     else:
422         # We'll autodetect later.
423         versions = set()
424     mode = FileMode(
425         target_versions=versions,
426         line_length=line_length,
427         is_pyi=pyi,
428         string_normalization=not skip_string_normalization,
429     )
430     if config and verbose:
431         out(f"Using configuration from {config}.", bold=False, fg="blue")
432     if code is not None:
433         print(format_str(code, mode=mode))
434         ctx.exit(0)
435     try:
436         include_regex = re_compile_maybe_verbose(include)
437     except re.error:
438         err(f"Invalid regular expression for include given: {include!r}")
439         ctx.exit(2)
440     try:
441         exclude_regex = re_compile_maybe_verbose(exclude)
442     except re.error:
443         err(f"Invalid regular expression for exclude given: {exclude!r}")
444         ctx.exit(2)
445     report = Report(check=check, quiet=quiet, verbose=verbose)
446     root = find_project_root(src)
447     sources: Set[Path] = set()
448     path_empty(src, quiet, verbose, ctx)
449     for s in src:
450         p = Path(s)
451         if p.is_dir():
452             sources.update(
453                 gen_python_files_in_dir(
454                     p, root, include_regex, exclude_regex, report, get_gitignore(root)
455                 )
456             )
457         elif p.is_file() or s == "-":
458             # if a file was explicitly given, we don't care about its extension
459             sources.add(p)
460         else:
461             err(f"invalid path: {s}")
462     if len(sources) == 0:
463         if verbose or not quiet:
464             out("No Python files are present to be formatted. Nothing to do 😴")
465         ctx.exit(0)
466
467     if len(sources) == 1:
468         reformat_one(
469             src=sources.pop(),
470             fast=fast,
471             write_back=write_back,
472             mode=mode,
473             report=report,
474         )
475     else:
476         reformat_many(
477             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
478         )
479
480     if verbose or not quiet:
481         out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
482         click.secho(str(report), err=True)
483     ctx.exit(report.return_code)
484
485
486 def path_empty(
487     src: Tuple[str, ...], quiet: bool, verbose: bool, ctx: click.Context
488 ) -> None:
489     """
490     Exit if there is no `src` provided for formatting
491     """
492     if not src:
493         if verbose or not quiet:
494             out("No Path provided. Nothing to do 😴")
495             ctx.exit(0)
496
497
498 def reformat_one(
499     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
500 ) -> None:
501     """Reformat a single file under `src` without spawning child processes.
502
503     `fast`, `write_back`, and `mode` options are passed to
504     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
505     """
506     try:
507         changed = Changed.NO
508         if not src.is_file() and str(src) == "-":
509             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
510                 changed = Changed.YES
511         else:
512             cache: Cache = {}
513             if write_back != WriteBack.DIFF:
514                 cache = read_cache(mode)
515                 res_src = src.resolve()
516                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
517                     changed = Changed.CACHED
518             if changed is not Changed.CACHED and format_file_in_place(
519                 src, fast=fast, write_back=write_back, mode=mode
520             ):
521                 changed = Changed.YES
522             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
523                 write_back is WriteBack.CHECK and changed is Changed.NO
524             ):
525                 write_cache(cache, [src], mode)
526         report.done(src, changed)
527     except Exception as exc:
528         report.failed(src, str(exc))
529
530
531 def reformat_many(
532     sources: Set[Path],
533     fast: bool,
534     write_back: WriteBack,
535     mode: FileMode,
536     report: "Report",
537 ) -> None:
538     """Reformat multiple files using a ProcessPoolExecutor."""
539     loop = asyncio.get_event_loop()
540     worker_count = os.cpu_count()
541     if sys.platform == "win32":
542         # Work around https://bugs.python.org/issue26903
543         worker_count = min(worker_count, 61)
544     executor = ProcessPoolExecutor(max_workers=worker_count)
545     try:
546         loop.run_until_complete(
547             schedule_formatting(
548                 sources=sources,
549                 fast=fast,
550                 write_back=write_back,
551                 mode=mode,
552                 report=report,
553                 loop=loop,
554                 executor=executor,
555             )
556         )
557     finally:
558         shutdown(loop)
559         executor.shutdown()
560
561
562 async def schedule_formatting(
563     sources: Set[Path],
564     fast: bool,
565     write_back: WriteBack,
566     mode: FileMode,
567     report: "Report",
568     loop: asyncio.AbstractEventLoop,
569     executor: Executor,
570 ) -> None:
571     """Run formatting of `sources` in parallel using the provided `executor`.
572
573     (Use ProcessPoolExecutors for actual parallelism.)
574
575     `write_back`, `fast`, and `mode` options are passed to
576     :func:`format_file_in_place`.
577     """
578     cache: Cache = {}
579     if write_back != WriteBack.DIFF:
580         cache = read_cache(mode)
581         sources, cached = filter_cached(cache, sources)
582         for src in sorted(cached):
583             report.done(src, Changed.CACHED)
584     if not sources:
585         return
586
587     cancelled = []
588     sources_to_cache = []
589     lock = None
590     if write_back == WriteBack.DIFF:
591         # For diff output, we need locks to ensure we don't interleave output
592         # from different processes.
593         manager = Manager()
594         lock = manager.Lock()
595     tasks = {
596         asyncio.ensure_future(
597             loop.run_in_executor(
598                 executor, format_file_in_place, src, fast, mode, write_back, lock
599             )
600         ): src
601         for src in sorted(sources)
602     }
603     pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
604     try:
605         loop.add_signal_handler(signal.SIGINT, cancel, pending)
606         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
607     except NotImplementedError:
608         # There are no good alternatives for these on Windows.
609         pass
610     while pending:
611         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
612         for task in done:
613             src = tasks.pop(task)
614             if task.cancelled():
615                 cancelled.append(task)
616             elif task.exception():
617                 report.failed(src, str(task.exception()))
618             else:
619                 changed = Changed.YES if task.result() else Changed.NO
620                 # If the file was written back or was successfully checked as
621                 # well-formatted, store this information in the cache.
622                 if write_back is WriteBack.YES or (
623                     write_back is WriteBack.CHECK and changed is Changed.NO
624                 ):
625                     sources_to_cache.append(src)
626                 report.done(src, changed)
627     if cancelled:
628         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
629     if sources_to_cache:
630         write_cache(cache, sources_to_cache, mode)
631
632
633 def format_file_in_place(
634     src: Path,
635     fast: bool,
636     mode: FileMode,
637     write_back: WriteBack = WriteBack.NO,
638     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
639 ) -> bool:
640     """Format file under `src` path. Return True if changed.
641
642     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
643     code to the file.
644     `mode` and `fast` options are passed to :func:`format_file_contents`.
645     """
646     if src.suffix == ".pyi":
647         mode = replace(mode, is_pyi=True)
648
649     then = datetime.utcfromtimestamp(src.stat().st_mtime)
650     with open(src, "rb") as buf:
651         src_contents, encoding, newline = decode_bytes(buf.read())
652     try:
653         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
654     except NothingChanged:
655         return False
656
657     if write_back == WriteBack.YES:
658         with open(src, "w", encoding=encoding, newline=newline) as f:
659             f.write(dst_contents)
660     elif write_back == WriteBack.DIFF:
661         now = datetime.utcnow()
662         src_name = f"{src}\t{then} +0000"
663         dst_name = f"{src}\t{now} +0000"
664         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
665
666         with lock or nullcontext():
667             f = io.TextIOWrapper(
668                 sys.stdout.buffer,
669                 encoding=encoding,
670                 newline=newline,
671                 write_through=True,
672             )
673             f.write(diff_contents)
674             f.detach()
675
676     return True
677
678
679 def format_stdin_to_stdout(
680     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
681 ) -> bool:
682     """Format file on stdin. Return True if changed.
683
684     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
685     write a diff to stdout. The `mode` argument is passed to
686     :func:`format_file_contents`.
687     """
688     then = datetime.utcnow()
689     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
690     dst = src
691     try:
692         dst = format_file_contents(src, fast=fast, mode=mode)
693         return True
694
695     except NothingChanged:
696         return False
697
698     finally:
699         f = io.TextIOWrapper(
700             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
701         )
702         if write_back == WriteBack.YES:
703             f.write(dst)
704         elif write_back == WriteBack.DIFF:
705             now = datetime.utcnow()
706             src_name = f"STDIN\t{then} +0000"
707             dst_name = f"STDOUT\t{now} +0000"
708             f.write(diff(src, dst, src_name, dst_name))
709         f.detach()
710
711
712 def format_file_contents(
713     src_contents: str, *, fast: bool, mode: FileMode
714 ) -> FileContent:
715     """Reformat contents a file and return new contents.
716
717     If `fast` is False, additionally confirm that the reformatted code is
718     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
719     `mode` is passed to :func:`format_str`.
720     """
721     if src_contents.strip() == "":
722         raise NothingChanged
723
724     dst_contents = format_str(src_contents, mode=mode)
725     if src_contents == dst_contents:
726         raise NothingChanged
727
728     if not fast:
729         assert_equivalent(src_contents, dst_contents)
730         assert_stable(src_contents, dst_contents, mode=mode)
731     return dst_contents
732
733
734 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
735     """Reformat a string and return new contents.
736
737     `mode` determines formatting options, such as how many characters per line are
738     allowed.
739     """
740     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
741     dst_contents = []
742     future_imports = get_future_imports(src_node)
743     if mode.target_versions:
744         versions = mode.target_versions
745     else:
746         versions = detect_target_versions(src_node)
747     normalize_fmt_off(src_node)
748     lines = LineGenerator(
749         remove_u_prefix="unicode_literals" in future_imports
750         or supports_feature(versions, Feature.UNICODE_LITERALS),
751         is_pyi=mode.is_pyi,
752         normalize_strings=mode.string_normalization,
753     )
754     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
755     empty_line = Line()
756     after = 0
757     split_line_features = {
758         feature
759         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
760         if supports_feature(versions, feature)
761     }
762     for current_line in lines.visit(src_node):
763         for _ in range(after):
764             dst_contents.append(str(empty_line))
765         before, after = elt.maybe_empty_lines(current_line)
766         for _ in range(before):
767             dst_contents.append(str(empty_line))
768         for line in split_line(
769             current_line, line_length=mode.line_length, features=split_line_features
770         ):
771             dst_contents.append(str(line))
772     return "".join(dst_contents)
773
774
775 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
776     """Return a tuple of (decoded_contents, encoding, newline).
777
778     `newline` is either CRLF or LF but `decoded_contents` is decoded with
779     universal newlines (i.e. only contains LF).
780     """
781     srcbuf = io.BytesIO(src)
782     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
783     if not lines:
784         return "", encoding, "\n"
785
786     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
787     srcbuf.seek(0)
788     with io.TextIOWrapper(srcbuf, encoding) as tiow:
789         return tiow.read(), encoding, newline
790
791
792 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
793     if not target_versions:
794         # No target_version specified, so try all grammars.
795         return [
796             # Python 3.7+
797             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
798             # Python 3.0-3.6
799             pygram.python_grammar_no_print_statement_no_exec_statement,
800             # Python 2.7 with future print_function import
801             pygram.python_grammar_no_print_statement,
802             # Python 2.7
803             pygram.python_grammar,
804         ]
805
806     if all(version.is_python2() for version in target_versions):
807         # Python 2-only code, so try Python 2 grammars.
808         return [
809             # Python 2.7 with future print_function import
810             pygram.python_grammar_no_print_statement,
811             # Python 2.7
812             pygram.python_grammar,
813         ]
814
815     # Python 3-compatible code, so only try Python 3 grammar.
816     grammars = []
817     # If we have to parse both, try to parse async as a keyword first
818     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
819         # Python 3.7+
820         grammars.append(
821             pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
822         )
823     if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
824         # Python 3.0-3.6
825         grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
826     # At least one of the above branches must have been taken, because every Python
827     # version has exactly one of the two 'ASYNC_*' flags
828     return grammars
829
830
831 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
832     """Given a string with source, return the lib2to3 Node."""
833     if src_txt[-1:] != "\n":
834         src_txt += "\n"
835
836     for grammar in get_grammars(set(target_versions)):
837         drv = driver.Driver(grammar, pytree.convert)
838         try:
839             result = drv.parse_string(src_txt, True)
840             break
841
842         except ParseError as pe:
843             lineno, column = pe.context[1]
844             lines = src_txt.splitlines()
845             try:
846                 faulty_line = lines[lineno - 1]
847             except IndexError:
848                 faulty_line = "<line number missing in source>"
849             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
850     else:
851         raise exc from None
852
853     if isinstance(result, Leaf):
854         result = Node(syms.file_input, [result])
855     return result
856
857
858 def lib2to3_unparse(node: Node) -> str:
859     """Given a lib2to3 node, return its string representation."""
860     code = str(node)
861     return code
862
863
864 T = TypeVar("T")
865
866
867 class Visitor(Generic[T]):
868     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
869
870     def visit(self, node: LN) -> Iterator[T]:
871         """Main method to visit `node` and its children.
872
873         It tries to find a `visit_*()` method for the given `node.type`, like
874         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
875         If no dedicated `visit_*()` method is found, chooses `visit_default()`
876         instead.
877
878         Then yields objects of type `T` from the selected visitor.
879         """
880         if node.type < 256:
881             name = token.tok_name[node.type]
882         else:
883             name = str(type_repr(node.type))
884         # We explicitly branch on whether a visitor exists (instead of
885         # using self.visit_default as the default arg to getattr) in order
886         # to save needing to create a bound method object and so mypyc can
887         # generate a native call to visit_default.
888         visitf = getattr(self, f"visit_{name}", None)
889         if visitf:
890             yield from visitf(node)
891         else:
892             yield from self.visit_default(node)
893
894     def visit_default(self, node: LN) -> Iterator[T]:
895         """Default `visit_*()` implementation. Recurses to children of `node`."""
896         if isinstance(node, Node):
897             for child in node.children:
898                 yield from self.visit(child)
899
900
901 @dataclass
902 class DebugVisitor(Visitor[T]):
903     tree_depth: int = 0
904
905     def visit_default(self, node: LN) -> Iterator[T]:
906         indent = " " * (2 * self.tree_depth)
907         if isinstance(node, Node):
908             _type = type_repr(node.type)
909             out(f"{indent}{_type}", fg="yellow")
910             self.tree_depth += 1
911             for child in node.children:
912                 yield from self.visit(child)
913
914             self.tree_depth -= 1
915             out(f"{indent}/{_type}", fg="yellow", bold=False)
916         else:
917             _type = token.tok_name.get(node.type, str(node.type))
918             out(f"{indent}{_type}", fg="blue", nl=False)
919             if node.prefix:
920                 # We don't have to handle prefixes for `Node` objects since
921                 # that delegates to the first child anyway.
922                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
923             out(f" {node.value!r}", fg="blue", bold=False)
924
925     @classmethod
926     def show(cls, code: Union[str, Leaf, Node]) -> None:
927         """Pretty-print the lib2to3 AST of a given string of `code`.
928
929         Convenience method for debugging.
930         """
931         v: DebugVisitor[None] = DebugVisitor()
932         if isinstance(code, str):
933             code = lib2to3_parse(code)
934         list(v.visit(code))
935
936
937 WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
938 STATEMENT: Final = {
939     syms.if_stmt,
940     syms.while_stmt,
941     syms.for_stmt,
942     syms.try_stmt,
943     syms.except_clause,
944     syms.with_stmt,
945     syms.funcdef,
946     syms.classdef,
947 }
948 STANDALONE_COMMENT: Final = 153
949 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
950 LOGIC_OPERATORS: Final = {"and", "or"}
951 COMPARATORS: Final = {
952     token.LESS,
953     token.GREATER,
954     token.EQEQUAL,
955     token.NOTEQUAL,
956     token.LESSEQUAL,
957     token.GREATEREQUAL,
958 }
959 MATH_OPERATORS: Final = {
960     token.VBAR,
961     token.CIRCUMFLEX,
962     token.AMPER,
963     token.LEFTSHIFT,
964     token.RIGHTSHIFT,
965     token.PLUS,
966     token.MINUS,
967     token.STAR,
968     token.SLASH,
969     token.DOUBLESLASH,
970     token.PERCENT,
971     token.AT,
972     token.TILDE,
973     token.DOUBLESTAR,
974 }
975 STARS: Final = {token.STAR, token.DOUBLESTAR}
976 VARARGS_SPECIALS: Final = STARS | {token.SLASH}
977 VARARGS_PARENTS: Final = {
978     syms.arglist,
979     syms.argument,  # double star in arglist
980     syms.trailer,  # single argument to call
981     syms.typedargslist,
982     syms.varargslist,  # lambdas
983 }
984 UNPACKING_PARENTS: Final = {
985     syms.atom,  # single element of a list or set literal
986     syms.dictsetmaker,
987     syms.listmaker,
988     syms.testlist_gexp,
989     syms.testlist_star_expr,
990 }
991 TEST_DESCENDANTS: Final = {
992     syms.test,
993     syms.lambdef,
994     syms.or_test,
995     syms.and_test,
996     syms.not_test,
997     syms.comparison,
998     syms.star_expr,
999     syms.expr,
1000     syms.xor_expr,
1001     syms.and_expr,
1002     syms.shift_expr,
1003     syms.arith_expr,
1004     syms.trailer,
1005     syms.term,
1006     syms.power,
1007 }
1008 ASSIGNMENTS: Final = {
1009     "=",
1010     "+=",
1011     "-=",
1012     "*=",
1013     "@=",
1014     "/=",
1015     "%=",
1016     "&=",
1017     "|=",
1018     "^=",
1019     "<<=",
1020     ">>=",
1021     "**=",
1022     "//=",
1023 }
1024 COMPREHENSION_PRIORITY: Final = 20
1025 COMMA_PRIORITY: Final = 18
1026 TERNARY_PRIORITY: Final = 16
1027 LOGIC_PRIORITY: Final = 14
1028 STRING_PRIORITY: Final = 12
1029 COMPARATOR_PRIORITY: Final = 10
1030 MATH_PRIORITIES: Final = {
1031     token.VBAR: 9,
1032     token.CIRCUMFLEX: 8,
1033     token.AMPER: 7,
1034     token.LEFTSHIFT: 6,
1035     token.RIGHTSHIFT: 6,
1036     token.PLUS: 5,
1037     token.MINUS: 5,
1038     token.STAR: 4,
1039     token.SLASH: 4,
1040     token.DOUBLESLASH: 4,
1041     token.PERCENT: 4,
1042     token.AT: 4,
1043     token.TILDE: 3,
1044     token.DOUBLESTAR: 2,
1045 }
1046 DOT_PRIORITY: Final = 1
1047
1048
1049 @dataclass
1050 class BracketTracker:
1051     """Keeps track of brackets on a line."""
1052
1053     depth: int = 0
1054     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
1055     delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
1056     previous: Optional[Leaf] = None
1057     _for_loop_depths: List[int] = field(default_factory=list)
1058     _lambda_argument_depths: List[int] = field(default_factory=list)
1059
1060     def mark(self, leaf: Leaf) -> None:
1061         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
1062
1063         All leaves receive an int `bracket_depth` field that stores how deep
1064         within brackets a given leaf is. 0 means there are no enclosing brackets
1065         that started on this line.
1066
1067         If a leaf is itself a closing bracket, it receives an `opening_bracket`
1068         field that it forms a pair with. This is a one-directional link to
1069         avoid reference cycles.
1070
1071         If a leaf is a delimiter (a token on which Black can split the line if
1072         needed) and it's on depth 0, its `id()` is stored in the tracker's
1073         `delimiters` field.
1074         """
1075         if leaf.type == token.COMMENT:
1076             return
1077
1078         self.maybe_decrement_after_for_loop_variable(leaf)
1079         self.maybe_decrement_after_lambda_arguments(leaf)
1080         if leaf.type in CLOSING_BRACKETS:
1081             self.depth -= 1
1082             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1083             leaf.opening_bracket = opening_bracket
1084         leaf.bracket_depth = self.depth
1085         if self.depth == 0:
1086             delim = is_split_before_delimiter(leaf, self.previous)
1087             if delim and self.previous is not None:
1088                 self.delimiters[id(self.previous)] = delim
1089             else:
1090                 delim = is_split_after_delimiter(leaf, self.previous)
1091                 if delim:
1092                     self.delimiters[id(leaf)] = delim
1093         if leaf.type in OPENING_BRACKETS:
1094             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1095             self.depth += 1
1096         self.previous = leaf
1097         self.maybe_increment_lambda_arguments(leaf)
1098         self.maybe_increment_for_loop_variable(leaf)
1099
1100     def any_open_brackets(self) -> bool:
1101         """Return True if there is an yet unmatched open bracket on the line."""
1102         return bool(self.bracket_match)
1103
1104     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
1105         """Return the highest priority of a delimiter found on the line.
1106
1107         Values are consistent with what `is_split_*_delimiter()` return.
1108         Raises ValueError on no delimiters.
1109         """
1110         return max(v for k, v in self.delimiters.items() if k not in exclude)
1111
1112     def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
1113         """Return the number of delimiters with the given `priority`.
1114
1115         If no `priority` is passed, defaults to max priority on the line.
1116         """
1117         if not self.delimiters:
1118             return 0
1119
1120         priority = priority or self.max_delimiter_priority()
1121         return sum(1 for p in self.delimiters.values() if p == priority)
1122
1123     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1124         """In a for loop, or comprehension, the variables are often unpacks.
1125
1126         To avoid splitting on the comma in this situation, increase the depth of
1127         tokens between `for` and `in`.
1128         """
1129         if leaf.type == token.NAME and leaf.value == "for":
1130             self.depth += 1
1131             self._for_loop_depths.append(self.depth)
1132             return True
1133
1134         return False
1135
1136     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1137         """See `maybe_increment_for_loop_variable` above for explanation."""
1138         if (
1139             self._for_loop_depths
1140             and self._for_loop_depths[-1] == self.depth
1141             and leaf.type == token.NAME
1142             and leaf.value == "in"
1143         ):
1144             self.depth -= 1
1145             self._for_loop_depths.pop()
1146             return True
1147
1148         return False
1149
1150     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1151         """In a lambda expression, there might be more than one argument.
1152
1153         To avoid splitting on the comma in this situation, increase the depth of
1154         tokens between `lambda` and `:`.
1155         """
1156         if leaf.type == token.NAME and leaf.value == "lambda":
1157             self.depth += 1
1158             self._lambda_argument_depths.append(self.depth)
1159             return True
1160
1161         return False
1162
1163     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1164         """See `maybe_increment_lambda_arguments` above for explanation."""
1165         if (
1166             self._lambda_argument_depths
1167             and self._lambda_argument_depths[-1] == self.depth
1168             and leaf.type == token.COLON
1169         ):
1170             self.depth -= 1
1171             self._lambda_argument_depths.pop()
1172             return True
1173
1174         return False
1175
1176     def get_open_lsqb(self) -> Optional[Leaf]:
1177         """Return the most recent opening square bracket (if any)."""
1178         return self.bracket_match.get((self.depth - 1, token.RSQB))
1179
1180
1181 @dataclass
1182 class Line:
1183     """Holds leaves and comments. Can be printed with `str(line)`."""
1184
1185     depth: int = 0
1186     leaves: List[Leaf] = field(default_factory=list)
1187     # keys ordered like `leaves`
1188     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
1189     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
1190     inside_brackets: bool = False
1191     should_explode: bool = False
1192
1193     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1194         """Add a new `leaf` to the end of the line.
1195
1196         Unless `preformatted` is True, the `leaf` will receive a new consistent
1197         whitespace prefix and metadata applied by :class:`BracketTracker`.
1198         Trailing commas are maybe removed, unpacked for loop variables are
1199         demoted from being delimiters.
1200
1201         Inline comments are put aside.
1202         """
1203         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1204         if not has_value:
1205             return
1206
1207         if token.COLON == leaf.type and self.is_class_paren_empty:
1208             del self.leaves[-2:]
1209         if self.leaves and not preformatted:
1210             # Note: at this point leaf.prefix should be empty except for
1211             # imports, for which we only preserve newlines.
1212             leaf.prefix += whitespace(
1213                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1214             )
1215         if self.inside_brackets or not preformatted:
1216             self.bracket_tracker.mark(leaf)
1217             self.maybe_remove_trailing_comma(leaf)
1218         if not self.append_comment(leaf):
1219             self.leaves.append(leaf)
1220
1221     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1222         """Like :func:`append()` but disallow invalid standalone comment structure.
1223
1224         Raises ValueError when any `leaf` is appended after a standalone comment
1225         or when a standalone comment is not the first leaf on the line.
1226         """
1227         if self.bracket_tracker.depth == 0:
1228             if self.is_comment:
1229                 raise ValueError("cannot append to standalone comments")
1230
1231             if self.leaves and leaf.type == STANDALONE_COMMENT:
1232                 raise ValueError(
1233                     "cannot append standalone comments to a populated line"
1234                 )
1235
1236         self.append(leaf, preformatted=preformatted)
1237
1238     @property
1239     def is_comment(self) -> bool:
1240         """Is this line a standalone comment?"""
1241         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1242
1243     @property
1244     def is_decorator(self) -> bool:
1245         """Is this line a decorator?"""
1246         return bool(self) and self.leaves[0].type == token.AT
1247
1248     @property
1249     def is_import(self) -> bool:
1250         """Is this an import line?"""
1251         return bool(self) and is_import(self.leaves[0])
1252
1253     @property
1254     def is_class(self) -> bool:
1255         """Is this line a class definition?"""
1256         return (
1257             bool(self)
1258             and self.leaves[0].type == token.NAME
1259             and self.leaves[0].value == "class"
1260         )
1261
1262     @property
1263     def is_stub_class(self) -> bool:
1264         """Is this line a class definition with a body consisting only of "..."?"""
1265         return self.is_class and self.leaves[-3:] == [
1266             Leaf(token.DOT, ".") for _ in range(3)
1267         ]
1268
1269     @property
1270     def is_collection_with_optional_trailing_comma(self) -> bool:
1271         """Is this line a collection literal with a trailing comma that's optional?
1272
1273         Note that the trailing comma in a 1-tuple is not optional.
1274         """
1275         if not self.leaves or len(self.leaves) < 4:
1276             return False
1277
1278         # Look for and address a trailing colon.
1279         if self.leaves[-1].type == token.COLON:
1280             closer = self.leaves[-2]
1281             close_index = -2
1282         else:
1283             closer = self.leaves[-1]
1284             close_index = -1
1285         if closer.type not in CLOSING_BRACKETS or self.inside_brackets:
1286             return False
1287
1288         if closer.type == token.RPAR:
1289             # Tuples require an extra check, because if there's only
1290             # one element in the tuple removing the comma unmakes the
1291             # tuple.
1292             #
1293             # We also check for parens before looking for the trailing
1294             # comma because in some cases (eg assigning a dict
1295             # literal) the literal gets wrapped in temporary parens
1296             # during parsing. This case is covered by the
1297             # collections.py test data.
1298             opener = closer.opening_bracket
1299             for _open_index, leaf in enumerate(self.leaves):
1300                 if leaf is opener:
1301                     break
1302
1303             else:
1304                 # Couldn't find the matching opening paren, play it safe.
1305                 return False
1306
1307             commas = 0
1308             comma_depth = self.leaves[close_index - 1].bracket_depth
1309             for leaf in self.leaves[_open_index + 1 : close_index]:
1310                 if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA:
1311                     commas += 1
1312             if commas > 1:
1313                 # We haven't looked yet for the trailing comma because
1314                 # we might also have caught noop parens.
1315                 return self.leaves[close_index - 1].type == token.COMMA
1316
1317             elif commas == 1:
1318                 return False  # it's either a one-tuple or didn't have a trailing comma
1319
1320             if self.leaves[close_index - 1].type in CLOSING_BRACKETS:
1321                 close_index -= 1
1322                 closer = self.leaves[close_index]
1323                 if closer.type == token.RPAR:
1324                     # TODO: this is a gut feeling. Will we ever see this?
1325                     return False
1326
1327         if self.leaves[close_index - 1].type != token.COMMA:
1328             return False
1329
1330         return True
1331
1332     @property
1333     def is_def(self) -> bool:
1334         """Is this a function definition? (Also returns True for async defs.)"""
1335         try:
1336             first_leaf = self.leaves[0]
1337         except IndexError:
1338             return False
1339
1340         try:
1341             second_leaf: Optional[Leaf] = self.leaves[1]
1342         except IndexError:
1343             second_leaf = None
1344         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1345             first_leaf.type == token.ASYNC
1346             and second_leaf is not None
1347             and second_leaf.type == token.NAME
1348             and second_leaf.value == "def"
1349         )
1350
1351     @property
1352     def is_class_paren_empty(self) -> bool:
1353         """Is this a class with no base classes but using parentheses?
1354
1355         Those are unnecessary and should be removed.
1356         """
1357         return (
1358             bool(self)
1359             and len(self.leaves) == 4
1360             and self.is_class
1361             and self.leaves[2].type == token.LPAR
1362             and self.leaves[2].value == "("
1363             and self.leaves[3].type == token.RPAR
1364             and self.leaves[3].value == ")"
1365         )
1366
1367     @property
1368     def is_triple_quoted_string(self) -> bool:
1369         """Is the line a triple quoted string?"""
1370         return (
1371             bool(self)
1372             and self.leaves[0].type == token.STRING
1373             and self.leaves[0].value.startswith(('"""', "'''"))
1374         )
1375
1376     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1377         """If so, needs to be split before emitting."""
1378         for leaf in self.leaves:
1379             if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
1380                 return True
1381
1382         return False
1383
1384     def contains_uncollapsable_type_comments(self) -> bool:
1385         ignored_ids = set()
1386         try:
1387             last_leaf = self.leaves[-1]
1388             ignored_ids.add(id(last_leaf))
1389             if last_leaf.type == token.COMMA or (
1390                 last_leaf.type == token.RPAR and not last_leaf.value
1391             ):
1392                 # When trailing commas or optional parens are inserted by Black for
1393                 # consistency, comments after the previous last element are not moved
1394                 # (they don't have to, rendering will still be correct).  So we ignore
1395                 # trailing commas and invisible.
1396                 last_leaf = self.leaves[-2]
1397                 ignored_ids.add(id(last_leaf))
1398         except IndexError:
1399             return False
1400
1401         # A type comment is uncollapsable if it is attached to a leaf
1402         # that isn't at the end of the line (since that could cause it
1403         # to get associated to a different argument) or if there are
1404         # comments before it (since that could cause it to get hidden
1405         # behind a comment.
1406         comment_seen = False
1407         for leaf_id, comments in self.comments.items():
1408             for comment in comments:
1409                 if is_type_comment(comment):
1410                     if leaf_id not in ignored_ids or comment_seen:
1411                         return True
1412
1413                 comment_seen = True
1414
1415         return False
1416
1417     def contains_unsplittable_type_ignore(self) -> bool:
1418         if not self.leaves:
1419             return False
1420
1421         # If a 'type: ignore' is attached to the end of a line, we
1422         # can't split the line, because we can't know which of the
1423         # subexpressions the ignore was meant to apply to.
1424         #
1425         # We only want this to apply to actual physical lines from the
1426         # original source, though: we don't want the presence of a
1427         # 'type: ignore' at the end of a multiline expression to
1428         # justify pushing it all onto one line. Thus we
1429         # (unfortunately) need to check the actual source lines and
1430         # only report an unsplittable 'type: ignore' if this line was
1431         # one line in the original code.
1432
1433         # Grab the first and last line numbers, skipping generated leaves
1434         first_line = next((l.lineno for l in self.leaves if l.lineno != 0), 0)
1435         last_line = next((l.lineno for l in reversed(self.leaves) if l.lineno != 0), 0)
1436
1437         if first_line == last_line:
1438             # We look at the last two leaves since a comma or an
1439             # invisible paren could have been added at the end of the
1440             # line.
1441             for node in self.leaves[-2:]:
1442                 for comment in self.comments.get(id(node), []):
1443                     if is_type_comment(comment, " ignore"):
1444                         return True
1445
1446         return False
1447
1448     def contains_multiline_strings(self) -> bool:
1449         for leaf in self.leaves:
1450             if is_multiline_string(leaf):
1451                 return True
1452
1453         return False
1454
1455     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1456         """Remove trailing comma if there is one and it's safe."""
1457         if not (self.leaves and self.leaves[-1].type == token.COMMA):
1458             return False
1459
1460         # We remove trailing commas only in the case of importing a
1461         # single name from a module.
1462         if not (
1463             self.leaves
1464             and self.is_import
1465             and len(self.leaves) > 4
1466             and self.leaves[-1].type == token.COMMA
1467             and closing.type in CLOSING_BRACKETS
1468             and self.leaves[-4].type == token.NAME
1469             and (
1470                 # regular `from foo import bar,`
1471                 self.leaves[-4].value == "import"
1472                 # `from foo import (bar as baz,)
1473                 or (
1474                     len(self.leaves) > 6
1475                     and self.leaves[-6].value == "import"
1476                     and self.leaves[-3].value == "as"
1477                 )
1478                 # `from foo import bar as baz,`
1479                 or (
1480                     len(self.leaves) > 5
1481                     and self.leaves[-5].value == "import"
1482                     and self.leaves[-3].value == "as"
1483                 )
1484             )
1485             and closing.type == token.RPAR
1486         ):
1487             return False
1488
1489         self.remove_trailing_comma()
1490         return True
1491
1492     def append_comment(self, comment: Leaf) -> bool:
1493         """Add an inline or standalone comment to the line."""
1494         if (
1495             comment.type == STANDALONE_COMMENT
1496             and self.bracket_tracker.any_open_brackets()
1497         ):
1498             comment.prefix = ""
1499             return False
1500
1501         if comment.type != token.COMMENT:
1502             return False
1503
1504         if not self.leaves:
1505             comment.type = STANDALONE_COMMENT
1506             comment.prefix = ""
1507             return False
1508
1509         last_leaf = self.leaves[-1]
1510         if (
1511             last_leaf.type == token.RPAR
1512             and not last_leaf.value
1513             and last_leaf.parent
1514             and len(list(last_leaf.parent.leaves())) <= 3
1515             and not is_type_comment(comment)
1516         ):
1517             # Comments on an optional parens wrapping a single leaf should belong to
1518             # the wrapped node except if it's a type comment. Pinning the comment like
1519             # this avoids unstable formatting caused by comment migration.
1520             if len(self.leaves) < 2:
1521                 comment.type = STANDALONE_COMMENT
1522                 comment.prefix = ""
1523                 return False
1524
1525             last_leaf = self.leaves[-2]
1526         self.comments.setdefault(id(last_leaf), []).append(comment)
1527         return True
1528
1529     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1530         """Generate comments that should appear directly after `leaf`."""
1531         return self.comments.get(id(leaf), [])
1532
1533     def remove_trailing_comma(self) -> None:
1534         """Remove the trailing comma and moves the comments attached to it."""
1535         trailing_comma = self.leaves.pop()
1536         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1537         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1538             trailing_comma_comments
1539         )
1540
1541     def is_complex_subscript(self, leaf: Leaf) -> bool:
1542         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1543         open_lsqb = self.bracket_tracker.get_open_lsqb()
1544         if open_lsqb is None:
1545             return False
1546
1547         subscript_start = open_lsqb.next_sibling
1548
1549         if isinstance(subscript_start, Node):
1550             if subscript_start.type == syms.listmaker:
1551                 return False
1552
1553             if subscript_start.type == syms.subscriptlist:
1554                 subscript_start = child_towards(subscript_start, leaf)
1555         return subscript_start is not None and any(
1556             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1557         )
1558
1559     def __str__(self) -> str:
1560         """Render the line."""
1561         if not self:
1562             return "\n"
1563
1564         indent = "    " * self.depth
1565         leaves = iter(self.leaves)
1566         first = next(leaves)
1567         res = f"{first.prefix}{indent}{first.value}"
1568         for leaf in leaves:
1569             res += str(leaf)
1570         for comment in itertools.chain.from_iterable(self.comments.values()):
1571             res += str(comment)
1572         return res + "\n"
1573
1574     def __bool__(self) -> bool:
1575         """Return True if the line has leaves or comments."""
1576         return bool(self.leaves or self.comments)
1577
1578
1579 @dataclass
1580 class EmptyLineTracker:
1581     """Provides a stateful method that returns the number of potential extra
1582     empty lines needed before and after the currently processed line.
1583
1584     Note: this tracker works on lines that haven't been split yet.  It assumes
1585     the prefix of the first leaf consists of optional newlines.  Those newlines
1586     are consumed by `maybe_empty_lines()` and included in the computation.
1587     """
1588
1589     is_pyi: bool = False
1590     previous_line: Optional[Line] = None
1591     previous_after: int = 0
1592     previous_defs: List[int] = field(default_factory=list)
1593
1594     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1595         """Return the number of extra empty lines before and after the `current_line`.
1596
1597         This is for separating `def`, `async def` and `class` with extra empty
1598         lines (two on module-level).
1599         """
1600         before, after = self._maybe_empty_lines(current_line)
1601         before = (
1602             # Black should not insert empty lines at the beginning
1603             # of the file
1604             0
1605             if self.previous_line is None
1606             else before - self.previous_after
1607         )
1608         self.previous_after = after
1609         self.previous_line = current_line
1610         return before, after
1611
1612     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1613         max_allowed = 1
1614         if current_line.depth == 0:
1615             max_allowed = 1 if self.is_pyi else 2
1616         if current_line.leaves:
1617             # Consume the first leaf's extra newlines.
1618             first_leaf = current_line.leaves[0]
1619             before = first_leaf.prefix.count("\n")
1620             before = min(before, max_allowed)
1621             first_leaf.prefix = ""
1622         else:
1623             before = 0
1624         depth = current_line.depth
1625         while self.previous_defs and self.previous_defs[-1] >= depth:
1626             self.previous_defs.pop()
1627             if self.is_pyi:
1628                 before = 0 if depth else 1
1629             else:
1630                 before = 1 if depth else 2
1631         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1632             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1633
1634         if (
1635             self.previous_line
1636             and self.previous_line.is_import
1637             and not current_line.is_import
1638             and depth == self.previous_line.depth
1639         ):
1640             return (before or 1), 0
1641
1642         if (
1643             self.previous_line
1644             and self.previous_line.is_class
1645             and current_line.is_triple_quoted_string
1646         ):
1647             return before, 1
1648
1649         return before, 0
1650
1651     def _maybe_empty_lines_for_class_or_def(
1652         self, current_line: Line, before: int
1653     ) -> Tuple[int, int]:
1654         if not current_line.is_decorator:
1655             self.previous_defs.append(current_line.depth)
1656         if self.previous_line is None:
1657             # Don't insert empty lines before the first line in the file.
1658             return 0, 0
1659
1660         if self.previous_line.is_decorator:
1661             return 0, 0
1662
1663         if self.previous_line.depth < current_line.depth and (
1664             self.previous_line.is_class or self.previous_line.is_def
1665         ):
1666             return 0, 0
1667
1668         if (
1669             self.previous_line.is_comment
1670             and self.previous_line.depth == current_line.depth
1671             and before == 0
1672         ):
1673             return 0, 0
1674
1675         if self.is_pyi:
1676             if self.previous_line.depth > current_line.depth:
1677                 newlines = 1
1678             elif current_line.is_class or self.previous_line.is_class:
1679                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1680                     # No blank line between classes with an empty body
1681                     newlines = 0
1682                 else:
1683                     newlines = 1
1684             elif current_line.is_def and not self.previous_line.is_def:
1685                 # Blank line between a block of functions and a block of non-functions
1686                 newlines = 1
1687             else:
1688                 newlines = 0
1689         else:
1690             newlines = 2
1691         if current_line.depth and newlines:
1692             newlines -= 1
1693         return newlines, 0
1694
1695
1696 @dataclass
1697 class LineGenerator(Visitor[Line]):
1698     """Generates reformatted Line objects.  Empty lines are not emitted.
1699
1700     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1701     in ways that will no longer stringify to valid Python code on the tree.
1702     """
1703
1704     is_pyi: bool = False
1705     normalize_strings: bool = True
1706     current_line: Line = field(default_factory=Line)
1707     remove_u_prefix: bool = False
1708
1709     def line(self, indent: int = 0) -> Iterator[Line]:
1710         """Generate a line.
1711
1712         If the line is empty, only emit if it makes sense.
1713         If the line is too long, split it first and then generate.
1714
1715         If any lines were generated, set up a new current_line.
1716         """
1717         if not self.current_line:
1718             self.current_line.depth += indent
1719             return  # Line is empty, don't emit. Creating a new one unnecessary.
1720
1721         complete_line = self.current_line
1722         self.current_line = Line(depth=complete_line.depth + indent)
1723         yield complete_line
1724
1725     def visit_default(self, node: LN) -> Iterator[Line]:
1726         """Default `visit_*()` implementation. Recurses to children of `node`."""
1727         if isinstance(node, Leaf):
1728             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1729             for comment in generate_comments(node):
1730                 if any_open_brackets:
1731                     # any comment within brackets is subject to splitting
1732                     self.current_line.append(comment)
1733                 elif comment.type == token.COMMENT:
1734                     # regular trailing comment
1735                     self.current_line.append(comment)
1736                     yield from self.line()
1737
1738                 else:
1739                     # regular standalone comment
1740                     yield from self.line()
1741
1742                     self.current_line.append(comment)
1743                     yield from self.line()
1744
1745             normalize_prefix(node, inside_brackets=any_open_brackets)
1746             if self.normalize_strings and node.type == token.STRING:
1747                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1748                 normalize_string_quotes(node)
1749             if node.type == token.NUMBER:
1750                 normalize_numeric_literal(node)
1751             if node.type not in WHITESPACE:
1752                 self.current_line.append(node)
1753         yield from super().visit_default(node)
1754
1755     def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
1756         """Increase indentation level, maybe yield a line."""
1757         # In blib2to3 INDENT never holds comments.
1758         yield from self.line(+1)
1759         yield from self.visit_default(node)
1760
1761     def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
1762         """Decrease indentation level, maybe yield a line."""
1763         # The current line might still wait for trailing comments.  At DEDENT time
1764         # there won't be any (they would be prefixes on the preceding NEWLINE).
1765         # Emit the line then.
1766         yield from self.line()
1767
1768         # While DEDENT has no value, its prefix may contain standalone comments
1769         # that belong to the current indentation level.  Get 'em.
1770         yield from self.visit_default(node)
1771
1772         # Finally, emit the dedent.
1773         yield from self.line(-1)
1774
1775     def visit_stmt(
1776         self, node: Node, keywords: Set[str], parens: Set[str]
1777     ) -> Iterator[Line]:
1778         """Visit a statement.
1779
1780         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1781         `def`, `with`, `class`, `assert` and assignments.
1782
1783         The relevant Python language `keywords` for a given statement will be
1784         NAME leaves within it. This methods puts those on a separate line.
1785
1786         `parens` holds a set of string leaf values immediately after which
1787         invisible parens should be put.
1788         """
1789         normalize_invisible_parens(node, parens_after=parens)
1790         for child in node.children:
1791             if child.type == token.NAME and child.value in keywords:  # type: ignore
1792                 yield from self.line()
1793
1794             yield from self.visit(child)
1795
1796     def visit_suite(self, node: Node) -> Iterator[Line]:
1797         """Visit a suite."""
1798         if self.is_pyi and is_stub_suite(node):
1799             yield from self.visit(node.children[2])
1800         else:
1801             yield from self.visit_default(node)
1802
1803     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1804         """Visit a statement without nested statements."""
1805         is_suite_like = node.parent and node.parent.type in STATEMENT
1806         if is_suite_like:
1807             if self.is_pyi and is_stub_body(node):
1808                 yield from self.visit_default(node)
1809             else:
1810                 yield from self.line(+1)
1811                 yield from self.visit_default(node)
1812                 yield from self.line(-1)
1813
1814         else:
1815             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1816                 yield from self.line()
1817             yield from self.visit_default(node)
1818
1819     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1820         """Visit `async def`, `async for`, `async with`."""
1821         yield from self.line()
1822
1823         children = iter(node.children)
1824         for child in children:
1825             yield from self.visit(child)
1826
1827             if child.type == token.ASYNC:
1828                 break
1829
1830         internal_stmt = next(children)
1831         for child in internal_stmt.children:
1832             yield from self.visit(child)
1833
1834     def visit_decorators(self, node: Node) -> Iterator[Line]:
1835         """Visit decorators."""
1836         for child in node.children:
1837             yield from self.line()
1838             yield from self.visit(child)
1839
1840     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1841         """Remove a semicolon and put the other statement on a separate line."""
1842         yield from self.line()
1843
1844     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1845         """End of file. Process outstanding comments and end with a newline."""
1846         yield from self.visit_default(leaf)
1847         yield from self.line()
1848
1849     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1850         if not self.current_line.bracket_tracker.any_open_brackets():
1851             yield from self.line()
1852         yield from self.visit_default(leaf)
1853
1854     def visit_factor(self, node: Node) -> Iterator[Line]:
1855         """Force parentheses between a unary op and a binary power:
1856
1857         -2 ** 8 -> -(2 ** 8)
1858         """
1859         _operator, operand = node.children
1860         if (
1861             operand.type == syms.power
1862             and len(operand.children) == 3
1863             and operand.children[1].type == token.DOUBLESTAR
1864         ):
1865             lpar = Leaf(token.LPAR, "(")
1866             rpar = Leaf(token.RPAR, ")")
1867             index = operand.remove() or 0
1868             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
1869         yield from self.visit_default(node)
1870
1871     def __post_init__(self) -> None:
1872         """You are in a twisty little maze of passages."""
1873         v = self.visit_stmt
1874         Ø: Set[str] = set()
1875         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1876         self.visit_if_stmt = partial(
1877             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1878         )
1879         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1880         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1881         self.visit_try_stmt = partial(
1882             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1883         )
1884         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1885         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1886         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1887         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1888         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1889         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1890         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1891         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1892         self.visit_async_funcdef = self.visit_async_stmt
1893         self.visit_decorated = self.visit_decorators
1894
1895
1896 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1897 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1898 OPENING_BRACKETS = set(BRACKET.keys())
1899 CLOSING_BRACKETS = set(BRACKET.values())
1900 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1901 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1902
1903
1904 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1905     """Return whitespace prefix if needed for the given `leaf`.
1906
1907     `complex_subscript` signals whether the given leaf is part of a subscription
1908     which has non-trivial arguments, like arithmetic expressions or function calls.
1909     """
1910     NO = ""
1911     SPACE = " "
1912     DOUBLESPACE = "  "
1913     t = leaf.type
1914     p = leaf.parent
1915     v = leaf.value
1916     if t in ALWAYS_NO_SPACE:
1917         return NO
1918
1919     if t == token.COMMENT:
1920         return DOUBLESPACE
1921
1922     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1923     if t == token.COLON and p.type not in {
1924         syms.subscript,
1925         syms.subscriptlist,
1926         syms.sliceop,
1927     }:
1928         return NO
1929
1930     prev = leaf.prev_sibling
1931     if not prev:
1932         prevp = preceding_leaf(p)
1933         if not prevp or prevp.type in OPENING_BRACKETS:
1934             return NO
1935
1936         if t == token.COLON:
1937             if prevp.type == token.COLON:
1938                 return NO
1939
1940             elif prevp.type != token.COMMA and not complex_subscript:
1941                 return NO
1942
1943             return SPACE
1944
1945         if prevp.type == token.EQUAL:
1946             if prevp.parent:
1947                 if prevp.parent.type in {
1948                     syms.arglist,
1949                     syms.argument,
1950                     syms.parameters,
1951                     syms.varargslist,
1952                 }:
1953                     return NO
1954
1955                 elif prevp.parent.type == syms.typedargslist:
1956                     # A bit hacky: if the equal sign has whitespace, it means we
1957                     # previously found it's a typed argument.  So, we're using
1958                     # that, too.
1959                     return prevp.prefix
1960
1961         elif prevp.type in VARARGS_SPECIALS:
1962             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1963                 return NO
1964
1965         elif prevp.type == token.COLON:
1966             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1967                 return SPACE if complex_subscript else NO
1968
1969         elif (
1970             prevp.parent
1971             and prevp.parent.type == syms.factor
1972             and prevp.type in MATH_OPERATORS
1973         ):
1974             return NO
1975
1976         elif (
1977             prevp.type == token.RIGHTSHIFT
1978             and prevp.parent
1979             and prevp.parent.type == syms.shift_expr
1980             and prevp.prev_sibling
1981             and prevp.prev_sibling.type == token.NAME
1982             and prevp.prev_sibling.value == "print"  # type: ignore
1983         ):
1984             # Python 2 print chevron
1985             return NO
1986
1987     elif prev.type in OPENING_BRACKETS:
1988         return NO
1989
1990     if p.type in {syms.parameters, syms.arglist}:
1991         # untyped function signatures or calls
1992         if not prev or prev.type != token.COMMA:
1993             return NO
1994
1995     elif p.type == syms.varargslist:
1996         # lambdas
1997         if prev and prev.type != token.COMMA:
1998             return NO
1999
2000     elif p.type == syms.typedargslist:
2001         # typed function signatures
2002         if not prev:
2003             return NO
2004
2005         if t == token.EQUAL:
2006             if prev.type != syms.tname:
2007                 return NO
2008
2009         elif prev.type == token.EQUAL:
2010             # A bit hacky: if the equal sign has whitespace, it means we
2011             # previously found it's a typed argument.  So, we're using that, too.
2012             return prev.prefix
2013
2014         elif prev.type != token.COMMA:
2015             return NO
2016
2017     elif p.type == syms.tname:
2018         # type names
2019         if not prev:
2020             prevp = preceding_leaf(p)
2021             if not prevp or prevp.type != token.COMMA:
2022                 return NO
2023
2024     elif p.type == syms.trailer:
2025         # attributes and calls
2026         if t == token.LPAR or t == token.RPAR:
2027             return NO
2028
2029         if not prev:
2030             if t == token.DOT:
2031                 prevp = preceding_leaf(p)
2032                 if not prevp or prevp.type != token.NUMBER:
2033                     return NO
2034
2035             elif t == token.LSQB:
2036                 return NO
2037
2038         elif prev.type != token.COMMA:
2039             return NO
2040
2041     elif p.type == syms.argument:
2042         # single argument
2043         if t == token.EQUAL:
2044             return NO
2045
2046         if not prev:
2047             prevp = preceding_leaf(p)
2048             if not prevp or prevp.type == token.LPAR:
2049                 return NO
2050
2051         elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
2052             return NO
2053
2054     elif p.type == syms.decorator:
2055         # decorators
2056         return NO
2057
2058     elif p.type == syms.dotted_name:
2059         if prev:
2060             return NO
2061
2062         prevp = preceding_leaf(p)
2063         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
2064             return NO
2065
2066     elif p.type == syms.classdef:
2067         if t == token.LPAR:
2068             return NO
2069
2070         if prev and prev.type == token.LPAR:
2071             return NO
2072
2073     elif p.type in {syms.subscript, syms.sliceop}:
2074         # indexing
2075         if not prev:
2076             assert p.parent is not None, "subscripts are always parented"
2077             if p.parent.type == syms.subscriptlist:
2078                 return SPACE
2079
2080             return NO
2081
2082         elif not complex_subscript:
2083             return NO
2084
2085     elif p.type == syms.atom:
2086         if prev and t == token.DOT:
2087             # dots, but not the first one.
2088             return NO
2089
2090     elif p.type == syms.dictsetmaker:
2091         # dict unpacking
2092         if prev and prev.type == token.DOUBLESTAR:
2093             return NO
2094
2095     elif p.type in {syms.factor, syms.star_expr}:
2096         # unary ops
2097         if not prev:
2098             prevp = preceding_leaf(p)
2099             if not prevp or prevp.type in OPENING_BRACKETS:
2100                 return NO
2101
2102             prevp_parent = prevp.parent
2103             assert prevp_parent is not None
2104             if prevp.type == token.COLON and prevp_parent.type in {
2105                 syms.subscript,
2106                 syms.sliceop,
2107             }:
2108                 return NO
2109
2110             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
2111                 return NO
2112
2113         elif t in {token.NAME, token.NUMBER, token.STRING}:
2114             return NO
2115
2116     elif p.type == syms.import_from:
2117         if t == token.DOT:
2118             if prev and prev.type == token.DOT:
2119                 return NO
2120
2121         elif t == token.NAME:
2122             if v == "import":
2123                 return SPACE
2124
2125             if prev and prev.type == token.DOT:
2126                 return NO
2127
2128     elif p.type == syms.sliceop:
2129         return NO
2130
2131     return SPACE
2132
2133
2134 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
2135     """Return the first leaf that precedes `node`, if any."""
2136     while node:
2137         res = node.prev_sibling
2138         if res:
2139             if isinstance(res, Leaf):
2140                 return res
2141
2142             try:
2143                 return list(res.leaves())[-1]
2144
2145             except IndexError:
2146                 return None
2147
2148         node = node.parent
2149     return None
2150
2151
2152 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
2153     """Return the child of `ancestor` that contains `descendant`."""
2154     node: Optional[LN] = descendant
2155     while node and node.parent != ancestor:
2156         node = node.parent
2157     return node
2158
2159
2160 def container_of(leaf: Leaf) -> LN:
2161     """Return `leaf` or one of its ancestors that is the topmost container of it.
2162
2163     By "container" we mean a node where `leaf` is the very first child.
2164     """
2165     same_prefix = leaf.prefix
2166     container: LN = leaf
2167     while container:
2168         parent = container.parent
2169         if parent is None:
2170             break
2171
2172         if parent.children[0].prefix != same_prefix:
2173             break
2174
2175         if parent.type == syms.file_input:
2176             break
2177
2178         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
2179             break
2180
2181         container = parent
2182     return container
2183
2184
2185 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2186     """Return the priority of the `leaf` delimiter, given a line break after it.
2187
2188     The delimiter priorities returned here are from those delimiters that would
2189     cause a line break after themselves.
2190
2191     Higher numbers are higher priority.
2192     """
2193     if leaf.type == token.COMMA:
2194         return COMMA_PRIORITY
2195
2196     return 0
2197
2198
2199 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
2200     """Return the priority of the `leaf` delimiter, given a line break before it.
2201
2202     The delimiter priorities returned here are from those delimiters that would
2203     cause a line break before themselves.
2204
2205     Higher numbers are higher priority.
2206     """
2207     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2208         # * and ** might also be MATH_OPERATORS but in this case they are not.
2209         # Don't treat them as a delimiter.
2210         return 0
2211
2212     if (
2213         leaf.type == token.DOT
2214         and leaf.parent
2215         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2216         and (previous is None or previous.type in CLOSING_BRACKETS)
2217     ):
2218         return DOT_PRIORITY
2219
2220     if (
2221         leaf.type in MATH_OPERATORS
2222         and leaf.parent
2223         and leaf.parent.type not in {syms.factor, syms.star_expr}
2224     ):
2225         return MATH_PRIORITIES[leaf.type]
2226
2227     if leaf.type in COMPARATORS:
2228         return COMPARATOR_PRIORITY
2229
2230     if (
2231         leaf.type == token.STRING
2232         and previous is not None
2233         and previous.type == token.STRING
2234     ):
2235         return STRING_PRIORITY
2236
2237     if leaf.type not in {token.NAME, token.ASYNC}:
2238         return 0
2239
2240     if (
2241         leaf.value == "for"
2242         and leaf.parent
2243         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2244         or leaf.type == token.ASYNC
2245     ):
2246         if (
2247             not isinstance(leaf.prev_sibling, Leaf)
2248             or leaf.prev_sibling.value != "async"
2249         ):
2250             return COMPREHENSION_PRIORITY
2251
2252     if (
2253         leaf.value == "if"
2254         and leaf.parent
2255         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2256     ):
2257         return COMPREHENSION_PRIORITY
2258
2259     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2260         return TERNARY_PRIORITY
2261
2262     if leaf.value == "is":
2263         return COMPARATOR_PRIORITY
2264
2265     if (
2266         leaf.value == "in"
2267         and leaf.parent
2268         and leaf.parent.type in {syms.comp_op, syms.comparison}
2269         and not (
2270             previous is not None
2271             and previous.type == token.NAME
2272             and previous.value == "not"
2273         )
2274     ):
2275         return COMPARATOR_PRIORITY
2276
2277     if (
2278         leaf.value == "not"
2279         and leaf.parent
2280         and leaf.parent.type == syms.comp_op
2281         and not (
2282             previous is not None
2283             and previous.type == token.NAME
2284             and previous.value == "is"
2285         )
2286     ):
2287         return COMPARATOR_PRIORITY
2288
2289     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2290         return LOGIC_PRIORITY
2291
2292     return 0
2293
2294
2295 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2296 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2297
2298
2299 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2300     """Clean the prefix of the `leaf` and generate comments from it, if any.
2301
2302     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2303     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2304     move because it does away with modifying the grammar to include all the
2305     possible places in which comments can be placed.
2306
2307     The sad consequence for us though is that comments don't "belong" anywhere.
2308     This is why this function generates simple parentless Leaf objects for
2309     comments.  We simply don't know what the correct parent should be.
2310
2311     No matter though, we can live without this.  We really only need to
2312     differentiate between inline and standalone comments.  The latter don't
2313     share the line with any code.
2314
2315     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2316     are emitted with a fake STANDALONE_COMMENT token identifier.
2317     """
2318     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2319         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2320
2321
2322 @dataclass
2323 class ProtoComment:
2324     """Describes a piece of syntax that is a comment.
2325
2326     It's not a :class:`blib2to3.pytree.Leaf` so that:
2327
2328     * it can be cached (`Leaf` objects should not be reused more than once as
2329       they store their lineno, column, prefix, and parent information);
2330     * `newlines` and `consumed` fields are kept separate from the `value`. This
2331       simplifies handling of special marker comments like ``# fmt: off/on``.
2332     """
2333
2334     type: int  # token.COMMENT or STANDALONE_COMMENT
2335     value: str  # content of the comment
2336     newlines: int  # how many newlines before the comment
2337     consumed: int  # how many characters of the original leaf's prefix did we consume
2338
2339
2340 @lru_cache(maxsize=4096)
2341 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2342     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2343     result: List[ProtoComment] = []
2344     if not prefix or "#" not in prefix:
2345         return result
2346
2347     consumed = 0
2348     nlines = 0
2349     ignored_lines = 0
2350     for index, line in enumerate(prefix.split("\n")):
2351         consumed += len(line) + 1  # adding the length of the split '\n'
2352         line = line.lstrip()
2353         if not line:
2354             nlines += 1
2355         if not line.startswith("#"):
2356             # Escaped newlines outside of a comment are not really newlines at
2357             # all. We treat a single-line comment following an escaped newline
2358             # as a simple trailing comment.
2359             if line.endswith("\\"):
2360                 ignored_lines += 1
2361             continue
2362
2363         if index == ignored_lines and not is_endmarker:
2364             comment_type = token.COMMENT  # simple trailing comment
2365         else:
2366             comment_type = STANDALONE_COMMENT
2367         comment = make_comment(line)
2368         result.append(
2369             ProtoComment(
2370                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2371             )
2372         )
2373         nlines = 0
2374     return result
2375
2376
2377 def make_comment(content: str) -> str:
2378     """Return a consistently formatted comment from the given `content` string.
2379
2380     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2381     space between the hash sign and the content.
2382
2383     If `content` didn't start with a hash sign, one is provided.
2384     """
2385     content = content.rstrip()
2386     if not content:
2387         return "#"
2388
2389     if content[0] == "#":
2390         content = content[1:]
2391     if content and content[0] not in " !:#'%":
2392         content = " " + content
2393     return "#" + content
2394
2395
2396 def split_line(
2397     line: Line,
2398     line_length: int,
2399     inner: bool = False,
2400     features: Collection[Feature] = (),
2401 ) -> Iterator[Line]:
2402     """Split a `line` into potentially many lines.
2403
2404     They should fit in the allotted `line_length` but might not be able to.
2405     `inner` signifies that there were a pair of brackets somewhere around the
2406     current `line`, possibly transitively. This means we can fallback to splitting
2407     by delimiters if the LHS/RHS don't yield any results.
2408
2409     `features` are syntactical features that may be used in the output.
2410     """
2411     if line.is_comment:
2412         yield line
2413         return
2414
2415     line_str = str(line).strip("\n")
2416
2417     if (
2418         not line.contains_uncollapsable_type_comments()
2419         and not line.should_explode
2420         and not line.is_collection_with_optional_trailing_comma
2421         and (
2422             is_line_short_enough(line, line_length=line_length, line_str=line_str)
2423             or line.contains_unsplittable_type_ignore()
2424         )
2425     ):
2426         yield line
2427         return
2428
2429     split_funcs: List[SplitFunc]
2430     if line.is_def:
2431         split_funcs = [left_hand_split]
2432     else:
2433
2434         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2435             for omit in generate_trailers_to_omit(line, line_length):
2436                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2437                 if is_line_short_enough(lines[0], line_length=line_length):
2438                     yield from lines
2439                     return
2440
2441             # All splits failed, best effort split with no omits.
2442             # This mostly happens to multiline strings that are by definition
2443             # reported as not fitting a single line.
2444             # line_length=1 here was historically a bug that somehow became a feature.
2445             # See #762 and #781 for the full story.
2446             yield from right_hand_split(line, line_length=1, features=features)
2447
2448         if line.inside_brackets:
2449             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2450         else:
2451             split_funcs = [rhs]
2452     for split_func in split_funcs:
2453         # We are accumulating lines in `result` because we might want to abort
2454         # mission and return the original line in the end, or attempt a different
2455         # split altogether.
2456         result: List[Line] = []
2457         try:
2458             for l in split_func(line, features):
2459                 if str(l).strip("\n") == line_str:
2460                     raise CannotSplit("Split function returned an unchanged result")
2461
2462                 result.extend(
2463                     split_line(
2464                         l, line_length=line_length, inner=True, features=features
2465                     )
2466                 )
2467         except CannotSplit:
2468             continue
2469
2470         else:
2471             yield from result
2472             break
2473
2474     else:
2475         yield line
2476
2477
2478 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2479     """Split line into many lines, starting with the first matching bracket pair.
2480
2481     Note: this usually looks weird, only use this for function definitions.
2482     Prefer RHS otherwise.  This is why this function is not symmetrical with
2483     :func:`right_hand_split` which also handles optional parentheses.
2484     """
2485     tail_leaves: List[Leaf] = []
2486     body_leaves: List[Leaf] = []
2487     head_leaves: List[Leaf] = []
2488     current_leaves = head_leaves
2489     matching_bracket: Optional[Leaf] = None
2490     for leaf in line.leaves:
2491         if (
2492             current_leaves is body_leaves
2493             and leaf.type in CLOSING_BRACKETS
2494             and leaf.opening_bracket is matching_bracket
2495         ):
2496             current_leaves = tail_leaves if body_leaves else head_leaves
2497         current_leaves.append(leaf)
2498         if current_leaves is head_leaves:
2499             if leaf.type in OPENING_BRACKETS:
2500                 matching_bracket = leaf
2501                 current_leaves = body_leaves
2502     if not matching_bracket:
2503         raise CannotSplit("No brackets found")
2504
2505     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2506     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2507     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2508     bracket_split_succeeded_or_raise(head, body, tail)
2509     for result in (head, body, tail):
2510         if result:
2511             yield result
2512
2513
2514 def right_hand_split(
2515     line: Line,
2516     line_length: int,
2517     features: Collection[Feature] = (),
2518     omit: Collection[LeafID] = (),
2519 ) -> Iterator[Line]:
2520     """Split line into many lines, starting with the last matching bracket pair.
2521
2522     If the split was by optional parentheses, attempt splitting without them, too.
2523     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2524     this split.
2525
2526     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2527     """
2528     tail_leaves: List[Leaf] = []
2529     body_leaves: List[Leaf] = []
2530     head_leaves: List[Leaf] = []
2531     current_leaves = tail_leaves
2532     opening_bracket: Optional[Leaf] = None
2533     closing_bracket: Optional[Leaf] = None
2534     for leaf in reversed(line.leaves):
2535         if current_leaves is body_leaves:
2536             if leaf is opening_bracket:
2537                 current_leaves = head_leaves if body_leaves else tail_leaves
2538         current_leaves.append(leaf)
2539         if current_leaves is tail_leaves:
2540             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2541                 opening_bracket = leaf.opening_bracket
2542                 closing_bracket = leaf
2543                 current_leaves = body_leaves
2544     if not (opening_bracket and closing_bracket and head_leaves):
2545         # If there is no opening or closing_bracket that means the split failed and
2546         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2547         # the matching `opening_bracket` wasn't available on `line` anymore.
2548         raise CannotSplit("No brackets found")
2549
2550     tail_leaves.reverse()
2551     body_leaves.reverse()
2552     head_leaves.reverse()
2553     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2554     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2555     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2556     bracket_split_succeeded_or_raise(head, body, tail)
2557     if (
2558         # the body shouldn't be exploded
2559         not body.should_explode
2560         # the opening bracket is an optional paren
2561         and opening_bracket.type == token.LPAR
2562         and not opening_bracket.value
2563         # the closing bracket is an optional paren
2564         and closing_bracket.type == token.RPAR
2565         and not closing_bracket.value
2566         # it's not an import (optional parens are the only thing we can split on
2567         # in this case; attempting a split without them is a waste of time)
2568         and not line.is_import
2569         # there are no standalone comments in the body
2570         and not body.contains_standalone_comments(0)
2571         # and we can actually remove the parens
2572         and can_omit_invisible_parens(body, line_length)
2573     ):
2574         omit = {id(closing_bracket), *omit}
2575         try:
2576             yield from right_hand_split(line, line_length, features=features, omit=omit)
2577             return
2578
2579         except CannotSplit:
2580             if not (
2581                 can_be_split(body)
2582                 or is_line_short_enough(body, line_length=line_length)
2583             ):
2584                 raise CannotSplit(
2585                     "Splitting failed, body is still too long and can't be split."
2586                 )
2587
2588             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2589                 raise CannotSplit(
2590                     "The current optional pair of parentheses is bound to fail to "
2591                     "satisfy the splitting algorithm because the head or the tail "
2592                     "contains multiline strings which by definition never fit one "
2593                     "line."
2594                 )
2595
2596     ensure_visible(opening_bracket)
2597     ensure_visible(closing_bracket)
2598     for result in (head, body, tail):
2599         if result:
2600             yield result
2601
2602
2603 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2604     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2605
2606     Do nothing otherwise.
2607
2608     A left- or right-hand split is based on a pair of brackets. Content before
2609     (and including) the opening bracket is left on one line, content inside the
2610     brackets is put on a separate line, and finally content starting with and
2611     following the closing bracket is put on a separate line.
2612
2613     Those are called `head`, `body`, and `tail`, respectively. If the split
2614     produced the same line (all content in `head`) or ended up with an empty `body`
2615     and the `tail` is just the closing bracket, then it's considered failed.
2616     """
2617     tail_len = len(str(tail).strip())
2618     if not body:
2619         if tail_len == 0:
2620             raise CannotSplit("Splitting brackets produced the same line")
2621
2622         elif tail_len < 3:
2623             raise CannotSplit(
2624                 f"Splitting brackets on an empty body to save "
2625                 f"{tail_len} characters is not worth it"
2626             )
2627
2628
2629 def bracket_split_build_line(
2630     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2631 ) -> Line:
2632     """Return a new line with given `leaves` and respective comments from `original`.
2633
2634     If `is_body` is True, the result line is one-indented inside brackets and as such
2635     has its first leaf's prefix normalized and a trailing comma added when expected.
2636     """
2637     result = Line(depth=original.depth)
2638     if is_body:
2639         result.inside_brackets = True
2640         result.depth += 1
2641         if leaves:
2642             # Since body is a new indent level, remove spurious leading whitespace.
2643             normalize_prefix(leaves[0], inside_brackets=True)
2644             # Ensure a trailing comma for imports and standalone function arguments, but
2645             # be careful not to add one after any comments or within type annotations.
2646             no_commas = (
2647                 original.is_def
2648                 and opening_bracket.value == "("
2649                 and not any(l.type == token.COMMA for l in leaves)
2650             )
2651
2652             if original.is_import or no_commas:
2653                 for i in range(len(leaves) - 1, -1, -1):
2654                     if leaves[i].type == STANDALONE_COMMENT:
2655                         continue
2656
2657                     if leaves[i].type != token.COMMA:
2658                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2659                     break
2660
2661     # Populate the line
2662     for leaf in leaves:
2663         result.append(leaf, preformatted=True)
2664         for comment_after in original.comments_after(leaf):
2665             result.append(comment_after, preformatted=True)
2666     if is_body:
2667         result.should_explode = should_explode(result, opening_bracket)
2668     return result
2669
2670
2671 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2672     """Normalize prefix of the first leaf in every line returned by `split_func`.
2673
2674     This is a decorator over relevant split functions.
2675     """
2676
2677     @wraps(split_func)
2678     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2679         for l in split_func(line, features):
2680             normalize_prefix(l.leaves[0], inside_brackets=True)
2681             yield l
2682
2683     return split_wrapper
2684
2685
2686 @dont_increase_indentation
2687 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2688     """Split according to delimiters of the highest priority.
2689
2690     If the appropriate Features are given, the split will add trailing commas
2691     also in function signatures and calls that contain `*` and `**`.
2692     """
2693     try:
2694         last_leaf = line.leaves[-1]
2695     except IndexError:
2696         raise CannotSplit("Line empty")
2697
2698     bt = line.bracket_tracker
2699     try:
2700         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2701     except ValueError:
2702         raise CannotSplit("No delimiters found")
2703
2704     if delimiter_priority == DOT_PRIORITY:
2705         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2706             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2707
2708     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2709     lowest_depth = sys.maxsize
2710     trailing_comma_safe = True
2711
2712     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2713         """Append `leaf` to current line or to new line if appending impossible."""
2714         nonlocal current_line
2715         try:
2716             current_line.append_safe(leaf, preformatted=True)
2717         except ValueError:
2718             yield current_line
2719
2720             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2721             current_line.append(leaf)
2722
2723     for leaf in line.leaves:
2724         yield from append_to_line(leaf)
2725
2726         for comment_after in line.comments_after(leaf):
2727             yield from append_to_line(comment_after)
2728
2729         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2730         if leaf.bracket_depth == lowest_depth:
2731             if is_vararg(leaf, within={syms.typedargslist}):
2732                 trailing_comma_safe = (
2733                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2734                 )
2735             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2736                 trailing_comma_safe = (
2737                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2738                 )
2739
2740         leaf_priority = bt.delimiters.get(id(leaf))
2741         if leaf_priority == delimiter_priority:
2742             yield current_line
2743
2744             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2745     if current_line:
2746         if (
2747             trailing_comma_safe
2748             and delimiter_priority == COMMA_PRIORITY
2749             and current_line.leaves[-1].type != token.COMMA
2750             and current_line.leaves[-1].type != STANDALONE_COMMENT
2751         ):
2752             current_line.append(Leaf(token.COMMA, ","))
2753         yield current_line
2754
2755
2756 @dont_increase_indentation
2757 def standalone_comment_split(
2758     line: Line, features: Collection[Feature] = ()
2759 ) -> Iterator[Line]:
2760     """Split standalone comments from the rest of the line."""
2761     if not line.contains_standalone_comments(0):
2762         raise CannotSplit("Line does not have any standalone comments")
2763
2764     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2765
2766     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2767         """Append `leaf` to current line or to new line if appending impossible."""
2768         nonlocal current_line
2769         try:
2770             current_line.append_safe(leaf, preformatted=True)
2771         except ValueError:
2772             yield current_line
2773
2774             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2775             current_line.append(leaf)
2776
2777     for leaf in line.leaves:
2778         yield from append_to_line(leaf)
2779
2780         for comment_after in line.comments_after(leaf):
2781             yield from append_to_line(comment_after)
2782
2783     if current_line:
2784         yield current_line
2785
2786
2787 def is_import(leaf: Leaf) -> bool:
2788     """Return True if the given leaf starts an import statement."""
2789     p = leaf.parent
2790     t = leaf.type
2791     v = leaf.value
2792     return bool(
2793         t == token.NAME
2794         and (
2795             (v == "import" and p and p.type == syms.import_name)
2796             or (v == "from" and p and p.type == syms.import_from)
2797         )
2798     )
2799
2800
2801 def is_type_comment(leaf: Leaf, suffix: str = "") -> bool:
2802     """Return True if the given leaf is a special comment.
2803     Only returns true for type comments for now."""
2804     t = leaf.type
2805     v = leaf.value
2806     return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix)
2807
2808
2809 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2810     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2811     else.
2812
2813     Note: don't use backslashes for formatting or you'll lose your voting rights.
2814     """
2815     if not inside_brackets:
2816         spl = leaf.prefix.split("#")
2817         if "\\" not in spl[0]:
2818             nl_count = spl[-1].count("\n")
2819             if len(spl) > 1:
2820                 nl_count -= 1
2821             leaf.prefix = "\n" * nl_count
2822             return
2823
2824     leaf.prefix = ""
2825
2826
2827 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2828     """Make all string prefixes lowercase.
2829
2830     If remove_u_prefix is given, also removes any u prefix from the string.
2831
2832     Note: Mutates its argument.
2833     """
2834     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2835     assert match is not None, f"failed to match string {leaf.value!r}"
2836     orig_prefix = match.group(1)
2837     new_prefix = orig_prefix.lower()
2838     if remove_u_prefix:
2839         new_prefix = new_prefix.replace("u", "")
2840     leaf.value = f"{new_prefix}{match.group(2)}"
2841
2842
2843 def normalize_string_quotes(leaf: Leaf) -> None:
2844     """Prefer double quotes but only if it doesn't cause more escaping.
2845
2846     Adds or removes backslashes as appropriate. Doesn't parse and fix
2847     strings nested in f-strings (yet).
2848
2849     Note: Mutates its argument.
2850     """
2851     value = leaf.value.lstrip("furbFURB")
2852     if value[:3] == '"""':
2853         return
2854
2855     elif value[:3] == "'''":
2856         orig_quote = "'''"
2857         new_quote = '"""'
2858     elif value[0] == '"':
2859         orig_quote = '"'
2860         new_quote = "'"
2861     else:
2862         orig_quote = "'"
2863         new_quote = '"'
2864     first_quote_pos = leaf.value.find(orig_quote)
2865     if first_quote_pos == -1:
2866         return  # There's an internal error
2867
2868     prefix = leaf.value[:first_quote_pos]
2869     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2870     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2871     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2872     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2873     if "r" in prefix.casefold():
2874         if unescaped_new_quote.search(body):
2875             # There's at least one unescaped new_quote in this raw string
2876             # so converting is impossible
2877             return
2878
2879         # Do not introduce or remove backslashes in raw strings
2880         new_body = body
2881     else:
2882         # remove unnecessary escapes
2883         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2884         if body != new_body:
2885             # Consider the string without unnecessary escapes as the original
2886             body = new_body
2887             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2888         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2889         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2890     if "f" in prefix.casefold():
2891         matches = re.findall(
2892             r"""
2893             (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
2894                 ([^{].*?)  # contents of the brackets except if begins with {{
2895             \}(?:[^}]|$)  # A } followed by end of the string or a non-}
2896             """,
2897             new_body,
2898             re.VERBOSE,
2899         )
2900         for m in matches:
2901             if "\\" in str(m):
2902                 # Do not introduce backslashes in interpolated expressions
2903                 return
2904
2905     if new_quote == '"""' and new_body[-1:] == '"':
2906         # edge case:
2907         new_body = new_body[:-1] + '\\"'
2908     orig_escape_count = body.count("\\")
2909     new_escape_count = new_body.count("\\")
2910     if new_escape_count > orig_escape_count:
2911         return  # Do not introduce more escaping
2912
2913     if new_escape_count == orig_escape_count and orig_quote == '"':
2914         return  # Prefer double quotes
2915
2916     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2917
2918
2919 def normalize_numeric_literal(leaf: Leaf) -> None:
2920     """Normalizes numeric (float, int, and complex) literals.
2921
2922     All letters used in the representation are normalized to lowercase (except
2923     in Python 2 long literals).
2924     """
2925     text = leaf.value.lower()
2926     if text.startswith(("0o", "0b")):
2927         # Leave octal and binary literals alone.
2928         pass
2929     elif text.startswith("0x"):
2930         # Change hex literals to upper case.
2931         before, after = text[:2], text[2:]
2932         text = f"{before}{after.upper()}"
2933     elif "e" in text:
2934         before, after = text.split("e")
2935         sign = ""
2936         if after.startswith("-"):
2937             after = after[1:]
2938             sign = "-"
2939         elif after.startswith("+"):
2940             after = after[1:]
2941         before = format_float_or_int_string(before)
2942         text = f"{before}e{sign}{after}"
2943     elif text.endswith(("j", "l")):
2944         number = text[:-1]
2945         suffix = text[-1]
2946         # Capitalize in "2L" because "l" looks too similar to "1".
2947         if suffix == "l":
2948             suffix = "L"
2949         text = f"{format_float_or_int_string(number)}{suffix}"
2950     else:
2951         text = format_float_or_int_string(text)
2952     leaf.value = text
2953
2954
2955 def format_float_or_int_string(text: str) -> str:
2956     """Formats a float string like "1.0"."""
2957     if "." not in text:
2958         return text
2959
2960     before, after = text.split(".")
2961     return f"{before or 0}.{after or 0}"
2962
2963
2964 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2965     """Make existing optional parentheses invisible or create new ones.
2966
2967     `parens_after` is a set of string leaf values immediately after which parens
2968     should be put.
2969
2970     Standardizes on visible parentheses for single-element tuples, and keeps
2971     existing visible parentheses for other tuples and generator expressions.
2972     """
2973     for pc in list_comments(node.prefix, is_endmarker=False):
2974         if pc.value in FMT_OFF:
2975             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2976             return
2977     check_lpar = False
2978     for index, child in enumerate(list(node.children)):
2979         # Add parentheses around long tuple unpacking in assignments.
2980         if (
2981             index == 0
2982             and isinstance(child, Node)
2983             and child.type == syms.testlist_star_expr
2984         ):
2985             check_lpar = True
2986
2987         if check_lpar:
2988             if is_walrus_assignment(child):
2989                 continue
2990
2991             if child.type == syms.atom:
2992                 if maybe_make_parens_invisible_in_atom(child, parent=node):
2993                     wrap_in_parentheses(node, child, visible=False)
2994             elif is_one_tuple(child):
2995                 wrap_in_parentheses(node, child, visible=True)
2996             elif node.type == syms.import_from:
2997                 # "import from" nodes store parentheses directly as part of
2998                 # the statement
2999                 if child.type == token.LPAR:
3000                     # make parentheses invisible
3001                     child.value = ""  # type: ignore
3002                     node.children[-1].value = ""  # type: ignore
3003                 elif child.type != token.STAR:
3004                     # insert invisible parentheses
3005                     node.insert_child(index, Leaf(token.LPAR, ""))
3006                     node.append_child(Leaf(token.RPAR, ""))
3007                 break
3008
3009             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
3010                 wrap_in_parentheses(node, child, visible=False)
3011
3012         check_lpar = isinstance(child, Leaf) and child.value in parens_after
3013
3014
3015 def normalize_fmt_off(node: Node) -> None:
3016     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
3017     try_again = True
3018     while try_again:
3019         try_again = convert_one_fmt_off_pair(node)
3020
3021
3022 def convert_one_fmt_off_pair(node: Node) -> bool:
3023     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
3024
3025     Returns True if a pair was converted.
3026     """
3027     for leaf in node.leaves():
3028         previous_consumed = 0
3029         for comment in list_comments(leaf.prefix, is_endmarker=False):
3030             if comment.value in FMT_OFF:
3031                 # We only want standalone comments. If there's no previous leaf or
3032                 # the previous leaf is indentation, it's a standalone comment in
3033                 # disguise.
3034                 if comment.type != STANDALONE_COMMENT:
3035                     prev = preceding_leaf(leaf)
3036                     if prev and prev.type not in WHITESPACE:
3037                         continue
3038
3039                 ignored_nodes = list(generate_ignored_nodes(leaf))
3040                 if not ignored_nodes:
3041                     continue
3042
3043                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
3044                 parent = first.parent
3045                 prefix = first.prefix
3046                 first.prefix = prefix[comment.consumed :]
3047                 hidden_value = (
3048                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
3049                 )
3050                 if hidden_value.endswith("\n"):
3051                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
3052                     # leaf (possibly followed by a DEDENT).
3053                     hidden_value = hidden_value[:-1]
3054                 first_idx: Optional[int] = None
3055                 for ignored in ignored_nodes:
3056                     index = ignored.remove()
3057                     if first_idx is None:
3058                         first_idx = index
3059                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
3060                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
3061                 parent.insert_child(
3062                     first_idx,
3063                     Leaf(
3064                         STANDALONE_COMMENT,
3065                         hidden_value,
3066                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
3067                     ),
3068                 )
3069                 return True
3070
3071             previous_consumed = comment.consumed
3072
3073     return False
3074
3075
3076 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
3077     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
3078
3079     Stops at the end of the block.
3080     """
3081     container: Optional[LN] = container_of(leaf)
3082     while container is not None and container.type != token.ENDMARKER:
3083         is_fmt_on = False
3084         for comment in list_comments(container.prefix, is_endmarker=False):
3085             if comment.value in FMT_ON:
3086                 is_fmt_on = True
3087             elif comment.value in FMT_OFF:
3088                 is_fmt_on = False
3089         if is_fmt_on:
3090             return
3091
3092         yield container
3093
3094         container = container.next_sibling
3095
3096
3097 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
3098     """If it's safe, make the parens in the atom `node` invisible, recursively.
3099     Additionally, remove repeated, adjacent invisible parens from the atom `node`
3100     as they are redundant.
3101
3102     Returns whether the node should itself be wrapped in invisible parentheses.
3103
3104     """
3105     if (
3106         node.type != syms.atom
3107         or is_empty_tuple(node)
3108         or is_one_tuple(node)
3109         or (is_yield(node) and parent.type != syms.expr_stmt)
3110         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
3111     ):
3112         return False
3113
3114     first = node.children[0]
3115     last = node.children[-1]
3116     if first.type == token.LPAR and last.type == token.RPAR:
3117         middle = node.children[1]
3118         # make parentheses invisible
3119         first.value = ""  # type: ignore
3120         last.value = ""  # type: ignore
3121         maybe_make_parens_invisible_in_atom(middle, parent=parent)
3122
3123         if is_atom_with_invisible_parens(middle):
3124             # Strip the invisible parens from `middle` by replacing
3125             # it with the child in-between the invisible parens
3126             middle.replace(middle.children[1])
3127
3128         return False
3129
3130     return True
3131
3132
3133 def is_atom_with_invisible_parens(node: LN) -> bool:
3134     """Given a `LN`, determines whether it's an atom `node` with invisible
3135     parens. Useful in dedupe-ing and normalizing parens.
3136     """
3137     if isinstance(node, Leaf) or node.type != syms.atom:
3138         return False
3139
3140     first, last = node.children[0], node.children[-1]
3141     return (
3142         isinstance(first, Leaf)
3143         and first.type == token.LPAR
3144         and first.value == ""
3145         and isinstance(last, Leaf)
3146         and last.type == token.RPAR
3147         and last.value == ""
3148     )
3149
3150
3151 def is_empty_tuple(node: LN) -> bool:
3152     """Return True if `node` holds an empty tuple."""
3153     return (
3154         node.type == syms.atom
3155         and len(node.children) == 2
3156         and node.children[0].type == token.LPAR
3157         and node.children[1].type == token.RPAR
3158     )
3159
3160
3161 def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
3162     """Returns `wrapped` if `node` is of the shape ( wrapped ).
3163
3164     Parenthesis can be optional. Returns None otherwise"""
3165     if len(node.children) != 3:
3166         return None
3167
3168     lpar, wrapped, rpar = node.children
3169     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
3170         return None
3171
3172     return wrapped
3173
3174
3175 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
3176     """Wrap `child` in parentheses.
3177
3178     This replaces `child` with an atom holding the parentheses and the old
3179     child.  That requires moving the prefix.
3180
3181     If `visible` is False, the leaves will be valueless (and thus invisible).
3182     """
3183     lpar = Leaf(token.LPAR, "(" if visible else "")
3184     rpar = Leaf(token.RPAR, ")" if visible else "")
3185     prefix = child.prefix
3186     child.prefix = ""
3187     index = child.remove() or 0
3188     new_child = Node(syms.atom, [lpar, child, rpar])
3189     new_child.prefix = prefix
3190     parent.insert_child(index, new_child)
3191
3192
3193 def is_one_tuple(node: LN) -> bool:
3194     """Return True if `node` holds a tuple with one element, with or without parens."""
3195     if node.type == syms.atom:
3196         gexp = unwrap_singleton_parenthesis(node)
3197         if gexp is None or gexp.type != syms.testlist_gexp:
3198             return False
3199
3200         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
3201
3202     return (
3203         node.type in IMPLICIT_TUPLE
3204         and len(node.children) == 2
3205         and node.children[1].type == token.COMMA
3206     )
3207
3208
3209 def is_walrus_assignment(node: LN) -> bool:
3210     """Return True iff `node` is of the shape ( test := test )"""
3211     inner = unwrap_singleton_parenthesis(node)
3212     return inner is not None and inner.type == syms.namedexpr_test
3213
3214
3215 def is_yield(node: LN) -> bool:
3216     """Return True if `node` holds a `yield` or `yield from` expression."""
3217     if node.type == syms.yield_expr:
3218         return True
3219
3220     if node.type == token.NAME and node.value == "yield":  # type: ignore
3221         return True
3222
3223     if node.type != syms.atom:
3224         return False
3225
3226     if len(node.children) != 3:
3227         return False
3228
3229     lpar, expr, rpar = node.children
3230     if lpar.type == token.LPAR and rpar.type == token.RPAR:
3231         return is_yield(expr)
3232
3233     return False
3234
3235
3236 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
3237     """Return True if `leaf` is a star or double star in a vararg or kwarg.
3238
3239     If `within` includes VARARGS_PARENTS, this applies to function signatures.
3240     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
3241     extended iterable unpacking (PEP 3132) and additional unpacking
3242     generalizations (PEP 448).
3243     """
3244     if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
3245         return False
3246
3247     p = leaf.parent
3248     if p.type == syms.star_expr:
3249         # Star expressions are also used as assignment targets in extended
3250         # iterable unpacking (PEP 3132).  See what its parent is instead.
3251         if not p.parent:
3252             return False
3253
3254         p = p.parent
3255
3256     return p.type in within
3257
3258
3259 def is_multiline_string(leaf: Leaf) -> bool:
3260     """Return True if `leaf` is a multiline string that actually spans many lines."""
3261     value = leaf.value.lstrip("furbFURB")
3262     return value[:3] in {'"""', "'''"} and "\n" in value
3263
3264
3265 def is_stub_suite(node: Node) -> bool:
3266     """Return True if `node` is a suite with a stub body."""
3267     if (
3268         len(node.children) != 4
3269         or node.children[0].type != token.NEWLINE
3270         or node.children[1].type != token.INDENT
3271         or node.children[3].type != token.DEDENT
3272     ):
3273         return False
3274
3275     return is_stub_body(node.children[2])
3276
3277
3278 def is_stub_body(node: LN) -> bool:
3279     """Return True if `node` is a simple statement containing an ellipsis."""
3280     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3281         return False
3282
3283     if len(node.children) != 2:
3284         return False
3285
3286     child = node.children[0]
3287     return (
3288         child.type == syms.atom
3289         and len(child.children) == 3
3290         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3291     )
3292
3293
3294 def max_delimiter_priority_in_atom(node: LN) -> Priority:
3295     """Return maximum delimiter priority inside `node`.
3296
3297     This is specific to atoms with contents contained in a pair of parentheses.
3298     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3299     """
3300     if node.type != syms.atom:
3301         return 0
3302
3303     first = node.children[0]
3304     last = node.children[-1]
3305     if not (first.type == token.LPAR and last.type == token.RPAR):
3306         return 0
3307
3308     bt = BracketTracker()
3309     for c in node.children[1:-1]:
3310         if isinstance(c, Leaf):
3311             bt.mark(c)
3312         else:
3313             for leaf in c.leaves():
3314                 bt.mark(leaf)
3315     try:
3316         return bt.max_delimiter_priority()
3317
3318     except ValueError:
3319         return 0
3320
3321
3322 def ensure_visible(leaf: Leaf) -> None:
3323     """Make sure parentheses are visible.
3324
3325     They could be invisible as part of some statements (see
3326     :func:`normalize_invisible_parens` and :func:`visit_import_from`).
3327     """
3328     if leaf.type == token.LPAR:
3329         leaf.value = "("
3330     elif leaf.type == token.RPAR:
3331         leaf.value = ")"
3332
3333
3334 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3335     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3336
3337     if not (
3338         opening_bracket.parent
3339         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3340         and opening_bracket.value in "[{("
3341     ):
3342         return False
3343
3344     try:
3345         last_leaf = line.leaves[-1]
3346         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3347         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3348     except (IndexError, ValueError):
3349         return False
3350
3351     return max_priority == COMMA_PRIORITY
3352
3353
3354 def get_features_used(node: Node) -> Set[Feature]:
3355     """Return a set of (relatively) new Python features used in this file.
3356
3357     Currently looking for:
3358     - f-strings;
3359     - underscores in numeric literals;
3360     - trailing commas after * or ** in function signatures and calls;
3361     - positional only arguments in function signatures and lambdas;
3362     """
3363     features: Set[Feature] = set()
3364     for n in node.pre_order():
3365         if n.type == token.STRING:
3366             value_head = n.value[:2]  # type: ignore
3367             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3368                 features.add(Feature.F_STRINGS)
3369
3370         elif n.type == token.NUMBER:
3371             if "_" in n.value:  # type: ignore
3372                 features.add(Feature.NUMERIC_UNDERSCORES)
3373
3374         elif n.type == token.SLASH:
3375             if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
3376                 features.add(Feature.POS_ONLY_ARGUMENTS)
3377
3378         elif n.type == token.COLONEQUAL:
3379             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
3380
3381         elif (
3382             n.type in {syms.typedargslist, syms.arglist}
3383             and n.children
3384             and n.children[-1].type == token.COMMA
3385         ):
3386             if n.type == syms.typedargslist:
3387                 feature = Feature.TRAILING_COMMA_IN_DEF
3388             else:
3389                 feature = Feature.TRAILING_COMMA_IN_CALL
3390
3391             for ch in n.children:
3392                 if ch.type in STARS:
3393                     features.add(feature)
3394
3395                 if ch.type == syms.argument:
3396                     for argch in ch.children:
3397                         if argch.type in STARS:
3398                             features.add(feature)
3399
3400     return features
3401
3402
3403 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3404     """Detect the version to target based on the nodes used."""
3405     features = get_features_used(node)
3406     return {
3407         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3408     }
3409
3410
3411 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3412     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3413
3414     Brackets can be omitted if the entire trailer up to and including
3415     a preceding closing bracket fits in one line.
3416
3417     Yielded sets are cumulative (contain results of previous yields, too).  First
3418     set is empty.
3419     """
3420
3421     omit: Set[LeafID] = set()
3422     yield omit
3423
3424     length = 4 * line.depth
3425     opening_bracket: Optional[Leaf] = None
3426     closing_bracket: Optional[Leaf] = None
3427     inner_brackets: Set[LeafID] = set()
3428     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3429         length += leaf_length
3430         if length > line_length:
3431             break
3432
3433         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3434         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3435             break
3436
3437         if opening_bracket:
3438             if leaf is opening_bracket:
3439                 opening_bracket = None
3440             elif leaf.type in CLOSING_BRACKETS:
3441                 inner_brackets.add(id(leaf))
3442         elif leaf.type in CLOSING_BRACKETS:
3443             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3444                 # Empty brackets would fail a split so treat them as "inner"
3445                 # brackets (e.g. only add them to the `omit` set if another
3446                 # pair of brackets was good enough.
3447                 inner_brackets.add(id(leaf))
3448                 continue
3449
3450             if closing_bracket:
3451                 omit.add(id(closing_bracket))
3452                 omit.update(inner_brackets)
3453                 inner_brackets.clear()
3454                 yield omit
3455
3456             if leaf.value:
3457                 opening_bracket = leaf.opening_bracket
3458                 closing_bracket = leaf
3459
3460
3461 def get_future_imports(node: Node) -> Set[str]:
3462     """Return a set of __future__ imports in the file."""
3463     imports: Set[str] = set()
3464
3465     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3466         for child in children:
3467             if isinstance(child, Leaf):
3468                 if child.type == token.NAME:
3469                     yield child.value
3470
3471             elif child.type == syms.import_as_name:
3472                 orig_name = child.children[0]
3473                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3474                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3475                 yield orig_name.value
3476
3477             elif child.type == syms.import_as_names:
3478                 yield from get_imports_from_children(child.children)
3479
3480             else:
3481                 raise AssertionError("Invalid syntax parsing imports")
3482
3483     for child in node.children:
3484         if child.type != syms.simple_stmt:
3485             break
3486
3487         first_child = child.children[0]
3488         if isinstance(first_child, Leaf):
3489             # Continue looking if we see a docstring; otherwise stop.
3490             if (
3491                 len(child.children) == 2
3492                 and first_child.type == token.STRING
3493                 and child.children[1].type == token.NEWLINE
3494             ):
3495                 continue
3496
3497             break
3498
3499         elif first_child.type == syms.import_from:
3500             module_name = first_child.children[1]
3501             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3502                 break
3503
3504             imports |= set(get_imports_from_children(first_child.children[3:]))
3505         else:
3506             break
3507
3508     return imports
3509
3510
3511 @lru_cache()
3512 def get_gitignore(root: Path) -> PathSpec:
3513     """ Return a PathSpec matching gitignore content if present."""
3514     gitignore = root / ".gitignore"
3515     lines: List[str] = []
3516     if gitignore.is_file():
3517         with gitignore.open() as gf:
3518             lines = gf.readlines()
3519     return PathSpec.from_lines("gitwildmatch", lines)
3520
3521
3522 def gen_python_files_in_dir(
3523     path: Path,
3524     root: Path,
3525     include: Pattern[str],
3526     exclude: Pattern[str],
3527     report: "Report",
3528     gitignore: PathSpec,
3529 ) -> Iterator[Path]:
3530     """Generate all files under `path` whose paths are not excluded by the
3531     `exclude` regex, but are included by the `include` regex.
3532
3533     Symbolic links pointing outside of the `root` directory are ignored.
3534
3535     `report` is where output about exclusions goes.
3536     """
3537     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3538     for child in path.iterdir():
3539         # First ignore files matching .gitignore
3540         if gitignore.match_file(child.as_posix()):
3541             report.path_ignored(child, f"matches the .gitignore file content")
3542             continue
3543
3544         # Then ignore with `exclude` option.
3545         try:
3546             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3547         except OSError as e:
3548             report.path_ignored(child, f"cannot be read because {e}")
3549             continue
3550
3551         except ValueError:
3552             if child.is_symlink():
3553                 report.path_ignored(
3554                     child, f"is a symbolic link that points outside {root}"
3555                 )
3556                 continue
3557
3558             raise
3559
3560         if child.is_dir():
3561             normalized_path += "/"
3562
3563         exclude_match = exclude.search(normalized_path)
3564         if exclude_match and exclude_match.group(0):
3565             report.path_ignored(child, f"matches the --exclude regular expression")
3566             continue
3567
3568         if child.is_dir():
3569             yield from gen_python_files_in_dir(
3570                 child, root, include, exclude, report, gitignore
3571             )
3572
3573         elif child.is_file():
3574             include_match = include.search(normalized_path)
3575             if include_match:
3576                 yield child
3577
3578
3579 @lru_cache()
3580 def find_project_root(srcs: Iterable[str]) -> Path:
3581     """Return a directory containing .git, .hg, or pyproject.toml.
3582
3583     That directory can be one of the directories passed in `srcs` or their
3584     common parent.
3585
3586     If no directory in the tree contains a marker that would specify it's the
3587     project root, the root of the file system is returned.
3588     """
3589     if not srcs:
3590         return Path("/").resolve()
3591
3592     common_base = min(Path(src).resolve() for src in srcs)
3593     if common_base.is_dir():
3594         # Append a fake file so `parents` below returns `common_base_dir`, too.
3595         common_base /= "fake-file"
3596     for directory in common_base.parents:
3597         if (directory / ".git").is_dir():
3598             return directory
3599
3600         if (directory / ".hg").is_dir():
3601             return directory
3602
3603         if (directory / "pyproject.toml").is_file():
3604             return directory
3605
3606     return directory
3607
3608
3609 @dataclass
3610 class Report:
3611     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3612
3613     check: bool = False
3614     quiet: bool = False
3615     verbose: bool = False
3616     change_count: int = 0
3617     same_count: int = 0
3618     failure_count: int = 0
3619
3620     def done(self, src: Path, changed: Changed) -> None:
3621         """Increment the counter for successful reformatting. Write out a message."""
3622         if changed is Changed.YES:
3623             reformatted = "would reformat" if self.check else "reformatted"
3624             if self.verbose or not self.quiet:
3625                 out(f"{reformatted} {src}")
3626             self.change_count += 1
3627         else:
3628             if self.verbose:
3629                 if changed is Changed.NO:
3630                     msg = f"{src} already well formatted, good job."
3631                 else:
3632                     msg = f"{src} wasn't modified on disk since last run."
3633                 out(msg, bold=False)
3634             self.same_count += 1
3635
3636     def failed(self, src: Path, message: str) -> None:
3637         """Increment the counter for failed reformatting. Write out a message."""
3638         err(f"error: cannot format {src}: {message}")
3639         self.failure_count += 1
3640
3641     def path_ignored(self, path: Path, message: str) -> None:
3642         if self.verbose:
3643             out(f"{path} ignored: {message}", bold=False)
3644
3645     @property
3646     def return_code(self) -> int:
3647         """Return the exit code that the app should use.
3648
3649         This considers the current state of changed files and failures:
3650         - if there were any failures, return 123;
3651         - if any files were changed and --check is being used, return 1;
3652         - otherwise return 0.
3653         """
3654         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3655         # 126 we have special return codes reserved by the shell.
3656         if self.failure_count:
3657             return 123
3658
3659         elif self.change_count and self.check:
3660             return 1
3661
3662         return 0
3663
3664     def __str__(self) -> str:
3665         """Render a color report of the current state.
3666
3667         Use `click.unstyle` to remove colors.
3668         """
3669         if self.check:
3670             reformatted = "would be reformatted"
3671             unchanged = "would be left unchanged"
3672             failed = "would fail to reformat"
3673         else:
3674             reformatted = "reformatted"
3675             unchanged = "left unchanged"
3676             failed = "failed to reformat"
3677         report = []
3678         if self.change_count:
3679             s = "s" if self.change_count > 1 else ""
3680             report.append(
3681                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3682             )
3683         if self.same_count:
3684             s = "s" if self.same_count > 1 else ""
3685             report.append(f"{self.same_count} file{s} {unchanged}")
3686         if self.failure_count:
3687             s = "s" if self.failure_count > 1 else ""
3688             report.append(
3689                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3690             )
3691         return ", ".join(report) + "."
3692
3693
3694 def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
3695     filename = "<unknown>"
3696     if sys.version_info >= (3, 8):
3697         # TODO: support Python 4+ ;)
3698         for minor_version in range(sys.version_info[1], 4, -1):
3699             try:
3700                 return ast.parse(src, filename, feature_version=(3, minor_version))
3701             except SyntaxError:
3702                 continue
3703     else:
3704         for feature_version in (7, 6):
3705             try:
3706                 return ast3.parse(src, filename, feature_version=feature_version)
3707             except SyntaxError:
3708                 continue
3709
3710     return ast27.parse(src)
3711
3712
3713 def _fixup_ast_constants(
3714     node: Union[ast.AST, ast3.AST, ast27.AST]
3715 ) -> Union[ast.AST, ast3.AST, ast27.AST]:
3716     """Map ast nodes deprecated in 3.8 to Constant."""
3717     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
3718         return ast.Constant(value=node.s)
3719
3720     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
3721         return ast.Constant(value=node.n)
3722
3723     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
3724         return ast.Constant(value=node.value)
3725
3726     return node
3727
3728
3729 def assert_equivalent(src: str, dst: str) -> None:
3730     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3731
3732     def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3733         """Simple visitor generating strings to compare ASTs by content."""
3734
3735         node = _fixup_ast_constants(node)
3736
3737         yield f"{'  ' * depth}{node.__class__.__name__}("
3738
3739         for field in sorted(node._fields):  # noqa: F402
3740             # TypeIgnore has only one field 'lineno' which breaks this comparison
3741             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
3742             if sys.version_info >= (3, 8):
3743                 type_ignore_classes += (ast.TypeIgnore,)
3744             if isinstance(node, type_ignore_classes):
3745                 break
3746
3747             try:
3748                 value = getattr(node, field)
3749             except AttributeError:
3750                 continue
3751
3752             yield f"{'  ' * (depth+1)}{field}="
3753
3754             if isinstance(value, list):
3755                 for item in value:
3756                     # Ignore nested tuples within del statements, because we may insert
3757                     # parentheses and they change the AST.
3758                     if (
3759                         field == "targets"
3760                         and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
3761                         and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
3762                     ):
3763                         for item in item.elts:
3764                             yield from _v(item, depth + 2)
3765
3766                     elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
3767                         yield from _v(item, depth + 2)
3768
3769             elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
3770                 yield from _v(value, depth + 2)
3771
3772             else:
3773                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3774
3775         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3776
3777     try:
3778         src_ast = parse_ast(src)
3779     except Exception as exc:
3780         raise AssertionError(
3781             f"cannot use --safe with this file; failed to parse source file.  "
3782             f"AST error message: {exc}"
3783         )
3784
3785     try:
3786         dst_ast = parse_ast(dst)
3787     except Exception as exc:
3788         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3789         raise AssertionError(
3790             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3791             f"Please report a bug on https://github.com/psf/black/issues.  "
3792             f"This invalid output might be helpful: {log}"
3793         ) from None
3794
3795     src_ast_str = "\n".join(_v(src_ast))
3796     dst_ast_str = "\n".join(_v(dst_ast))
3797     if src_ast_str != dst_ast_str:
3798         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3799         raise AssertionError(
3800             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3801             f"the source.  "
3802             f"Please report a bug on https://github.com/psf/black/issues.  "
3803             f"This diff might be helpful: {log}"
3804         ) from None
3805
3806
3807 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3808     """Raise AssertionError if `dst` reformats differently the second time."""
3809     newdst = format_str(dst, mode=mode)
3810     if dst != newdst:
3811         log = dump_to_file(
3812             diff(src, dst, "source", "first pass"),
3813             diff(dst, newdst, "first pass", "second pass"),
3814         )
3815         raise AssertionError(
3816             f"INTERNAL ERROR: Black produced different code on the second pass "
3817             f"of the formatter.  "
3818             f"Please report a bug on https://github.com/psf/black/issues.  "
3819             f"This diff might be helpful: {log}"
3820         ) from None
3821
3822
3823 @mypyc_attr(patchable=True)
3824 def dump_to_file(*output: str) -> str:
3825     """Dump `output` to a temporary file. Return path to the file."""
3826     with tempfile.NamedTemporaryFile(
3827         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3828     ) as f:
3829         for lines in output:
3830             f.write(lines)
3831             if lines and lines[-1] != "\n":
3832                 f.write("\n")
3833     return f.name
3834
3835
3836 @contextmanager
3837 def nullcontext() -> Iterator[None]:
3838     """Return an empty context manager.
3839
3840     To be used like `nullcontext` in Python 3.7.
3841     """
3842     yield
3843
3844
3845 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3846     """Return a unified diff string between strings `a` and `b`."""
3847     import difflib
3848
3849     a_lines = [line + "\n" for line in a.split("\n")]
3850     b_lines = [line + "\n" for line in b.split("\n")]
3851     return "".join(
3852         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3853     )
3854
3855
3856 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
3857     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3858     err("Aborted!")
3859     for task in tasks:
3860         task.cancel()
3861
3862
3863 def shutdown(loop: asyncio.AbstractEventLoop) -> None:
3864     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3865     try:
3866         if sys.version_info[:2] >= (3, 7):
3867             all_tasks = asyncio.all_tasks
3868         else:
3869             all_tasks = asyncio.Task.all_tasks
3870         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3871         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3872         if not to_cancel:
3873             return
3874
3875         for task in to_cancel:
3876             task.cancel()
3877         loop.run_until_complete(
3878             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3879         )
3880     finally:
3881         # `concurrent.futures.Future` objects cannot be cancelled once they
3882         # are already running. There might be some when the `shutdown()` happened.
3883         # Silence their logger's spew about the event loop being closed.
3884         cf_logger = logging.getLogger("concurrent.futures")
3885         cf_logger.setLevel(logging.CRITICAL)
3886         loop.close()
3887
3888
3889 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3890     """Replace `regex` with `replacement` twice on `original`.
3891
3892     This is used by string normalization to perform replaces on
3893     overlapping matches.
3894     """
3895     return regex.sub(replacement, regex.sub(replacement, original))
3896
3897
3898 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3899     """Compile a regular expression string in `regex`.
3900
3901     If it contains newlines, use verbose mode.
3902     """
3903     if "\n" in regex:
3904         regex = "(?x)" + regex
3905     compiled: Pattern[str] = re.compile(regex)
3906     return compiled
3907
3908
3909 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3910     """Like `reversed(enumerate(sequence))` if that were possible."""
3911     index = len(sequence) - 1
3912     for element in reversed(sequence):
3913         yield (index, element)
3914         index -= 1
3915
3916
3917 def enumerate_with_length(
3918     line: Line, reversed: bool = False
3919 ) -> Iterator[Tuple[Index, Leaf, int]]:
3920     """Return an enumeration of leaves with their length.
3921
3922     Stops prematurely on multiline strings and standalone comments.
3923     """
3924     op = cast(
3925         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3926         enumerate_reversed if reversed else enumerate,
3927     )
3928     for index, leaf in op(line.leaves):
3929         length = len(leaf.prefix) + len(leaf.value)
3930         if "\n" in leaf.value:
3931             return  # Multiline strings, we can't continue.
3932
3933         for comment in line.comments_after(leaf):
3934             length += len(comment.value)
3935
3936         yield index, leaf, length
3937
3938
3939 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3940     """Return True if `line` is no longer than `line_length`.
3941
3942     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3943     """
3944     if not line_str:
3945         line_str = str(line).strip("\n")
3946     return (
3947         len(line_str) <= line_length
3948         and "\n" not in line_str  # multiline strings
3949         and not line.contains_standalone_comments()
3950     )
3951
3952
3953 def can_be_split(line: Line) -> bool:
3954     """Return False if the line cannot be split *for sure*.
3955
3956     This is not an exhaustive search but a cheap heuristic that we can use to
3957     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3958     in unnecessary parentheses).
3959     """
3960     leaves = line.leaves
3961     if len(leaves) < 2:
3962         return False
3963
3964     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3965         call_count = 0
3966         dot_count = 0
3967         next = leaves[-1]
3968         for leaf in leaves[-2::-1]:
3969             if leaf.type in OPENING_BRACKETS:
3970                 if next.type not in CLOSING_BRACKETS:
3971                     return False
3972
3973                 call_count += 1
3974             elif leaf.type == token.DOT:
3975                 dot_count += 1
3976             elif leaf.type == token.NAME:
3977                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3978                     return False
3979
3980             elif leaf.type not in CLOSING_BRACKETS:
3981                 return False
3982
3983             if dot_count > 1 and call_count > 1:
3984                 return False
3985
3986     return True
3987
3988
3989 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3990     """Does `line` have a shape safe to reformat without optional parens around it?
3991
3992     Returns True for only a subset of potentially nice looking formattings but
3993     the point is to not return false positives that end up producing lines that
3994     are too long.
3995     """
3996     bt = line.bracket_tracker
3997     if not bt.delimiters:
3998         # Without delimiters the optional parentheses are useless.
3999         return True
4000
4001     max_priority = bt.max_delimiter_priority()
4002     if bt.delimiter_count_with_priority(max_priority) > 1:
4003         # With more than one delimiter of a kind the optional parentheses read better.
4004         return False
4005
4006     if max_priority == DOT_PRIORITY:
4007         # A single stranded method call doesn't require optional parentheses.
4008         return True
4009
4010     assert len(line.leaves) >= 2, "Stranded delimiter"
4011
4012     first = line.leaves[0]
4013     second = line.leaves[1]
4014     penultimate = line.leaves[-2]
4015     last = line.leaves[-1]
4016
4017     # With a single delimiter, omit if the expression starts or ends with
4018     # a bracket.
4019     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
4020         remainder = False
4021         length = 4 * line.depth
4022         for _index, leaf, leaf_length in enumerate_with_length(line):
4023             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
4024                 remainder = True
4025             if remainder:
4026                 length += leaf_length
4027                 if length > line_length:
4028                     break
4029
4030                 if leaf.type in OPENING_BRACKETS:
4031                     # There are brackets we can further split on.
4032                     remainder = False
4033
4034         else:
4035             # checked the entire string and line length wasn't exceeded
4036             if len(line.leaves) == _index + 1:
4037                 return True
4038
4039         # Note: we are not returning False here because a line might have *both*
4040         # a leading opening bracket and a trailing closing bracket.  If the
4041         # opening bracket doesn't match our rule, maybe the closing will.
4042
4043     if (
4044         last.type == token.RPAR
4045         or last.type == token.RBRACE
4046         or (
4047             # don't use indexing for omitting optional parentheses;
4048             # it looks weird
4049             last.type == token.RSQB
4050             and last.parent
4051             and last.parent.type != syms.trailer
4052         )
4053     ):
4054         if penultimate.type in OPENING_BRACKETS:
4055             # Empty brackets don't help.
4056             return False
4057
4058         if is_multiline_string(first):
4059             # Additional wrapping of a multiline string in this situation is
4060             # unnecessary.
4061             return True
4062
4063         length = 4 * line.depth
4064         seen_other_brackets = False
4065         for _index, leaf, leaf_length in enumerate_with_length(line):
4066             length += leaf_length
4067             if leaf is last.opening_bracket:
4068                 if seen_other_brackets or length <= line_length:
4069                     return True
4070
4071             elif leaf.type in OPENING_BRACKETS:
4072                 # There are brackets we can further split on.
4073                 seen_other_brackets = True
4074
4075     return False
4076
4077
4078 def get_cache_file(mode: FileMode) -> Path:
4079     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
4080
4081
4082 def read_cache(mode: FileMode) -> Cache:
4083     """Read the cache if it exists and is well formed.
4084
4085     If it is not well formed, the call to write_cache later should resolve the issue.
4086     """
4087     cache_file = get_cache_file(mode)
4088     if not cache_file.exists():
4089         return {}
4090
4091     with cache_file.open("rb") as fobj:
4092         try:
4093             cache: Cache = pickle.load(fobj)
4094         except (pickle.UnpicklingError, ValueError):
4095             return {}
4096
4097     return cache
4098
4099
4100 def get_cache_info(path: Path) -> CacheInfo:
4101     """Return the information used to check if a file is already formatted or not."""
4102     stat = path.stat()
4103     return stat.st_mtime, stat.st_size
4104
4105
4106 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
4107     """Split an iterable of paths in `sources` into two sets.
4108
4109     The first contains paths of files that modified on disk or are not in the
4110     cache. The other contains paths to non-modified files.
4111     """
4112     todo, done = set(), set()
4113     for src in sources:
4114         src = src.resolve()
4115         if cache.get(src) != get_cache_info(src):
4116             todo.add(src)
4117         else:
4118             done.add(src)
4119     return todo, done
4120
4121
4122 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
4123     """Update the cache file."""
4124     cache_file = get_cache_file(mode)
4125     try:
4126         CACHE_DIR.mkdir(parents=True, exist_ok=True)
4127         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
4128         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
4129             pickle.dump(new_cache, f, protocol=4)
4130         os.replace(f.name, cache_file)
4131     except OSError:
4132         pass
4133
4134
4135 def patch_click() -> None:
4136     """Make Click not crash.
4137
4138     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
4139     default which restricts paths that it can access during the lifetime of the
4140     application.  Click refuses to work in this scenario by raising a RuntimeError.
4141
4142     In case of Black the likelihood that non-ASCII characters are going to be used in
4143     file paths is minimal since it's Python source code.  Moreover, this crash was
4144     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
4145     """
4146     try:
4147         from click import core
4148         from click import _unicodefun  # type: ignore
4149     except ModuleNotFoundError:
4150         return
4151
4152     for module in (core, _unicodefun):
4153         if hasattr(module, "_verify_python3_env"):
4154             module._verify_python3_env = lambda: None
4155
4156
4157 def patched_main() -> None:
4158     freeze_support()
4159     patch_click()
4160     main()
4161
4162
4163 if __name__ == "__main__":
4164     patched_main()