]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

acks += revfried
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43 from typed_ast import ast3, ast27
44
45 # lib2to3 fork
46 from blib2to3.pytree import Node, Leaf, type_repr
47 from blib2to3 import pygram, pytree
48 from blib2to3.pgen2 import driver, token
49 from blib2to3.pgen2.grammar import Grammar
50 from blib2to3.pgen2.parse import ParseError
51
52
53 __version__ = "19.3b0"
54 DEFAULT_LINE_LENGTH = 88
55 DEFAULT_EXCLUDES = (
56     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
57 )
58 DEFAULT_INCLUDES = r"\.pyi?$"
59 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
60
61
62 # types
63 FileContent = str
64 Encoding = str
65 NewLine = str
66 Depth = int
67 NodeType = int
68 LeafID = int
69 Priority = int
70 Index = int
71 LN = Union[Leaf, Node]
72 SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
73 Timestamp = float
74 FileSize = int
75 CacheInfo = Tuple[Timestamp, FileSize]
76 Cache = Dict[Path, CacheInfo]
77 out = partial(click.secho, bold=True, err=True)
78 err = partial(click.secho, fg="red", err=True)
79
80 pygram.initialize(CACHE_DIR)
81 syms = pygram.python_symbols
82
83
84 class NothingChanged(UserWarning):
85     """Raised when reformatted code is the same as source."""
86
87
88 class CannotSplit(Exception):
89     """A readable split that fits the allotted line length is impossible."""
90
91
92 class InvalidInput(ValueError):
93     """Raised when input source code fails all parse attempts."""
94
95
96 class WriteBack(Enum):
97     NO = 0
98     YES = 1
99     DIFF = 2
100     CHECK = 3
101
102     @classmethod
103     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
104         if check and not diff:
105             return cls.CHECK
106
107         return cls.DIFF if diff else cls.YES
108
109
110 class Changed(Enum):
111     NO = 0
112     CACHED = 1
113     YES = 2
114
115
116 class TargetVersion(Enum):
117     PY27 = 2
118     PY33 = 3
119     PY34 = 4
120     PY35 = 5
121     PY36 = 6
122     PY37 = 7
123     PY38 = 8
124
125     def is_python2(self) -> bool:
126         return self is TargetVersion.PY27
127
128
129 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
130
131
132 class Feature(Enum):
133     # All string literals are unicode
134     UNICODE_LITERALS = 1
135     F_STRINGS = 2
136     NUMERIC_UNDERSCORES = 3
137     TRAILING_COMMA_IN_CALL = 4
138     TRAILING_COMMA_IN_DEF = 5
139
140
141 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
142     TargetVersion.PY27: set(),
143     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
144     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
145     TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
146     TargetVersion.PY36: {
147         Feature.UNICODE_LITERALS,
148         Feature.F_STRINGS,
149         Feature.NUMERIC_UNDERSCORES,
150         Feature.TRAILING_COMMA_IN_CALL,
151         Feature.TRAILING_COMMA_IN_DEF,
152     },
153     TargetVersion.PY37: {
154         Feature.UNICODE_LITERALS,
155         Feature.F_STRINGS,
156         Feature.NUMERIC_UNDERSCORES,
157         Feature.TRAILING_COMMA_IN_CALL,
158         Feature.TRAILING_COMMA_IN_DEF,
159     },
160     TargetVersion.PY38: {
161         Feature.UNICODE_LITERALS,
162         Feature.F_STRINGS,
163         Feature.NUMERIC_UNDERSCORES,
164         Feature.TRAILING_COMMA_IN_CALL,
165         Feature.TRAILING_COMMA_IN_DEF,
166     },
167 }
168
169
170 @dataclass
171 class FileMode:
172     target_versions: Set[TargetVersion] = Factory(set)
173     line_length: int = DEFAULT_LINE_LENGTH
174     string_normalization: bool = True
175     is_pyi: bool = False
176
177     def get_cache_key(self) -> str:
178         if self.target_versions:
179             version_str = ",".join(
180                 str(version.value)
181                 for version in sorted(self.target_versions, key=lambda v: v.value)
182             )
183         else:
184             version_str = "-"
185         parts = [
186             version_str,
187             str(self.line_length),
188             str(int(self.string_normalization)),
189             str(int(self.is_pyi)),
190         ]
191         return ".".join(parts)
192
193
194 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
195     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
196
197
198 def read_pyproject_toml(
199     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
200 ) -> Optional[str]:
201     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
202
203     Returns the path to a successfully found and read configuration file, None
204     otherwise.
205     """
206     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
207     if not value:
208         root = find_project_root(ctx.params.get("src", ()))
209         path = root / "pyproject.toml"
210         if path.is_file():
211             value = str(path)
212         else:
213             return None
214
215     try:
216         pyproject_toml = toml.load(value)
217         config = pyproject_toml.get("tool", {}).get("black", {})
218     except (toml.TomlDecodeError, OSError) as e:
219         raise click.FileError(
220             filename=value, hint=f"Error reading configuration file: {e}"
221         )
222
223     if not config:
224         return None
225
226     if ctx.default_map is None:
227         ctx.default_map = {}
228     ctx.default_map.update(  # type: ignore  # bad types in .pyi
229         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
230     )
231     return value
232
233
234 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
235 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
236 @click.option(
237     "-l",
238     "--line-length",
239     type=int,
240     default=DEFAULT_LINE_LENGTH,
241     help="How many characters per line to allow.",
242     show_default=True,
243 )
244 @click.option(
245     "-t",
246     "--target-version",
247     type=click.Choice([v.name.lower() for v in TargetVersion]),
248     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
249     multiple=True,
250     help=(
251         "Python versions that should be supported by Black's output. [default: "
252         "per-file auto-detection]"
253     ),
254 )
255 @click.option(
256     "--py36",
257     is_flag=True,
258     help=(
259         "Allow using Python 3.6-only syntax on all input files.  This will put "
260         "trailing commas in function signatures and calls also after *args and "
261         "**kwargs. Deprecated; use --target-version instead. "
262         "[default: per-file auto-detection]"
263     ),
264 )
265 @click.option(
266     "--pyi",
267     is_flag=True,
268     help=(
269         "Format all input files like typing stubs regardless of file extension "
270         "(useful when piping source on standard input)."
271     ),
272 )
273 @click.option(
274     "-S",
275     "--skip-string-normalization",
276     is_flag=True,
277     help="Don't normalize string quotes or prefixes.",
278 )
279 @click.option(
280     "--check",
281     is_flag=True,
282     help=(
283         "Don't write the files back, just return the status.  Return code 0 "
284         "means nothing would change.  Return code 1 means some files would be "
285         "reformatted.  Return code 123 means there was an internal error."
286     ),
287 )
288 @click.option(
289     "--diff",
290     is_flag=True,
291     help="Don't write the files back, just output a diff for each file on stdout.",
292 )
293 @click.option(
294     "--fast/--safe",
295     is_flag=True,
296     help="If --fast given, skip temporary sanity checks. [default: --safe]",
297 )
298 @click.option(
299     "--include",
300     type=str,
301     default=DEFAULT_INCLUDES,
302     help=(
303         "A regular expression that matches files and directories that should be "
304         "included on recursive searches.  An empty value means all files are "
305         "included regardless of the name.  Use forward slashes for directories on "
306         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
307         "later."
308     ),
309     show_default=True,
310 )
311 @click.option(
312     "--exclude",
313     type=str,
314     default=DEFAULT_EXCLUDES,
315     help=(
316         "A regular expression that matches files and directories that should be "
317         "excluded on recursive searches.  An empty value means no paths are excluded. "
318         "Use forward slashes for directories on all platforms (Windows, too).  "
319         "Exclusions are calculated first, inclusions later."
320     ),
321     show_default=True,
322 )
323 @click.option(
324     "-q",
325     "--quiet",
326     is_flag=True,
327     help=(
328         "Don't emit non-error messages to stderr. Errors are still emitted, "
329         "silence those with 2>/dev/null."
330     ),
331 )
332 @click.option(
333     "-v",
334     "--verbose",
335     is_flag=True,
336     help=(
337         "Also emit messages to stderr about files that were not changed or were "
338         "ignored due to --exclude=."
339     ),
340 )
341 @click.version_option(version=__version__)
342 @click.argument(
343     "src",
344     nargs=-1,
345     type=click.Path(
346         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
347     ),
348     is_eager=True,
349 )
350 @click.option(
351     "--config",
352     type=click.Path(
353         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
354     ),
355     is_eager=True,
356     callback=read_pyproject_toml,
357     help="Read configuration from PATH.",
358 )
359 @click.pass_context
360 def main(
361     ctx: click.Context,
362     code: Optional[str],
363     line_length: int,
364     target_version: List[TargetVersion],
365     check: bool,
366     diff: bool,
367     fast: bool,
368     pyi: bool,
369     py36: bool,
370     skip_string_normalization: bool,
371     quiet: bool,
372     verbose: bool,
373     include: str,
374     exclude: str,
375     src: Tuple[str],
376     config: Optional[str],
377 ) -> None:
378     """The uncompromising code formatter."""
379     write_back = WriteBack.from_configuration(check=check, diff=diff)
380     if target_version:
381         if py36:
382             err(f"Cannot use both --target-version and --py36")
383             ctx.exit(2)
384         else:
385             versions = set(target_version)
386     elif py36:
387         err(
388             "--py36 is deprecated and will be removed in a future version. "
389             "Use --target-version py36 instead."
390         )
391         versions = PY36_VERSIONS
392     else:
393         # We'll autodetect later.
394         versions = set()
395     mode = FileMode(
396         target_versions=versions,
397         line_length=line_length,
398         is_pyi=pyi,
399         string_normalization=not skip_string_normalization,
400     )
401     if config and verbose:
402         out(f"Using configuration from {config}.", bold=False, fg="blue")
403     if code is not None:
404         print(format_str(code, mode=mode))
405         ctx.exit(0)
406     try:
407         include_regex = re_compile_maybe_verbose(include)
408     except re.error:
409         err(f"Invalid regular expression for include given: {include!r}")
410         ctx.exit(2)
411     try:
412         exclude_regex = re_compile_maybe_verbose(exclude)
413     except re.error:
414         err(f"Invalid regular expression for exclude given: {exclude!r}")
415         ctx.exit(2)
416     report = Report(check=check, quiet=quiet, verbose=verbose)
417     root = find_project_root(src)
418     sources: Set[Path] = set()
419     for s in src:
420         p = Path(s)
421         if p.is_dir():
422             sources.update(
423                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
424             )
425         elif p.is_file() or s == "-":
426             # if a file was explicitly given, we don't care about its extension
427             sources.add(p)
428         else:
429             err(f"invalid path: {s}")
430     if len(sources) == 0:
431         if verbose or not quiet:
432             out("No paths given. Nothing to do 😴")
433         ctx.exit(0)
434
435     if len(sources) == 1:
436         reformat_one(
437             src=sources.pop(),
438             fast=fast,
439             write_back=write_back,
440             mode=mode,
441             report=report,
442         )
443     else:
444         reformat_many(
445             sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
446         )
447
448     if verbose or not quiet:
449         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
450         out(f"All done! {bang}")
451         click.secho(str(report), err=True)
452     ctx.exit(report.return_code)
453
454
455 def reformat_one(
456     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
457 ) -> None:
458     """Reformat a single file under `src` without spawning child processes.
459
460     If `quiet` is True, non-error messages are not output. `line_length`,
461     `write_back`, `fast` and `pyi` options are passed to
462     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
463     """
464     try:
465         changed = Changed.NO
466         if not src.is_file() and str(src) == "-":
467             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
468                 changed = Changed.YES
469         else:
470             cache: Cache = {}
471             if write_back != WriteBack.DIFF:
472                 cache = read_cache(mode)
473                 res_src = src.resolve()
474                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
475                     changed = Changed.CACHED
476             if changed is not Changed.CACHED and format_file_in_place(
477                 src, fast=fast, write_back=write_back, mode=mode
478             ):
479                 changed = Changed.YES
480             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
481                 write_back is WriteBack.CHECK and changed is Changed.NO
482             ):
483                 write_cache(cache, [src], mode)
484         report.done(src, changed)
485     except Exception as exc:
486         report.failed(src, str(exc))
487
488
489 def reformat_many(
490     sources: Set[Path],
491     fast: bool,
492     write_back: WriteBack,
493     mode: FileMode,
494     report: "Report",
495 ) -> None:
496     """Reformat multiple files using a ProcessPoolExecutor."""
497     loop = asyncio.get_event_loop()
498     worker_count = os.cpu_count()
499     if sys.platform == "win32":
500         # Work around https://bugs.python.org/issue26903
501         worker_count = min(worker_count, 61)
502     executor = ProcessPoolExecutor(max_workers=worker_count)
503     try:
504         loop.run_until_complete(
505             schedule_formatting(
506                 sources=sources,
507                 fast=fast,
508                 write_back=write_back,
509                 mode=mode,
510                 report=report,
511                 loop=loop,
512                 executor=executor,
513             )
514         )
515     finally:
516         shutdown(loop)
517
518
519 async def schedule_formatting(
520     sources: Set[Path],
521     fast: bool,
522     write_back: WriteBack,
523     mode: FileMode,
524     report: "Report",
525     loop: BaseEventLoop,
526     executor: Executor,
527 ) -> None:
528     """Run formatting of `sources` in parallel using the provided `executor`.
529
530     (Use ProcessPoolExecutors for actual parallelism.)
531
532     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
533     :func:`format_file_in_place`.
534     """
535     cache: Cache = {}
536     if write_back != WriteBack.DIFF:
537         cache = read_cache(mode)
538         sources, cached = filter_cached(cache, sources)
539         for src in sorted(cached):
540             report.done(src, Changed.CACHED)
541     if not sources:
542         return
543
544     cancelled = []
545     sources_to_cache = []
546     lock = None
547     if write_back == WriteBack.DIFF:
548         # For diff output, we need locks to ensure we don't interleave output
549         # from different processes.
550         manager = Manager()
551         lock = manager.Lock()
552     tasks = {
553         asyncio.ensure_future(
554             loop.run_in_executor(
555                 executor, format_file_in_place, src, fast, mode, write_back, lock
556             )
557         ): src
558         for src in sorted(sources)
559     }
560     pending: Iterable[asyncio.Future] = tasks.keys()
561     try:
562         loop.add_signal_handler(signal.SIGINT, cancel, pending)
563         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
564     except NotImplementedError:
565         # There are no good alternatives for these on Windows.
566         pass
567     while pending:
568         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
569         for task in done:
570             src = tasks.pop(task)
571             if task.cancelled():
572                 cancelled.append(task)
573             elif task.exception():
574                 report.failed(src, str(task.exception()))
575             else:
576                 changed = Changed.YES if task.result() else Changed.NO
577                 # If the file was written back or was successfully checked as
578                 # well-formatted, store this information in the cache.
579                 if write_back is WriteBack.YES or (
580                     write_back is WriteBack.CHECK and changed is Changed.NO
581                 ):
582                     sources_to_cache.append(src)
583                 report.done(src, changed)
584     if cancelled:
585         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
586     if sources_to_cache:
587         write_cache(cache, sources_to_cache, mode)
588
589
590 def format_file_in_place(
591     src: Path,
592     fast: bool,
593     mode: FileMode,
594     write_back: WriteBack = WriteBack.NO,
595     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
596 ) -> bool:
597     """Format file under `src` path. Return True if changed.
598
599     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
600     code to the file.
601     `line_length` and `fast` options are passed to :func:`format_file_contents`.
602     """
603     if src.suffix == ".pyi":
604         mode = evolve(mode, is_pyi=True)
605
606     then = datetime.utcfromtimestamp(src.stat().st_mtime)
607     with open(src, "rb") as buf:
608         src_contents, encoding, newline = decode_bytes(buf.read())
609     try:
610         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
611     except NothingChanged:
612         return False
613
614     if write_back == write_back.YES:
615         with open(src, "w", encoding=encoding, newline=newline) as f:
616             f.write(dst_contents)
617     elif write_back == write_back.DIFF:
618         now = datetime.utcnow()
619         src_name = f"{src}\t{then} +0000"
620         dst_name = f"{src}\t{now} +0000"
621         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
622         if lock:
623             lock.acquire()
624         try:
625             f = io.TextIOWrapper(
626                 sys.stdout.buffer,
627                 encoding=encoding,
628                 newline=newline,
629                 write_through=True,
630             )
631             f.write(diff_contents)
632             f.detach()
633         finally:
634             if lock:
635                 lock.release()
636     return True
637
638
639 def format_stdin_to_stdout(
640     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
641 ) -> bool:
642     """Format file on stdin. Return True if changed.
643
644     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
645     write a diff to stdout. The `mode` argument is passed to
646     :func:`format_file_contents`.
647     """
648     then = datetime.utcnow()
649     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
650     dst = src
651     try:
652         dst = format_file_contents(src, fast=fast, mode=mode)
653         return True
654
655     except NothingChanged:
656         return False
657
658     finally:
659         f = io.TextIOWrapper(
660             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
661         )
662         if write_back == WriteBack.YES:
663             f.write(dst)
664         elif write_back == WriteBack.DIFF:
665             now = datetime.utcnow()
666             src_name = f"STDIN\t{then} +0000"
667             dst_name = f"STDOUT\t{now} +0000"
668             f.write(diff(src, dst, src_name, dst_name))
669         f.detach()
670
671
672 def format_file_contents(
673     src_contents: str, *, fast: bool, mode: FileMode
674 ) -> FileContent:
675     """Reformat contents a file and return new contents.
676
677     If `fast` is False, additionally confirm that the reformatted code is
678     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
679     `line_length` is passed to :func:`format_str`.
680     """
681     if src_contents.strip() == "":
682         raise NothingChanged
683
684     dst_contents = format_str(src_contents, mode=mode)
685     if src_contents == dst_contents:
686         raise NothingChanged
687
688     if not fast:
689         assert_equivalent(src_contents, dst_contents)
690         assert_stable(src_contents, dst_contents, mode=mode)
691     return dst_contents
692
693
694 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
695     """Reformat a string and return new contents.
696
697     `line_length` determines how many characters per line are allowed.
698     """
699     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
700     dst_contents = ""
701     future_imports = get_future_imports(src_node)
702     if mode.target_versions:
703         versions = mode.target_versions
704     else:
705         versions = detect_target_versions(src_node)
706     normalize_fmt_off(src_node)
707     lines = LineGenerator(
708         remove_u_prefix="unicode_literals" in future_imports
709         or supports_feature(versions, Feature.UNICODE_LITERALS),
710         is_pyi=mode.is_pyi,
711         normalize_strings=mode.string_normalization,
712     )
713     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
714     empty_line = Line()
715     after = 0
716     split_line_features = {
717         feature
718         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
719         if supports_feature(versions, feature)
720     }
721     for current_line in lines.visit(src_node):
722         for _ in range(after):
723             dst_contents += str(empty_line)
724         before, after = elt.maybe_empty_lines(current_line)
725         for _ in range(before):
726             dst_contents += str(empty_line)
727         for line in split_line(
728             current_line, line_length=mode.line_length, features=split_line_features
729         ):
730             dst_contents += str(line)
731     return dst_contents
732
733
734 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
735     """Return a tuple of (decoded_contents, encoding, newline).
736
737     `newline` is either CRLF or LF but `decoded_contents` is decoded with
738     universal newlines (i.e. only contains LF).
739     """
740     srcbuf = io.BytesIO(src)
741     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
742     if not lines:
743         return "", encoding, "\n"
744
745     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
746     srcbuf.seek(0)
747     with io.TextIOWrapper(srcbuf, encoding) as tiow:
748         return tiow.read(), encoding, newline
749
750
751 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
752     if not target_versions:
753         # No target_version specified, so try all grammars.
754         return [
755             pygram.python_grammar_no_print_statement_no_exec_statement,
756             pygram.python_grammar_no_print_statement,
757             pygram.python_grammar,
758         ]
759     elif all(version.is_python2() for version in target_versions):
760         # Python 2-only code, so try Python 2 grammars.
761         return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
762     else:
763         # Python 3-compatible code, so only try Python 3 grammar.
764         return [pygram.python_grammar_no_print_statement_no_exec_statement]
765
766
767 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
768     """Given a string with source, return the lib2to3 Node."""
769     if src_txt[-1:] != "\n":
770         src_txt += "\n"
771
772     for grammar in get_grammars(set(target_versions)):
773         drv = driver.Driver(grammar, pytree.convert)
774         try:
775             result = drv.parse_string(src_txt, True)
776             break
777
778         except ParseError as pe:
779             lineno, column = pe.context[1]
780             lines = src_txt.splitlines()
781             try:
782                 faulty_line = lines[lineno - 1]
783             except IndexError:
784                 faulty_line = "<line number missing in source>"
785             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
786     else:
787         raise exc from None
788
789     if isinstance(result, Leaf):
790         result = Node(syms.file_input, [result])
791     return result
792
793
794 def lib2to3_unparse(node: Node) -> str:
795     """Given a lib2to3 node, return its string representation."""
796     code = str(node)
797     return code
798
799
800 T = TypeVar("T")
801
802
803 class Visitor(Generic[T]):
804     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
805
806     def visit(self, node: LN) -> Iterator[T]:
807         """Main method to visit `node` and its children.
808
809         It tries to find a `visit_*()` method for the given `node.type`, like
810         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
811         If no dedicated `visit_*()` method is found, chooses `visit_default()`
812         instead.
813
814         Then yields objects of type `T` from the selected visitor.
815         """
816         if node.type < 256:
817             name = token.tok_name[node.type]
818         else:
819             name = type_repr(node.type)
820         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
821
822     def visit_default(self, node: LN) -> Iterator[T]:
823         """Default `visit_*()` implementation. Recurses to children of `node`."""
824         if isinstance(node, Node):
825             for child in node.children:
826                 yield from self.visit(child)
827
828
829 @dataclass
830 class DebugVisitor(Visitor[T]):
831     tree_depth: int = 0
832
833     def visit_default(self, node: LN) -> Iterator[T]:
834         indent = " " * (2 * self.tree_depth)
835         if isinstance(node, Node):
836             _type = type_repr(node.type)
837             out(f"{indent}{_type}", fg="yellow")
838             self.tree_depth += 1
839             for child in node.children:
840                 yield from self.visit(child)
841
842             self.tree_depth -= 1
843             out(f"{indent}/{_type}", fg="yellow", bold=False)
844         else:
845             _type = token.tok_name.get(node.type, str(node.type))
846             out(f"{indent}{_type}", fg="blue", nl=False)
847             if node.prefix:
848                 # We don't have to handle prefixes for `Node` objects since
849                 # that delegates to the first child anyway.
850                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
851             out(f" {node.value!r}", fg="blue", bold=False)
852
853     @classmethod
854     def show(cls, code: Union[str, Leaf, Node]) -> None:
855         """Pretty-print the lib2to3 AST of a given string of `code`.
856
857         Convenience method for debugging.
858         """
859         v: DebugVisitor[None] = DebugVisitor()
860         if isinstance(code, str):
861             code = lib2to3_parse(code)
862         list(v.visit(code))
863
864
865 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
866 STATEMENT = {
867     syms.if_stmt,
868     syms.while_stmt,
869     syms.for_stmt,
870     syms.try_stmt,
871     syms.except_clause,
872     syms.with_stmt,
873     syms.funcdef,
874     syms.classdef,
875 }
876 STANDALONE_COMMENT = 153
877 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
878 LOGIC_OPERATORS = {"and", "or"}
879 COMPARATORS = {
880     token.LESS,
881     token.GREATER,
882     token.EQEQUAL,
883     token.NOTEQUAL,
884     token.LESSEQUAL,
885     token.GREATEREQUAL,
886 }
887 MATH_OPERATORS = {
888     token.VBAR,
889     token.CIRCUMFLEX,
890     token.AMPER,
891     token.LEFTSHIFT,
892     token.RIGHTSHIFT,
893     token.PLUS,
894     token.MINUS,
895     token.STAR,
896     token.SLASH,
897     token.DOUBLESLASH,
898     token.PERCENT,
899     token.AT,
900     token.TILDE,
901     token.DOUBLESTAR,
902 }
903 STARS = {token.STAR, token.DOUBLESTAR}
904 VARARGS_PARENTS = {
905     syms.arglist,
906     syms.argument,  # double star in arglist
907     syms.trailer,  # single argument to call
908     syms.typedargslist,
909     syms.varargslist,  # lambdas
910 }
911 UNPACKING_PARENTS = {
912     syms.atom,  # single element of a list or set literal
913     syms.dictsetmaker,
914     syms.listmaker,
915     syms.testlist_gexp,
916     syms.testlist_star_expr,
917 }
918 TEST_DESCENDANTS = {
919     syms.test,
920     syms.lambdef,
921     syms.or_test,
922     syms.and_test,
923     syms.not_test,
924     syms.comparison,
925     syms.star_expr,
926     syms.expr,
927     syms.xor_expr,
928     syms.and_expr,
929     syms.shift_expr,
930     syms.arith_expr,
931     syms.trailer,
932     syms.term,
933     syms.power,
934 }
935 ASSIGNMENTS = {
936     "=",
937     "+=",
938     "-=",
939     "*=",
940     "@=",
941     "/=",
942     "%=",
943     "&=",
944     "|=",
945     "^=",
946     "<<=",
947     ">>=",
948     "**=",
949     "//=",
950 }
951 COMPREHENSION_PRIORITY = 20
952 COMMA_PRIORITY = 18
953 TERNARY_PRIORITY = 16
954 LOGIC_PRIORITY = 14
955 STRING_PRIORITY = 12
956 COMPARATOR_PRIORITY = 10
957 MATH_PRIORITIES = {
958     token.VBAR: 9,
959     token.CIRCUMFLEX: 8,
960     token.AMPER: 7,
961     token.LEFTSHIFT: 6,
962     token.RIGHTSHIFT: 6,
963     token.PLUS: 5,
964     token.MINUS: 5,
965     token.STAR: 4,
966     token.SLASH: 4,
967     token.DOUBLESLASH: 4,
968     token.PERCENT: 4,
969     token.AT: 4,
970     token.TILDE: 3,
971     token.DOUBLESTAR: 2,
972 }
973 DOT_PRIORITY = 1
974
975
976 @dataclass
977 class BracketTracker:
978     """Keeps track of brackets on a line."""
979
980     depth: int = 0
981     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
982     delimiters: Dict[LeafID, Priority] = Factory(dict)
983     previous: Optional[Leaf] = None
984     _for_loop_depths: List[int] = Factory(list)
985     _lambda_argument_depths: List[int] = Factory(list)
986
987     def mark(self, leaf: Leaf) -> None:
988         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
989
990         All leaves receive an int `bracket_depth` field that stores how deep
991         within brackets a given leaf is. 0 means there are no enclosing brackets
992         that started on this line.
993
994         If a leaf is itself a closing bracket, it receives an `opening_bracket`
995         field that it forms a pair with. This is a one-directional link to
996         avoid reference cycles.
997
998         If a leaf is a delimiter (a token on which Black can split the line if
999         needed) and it's on depth 0, its `id()` is stored in the tracker's
1000         `delimiters` field.
1001         """
1002         if leaf.type == token.COMMENT:
1003             return
1004
1005         self.maybe_decrement_after_for_loop_variable(leaf)
1006         self.maybe_decrement_after_lambda_arguments(leaf)
1007         if leaf.type in CLOSING_BRACKETS:
1008             self.depth -= 1
1009             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
1010             leaf.opening_bracket = opening_bracket
1011         leaf.bracket_depth = self.depth
1012         if self.depth == 0:
1013             delim = is_split_before_delimiter(leaf, self.previous)
1014             if delim and self.previous is not None:
1015                 self.delimiters[id(self.previous)] = delim
1016             else:
1017                 delim = is_split_after_delimiter(leaf, self.previous)
1018                 if delim:
1019                     self.delimiters[id(leaf)] = delim
1020         if leaf.type in OPENING_BRACKETS:
1021             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
1022             self.depth += 1
1023         self.previous = leaf
1024         self.maybe_increment_lambda_arguments(leaf)
1025         self.maybe_increment_for_loop_variable(leaf)
1026
1027     def any_open_brackets(self) -> bool:
1028         """Return True if there is an yet unmatched open bracket on the line."""
1029         return bool(self.bracket_match)
1030
1031     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1032         """Return the highest priority of a delimiter found on the line.
1033
1034         Values are consistent with what `is_split_*_delimiter()` return.
1035         Raises ValueError on no delimiters.
1036         """
1037         return max(v for k, v in self.delimiters.items() if k not in exclude)
1038
1039     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1040         """Return the number of delimiters with the given `priority`.
1041
1042         If no `priority` is passed, defaults to max priority on the line.
1043         """
1044         if not self.delimiters:
1045             return 0
1046
1047         priority = priority or self.max_delimiter_priority()
1048         return sum(1 for p in self.delimiters.values() if p == priority)
1049
1050     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1051         """In a for loop, or comprehension, the variables are often unpacks.
1052
1053         To avoid splitting on the comma in this situation, increase the depth of
1054         tokens between `for` and `in`.
1055         """
1056         if leaf.type == token.NAME and leaf.value == "for":
1057             self.depth += 1
1058             self._for_loop_depths.append(self.depth)
1059             return True
1060
1061         return False
1062
1063     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1064         """See `maybe_increment_for_loop_variable` above for explanation."""
1065         if (
1066             self._for_loop_depths
1067             and self._for_loop_depths[-1] == self.depth
1068             and leaf.type == token.NAME
1069             and leaf.value == "in"
1070         ):
1071             self.depth -= 1
1072             self._for_loop_depths.pop()
1073             return True
1074
1075         return False
1076
1077     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1078         """In a lambda expression, there might be more than one argument.
1079
1080         To avoid splitting on the comma in this situation, increase the depth of
1081         tokens between `lambda` and `:`.
1082         """
1083         if leaf.type == token.NAME and leaf.value == "lambda":
1084             self.depth += 1
1085             self._lambda_argument_depths.append(self.depth)
1086             return True
1087
1088         return False
1089
1090     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1091         """See `maybe_increment_lambda_arguments` above for explanation."""
1092         if (
1093             self._lambda_argument_depths
1094             and self._lambda_argument_depths[-1] == self.depth
1095             and leaf.type == token.COLON
1096         ):
1097             self.depth -= 1
1098             self._lambda_argument_depths.pop()
1099             return True
1100
1101         return False
1102
1103     def get_open_lsqb(self) -> Optional[Leaf]:
1104         """Return the most recent opening square bracket (if any)."""
1105         return self.bracket_match.get((self.depth - 1, token.RSQB))
1106
1107
1108 @dataclass
1109 class Line:
1110     """Holds leaves and comments. Can be printed with `str(line)`."""
1111
1112     depth: int = 0
1113     leaves: List[Leaf] = Factory(list)
1114     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1115     bracket_tracker: BracketTracker = Factory(BracketTracker)
1116     inside_brackets: bool = False
1117     should_explode: bool = False
1118
1119     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1120         """Add a new `leaf` to the end of the line.
1121
1122         Unless `preformatted` is True, the `leaf` will receive a new consistent
1123         whitespace prefix and metadata applied by :class:`BracketTracker`.
1124         Trailing commas are maybe removed, unpacked for loop variables are
1125         demoted from being delimiters.
1126
1127         Inline comments are put aside.
1128         """
1129         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1130         if not has_value:
1131             return
1132
1133         if token.COLON == leaf.type and self.is_class_paren_empty:
1134             del self.leaves[-2:]
1135         if self.leaves and not preformatted:
1136             # Note: at this point leaf.prefix should be empty except for
1137             # imports, for which we only preserve newlines.
1138             leaf.prefix += whitespace(
1139                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1140             )
1141         if self.inside_brackets or not preformatted:
1142             self.bracket_tracker.mark(leaf)
1143             self.maybe_remove_trailing_comma(leaf)
1144         if not self.append_comment(leaf):
1145             self.leaves.append(leaf)
1146
1147     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1148         """Like :func:`append()` but disallow invalid standalone comment structure.
1149
1150         Raises ValueError when any `leaf` is appended after a standalone comment
1151         or when a standalone comment is not the first leaf on the line.
1152         """
1153         if self.bracket_tracker.depth == 0:
1154             if self.is_comment:
1155                 raise ValueError("cannot append to standalone comments")
1156
1157             if self.leaves and leaf.type == STANDALONE_COMMENT:
1158                 raise ValueError(
1159                     "cannot append standalone comments to a populated line"
1160                 )
1161
1162         self.append(leaf, preformatted=preformatted)
1163
1164     @property
1165     def is_comment(self) -> bool:
1166         """Is this line a standalone comment?"""
1167         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1168
1169     @property
1170     def is_decorator(self) -> bool:
1171         """Is this line a decorator?"""
1172         return bool(self) and self.leaves[0].type == token.AT
1173
1174     @property
1175     def is_import(self) -> bool:
1176         """Is this an import line?"""
1177         return bool(self) and is_import(self.leaves[0])
1178
1179     @property
1180     def is_class(self) -> bool:
1181         """Is this line a class definition?"""
1182         return (
1183             bool(self)
1184             and self.leaves[0].type == token.NAME
1185             and self.leaves[0].value == "class"
1186         )
1187
1188     @property
1189     def is_stub_class(self) -> bool:
1190         """Is this line a class definition with a body consisting only of "..."?"""
1191         return self.is_class and self.leaves[-3:] == [
1192             Leaf(token.DOT, ".") for _ in range(3)
1193         ]
1194
1195     @property
1196     def is_def(self) -> bool:
1197         """Is this a function definition? (Also returns True for async defs.)"""
1198         try:
1199             first_leaf = self.leaves[0]
1200         except IndexError:
1201             return False
1202
1203         try:
1204             second_leaf: Optional[Leaf] = self.leaves[1]
1205         except IndexError:
1206             second_leaf = None
1207         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1208             first_leaf.type == token.ASYNC
1209             and second_leaf is not None
1210             and second_leaf.type == token.NAME
1211             and second_leaf.value == "def"
1212         )
1213
1214     @property
1215     def is_class_paren_empty(self) -> bool:
1216         """Is this a class with no base classes but using parentheses?
1217
1218         Those are unnecessary and should be removed.
1219         """
1220         return (
1221             bool(self)
1222             and len(self.leaves) == 4
1223             and self.is_class
1224             and self.leaves[2].type == token.LPAR
1225             and self.leaves[2].value == "("
1226             and self.leaves[3].type == token.RPAR
1227             and self.leaves[3].value == ")"
1228         )
1229
1230     @property
1231     def is_triple_quoted_string(self) -> bool:
1232         """Is the line a triple quoted string?"""
1233         return (
1234             bool(self)
1235             and self.leaves[0].type == token.STRING
1236             and self.leaves[0].value.startswith(('"""', "'''"))
1237         )
1238
1239     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1240         """If so, needs to be split before emitting."""
1241         for leaf in self.leaves:
1242             if leaf.type == STANDALONE_COMMENT:
1243                 if leaf.bracket_depth <= depth_limit:
1244                     return True
1245         return False
1246
1247     def contains_inner_type_comments(self) -> bool:
1248         ignored_ids = set()
1249         try:
1250             last_leaf = self.leaves[-1]
1251             ignored_ids.add(id(last_leaf))
1252             if last_leaf.type == token.COMMA:
1253                 # When trailing commas are inserted by Black for consistency, comments
1254                 # after the previous last element are not moved (they don't have to,
1255                 # rendering will still be correct).  So we ignore trailing commas.
1256                 last_leaf = self.leaves[-2]
1257                 ignored_ids.add(id(last_leaf))
1258         except IndexError:
1259             return False
1260
1261         for leaf_id, comments in self.comments.items():
1262             if leaf_id in ignored_ids:
1263                 continue
1264
1265             for comment in comments:
1266                 if is_type_comment(comment):
1267                     return True
1268
1269         return False
1270
1271     def contains_multiline_strings(self) -> bool:
1272         for leaf in self.leaves:
1273             if is_multiline_string(leaf):
1274                 return True
1275
1276         return False
1277
1278     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1279         """Remove trailing comma if there is one and it's safe."""
1280         if not (
1281             self.leaves
1282             and self.leaves[-1].type == token.COMMA
1283             and closing.type in CLOSING_BRACKETS
1284         ):
1285             return False
1286
1287         if closing.type == token.RBRACE:
1288             self.remove_trailing_comma()
1289             return True
1290
1291         if closing.type == token.RSQB:
1292             comma = self.leaves[-1]
1293             if comma.parent and comma.parent.type == syms.listmaker:
1294                 self.remove_trailing_comma()
1295                 return True
1296
1297         # For parens let's check if it's safe to remove the comma.
1298         # Imports are always safe.
1299         if self.is_import:
1300             self.remove_trailing_comma()
1301             return True
1302
1303         # Otherwise, if the trailing one is the only one, we might mistakenly
1304         # change a tuple into a different type by removing the comma.
1305         depth = closing.bracket_depth + 1
1306         commas = 0
1307         opening = closing.opening_bracket
1308         for _opening_index, leaf in enumerate(self.leaves):
1309             if leaf is opening:
1310                 break
1311
1312         else:
1313             return False
1314
1315         for leaf in self.leaves[_opening_index + 1 :]:
1316             if leaf is closing:
1317                 break
1318
1319             bracket_depth = leaf.bracket_depth
1320             if bracket_depth == depth and leaf.type == token.COMMA:
1321                 commas += 1
1322                 if leaf.parent and leaf.parent.type == syms.arglist:
1323                     commas += 1
1324                     break
1325
1326         if commas > 1:
1327             self.remove_trailing_comma()
1328             return True
1329
1330         return False
1331
1332     def append_comment(self, comment: Leaf) -> bool:
1333         """Add an inline or standalone comment to the line."""
1334         if (
1335             comment.type == STANDALONE_COMMENT
1336             and self.bracket_tracker.any_open_brackets()
1337         ):
1338             comment.prefix = ""
1339             return False
1340
1341         if comment.type != token.COMMENT:
1342             return False
1343
1344         if not self.leaves:
1345             comment.type = STANDALONE_COMMENT
1346             comment.prefix = ""
1347             return False
1348
1349         self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
1350         return True
1351
1352     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1353         """Generate comments that should appear directly after `leaf`."""
1354         return self.comments.get(id(leaf), [])
1355
1356     def remove_trailing_comma(self) -> None:
1357         """Remove the trailing comma and moves the comments attached to it."""
1358         trailing_comma = self.leaves.pop()
1359         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1360         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1361             trailing_comma_comments
1362         )
1363
1364     def is_complex_subscript(self, leaf: Leaf) -> bool:
1365         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1366         open_lsqb = self.bracket_tracker.get_open_lsqb()
1367         if open_lsqb is None:
1368             return False
1369
1370         subscript_start = open_lsqb.next_sibling
1371
1372         if isinstance(subscript_start, Node):
1373             if subscript_start.type == syms.listmaker:
1374                 return False
1375
1376             if subscript_start.type == syms.subscriptlist:
1377                 subscript_start = child_towards(subscript_start, leaf)
1378         return subscript_start is not None and any(
1379             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1380         )
1381
1382     def __str__(self) -> str:
1383         """Render the line."""
1384         if not self:
1385             return "\n"
1386
1387         indent = "    " * self.depth
1388         leaves = iter(self.leaves)
1389         first = next(leaves)
1390         res = f"{first.prefix}{indent}{first.value}"
1391         for leaf in leaves:
1392             res += str(leaf)
1393         for comment in itertools.chain.from_iterable(self.comments.values()):
1394             res += str(comment)
1395         return res + "\n"
1396
1397     def __bool__(self) -> bool:
1398         """Return True if the line has leaves or comments."""
1399         return bool(self.leaves or self.comments)
1400
1401
1402 @dataclass
1403 class EmptyLineTracker:
1404     """Provides a stateful method that returns the number of potential extra
1405     empty lines needed before and after the currently processed line.
1406
1407     Note: this tracker works on lines that haven't been split yet.  It assumes
1408     the prefix of the first leaf consists of optional newlines.  Those newlines
1409     are consumed by `maybe_empty_lines()` and included in the computation.
1410     """
1411
1412     is_pyi: bool = False
1413     previous_line: Optional[Line] = None
1414     previous_after: int = 0
1415     previous_defs: List[int] = Factory(list)
1416
1417     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1418         """Return the number of extra empty lines before and after the `current_line`.
1419
1420         This is for separating `def`, `async def` and `class` with extra empty
1421         lines (two on module-level).
1422         """
1423         before, after = self._maybe_empty_lines(current_line)
1424         before -= self.previous_after
1425         self.previous_after = after
1426         self.previous_line = current_line
1427         return before, after
1428
1429     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1430         max_allowed = 1
1431         if current_line.depth == 0:
1432             max_allowed = 1 if self.is_pyi else 2
1433         if current_line.leaves:
1434             # Consume the first leaf's extra newlines.
1435             first_leaf = current_line.leaves[0]
1436             before = first_leaf.prefix.count("\n")
1437             before = min(before, max_allowed)
1438             first_leaf.prefix = ""
1439         else:
1440             before = 0
1441         depth = current_line.depth
1442         while self.previous_defs and self.previous_defs[-1] >= depth:
1443             self.previous_defs.pop()
1444             if self.is_pyi:
1445                 before = 0 if depth else 1
1446             else:
1447                 before = 1 if depth else 2
1448         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1449             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1450
1451         if (
1452             self.previous_line
1453             and self.previous_line.is_import
1454             and not current_line.is_import
1455             and depth == self.previous_line.depth
1456         ):
1457             return (before or 1), 0
1458
1459         if (
1460             self.previous_line
1461             and self.previous_line.is_class
1462             and current_line.is_triple_quoted_string
1463         ):
1464             return before, 1
1465
1466         return before, 0
1467
1468     def _maybe_empty_lines_for_class_or_def(
1469         self, current_line: Line, before: int
1470     ) -> Tuple[int, int]:
1471         if not current_line.is_decorator:
1472             self.previous_defs.append(current_line.depth)
1473         if self.previous_line is None:
1474             # Don't insert empty lines before the first line in the file.
1475             return 0, 0
1476
1477         if self.previous_line.is_decorator:
1478             return 0, 0
1479
1480         if self.previous_line.depth < current_line.depth and (
1481             self.previous_line.is_class or self.previous_line.is_def
1482         ):
1483             return 0, 0
1484
1485         if (
1486             self.previous_line.is_comment
1487             and self.previous_line.depth == current_line.depth
1488             and before == 0
1489         ):
1490             return 0, 0
1491
1492         if self.is_pyi:
1493             if self.previous_line.depth > current_line.depth:
1494                 newlines = 1
1495             elif current_line.is_class or self.previous_line.is_class:
1496                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1497                     # No blank line between classes with an empty body
1498                     newlines = 0
1499                 else:
1500                     newlines = 1
1501             elif current_line.is_def and not self.previous_line.is_def:
1502                 # Blank line between a block of functions and a block of non-functions
1503                 newlines = 1
1504             else:
1505                 newlines = 0
1506         else:
1507             newlines = 2
1508         if current_line.depth and newlines:
1509             newlines -= 1
1510         return newlines, 0
1511
1512
1513 @dataclass
1514 class LineGenerator(Visitor[Line]):
1515     """Generates reformatted Line objects.  Empty lines are not emitted.
1516
1517     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1518     in ways that will no longer stringify to valid Python code on the tree.
1519     """
1520
1521     is_pyi: bool = False
1522     normalize_strings: bool = True
1523     current_line: Line = Factory(Line)
1524     remove_u_prefix: bool = False
1525
1526     def line(self, indent: int = 0) -> Iterator[Line]:
1527         """Generate a line.
1528
1529         If the line is empty, only emit if it makes sense.
1530         If the line is too long, split it first and then generate.
1531
1532         If any lines were generated, set up a new current_line.
1533         """
1534         if not self.current_line:
1535             self.current_line.depth += indent
1536             return  # Line is empty, don't emit. Creating a new one unnecessary.
1537
1538         complete_line = self.current_line
1539         self.current_line = Line(depth=complete_line.depth + indent)
1540         yield complete_line
1541
1542     def visit_default(self, node: LN) -> Iterator[Line]:
1543         """Default `visit_*()` implementation. Recurses to children of `node`."""
1544         if isinstance(node, Leaf):
1545             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1546             for comment in generate_comments(node):
1547                 if any_open_brackets:
1548                     # any comment within brackets is subject to splitting
1549                     self.current_line.append(comment)
1550                 elif comment.type == token.COMMENT:
1551                     # regular trailing comment
1552                     self.current_line.append(comment)
1553                     yield from self.line()
1554
1555                 else:
1556                     # regular standalone comment
1557                     yield from self.line()
1558
1559                     self.current_line.append(comment)
1560                     yield from self.line()
1561
1562             normalize_prefix(node, inside_brackets=any_open_brackets)
1563             if self.normalize_strings and node.type == token.STRING:
1564                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1565                 normalize_string_quotes(node)
1566             if node.type == token.NUMBER:
1567                 normalize_numeric_literal(node)
1568             if node.type not in WHITESPACE:
1569                 self.current_line.append(node)
1570         yield from super().visit_default(node)
1571
1572     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1573         """Increase indentation level, maybe yield a line."""
1574         # In blib2to3 INDENT never holds comments.
1575         yield from self.line(+1)
1576         yield from self.visit_default(node)
1577
1578     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1579         """Decrease indentation level, maybe yield a line."""
1580         # The current line might still wait for trailing comments.  At DEDENT time
1581         # there won't be any (they would be prefixes on the preceding NEWLINE).
1582         # Emit the line then.
1583         yield from self.line()
1584
1585         # While DEDENT has no value, its prefix may contain standalone comments
1586         # that belong to the current indentation level.  Get 'em.
1587         yield from self.visit_default(node)
1588
1589         # Finally, emit the dedent.
1590         yield from self.line(-1)
1591
1592     def visit_stmt(
1593         self, node: Node, keywords: Set[str], parens: Set[str]
1594     ) -> Iterator[Line]:
1595         """Visit a statement.
1596
1597         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1598         `def`, `with`, `class`, `assert` and assignments.
1599
1600         The relevant Python language `keywords` for a given statement will be
1601         NAME leaves within it. This methods puts those on a separate line.
1602
1603         `parens` holds a set of string leaf values immediately after which
1604         invisible parens should be put.
1605         """
1606         normalize_invisible_parens(node, parens_after=parens)
1607         for child in node.children:
1608             if child.type == token.NAME and child.value in keywords:  # type: ignore
1609                 yield from self.line()
1610
1611             yield from self.visit(child)
1612
1613     def visit_suite(self, node: Node) -> Iterator[Line]:
1614         """Visit a suite."""
1615         if self.is_pyi and is_stub_suite(node):
1616             yield from self.visit(node.children[2])
1617         else:
1618             yield from self.visit_default(node)
1619
1620     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1621         """Visit a statement without nested statements."""
1622         is_suite_like = node.parent and node.parent.type in STATEMENT
1623         if is_suite_like:
1624             if self.is_pyi and is_stub_body(node):
1625                 yield from self.visit_default(node)
1626             else:
1627                 yield from self.line(+1)
1628                 yield from self.visit_default(node)
1629                 yield from self.line(-1)
1630
1631         else:
1632             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1633                 yield from self.line()
1634             yield from self.visit_default(node)
1635
1636     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1637         """Visit `async def`, `async for`, `async with`."""
1638         yield from self.line()
1639
1640         children = iter(node.children)
1641         for child in children:
1642             yield from self.visit(child)
1643
1644             if child.type == token.ASYNC:
1645                 break
1646
1647         internal_stmt = next(children)
1648         for child in internal_stmt.children:
1649             yield from self.visit(child)
1650
1651     def visit_decorators(self, node: Node) -> Iterator[Line]:
1652         """Visit decorators."""
1653         for child in node.children:
1654             yield from self.line()
1655             yield from self.visit(child)
1656
1657     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1658         """Remove a semicolon and put the other statement on a separate line."""
1659         yield from self.line()
1660
1661     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1662         """End of file. Process outstanding comments and end with a newline."""
1663         yield from self.visit_default(leaf)
1664         yield from self.line()
1665
1666     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1667         if not self.current_line.bracket_tracker.any_open_brackets():
1668             yield from self.line()
1669         yield from self.visit_default(leaf)
1670
1671     def __attrs_post_init__(self) -> None:
1672         """You are in a twisty little maze of passages."""
1673         v = self.visit_stmt
1674         Ø: Set[str] = set()
1675         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1676         self.visit_if_stmt = partial(
1677             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1678         )
1679         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1680         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1681         self.visit_try_stmt = partial(
1682             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1683         )
1684         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1685         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1686         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1687         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1688         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1689         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1690         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1691         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1692         self.visit_async_funcdef = self.visit_async_stmt
1693         self.visit_decorated = self.visit_decorators
1694
1695
1696 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1697 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1698 OPENING_BRACKETS = set(BRACKET.keys())
1699 CLOSING_BRACKETS = set(BRACKET.values())
1700 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1701 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1702
1703
1704 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1705     """Return whitespace prefix if needed for the given `leaf`.
1706
1707     `complex_subscript` signals whether the given leaf is part of a subscription
1708     which has non-trivial arguments, like arithmetic expressions or function calls.
1709     """
1710     NO = ""
1711     SPACE = " "
1712     DOUBLESPACE = "  "
1713     t = leaf.type
1714     p = leaf.parent
1715     v = leaf.value
1716     if t in ALWAYS_NO_SPACE:
1717         return NO
1718
1719     if t == token.COMMENT:
1720         return DOUBLESPACE
1721
1722     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1723     if t == token.COLON and p.type not in {
1724         syms.subscript,
1725         syms.subscriptlist,
1726         syms.sliceop,
1727     }:
1728         return NO
1729
1730     prev = leaf.prev_sibling
1731     if not prev:
1732         prevp = preceding_leaf(p)
1733         if not prevp or prevp.type in OPENING_BRACKETS:
1734             return NO
1735
1736         if t == token.COLON:
1737             if prevp.type == token.COLON:
1738                 return NO
1739
1740             elif prevp.type != token.COMMA and not complex_subscript:
1741                 return NO
1742
1743             return SPACE
1744
1745         if prevp.type == token.EQUAL:
1746             if prevp.parent:
1747                 if prevp.parent.type in {
1748                     syms.arglist,
1749                     syms.argument,
1750                     syms.parameters,
1751                     syms.varargslist,
1752                 }:
1753                     return NO
1754
1755                 elif prevp.parent.type == syms.typedargslist:
1756                     # A bit hacky: if the equal sign has whitespace, it means we
1757                     # previously found it's a typed argument.  So, we're using
1758                     # that, too.
1759                     return prevp.prefix
1760
1761         elif prevp.type in STARS:
1762             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1763                 return NO
1764
1765         elif prevp.type == token.COLON:
1766             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1767                 return SPACE if complex_subscript else NO
1768
1769         elif (
1770             prevp.parent
1771             and prevp.parent.type == syms.factor
1772             and prevp.type in MATH_OPERATORS
1773         ):
1774             return NO
1775
1776         elif (
1777             prevp.type == token.RIGHTSHIFT
1778             and prevp.parent
1779             and prevp.parent.type == syms.shift_expr
1780             and prevp.prev_sibling
1781             and prevp.prev_sibling.type == token.NAME
1782             and prevp.prev_sibling.value == "print"  # type: ignore
1783         ):
1784             # Python 2 print chevron
1785             return NO
1786
1787     elif prev.type in OPENING_BRACKETS:
1788         return NO
1789
1790     if p.type in {syms.parameters, syms.arglist}:
1791         # untyped function signatures or calls
1792         if not prev or prev.type != token.COMMA:
1793             return NO
1794
1795     elif p.type == syms.varargslist:
1796         # lambdas
1797         if prev and prev.type != token.COMMA:
1798             return NO
1799
1800     elif p.type == syms.typedargslist:
1801         # typed function signatures
1802         if not prev:
1803             return NO
1804
1805         if t == token.EQUAL:
1806             if prev.type != syms.tname:
1807                 return NO
1808
1809         elif prev.type == token.EQUAL:
1810             # A bit hacky: if the equal sign has whitespace, it means we
1811             # previously found it's a typed argument.  So, we're using that, too.
1812             return prev.prefix
1813
1814         elif prev.type != token.COMMA:
1815             return NO
1816
1817     elif p.type == syms.tname:
1818         # type names
1819         if not prev:
1820             prevp = preceding_leaf(p)
1821             if not prevp or prevp.type != token.COMMA:
1822                 return NO
1823
1824     elif p.type == syms.trailer:
1825         # attributes and calls
1826         if t == token.LPAR or t == token.RPAR:
1827             return NO
1828
1829         if not prev:
1830             if t == token.DOT:
1831                 prevp = preceding_leaf(p)
1832                 if not prevp or prevp.type != token.NUMBER:
1833                     return NO
1834
1835             elif t == token.LSQB:
1836                 return NO
1837
1838         elif prev.type != token.COMMA:
1839             return NO
1840
1841     elif p.type == syms.argument:
1842         # single argument
1843         if t == token.EQUAL:
1844             return NO
1845
1846         if not prev:
1847             prevp = preceding_leaf(p)
1848             if not prevp or prevp.type == token.LPAR:
1849                 return NO
1850
1851         elif prev.type in {token.EQUAL} | STARS:
1852             return NO
1853
1854     elif p.type == syms.decorator:
1855         # decorators
1856         return NO
1857
1858     elif p.type == syms.dotted_name:
1859         if prev:
1860             return NO
1861
1862         prevp = preceding_leaf(p)
1863         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1864             return NO
1865
1866     elif p.type == syms.classdef:
1867         if t == token.LPAR:
1868             return NO
1869
1870         if prev and prev.type == token.LPAR:
1871             return NO
1872
1873     elif p.type in {syms.subscript, syms.sliceop}:
1874         # indexing
1875         if not prev:
1876             assert p.parent is not None, "subscripts are always parented"
1877             if p.parent.type == syms.subscriptlist:
1878                 return SPACE
1879
1880             return NO
1881
1882         elif not complex_subscript:
1883             return NO
1884
1885     elif p.type == syms.atom:
1886         if prev and t == token.DOT:
1887             # dots, but not the first one.
1888             return NO
1889
1890     elif p.type == syms.dictsetmaker:
1891         # dict unpacking
1892         if prev and prev.type == token.DOUBLESTAR:
1893             return NO
1894
1895     elif p.type in {syms.factor, syms.star_expr}:
1896         # unary ops
1897         if not prev:
1898             prevp = preceding_leaf(p)
1899             if not prevp or prevp.type in OPENING_BRACKETS:
1900                 return NO
1901
1902             prevp_parent = prevp.parent
1903             assert prevp_parent is not None
1904             if prevp.type == token.COLON and prevp_parent.type in {
1905                 syms.subscript,
1906                 syms.sliceop,
1907             }:
1908                 return NO
1909
1910             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1911                 return NO
1912
1913         elif t in {token.NAME, token.NUMBER, token.STRING}:
1914             return NO
1915
1916     elif p.type == syms.import_from:
1917         if t == token.DOT:
1918             if prev and prev.type == token.DOT:
1919                 return NO
1920
1921         elif t == token.NAME:
1922             if v == "import":
1923                 return SPACE
1924
1925             if prev and prev.type == token.DOT:
1926                 return NO
1927
1928     elif p.type == syms.sliceop:
1929         return NO
1930
1931     return SPACE
1932
1933
1934 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1935     """Return the first leaf that precedes `node`, if any."""
1936     while node:
1937         res = node.prev_sibling
1938         if res:
1939             if isinstance(res, Leaf):
1940                 return res
1941
1942             try:
1943                 return list(res.leaves())[-1]
1944
1945             except IndexError:
1946                 return None
1947
1948         node = node.parent
1949     return None
1950
1951
1952 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1953     """Return the child of `ancestor` that contains `descendant`."""
1954     node: Optional[LN] = descendant
1955     while node and node.parent != ancestor:
1956         node = node.parent
1957     return node
1958
1959
1960 def container_of(leaf: Leaf) -> LN:
1961     """Return `leaf` or one of its ancestors that is the topmost container of it.
1962
1963     By "container" we mean a node where `leaf` is the very first child.
1964     """
1965     same_prefix = leaf.prefix
1966     container: LN = leaf
1967     while container:
1968         parent = container.parent
1969         if parent is None:
1970             break
1971
1972         if parent.children[0].prefix != same_prefix:
1973             break
1974
1975         if parent.type == syms.file_input:
1976             break
1977
1978         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1979             break
1980
1981         container = parent
1982     return container
1983
1984
1985 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1986     """Return the priority of the `leaf` delimiter, given a line break after it.
1987
1988     The delimiter priorities returned here are from those delimiters that would
1989     cause a line break after themselves.
1990
1991     Higher numbers are higher priority.
1992     """
1993     if leaf.type == token.COMMA:
1994         return COMMA_PRIORITY
1995
1996     return 0
1997
1998
1999 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
2000     """Return the priority of the `leaf` delimiter, given a line break before it.
2001
2002     The delimiter priorities returned here are from those delimiters that would
2003     cause a line break before themselves.
2004
2005     Higher numbers are higher priority.
2006     """
2007     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
2008         # * and ** might also be MATH_OPERATORS but in this case they are not.
2009         # Don't treat them as a delimiter.
2010         return 0
2011
2012     if (
2013         leaf.type == token.DOT
2014         and leaf.parent
2015         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
2016         and (previous is None or previous.type in CLOSING_BRACKETS)
2017     ):
2018         return DOT_PRIORITY
2019
2020     if (
2021         leaf.type in MATH_OPERATORS
2022         and leaf.parent
2023         and leaf.parent.type not in {syms.factor, syms.star_expr}
2024     ):
2025         return MATH_PRIORITIES[leaf.type]
2026
2027     if leaf.type in COMPARATORS:
2028         return COMPARATOR_PRIORITY
2029
2030     if (
2031         leaf.type == token.STRING
2032         and previous is not None
2033         and previous.type == token.STRING
2034     ):
2035         return STRING_PRIORITY
2036
2037     if leaf.type not in {token.NAME, token.ASYNC}:
2038         return 0
2039
2040     if (
2041         leaf.value == "for"
2042         and leaf.parent
2043         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2044         or leaf.type == token.ASYNC
2045     ):
2046         if (
2047             not isinstance(leaf.prev_sibling, Leaf)
2048             or leaf.prev_sibling.value != "async"
2049         ):
2050             return COMPREHENSION_PRIORITY
2051
2052     if (
2053         leaf.value == "if"
2054         and leaf.parent
2055         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2056     ):
2057         return COMPREHENSION_PRIORITY
2058
2059     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2060         return TERNARY_PRIORITY
2061
2062     if leaf.value == "is":
2063         return COMPARATOR_PRIORITY
2064
2065     if (
2066         leaf.value == "in"
2067         and leaf.parent
2068         and leaf.parent.type in {syms.comp_op, syms.comparison}
2069         and not (
2070             previous is not None
2071             and previous.type == token.NAME
2072             and previous.value == "not"
2073         )
2074     ):
2075         return COMPARATOR_PRIORITY
2076
2077     if (
2078         leaf.value == "not"
2079         and leaf.parent
2080         and leaf.parent.type == syms.comp_op
2081         and not (
2082             previous is not None
2083             and previous.type == token.NAME
2084             and previous.value == "is"
2085         )
2086     ):
2087         return COMPARATOR_PRIORITY
2088
2089     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2090         return LOGIC_PRIORITY
2091
2092     return 0
2093
2094
2095 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2096 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2097
2098
2099 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2100     """Clean the prefix of the `leaf` and generate comments from it, if any.
2101
2102     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2103     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2104     move because it does away with modifying the grammar to include all the
2105     possible places in which comments can be placed.
2106
2107     The sad consequence for us though is that comments don't "belong" anywhere.
2108     This is why this function generates simple parentless Leaf objects for
2109     comments.  We simply don't know what the correct parent should be.
2110
2111     No matter though, we can live without this.  We really only need to
2112     differentiate between inline and standalone comments.  The latter don't
2113     share the line with any code.
2114
2115     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2116     are emitted with a fake STANDALONE_COMMENT token identifier.
2117     """
2118     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2119         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2120
2121
2122 @dataclass
2123 class ProtoComment:
2124     """Describes a piece of syntax that is a comment.
2125
2126     It's not a :class:`blib2to3.pytree.Leaf` so that:
2127
2128     * it can be cached (`Leaf` objects should not be reused more than once as
2129       they store their lineno, column, prefix, and parent information);
2130     * `newlines` and `consumed` fields are kept separate from the `value`. This
2131       simplifies handling of special marker comments like ``# fmt: off/on``.
2132     """
2133
2134     type: int  # token.COMMENT or STANDALONE_COMMENT
2135     value: str  # content of the comment
2136     newlines: int  # how many newlines before the comment
2137     consumed: int  # how many characters of the original leaf's prefix did we consume
2138
2139
2140 @lru_cache(maxsize=4096)
2141 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2142     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2143     result: List[ProtoComment] = []
2144     if not prefix or "#" not in prefix:
2145         return result
2146
2147     consumed = 0
2148     nlines = 0
2149     ignored_lines = 0
2150     for index, line in enumerate(prefix.split("\n")):
2151         consumed += len(line) + 1  # adding the length of the split '\n'
2152         line = line.lstrip()
2153         if not line:
2154             nlines += 1
2155         if not line.startswith("#"):
2156             # Escaped newlines outside of a comment are not really newlines at
2157             # all. We treat a single-line comment following an escaped newline
2158             # as a simple trailing comment.
2159             if line.endswith("\\"):
2160                 ignored_lines += 1
2161             continue
2162
2163         if index == ignored_lines and not is_endmarker:
2164             comment_type = token.COMMENT  # simple trailing comment
2165         else:
2166             comment_type = STANDALONE_COMMENT
2167         comment = make_comment(line)
2168         result.append(
2169             ProtoComment(
2170                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2171             )
2172         )
2173         nlines = 0
2174     return result
2175
2176
2177 def make_comment(content: str) -> str:
2178     """Return a consistently formatted comment from the given `content` string.
2179
2180     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2181     space between the hash sign and the content.
2182
2183     If `content` didn't start with a hash sign, one is provided.
2184     """
2185     content = content.rstrip()
2186     if not content:
2187         return "#"
2188
2189     if content[0] == "#":
2190         content = content[1:]
2191     if content and content[0] not in " !:#'%":
2192         content = " " + content
2193     return "#" + content
2194
2195
2196 def split_line(
2197     line: Line,
2198     line_length: int,
2199     inner: bool = False,
2200     features: Collection[Feature] = (),
2201 ) -> Iterator[Line]:
2202     """Split a `line` into potentially many lines.
2203
2204     They should fit in the allotted `line_length` but might not be able to.
2205     `inner` signifies that there were a pair of brackets somewhere around the
2206     current `line`, possibly transitively. This means we can fallback to splitting
2207     by delimiters if the LHS/RHS don't yield any results.
2208
2209     `features` are syntactical features that may be used in the output.
2210     """
2211     if line.is_comment:
2212         yield line
2213         return
2214
2215     line_str = str(line).strip("\n")
2216
2217     if (
2218         not line.contains_inner_type_comments()
2219         and not line.should_explode
2220         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2221     ):
2222         yield line
2223         return
2224
2225     split_funcs: List[SplitFunc]
2226     if line.is_def:
2227         split_funcs = [left_hand_split]
2228     else:
2229
2230         def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
2231             for omit in generate_trailers_to_omit(line, line_length):
2232                 lines = list(right_hand_split(line, line_length, features, omit=omit))
2233                 if is_line_short_enough(lines[0], line_length=line_length):
2234                     yield from lines
2235                     return
2236
2237             # All splits failed, best effort split with no omits.
2238             # This mostly happens to multiline strings that are by definition
2239             # reported as not fitting a single line.
2240             yield from right_hand_split(line, line_length, features=features)
2241
2242         if line.inside_brackets:
2243             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2244         else:
2245             split_funcs = [rhs]
2246     for split_func in split_funcs:
2247         # We are accumulating lines in `result` because we might want to abort
2248         # mission and return the original line in the end, or attempt a different
2249         # split altogether.
2250         result: List[Line] = []
2251         try:
2252             for l in split_func(line, features):
2253                 if str(l).strip("\n") == line_str:
2254                     raise CannotSplit("Split function returned an unchanged result")
2255
2256                 result.extend(
2257                     split_line(
2258                         l, line_length=line_length, inner=True, features=features
2259                     )
2260                 )
2261         except CannotSplit:
2262             continue
2263
2264         else:
2265             yield from result
2266             break
2267
2268     else:
2269         yield line
2270
2271
2272 def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2273     """Split line into many lines, starting with the first matching bracket pair.
2274
2275     Note: this usually looks weird, only use this for function definitions.
2276     Prefer RHS otherwise.  This is why this function is not symmetrical with
2277     :func:`right_hand_split` which also handles optional parentheses.
2278     """
2279     tail_leaves: List[Leaf] = []
2280     body_leaves: List[Leaf] = []
2281     head_leaves: List[Leaf] = []
2282     current_leaves = head_leaves
2283     matching_bracket = None
2284     for leaf in line.leaves:
2285         if (
2286             current_leaves is body_leaves
2287             and leaf.type in CLOSING_BRACKETS
2288             and leaf.opening_bracket is matching_bracket
2289         ):
2290             current_leaves = tail_leaves if body_leaves else head_leaves
2291         current_leaves.append(leaf)
2292         if current_leaves is head_leaves:
2293             if leaf.type in OPENING_BRACKETS:
2294                 matching_bracket = leaf
2295                 current_leaves = body_leaves
2296     if not matching_bracket:
2297         raise CannotSplit("No brackets found")
2298
2299     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2300     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2301     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2302     bracket_split_succeeded_or_raise(head, body, tail)
2303     for result in (head, body, tail):
2304         if result:
2305             yield result
2306
2307
2308 def right_hand_split(
2309     line: Line,
2310     line_length: int,
2311     features: Collection[Feature] = (),
2312     omit: Collection[LeafID] = (),
2313 ) -> Iterator[Line]:
2314     """Split line into many lines, starting with the last matching bracket pair.
2315
2316     If the split was by optional parentheses, attempt splitting without them, too.
2317     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2318     this split.
2319
2320     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2321     """
2322     tail_leaves: List[Leaf] = []
2323     body_leaves: List[Leaf] = []
2324     head_leaves: List[Leaf] = []
2325     current_leaves = tail_leaves
2326     opening_bracket = None
2327     closing_bracket = None
2328     for leaf in reversed(line.leaves):
2329         if current_leaves is body_leaves:
2330             if leaf is opening_bracket:
2331                 current_leaves = head_leaves if body_leaves else tail_leaves
2332         current_leaves.append(leaf)
2333         if current_leaves is tail_leaves:
2334             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2335                 opening_bracket = leaf.opening_bracket
2336                 closing_bracket = leaf
2337                 current_leaves = body_leaves
2338     if not (opening_bracket and closing_bracket and head_leaves):
2339         # If there is no opening or closing_bracket that means the split failed and
2340         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2341         # the matching `opening_bracket` wasn't available on `line` anymore.
2342         raise CannotSplit("No brackets found")
2343
2344     tail_leaves.reverse()
2345     body_leaves.reverse()
2346     head_leaves.reverse()
2347     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2348     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2349     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2350     bracket_split_succeeded_or_raise(head, body, tail)
2351     if (
2352         # the body shouldn't be exploded
2353         not body.should_explode
2354         # the opening bracket is an optional paren
2355         and opening_bracket.type == token.LPAR
2356         and not opening_bracket.value
2357         # the closing bracket is an optional paren
2358         and closing_bracket.type == token.RPAR
2359         and not closing_bracket.value
2360         # it's not an import (optional parens are the only thing we can split on
2361         # in this case; attempting a split without them is a waste of time)
2362         and not line.is_import
2363         # there are no standalone comments in the body
2364         and not body.contains_standalone_comments(0)
2365         # and we can actually remove the parens
2366         and can_omit_invisible_parens(body, line_length)
2367     ):
2368         omit = {id(closing_bracket), *omit}
2369         try:
2370             yield from right_hand_split(line, line_length, features=features, omit=omit)
2371             return
2372
2373         except CannotSplit:
2374             if not (
2375                 can_be_split(body)
2376                 or is_line_short_enough(body, line_length=line_length)
2377             ):
2378                 raise CannotSplit(
2379                     "Splitting failed, body is still too long and can't be split."
2380                 )
2381
2382             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2383                 raise CannotSplit(
2384                     "The current optional pair of parentheses is bound to fail to "
2385                     "satisfy the splitting algorithm because the head or the tail "
2386                     "contains multiline strings which by definition never fit one "
2387                     "line."
2388                 )
2389
2390     ensure_visible(opening_bracket)
2391     ensure_visible(closing_bracket)
2392     for result in (head, body, tail):
2393         if result:
2394             yield result
2395
2396
2397 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2398     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2399
2400     Do nothing otherwise.
2401
2402     A left- or right-hand split is based on a pair of brackets. Content before
2403     (and including) the opening bracket is left on one line, content inside the
2404     brackets is put on a separate line, and finally content starting with and
2405     following the closing bracket is put on a separate line.
2406
2407     Those are called `head`, `body`, and `tail`, respectively. If the split
2408     produced the same line (all content in `head`) or ended up with an empty `body`
2409     and the `tail` is just the closing bracket, then it's considered failed.
2410     """
2411     tail_len = len(str(tail).strip())
2412     if not body:
2413         if tail_len == 0:
2414             raise CannotSplit("Splitting brackets produced the same line")
2415
2416         elif tail_len < 3:
2417             raise CannotSplit(
2418                 f"Splitting brackets on an empty body to save "
2419                 f"{tail_len} characters is not worth it"
2420             )
2421
2422
2423 def bracket_split_build_line(
2424     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2425 ) -> Line:
2426     """Return a new line with given `leaves` and respective comments from `original`.
2427
2428     If `is_body` is True, the result line is one-indented inside brackets and as such
2429     has its first leaf's prefix normalized and a trailing comma added when expected.
2430     """
2431     result = Line(depth=original.depth)
2432     if is_body:
2433         result.inside_brackets = True
2434         result.depth += 1
2435         if leaves:
2436             # Since body is a new indent level, remove spurious leading whitespace.
2437             normalize_prefix(leaves[0], inside_brackets=True)
2438             # Ensure a trailing comma for imports, but be careful not to add one after
2439             # any comments.
2440             if original.is_import:
2441                 for i in range(len(leaves) - 1, -1, -1):
2442                     if leaves[i].type == STANDALONE_COMMENT:
2443                         continue
2444                     elif leaves[i].type == token.COMMA:
2445                         break
2446                     else:
2447                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
2448                         break
2449     # Populate the line
2450     for leaf in leaves:
2451         result.append(leaf, preformatted=True)
2452         for comment_after in original.comments_after(leaf):
2453             result.append(comment_after, preformatted=True)
2454     if is_body:
2455         result.should_explode = should_explode(result, opening_bracket)
2456     return result
2457
2458
2459 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2460     """Normalize prefix of the first leaf in every line returned by `split_func`.
2461
2462     This is a decorator over relevant split functions.
2463     """
2464
2465     @wraps(split_func)
2466     def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2467         for l in split_func(line, features):
2468             normalize_prefix(l.leaves[0], inside_brackets=True)
2469             yield l
2470
2471     return split_wrapper
2472
2473
2474 @dont_increase_indentation
2475 def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
2476     """Split according to delimiters of the highest priority.
2477
2478     If the appropriate Features are given, the split will add trailing commas
2479     also in function signatures and calls that contain `*` and `**`.
2480     """
2481     try:
2482         last_leaf = line.leaves[-1]
2483     except IndexError:
2484         raise CannotSplit("Line empty")
2485
2486     bt = line.bracket_tracker
2487     try:
2488         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2489     except ValueError:
2490         raise CannotSplit("No delimiters found")
2491
2492     if delimiter_priority == DOT_PRIORITY:
2493         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2494             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2495
2496     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2497     lowest_depth = sys.maxsize
2498     trailing_comma_safe = True
2499
2500     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2501         """Append `leaf` to current line or to new line if appending impossible."""
2502         nonlocal current_line
2503         try:
2504             current_line.append_safe(leaf, preformatted=True)
2505         except ValueError:
2506             yield current_line
2507
2508             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2509             current_line.append(leaf)
2510
2511     for leaf in line.leaves:
2512         yield from append_to_line(leaf)
2513
2514         for comment_after in line.comments_after(leaf):
2515             yield from append_to_line(comment_after)
2516
2517         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2518         if leaf.bracket_depth == lowest_depth:
2519             if is_vararg(leaf, within={syms.typedargslist}):
2520                 trailing_comma_safe = (
2521                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
2522                 )
2523             elif is_vararg(leaf, within={syms.arglist, syms.argument}):
2524                 trailing_comma_safe = (
2525                     trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
2526                 )
2527
2528         leaf_priority = bt.delimiters.get(id(leaf))
2529         if leaf_priority == delimiter_priority:
2530             yield current_line
2531
2532             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2533     if current_line:
2534         if (
2535             trailing_comma_safe
2536             and delimiter_priority == COMMA_PRIORITY
2537             and current_line.leaves[-1].type != token.COMMA
2538             and current_line.leaves[-1].type != STANDALONE_COMMENT
2539         ):
2540             current_line.append(Leaf(token.COMMA, ","))
2541         yield current_line
2542
2543
2544 @dont_increase_indentation
2545 def standalone_comment_split(
2546     line: Line, features: Collection[Feature] = ()
2547 ) -> Iterator[Line]:
2548     """Split standalone comments from the rest of the line."""
2549     if not line.contains_standalone_comments(0):
2550         raise CannotSplit("Line does not have any standalone comments")
2551
2552     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2553
2554     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2555         """Append `leaf` to current line or to new line if appending impossible."""
2556         nonlocal current_line
2557         try:
2558             current_line.append_safe(leaf, preformatted=True)
2559         except ValueError:
2560             yield current_line
2561
2562             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2563             current_line.append(leaf)
2564
2565     for leaf in line.leaves:
2566         yield from append_to_line(leaf)
2567
2568         for comment_after in line.comments_after(leaf):
2569             yield from append_to_line(comment_after)
2570
2571     if current_line:
2572         yield current_line
2573
2574
2575 def is_import(leaf: Leaf) -> bool:
2576     """Return True if the given leaf starts an import statement."""
2577     p = leaf.parent
2578     t = leaf.type
2579     v = leaf.value
2580     return bool(
2581         t == token.NAME
2582         and (
2583             (v == "import" and p and p.type == syms.import_name)
2584             or (v == "from" and p and p.type == syms.import_from)
2585         )
2586     )
2587
2588
2589 def is_type_comment(leaf: Leaf) -> bool:
2590     """Return True if the given leaf is a special comment.
2591     Only returns true for type comments for now."""
2592     t = leaf.type
2593     v = leaf.value
2594     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2595
2596
2597 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2598     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2599     else.
2600
2601     Note: don't use backslashes for formatting or you'll lose your voting rights.
2602     """
2603     if not inside_brackets:
2604         spl = leaf.prefix.split("#")
2605         if "\\" not in spl[0]:
2606             nl_count = spl[-1].count("\n")
2607             if len(spl) > 1:
2608                 nl_count -= 1
2609             leaf.prefix = "\n" * nl_count
2610             return
2611
2612     leaf.prefix = ""
2613
2614
2615 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2616     """Make all string prefixes lowercase.
2617
2618     If remove_u_prefix is given, also removes any u prefix from the string.
2619
2620     Note: Mutates its argument.
2621     """
2622     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2623     assert match is not None, f"failed to match string {leaf.value!r}"
2624     orig_prefix = match.group(1)
2625     new_prefix = orig_prefix.lower()
2626     if remove_u_prefix:
2627         new_prefix = new_prefix.replace("u", "")
2628     leaf.value = f"{new_prefix}{match.group(2)}"
2629
2630
2631 def normalize_string_quotes(leaf: Leaf) -> None:
2632     """Prefer double quotes but only if it doesn't cause more escaping.
2633
2634     Adds or removes backslashes as appropriate. Doesn't parse and fix
2635     strings nested in f-strings (yet).
2636
2637     Note: Mutates its argument.
2638     """
2639     value = leaf.value.lstrip("furbFURB")
2640     if value[:3] == '"""':
2641         return
2642
2643     elif value[:3] == "'''":
2644         orig_quote = "'''"
2645         new_quote = '"""'
2646     elif value[0] == '"':
2647         orig_quote = '"'
2648         new_quote = "'"
2649     else:
2650         orig_quote = "'"
2651         new_quote = '"'
2652     first_quote_pos = leaf.value.find(orig_quote)
2653     if first_quote_pos == -1:
2654         return  # There's an internal error
2655
2656     prefix = leaf.value[:first_quote_pos]
2657     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2658     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2659     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2660     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2661     if "r" in prefix.casefold():
2662         if unescaped_new_quote.search(body):
2663             # There's at least one unescaped new_quote in this raw string
2664             # so converting is impossible
2665             return
2666
2667         # Do not introduce or remove backslashes in raw strings
2668         new_body = body
2669     else:
2670         # remove unnecessary escapes
2671         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2672         if body != new_body:
2673             # Consider the string without unnecessary escapes as the original
2674             body = new_body
2675             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2676         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2677         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2678     if "f" in prefix.casefold():
2679         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2680         for m in matches:
2681             if "\\" in str(m):
2682                 # Do not introduce backslashes in interpolated expressions
2683                 return
2684     if new_quote == '"""' and new_body[-1:] == '"':
2685         # edge case:
2686         new_body = new_body[:-1] + '\\"'
2687     orig_escape_count = body.count("\\")
2688     new_escape_count = new_body.count("\\")
2689     if new_escape_count > orig_escape_count:
2690         return  # Do not introduce more escaping
2691
2692     if new_escape_count == orig_escape_count and orig_quote == '"':
2693         return  # Prefer double quotes
2694
2695     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2696
2697
2698 def normalize_numeric_literal(leaf: Leaf) -> None:
2699     """Normalizes numeric (float, int, and complex) literals.
2700
2701     All letters used in the representation are normalized to lowercase (except
2702     in Python 2 long literals).
2703     """
2704     text = leaf.value.lower()
2705     if text.startswith(("0o", "0b")):
2706         # Leave octal and binary literals alone.
2707         pass
2708     elif text.startswith("0x"):
2709         # Change hex literals to upper case.
2710         before, after = text[:2], text[2:]
2711         text = f"{before}{after.upper()}"
2712     elif "e" in text:
2713         before, after = text.split("e")
2714         sign = ""
2715         if after.startswith("-"):
2716             after = after[1:]
2717             sign = "-"
2718         elif after.startswith("+"):
2719             after = after[1:]
2720         before = format_float_or_int_string(before)
2721         text = f"{before}e{sign}{after}"
2722     elif text.endswith(("j", "l")):
2723         number = text[:-1]
2724         suffix = text[-1]
2725         # Capitalize in "2L" because "l" looks too similar to "1".
2726         if suffix == "l":
2727             suffix = "L"
2728         text = f"{format_float_or_int_string(number)}{suffix}"
2729     else:
2730         text = format_float_or_int_string(text)
2731     leaf.value = text
2732
2733
2734 def format_float_or_int_string(text: str) -> str:
2735     """Formats a float string like "1.0"."""
2736     if "." not in text:
2737         return text
2738
2739     before, after = text.split(".")
2740     return f"{before or 0}.{after or 0}"
2741
2742
2743 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2744     """Make existing optional parentheses invisible or create new ones.
2745
2746     `parens_after` is a set of string leaf values immeditely after which parens
2747     should be put.
2748
2749     Standardizes on visible parentheses for single-element tuples, and keeps
2750     existing visible parentheses for other tuples and generator expressions.
2751     """
2752     for pc in list_comments(node.prefix, is_endmarker=False):
2753         if pc.value in FMT_OFF:
2754             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2755             return
2756
2757     check_lpar = False
2758     for index, child in enumerate(list(node.children)):
2759         # Add parentheses around long tuple unpacking in assignments.
2760         if (
2761             index == 0
2762             and isinstance(child, Node)
2763             and child.type == syms.testlist_star_expr
2764         ):
2765             check_lpar = True
2766
2767         if check_lpar:
2768             if child.type == syms.atom:
2769                 if maybe_make_parens_invisible_in_atom(child, parent=node):
2770                     lpar = Leaf(token.LPAR, "")
2771                     rpar = Leaf(token.RPAR, "")
2772                     index = child.remove() or 0
2773                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2774             elif is_one_tuple(child):
2775                 # wrap child in visible parentheses
2776                 lpar = Leaf(token.LPAR, "(")
2777                 rpar = Leaf(token.RPAR, ")")
2778                 child.remove()
2779                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2780             elif node.type == syms.import_from:
2781                 # "import from" nodes store parentheses directly as part of
2782                 # the statement
2783                 if child.type == token.LPAR:
2784                     # make parentheses invisible
2785                     child.value = ""  # type: ignore
2786                     node.children[-1].value = ""  # type: ignore
2787                 elif child.type != token.STAR:
2788                     # insert invisible parentheses
2789                     node.insert_child(index, Leaf(token.LPAR, ""))
2790                     node.append_child(Leaf(token.RPAR, ""))
2791                 break
2792
2793             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2794                 # wrap child in invisible parentheses
2795                 lpar = Leaf(token.LPAR, "")
2796                 rpar = Leaf(token.RPAR, "")
2797                 index = child.remove() or 0
2798                 prefix = child.prefix
2799                 child.prefix = ""
2800                 new_child = Node(syms.atom, [lpar, child, rpar])
2801                 new_child.prefix = prefix
2802                 node.insert_child(index, new_child)
2803
2804         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2805
2806
2807 def normalize_fmt_off(node: Node) -> None:
2808     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2809     try_again = True
2810     while try_again:
2811         try_again = convert_one_fmt_off_pair(node)
2812
2813
2814 def convert_one_fmt_off_pair(node: Node) -> bool:
2815     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2816
2817     Returns True if a pair was converted.
2818     """
2819     for leaf in node.leaves():
2820         previous_consumed = 0
2821         for comment in list_comments(leaf.prefix, is_endmarker=False):
2822             if comment.value in FMT_OFF:
2823                 # We only want standalone comments. If there's no previous leaf or
2824                 # the previous leaf is indentation, it's a standalone comment in
2825                 # disguise.
2826                 if comment.type != STANDALONE_COMMENT:
2827                     prev = preceding_leaf(leaf)
2828                     if prev and prev.type not in WHITESPACE:
2829                         continue
2830
2831                 ignored_nodes = list(generate_ignored_nodes(leaf))
2832                 if not ignored_nodes:
2833                     continue
2834
2835                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2836                 parent = first.parent
2837                 prefix = first.prefix
2838                 first.prefix = prefix[comment.consumed :]
2839                 hidden_value = (
2840                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2841                 )
2842                 if hidden_value.endswith("\n"):
2843                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2844                     # leaf (possibly followed by a DEDENT).
2845                     hidden_value = hidden_value[:-1]
2846                 first_idx = None
2847                 for ignored in ignored_nodes:
2848                     index = ignored.remove()
2849                     if first_idx is None:
2850                         first_idx = index
2851                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2852                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2853                 parent.insert_child(
2854                     first_idx,
2855                     Leaf(
2856                         STANDALONE_COMMENT,
2857                         hidden_value,
2858                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2859                     ),
2860                 )
2861                 return True
2862
2863             previous_consumed = comment.consumed
2864
2865     return False
2866
2867
2868 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2869     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2870
2871     Stops at the end of the block.
2872     """
2873     container: Optional[LN] = container_of(leaf)
2874     while container is not None and container.type != token.ENDMARKER:
2875         for comment in list_comments(container.prefix, is_endmarker=False):
2876             if comment.value in FMT_ON:
2877                 return
2878
2879         yield container
2880
2881         container = container.next_sibling
2882
2883
2884 def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
2885     """If it's safe, make the parens in the atom `node` invisible, recursively.
2886
2887     Returns whether the node should itself be wrapped in invisible parentheses.
2888
2889     """
2890     if (
2891         node.type != syms.atom
2892         or is_empty_tuple(node)
2893         or is_one_tuple(node)
2894         or (is_yield(node) and parent.type != syms.expr_stmt)
2895         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2896     ):
2897         return False
2898
2899     first = node.children[0]
2900     last = node.children[-1]
2901     if first.type == token.LPAR and last.type == token.RPAR:
2902         # make parentheses invisible
2903         first.value = ""  # type: ignore
2904         last.value = ""  # type: ignore
2905         if len(node.children) > 1:
2906             maybe_make_parens_invisible_in_atom(node.children[1], parent=parent)
2907         return False
2908
2909     return True
2910
2911
2912 def is_empty_tuple(node: LN) -> bool:
2913     """Return True if `node` holds an empty tuple."""
2914     return (
2915         node.type == syms.atom
2916         and len(node.children) == 2
2917         and node.children[0].type == token.LPAR
2918         and node.children[1].type == token.RPAR
2919     )
2920
2921
2922 def is_one_tuple(node: LN) -> bool:
2923     """Return True if `node` holds a tuple with one element, with or without parens."""
2924     if node.type == syms.atom:
2925         if len(node.children) != 3:
2926             return False
2927
2928         lpar, gexp, rpar = node.children
2929         if not (
2930             lpar.type == token.LPAR
2931             and gexp.type == syms.testlist_gexp
2932             and rpar.type == token.RPAR
2933         ):
2934             return False
2935
2936         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2937
2938     return (
2939         node.type in IMPLICIT_TUPLE
2940         and len(node.children) == 2
2941         and node.children[1].type == token.COMMA
2942     )
2943
2944
2945 def is_yield(node: LN) -> bool:
2946     """Return True if `node` holds a `yield` or `yield from` expression."""
2947     if node.type == syms.yield_expr:
2948         return True
2949
2950     if node.type == token.NAME and node.value == "yield":  # type: ignore
2951         return True
2952
2953     if node.type != syms.atom:
2954         return False
2955
2956     if len(node.children) != 3:
2957         return False
2958
2959     lpar, expr, rpar = node.children
2960     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2961         return is_yield(expr)
2962
2963     return False
2964
2965
2966 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2967     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2968
2969     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2970     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2971     extended iterable unpacking (PEP 3132) and additional unpacking
2972     generalizations (PEP 448).
2973     """
2974     if leaf.type not in STARS or not leaf.parent:
2975         return False
2976
2977     p = leaf.parent
2978     if p.type == syms.star_expr:
2979         # Star expressions are also used as assignment targets in extended
2980         # iterable unpacking (PEP 3132).  See what its parent is instead.
2981         if not p.parent:
2982             return False
2983
2984         p = p.parent
2985
2986     return p.type in within
2987
2988
2989 def is_multiline_string(leaf: Leaf) -> bool:
2990     """Return True if `leaf` is a multiline string that actually spans many lines."""
2991     value = leaf.value.lstrip("furbFURB")
2992     return value[:3] in {'"""', "'''"} and "\n" in value
2993
2994
2995 def is_stub_suite(node: Node) -> bool:
2996     """Return True if `node` is a suite with a stub body."""
2997     if (
2998         len(node.children) != 4
2999         or node.children[0].type != token.NEWLINE
3000         or node.children[1].type != token.INDENT
3001         or node.children[3].type != token.DEDENT
3002     ):
3003         return False
3004
3005     return is_stub_body(node.children[2])
3006
3007
3008 def is_stub_body(node: LN) -> bool:
3009     """Return True if `node` is a simple statement containing an ellipsis."""
3010     if not isinstance(node, Node) or node.type != syms.simple_stmt:
3011         return False
3012
3013     if len(node.children) != 2:
3014         return False
3015
3016     child = node.children[0]
3017     return (
3018         child.type == syms.atom
3019         and len(child.children) == 3
3020         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
3021     )
3022
3023
3024 def max_delimiter_priority_in_atom(node: LN) -> int:
3025     """Return maximum delimiter priority inside `node`.
3026
3027     This is specific to atoms with contents contained in a pair of parentheses.
3028     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
3029     """
3030     if node.type != syms.atom:
3031         return 0
3032
3033     first = node.children[0]
3034     last = node.children[-1]
3035     if not (first.type == token.LPAR and last.type == token.RPAR):
3036         return 0
3037
3038     bt = BracketTracker()
3039     for c in node.children[1:-1]:
3040         if isinstance(c, Leaf):
3041             bt.mark(c)
3042         else:
3043             for leaf in c.leaves():
3044                 bt.mark(leaf)
3045     try:
3046         return bt.max_delimiter_priority()
3047
3048     except ValueError:
3049         return 0
3050
3051
3052 def ensure_visible(leaf: Leaf) -> None:
3053     """Make sure parentheses are visible.
3054
3055     They could be invisible as part of some statements (see
3056     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3057     """
3058     if leaf.type == token.LPAR:
3059         leaf.value = "("
3060     elif leaf.type == token.RPAR:
3061         leaf.value = ")"
3062
3063
3064 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3065     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3066
3067     if not (
3068         opening_bracket.parent
3069         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3070         and opening_bracket.value in "[{("
3071     ):
3072         return False
3073
3074     try:
3075         last_leaf = line.leaves[-1]
3076         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3077         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3078     except (IndexError, ValueError):
3079         return False
3080
3081     return max_priority == COMMA_PRIORITY
3082
3083
3084 def get_features_used(node: Node) -> Set[Feature]:
3085     """Return a set of (relatively) new Python features used in this file.
3086
3087     Currently looking for:
3088     - f-strings;
3089     - underscores in numeric literals; and
3090     - trailing commas after * or ** in function signatures and calls.
3091     """
3092     features: Set[Feature] = set()
3093     for n in node.pre_order():
3094         if n.type == token.STRING:
3095             value_head = n.value[:2]  # type: ignore
3096             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3097                 features.add(Feature.F_STRINGS)
3098
3099         elif n.type == token.NUMBER:
3100             if "_" in n.value:  # type: ignore
3101                 features.add(Feature.NUMERIC_UNDERSCORES)
3102
3103         elif (
3104             n.type in {syms.typedargslist, syms.arglist}
3105             and n.children
3106             and n.children[-1].type == token.COMMA
3107         ):
3108             if n.type == syms.typedargslist:
3109                 feature = Feature.TRAILING_COMMA_IN_DEF
3110             else:
3111                 feature = Feature.TRAILING_COMMA_IN_CALL
3112
3113             for ch in n.children:
3114                 if ch.type in STARS:
3115                     features.add(feature)
3116
3117                 if ch.type == syms.argument:
3118                     for argch in ch.children:
3119                         if argch.type in STARS:
3120                             features.add(feature)
3121
3122     return features
3123
3124
3125 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3126     """Detect the version to target based on the nodes used."""
3127     features = get_features_used(node)
3128     return {
3129         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3130     }
3131
3132
3133 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3134     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3135
3136     Brackets can be omitted if the entire trailer up to and including
3137     a preceding closing bracket fits in one line.
3138
3139     Yielded sets are cumulative (contain results of previous yields, too).  First
3140     set is empty.
3141     """
3142
3143     omit: Set[LeafID] = set()
3144     yield omit
3145
3146     length = 4 * line.depth
3147     opening_bracket = None
3148     closing_bracket = None
3149     inner_brackets: Set[LeafID] = set()
3150     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3151         length += leaf_length
3152         if length > line_length:
3153             break
3154
3155         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3156         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3157             break
3158
3159         if opening_bracket:
3160             if leaf is opening_bracket:
3161                 opening_bracket = None
3162             elif leaf.type in CLOSING_BRACKETS:
3163                 inner_brackets.add(id(leaf))
3164         elif leaf.type in CLOSING_BRACKETS:
3165             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3166                 # Empty brackets would fail a split so treat them as "inner"
3167                 # brackets (e.g. only add them to the `omit` set if another
3168                 # pair of brackets was good enough.
3169                 inner_brackets.add(id(leaf))
3170                 continue
3171
3172             if closing_bracket:
3173                 omit.add(id(closing_bracket))
3174                 omit.update(inner_brackets)
3175                 inner_brackets.clear()
3176                 yield omit
3177
3178             if leaf.value:
3179                 opening_bracket = leaf.opening_bracket
3180                 closing_bracket = leaf
3181
3182
3183 def get_future_imports(node: Node) -> Set[str]:
3184     """Return a set of __future__ imports in the file."""
3185     imports: Set[str] = set()
3186
3187     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3188         for child in children:
3189             if isinstance(child, Leaf):
3190                 if child.type == token.NAME:
3191                     yield child.value
3192             elif child.type == syms.import_as_name:
3193                 orig_name = child.children[0]
3194                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3195                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3196                 yield orig_name.value
3197             elif child.type == syms.import_as_names:
3198                 yield from get_imports_from_children(child.children)
3199             else:
3200                 raise AssertionError("Invalid syntax parsing imports")
3201
3202     for child in node.children:
3203         if child.type != syms.simple_stmt:
3204             break
3205         first_child = child.children[0]
3206         if isinstance(first_child, Leaf):
3207             # Continue looking if we see a docstring; otherwise stop.
3208             if (
3209                 len(child.children) == 2
3210                 and first_child.type == token.STRING
3211                 and child.children[1].type == token.NEWLINE
3212             ):
3213                 continue
3214             else:
3215                 break
3216         elif first_child.type == syms.import_from:
3217             module_name = first_child.children[1]
3218             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3219                 break
3220             imports |= set(get_imports_from_children(first_child.children[3:]))
3221         else:
3222             break
3223     return imports
3224
3225
3226 def gen_python_files_in_dir(
3227     path: Path,
3228     root: Path,
3229     include: Pattern[str],
3230     exclude: Pattern[str],
3231     report: "Report",
3232 ) -> Iterator[Path]:
3233     """Generate all files under `path` whose paths are not excluded by the
3234     `exclude` regex, but are included by the `include` regex.
3235
3236     Symbolic links pointing outside of the `root` directory are ignored.
3237
3238     `report` is where output about exclusions goes.
3239     """
3240     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3241     for child in path.iterdir():
3242         try:
3243             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3244         except ValueError:
3245             if child.is_symlink():
3246                 report.path_ignored(
3247                     child, f"is a symbolic link that points outside {root}"
3248                 )
3249                 continue
3250
3251             raise
3252
3253         if child.is_dir():
3254             normalized_path += "/"
3255         exclude_match = exclude.search(normalized_path)
3256         if exclude_match and exclude_match.group(0):
3257             report.path_ignored(child, f"matches the --exclude regular expression")
3258             continue
3259
3260         if child.is_dir():
3261             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3262
3263         elif child.is_file():
3264             include_match = include.search(normalized_path)
3265             if include_match:
3266                 yield child
3267
3268
3269 @lru_cache()
3270 def find_project_root(srcs: Iterable[str]) -> Path:
3271     """Return a directory containing .git, .hg, or pyproject.toml.
3272
3273     That directory can be one of the directories passed in `srcs` or their
3274     common parent.
3275
3276     If no directory in the tree contains a marker that would specify it's the
3277     project root, the root of the file system is returned.
3278     """
3279     if not srcs:
3280         return Path("/").resolve()
3281
3282     common_base = min(Path(src).resolve() for src in srcs)
3283     if common_base.is_dir():
3284         # Append a fake file so `parents` below returns `common_base_dir`, too.
3285         common_base /= "fake-file"
3286     for directory in common_base.parents:
3287         if (directory / ".git").is_dir():
3288             return directory
3289
3290         if (directory / ".hg").is_dir():
3291             return directory
3292
3293         if (directory / "pyproject.toml").is_file():
3294             return directory
3295
3296     return directory
3297
3298
3299 @dataclass
3300 class Report:
3301     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3302
3303     check: bool = False
3304     quiet: bool = False
3305     verbose: bool = False
3306     change_count: int = 0
3307     same_count: int = 0
3308     failure_count: int = 0
3309
3310     def done(self, src: Path, changed: Changed) -> None:
3311         """Increment the counter for successful reformatting. Write out a message."""
3312         if changed is Changed.YES:
3313             reformatted = "would reformat" if self.check else "reformatted"
3314             if self.verbose or not self.quiet:
3315                 out(f"{reformatted} {src}")
3316             self.change_count += 1
3317         else:
3318             if self.verbose:
3319                 if changed is Changed.NO:
3320                     msg = f"{src} already well formatted, good job."
3321                 else:
3322                     msg = f"{src} wasn't modified on disk since last run."
3323                 out(msg, bold=False)
3324             self.same_count += 1
3325
3326     def failed(self, src: Path, message: str) -> None:
3327         """Increment the counter for failed reformatting. Write out a message."""
3328         err(f"error: cannot format {src}: {message}")
3329         self.failure_count += 1
3330
3331     def path_ignored(self, path: Path, message: str) -> None:
3332         if self.verbose:
3333             out(f"{path} ignored: {message}", bold=False)
3334
3335     @property
3336     def return_code(self) -> int:
3337         """Return the exit code that the app should use.
3338
3339         This considers the current state of changed files and failures:
3340         - if there were any failures, return 123;
3341         - if any files were changed and --check is being used, return 1;
3342         - otherwise return 0.
3343         """
3344         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3345         # 126 we have special return codes reserved by the shell.
3346         if self.failure_count:
3347             return 123
3348
3349         elif self.change_count and self.check:
3350             return 1
3351
3352         return 0
3353
3354     def __str__(self) -> str:
3355         """Render a color report of the current state.
3356
3357         Use `click.unstyle` to remove colors.
3358         """
3359         if self.check:
3360             reformatted = "would be reformatted"
3361             unchanged = "would be left unchanged"
3362             failed = "would fail to reformat"
3363         else:
3364             reformatted = "reformatted"
3365             unchanged = "left unchanged"
3366             failed = "failed to reformat"
3367         report = []
3368         if self.change_count:
3369             s = "s" if self.change_count > 1 else ""
3370             report.append(
3371                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3372             )
3373         if self.same_count:
3374             s = "s" if self.same_count > 1 else ""
3375             report.append(f"{self.same_count} file{s} {unchanged}")
3376         if self.failure_count:
3377             s = "s" if self.failure_count > 1 else ""
3378             report.append(
3379                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3380             )
3381         return ", ".join(report) + "."
3382
3383
3384 def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
3385     for feature_version in (7, 6):
3386         try:
3387             return ast3.parse(src, feature_version=feature_version)
3388         except SyntaxError:
3389             continue
3390
3391     return ast27.parse(src)
3392
3393
3394 def assert_equivalent(src: str, dst: str) -> None:
3395     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3396
3397     import traceback
3398
3399     def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
3400         """Simple visitor generating strings to compare ASTs by content."""
3401         yield f"{'  ' * depth}{node.__class__.__name__}("
3402
3403         for field in sorted(node._fields):
3404             # TypeIgnore has only one field 'lineno' which breaks this comparison
3405             if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
3406                 break
3407
3408             # Ignore str kind which is case sensitive / and ignores unicode_literals
3409             if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
3410                 continue
3411
3412             try:
3413                 value = getattr(node, field)
3414             except AttributeError:
3415                 continue
3416
3417             yield f"{'  ' * (depth+1)}{field}="
3418
3419             if isinstance(value, list):
3420                 for item in value:
3421                     # Ignore nested tuples within del statements, because we may insert
3422                     # parentheses and they change the AST.
3423                     if (
3424                         field == "targets"
3425                         and isinstance(node, (ast3.Delete, ast27.Delete))
3426                         and isinstance(item, (ast3.Tuple, ast27.Tuple))
3427                     ):
3428                         for item in item.elts:
3429                             yield from _v(item, depth + 2)
3430                     elif isinstance(item, (ast3.AST, ast27.AST)):
3431                         yield from _v(item, depth + 2)
3432
3433             elif isinstance(value, (ast3.AST, ast27.AST)):
3434                 yield from _v(value, depth + 2)
3435
3436             else:
3437                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3438
3439         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3440
3441     try:
3442         src_ast = parse_ast(src)
3443     except Exception as exc:
3444         raise AssertionError(
3445             f"cannot use --safe with this file; failed to parse source file.  "
3446             f"AST error message: {exc}"
3447         )
3448
3449     try:
3450         dst_ast = parse_ast(dst)
3451     except Exception as exc:
3452         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3453         raise AssertionError(
3454             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3455             f"Please report a bug on https://github.com/python/black/issues.  "
3456             f"This invalid output might be helpful: {log}"
3457         ) from None
3458
3459     src_ast_str = "\n".join(_v(src_ast))
3460     dst_ast_str = "\n".join(_v(dst_ast))
3461     if src_ast_str != dst_ast_str:
3462         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3463         raise AssertionError(
3464             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3465             f"the source.  "
3466             f"Please report a bug on https://github.com/python/black/issues.  "
3467             f"This diff might be helpful: {log}"
3468         ) from None
3469
3470
3471 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3472     """Raise AssertionError if `dst` reformats differently the second time."""
3473     newdst = format_str(dst, mode=mode)
3474     if dst != newdst:
3475         log = dump_to_file(
3476             diff(src, dst, "source", "first pass"),
3477             diff(dst, newdst, "first pass", "second pass"),
3478         )
3479         raise AssertionError(
3480             f"INTERNAL ERROR: Black produced different code on the second pass "
3481             f"of the formatter.  "
3482             f"Please report a bug on https://github.com/python/black/issues.  "
3483             f"This diff might be helpful: {log}"
3484         ) from None
3485
3486
3487 def dump_to_file(*output: str) -> str:
3488     """Dump `output` to a temporary file. Return path to the file."""
3489     import tempfile
3490
3491     with tempfile.NamedTemporaryFile(
3492         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3493     ) as f:
3494         for lines in output:
3495             f.write(lines)
3496             if lines and lines[-1] != "\n":
3497                 f.write("\n")
3498     return f.name
3499
3500
3501 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3502     """Return a unified diff string between strings `a` and `b`."""
3503     import difflib
3504
3505     a_lines = [line + "\n" for line in a.split("\n")]
3506     b_lines = [line + "\n" for line in b.split("\n")]
3507     return "".join(
3508         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3509     )
3510
3511
3512 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3513     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3514     err("Aborted!")
3515     for task in tasks:
3516         task.cancel()
3517
3518
3519 def shutdown(loop: BaseEventLoop) -> None:
3520     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3521     try:
3522         if sys.version_info[:2] >= (3, 7):
3523             all_tasks = asyncio.all_tasks
3524         else:
3525             all_tasks = asyncio.Task.all_tasks
3526         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3527         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3528         if not to_cancel:
3529             return
3530
3531         for task in to_cancel:
3532             task.cancel()
3533         loop.run_until_complete(
3534             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3535         )
3536     finally:
3537         # `concurrent.futures.Future` objects cannot be cancelled once they
3538         # are already running. There might be some when the `shutdown()` happened.
3539         # Silence their logger's spew about the event loop being closed.
3540         cf_logger = logging.getLogger("concurrent.futures")
3541         cf_logger.setLevel(logging.CRITICAL)
3542         loop.close()
3543
3544
3545 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3546     """Replace `regex` with `replacement` twice on `original`.
3547
3548     This is used by string normalization to perform replaces on
3549     overlapping matches.
3550     """
3551     return regex.sub(replacement, regex.sub(replacement, original))
3552
3553
3554 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3555     """Compile a regular expression string in `regex`.
3556
3557     If it contains newlines, use verbose mode.
3558     """
3559     if "\n" in regex:
3560         regex = "(?x)" + regex
3561     return re.compile(regex)
3562
3563
3564 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3565     """Like `reversed(enumerate(sequence))` if that were possible."""
3566     index = len(sequence) - 1
3567     for element in reversed(sequence):
3568         yield (index, element)
3569         index -= 1
3570
3571
3572 def enumerate_with_length(
3573     line: Line, reversed: bool = False
3574 ) -> Iterator[Tuple[Index, Leaf, int]]:
3575     """Return an enumeration of leaves with their length.
3576
3577     Stops prematurely on multiline strings and standalone comments.
3578     """
3579     op = cast(
3580         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3581         enumerate_reversed if reversed else enumerate,
3582     )
3583     for index, leaf in op(line.leaves):
3584         length = len(leaf.prefix) + len(leaf.value)
3585         if "\n" in leaf.value:
3586             return  # Multiline strings, we can't continue.
3587
3588         comment: Optional[Leaf]
3589         for comment in line.comments_after(leaf):
3590             length += len(comment.value)
3591
3592         yield index, leaf, length
3593
3594
3595 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3596     """Return True if `line` is no longer than `line_length`.
3597
3598     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3599     """
3600     if not line_str:
3601         line_str = str(line).strip("\n")
3602     return (
3603         len(line_str) <= line_length
3604         and "\n" not in line_str  # multiline strings
3605         and not line.contains_standalone_comments()
3606     )
3607
3608
3609 def can_be_split(line: Line) -> bool:
3610     """Return False if the line cannot be split *for sure*.
3611
3612     This is not an exhaustive search but a cheap heuristic that we can use to
3613     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3614     in unnecessary parentheses).
3615     """
3616     leaves = line.leaves
3617     if len(leaves) < 2:
3618         return False
3619
3620     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3621         call_count = 0
3622         dot_count = 0
3623         next = leaves[-1]
3624         for leaf in leaves[-2::-1]:
3625             if leaf.type in OPENING_BRACKETS:
3626                 if next.type not in CLOSING_BRACKETS:
3627                     return False
3628
3629                 call_count += 1
3630             elif leaf.type == token.DOT:
3631                 dot_count += 1
3632             elif leaf.type == token.NAME:
3633                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3634                     return False
3635
3636             elif leaf.type not in CLOSING_BRACKETS:
3637                 return False
3638
3639             if dot_count > 1 and call_count > 1:
3640                 return False
3641
3642     return True
3643
3644
3645 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3646     """Does `line` have a shape safe to reformat without optional parens around it?
3647
3648     Returns True for only a subset of potentially nice looking formattings but
3649     the point is to not return false positives that end up producing lines that
3650     are too long.
3651     """
3652     bt = line.bracket_tracker
3653     if not bt.delimiters:
3654         # Without delimiters the optional parentheses are useless.
3655         return True
3656
3657     max_priority = bt.max_delimiter_priority()
3658     if bt.delimiter_count_with_priority(max_priority) > 1:
3659         # With more than one delimiter of a kind the optional parentheses read better.
3660         return False
3661
3662     if max_priority == DOT_PRIORITY:
3663         # A single stranded method call doesn't require optional parentheses.
3664         return True
3665
3666     assert len(line.leaves) >= 2, "Stranded delimiter"
3667
3668     first = line.leaves[0]
3669     second = line.leaves[1]
3670     penultimate = line.leaves[-2]
3671     last = line.leaves[-1]
3672
3673     # With a single delimiter, omit if the expression starts or ends with
3674     # a bracket.
3675     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3676         remainder = False
3677         length = 4 * line.depth
3678         for _index, leaf, leaf_length in enumerate_with_length(line):
3679             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3680                 remainder = True
3681             if remainder:
3682                 length += leaf_length
3683                 if length > line_length:
3684                     break
3685
3686                 if leaf.type in OPENING_BRACKETS:
3687                     # There are brackets we can further split on.
3688                     remainder = False
3689
3690         else:
3691             # checked the entire string and line length wasn't exceeded
3692             if len(line.leaves) == _index + 1:
3693                 return True
3694
3695         # Note: we are not returning False here because a line might have *both*
3696         # a leading opening bracket and a trailing closing bracket.  If the
3697         # opening bracket doesn't match our rule, maybe the closing will.
3698
3699     if (
3700         last.type == token.RPAR
3701         or last.type == token.RBRACE
3702         or (
3703             # don't use indexing for omitting optional parentheses;
3704             # it looks weird
3705             last.type == token.RSQB
3706             and last.parent
3707             and last.parent.type != syms.trailer
3708         )
3709     ):
3710         if penultimate.type in OPENING_BRACKETS:
3711             # Empty brackets don't help.
3712             return False
3713
3714         if is_multiline_string(first):
3715             # Additional wrapping of a multiline string in this situation is
3716             # unnecessary.
3717             return True
3718
3719         length = 4 * line.depth
3720         seen_other_brackets = False
3721         for _index, leaf, leaf_length in enumerate_with_length(line):
3722             length += leaf_length
3723             if leaf is last.opening_bracket:
3724                 if seen_other_brackets or length <= line_length:
3725                     return True
3726
3727             elif leaf.type in OPENING_BRACKETS:
3728                 # There are brackets we can further split on.
3729                 seen_other_brackets = True
3730
3731     return False
3732
3733
3734 def get_cache_file(mode: FileMode) -> Path:
3735     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3736
3737
3738 def read_cache(mode: FileMode) -> Cache:
3739     """Read the cache if it exists and is well formed.
3740
3741     If it is not well formed, the call to write_cache later should resolve the issue.
3742     """
3743     cache_file = get_cache_file(mode)
3744     if not cache_file.exists():
3745         return {}
3746
3747     with cache_file.open("rb") as fobj:
3748         try:
3749             cache: Cache = pickle.load(fobj)
3750         except pickle.UnpicklingError:
3751             return {}
3752
3753     return cache
3754
3755
3756 def get_cache_info(path: Path) -> CacheInfo:
3757     """Return the information used to check if a file is already formatted or not."""
3758     stat = path.stat()
3759     return stat.st_mtime, stat.st_size
3760
3761
3762 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3763     """Split an iterable of paths in `sources` into two sets.
3764
3765     The first contains paths of files that modified on disk or are not in the
3766     cache. The other contains paths to non-modified files.
3767     """
3768     todo, done = set(), set()
3769     for src in sources:
3770         src = src.resolve()
3771         if cache.get(src) != get_cache_info(src):
3772             todo.add(src)
3773         else:
3774             done.add(src)
3775     return todo, done
3776
3777
3778 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3779     """Update the cache file."""
3780     cache_file = get_cache_file(mode)
3781     try:
3782         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3783         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3784         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3785             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3786         os.replace(f.name, cache_file)
3787     except OSError:
3788         pass
3789
3790
3791 def patch_click() -> None:
3792     """Make Click not crash.
3793
3794     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3795     default which restricts paths that it can access during the lifetime of the
3796     application.  Click refuses to work in this scenario by raising a RuntimeError.
3797
3798     In case of Black the likelihood that non-ASCII characters are going to be used in
3799     file paths is minimal since it's Python source code.  Moreover, this crash was
3800     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3801     """
3802     try:
3803         from click import core
3804         from click import _unicodefun  # type: ignore
3805     except ModuleNotFoundError:
3806         return
3807
3808     for module in (core, _unicodefun):
3809         if hasattr(module, "_verify_python3_env"):
3810             module._verify_python3_env = lambda: None
3811
3812
3813 def patched_main() -> None:
3814     freeze_support()
3815     patch_click()
3816     main()
3817
3818
3819 if __name__ == "__main__":
3820     patched_main()