]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Prefer https:// links where available (#485)
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum, Flag
6 from functools import lru_cache, partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tokenize
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generator,
24     Generic,
25     Iterable,
26     Iterator,
27     List,
28     Optional,
29     Pattern,
30     Sequence,
31     Set,
32     Tuple,
33     TypeVar,
34     Union,
35     cast,
36 )
37
38 from appdirs import user_cache_dir
39 from attr import dataclass, Factory
40 import click
41 import toml
42
43 # lib2to3 fork
44 from blib2to3.pytree import Node, Leaf, type_repr
45 from blib2to3 import pygram, pytree
46 from blib2to3.pgen2 import driver, token
47 from blib2to3.pgen2.parse import ParseError
48
49
50 __version__ = "18.6b4"
51 DEFAULT_LINE_LENGTH = 88
52 DEFAULT_EXCLUDES = (
53     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
54 )
55 DEFAULT_INCLUDES = r"\.pyi?$"
56 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
57
58
59 # types
60 FileContent = str
61 Encoding = str
62 NewLine = str
63 Depth = int
64 NodeType = int
65 LeafID = int
66 Priority = int
67 Index = int
68 LN = Union[Leaf, Node]
69 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
70 Timestamp = float
71 FileSize = int
72 CacheInfo = Tuple[Timestamp, FileSize]
73 Cache = Dict[Path, CacheInfo]
74 out = partial(click.secho, bold=True, err=True)
75 err = partial(click.secho, fg="red", err=True)
76
77 pygram.initialize(CACHE_DIR)
78 syms = pygram.python_symbols
79
80
81 class NothingChanged(UserWarning):
82     """Raised by :func:`format_file` when reformatted code is the same as source."""
83
84
85 class CannotSplit(Exception):
86     """A readable split that fits the allotted line length is impossible.
87
88     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
89     :func:`delimiter_split`.
90     """
91
92
93 class WriteBack(Enum):
94     NO = 0
95     YES = 1
96     DIFF = 2
97     CHECK = 3
98
99     @classmethod
100     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
101         if check and not diff:
102             return cls.CHECK
103
104         return cls.DIFF if diff else cls.YES
105
106
107 class Changed(Enum):
108     NO = 0
109     CACHED = 1
110     YES = 2
111
112
113 class FileMode(Flag):
114     AUTO_DETECT = 0
115     PYTHON36 = 1
116     PYI = 2
117     NO_STRING_NORMALIZATION = 4
118
119     @classmethod
120     def from_configuration(
121         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
122     ) -> "FileMode":
123         mode = cls.AUTO_DETECT
124         if py36:
125             mode |= cls.PYTHON36
126         if pyi:
127             mode |= cls.PYI
128         if skip_string_normalization:
129             mode |= cls.NO_STRING_NORMALIZATION
130         return mode
131
132
133 def read_pyproject_toml(
134     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
135 ) -> Optional[str]:
136     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
137
138     Returns the path to a successfully found and read configuration file, None
139     otherwise.
140     """
141     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
142     if not value:
143         root = find_project_root(ctx.params.get("src", ()))
144         path = root / "pyproject.toml"
145         if path.is_file():
146             value = str(path)
147         else:
148             return None
149
150     try:
151         pyproject_toml = toml.load(value)
152         config = pyproject_toml.get("tool", {}).get("black", {})
153     except (toml.TomlDecodeError, OSError) as e:
154         raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
155
156     if not config:
157         return None
158
159     if ctx.default_map is None:
160         ctx.default_map = {}
161     ctx.default_map.update(  # type: ignore  # bad types in .pyi
162         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
163     )
164     return value
165
166
167 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
168 @click.option(
169     "-l",
170     "--line-length",
171     type=int,
172     default=DEFAULT_LINE_LENGTH,
173     help="How many characters per line to allow.",
174     show_default=True,
175 )
176 @click.option(
177     "--py36",
178     is_flag=True,
179     help=(
180         "Allow using Python 3.6-only syntax on all input files.  This will put "
181         "trailing commas in function signatures and calls also after *args and "
182         "**kwargs.  [default: per-file auto-detection]"
183     ),
184 )
185 @click.option(
186     "--pyi",
187     is_flag=True,
188     help=(
189         "Format all input files like typing stubs regardless of file extension "
190         "(useful when piping source on standard input)."
191     ),
192 )
193 @click.option(
194     "-S",
195     "--skip-string-normalization",
196     is_flag=True,
197     help="Don't normalize string quotes or prefixes.",
198 )
199 @click.option(
200     "--check",
201     is_flag=True,
202     help=(
203         "Don't write the files back, just return the status.  Return code 0 "
204         "means nothing would change.  Return code 1 means some files would be "
205         "reformatted.  Return code 123 means there was an internal error."
206     ),
207 )
208 @click.option(
209     "--diff",
210     is_flag=True,
211     help="Don't write the files back, just output a diff for each file on stdout.",
212 )
213 @click.option(
214     "--fast/--safe",
215     is_flag=True,
216     help="If --fast given, skip temporary sanity checks. [default: --safe]",
217 )
218 @click.option(
219     "--include",
220     type=str,
221     default=DEFAULT_INCLUDES,
222     help=(
223         "A regular expression that matches files and directories that should be "
224         "included on recursive searches.  An empty value means all files are "
225         "included regardless of the name.  Use forward slashes for directories on "
226         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
227         "later."
228     ),
229     show_default=True,
230 )
231 @click.option(
232     "--exclude",
233     type=str,
234     default=DEFAULT_EXCLUDES,
235     help=(
236         "A regular expression that matches files and directories that should be "
237         "excluded on recursive searches.  An empty value means no paths are excluded. "
238         "Use forward slashes for directories on all platforms (Windows, too).  "
239         "Exclusions are calculated first, inclusions later."
240     ),
241     show_default=True,
242 )
243 @click.option(
244     "-q",
245     "--quiet",
246     is_flag=True,
247     help=(
248         "Don't emit non-error messages to stderr. Errors are still emitted, "
249         "silence those with 2>/dev/null."
250     ),
251 )
252 @click.option(
253     "-v",
254     "--verbose",
255     is_flag=True,
256     help=(
257         "Also emit messages to stderr about files that were not changed or were "
258         "ignored due to --exclude=."
259     ),
260 )
261 @click.version_option(version=__version__)
262 @click.argument(
263     "src",
264     nargs=-1,
265     type=click.Path(
266         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
267     ),
268     is_eager=True,
269 )
270 @click.option(
271     "--config",
272     type=click.Path(
273         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
274     ),
275     is_eager=True,
276     callback=read_pyproject_toml,
277     help="Read configuration from PATH.",
278 )
279 @click.pass_context
280 def main(
281     ctx: click.Context,
282     line_length: int,
283     check: bool,
284     diff: bool,
285     fast: bool,
286     pyi: bool,
287     py36: bool,
288     skip_string_normalization: bool,
289     quiet: bool,
290     verbose: bool,
291     include: str,
292     exclude: str,
293     src: Tuple[str],
294     config: Optional[str],
295 ) -> None:
296     """The uncompromising code formatter."""
297     write_back = WriteBack.from_configuration(check=check, diff=diff)
298     mode = FileMode.from_configuration(
299         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
300     )
301     if config and verbose:
302         out(f"Using configuration from {config}.", bold=False, fg="blue")
303     try:
304         include_regex = re_compile_maybe_verbose(include)
305     except re.error:
306         err(f"Invalid regular expression for include given: {include!r}")
307         ctx.exit(2)
308     try:
309         exclude_regex = re_compile_maybe_verbose(exclude)
310     except re.error:
311         err(f"Invalid regular expression for exclude given: {exclude!r}")
312         ctx.exit(2)
313     report = Report(check=check, quiet=quiet, verbose=verbose)
314     root = find_project_root(src)
315     sources: Set[Path] = set()
316     for s in src:
317         p = Path(s)
318         if p.is_dir():
319             sources.update(
320                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
321             )
322         elif p.is_file() or s == "-":
323             # if a file was explicitly given, we don't care about its extension
324             sources.add(p)
325         else:
326             err(f"invalid path: {s}")
327     if len(sources) == 0:
328         if verbose or not quiet:
329             out("No paths given. Nothing to do 😴")
330         ctx.exit(0)
331
332     if len(sources) == 1:
333         reformat_one(
334             src=sources.pop(),
335             line_length=line_length,
336             fast=fast,
337             write_back=write_back,
338             mode=mode,
339             report=report,
340         )
341     else:
342         loop = asyncio.get_event_loop()
343         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
344         try:
345             loop.run_until_complete(
346                 schedule_formatting(
347                     sources=sources,
348                     line_length=line_length,
349                     fast=fast,
350                     write_back=write_back,
351                     mode=mode,
352                     report=report,
353                     loop=loop,
354                     executor=executor,
355                 )
356             )
357         finally:
358             shutdown(loop)
359     if verbose or not quiet:
360         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
361         out(f"All done! {bang}")
362         click.secho(str(report), err=True)
363     ctx.exit(report.return_code)
364
365
366 def reformat_one(
367     src: Path,
368     line_length: int,
369     fast: bool,
370     write_back: WriteBack,
371     mode: FileMode,
372     report: "Report",
373 ) -> None:
374     """Reformat a single file under `src` without spawning child processes.
375
376     If `quiet` is True, non-error messages are not output. `line_length`,
377     `write_back`, `fast` and `pyi` options are passed to
378     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
379     """
380     try:
381         changed = Changed.NO
382         if not src.is_file() and str(src) == "-":
383             if format_stdin_to_stdout(
384                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
385             ):
386                 changed = Changed.YES
387         else:
388             cache: Cache = {}
389             if write_back != WriteBack.DIFF:
390                 cache = read_cache(line_length, mode)
391                 res_src = src.resolve()
392                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
393                     changed = Changed.CACHED
394             if changed is not Changed.CACHED and format_file_in_place(
395                 src,
396                 line_length=line_length,
397                 fast=fast,
398                 write_back=write_back,
399                 mode=mode,
400             ):
401                 changed = Changed.YES
402             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
403                 write_back is WriteBack.CHECK and changed is Changed.NO
404             ):
405                 write_cache(cache, [src], line_length, mode)
406         report.done(src, changed)
407     except Exception as exc:
408         report.failed(src, str(exc))
409
410
411 async def schedule_formatting(
412     sources: Set[Path],
413     line_length: int,
414     fast: bool,
415     write_back: WriteBack,
416     mode: FileMode,
417     report: "Report",
418     loop: BaseEventLoop,
419     executor: Executor,
420 ) -> None:
421     """Run formatting of `sources` in parallel using the provided `executor`.
422
423     (Use ProcessPoolExecutors for actual parallelism.)
424
425     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
426     :func:`format_file_in_place`.
427     """
428     cache: Cache = {}
429     if write_back != WriteBack.DIFF:
430         cache = read_cache(line_length, mode)
431         sources, cached = filter_cached(cache, sources)
432         for src in sorted(cached):
433             report.done(src, Changed.CACHED)
434     if not sources:
435         return
436
437     cancelled = []
438     sources_to_cache = []
439     lock = None
440     if write_back == WriteBack.DIFF:
441         # For diff output, we need locks to ensure we don't interleave output
442         # from different processes.
443         manager = Manager()
444         lock = manager.Lock()
445     tasks = {
446         loop.run_in_executor(
447             executor,
448             format_file_in_place,
449             src,
450             line_length,
451             fast,
452             write_back,
453             mode,
454             lock,
455         ): src
456         for src in sorted(sources)
457     }
458     pending: Iterable[asyncio.Task] = tasks.keys()
459     try:
460         loop.add_signal_handler(signal.SIGINT, cancel, pending)
461         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
462     except NotImplementedError:
463         # There are no good alternatives for these on Windows.
464         pass
465     while pending:
466         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
467         for task in done:
468             src = tasks.pop(task)
469             if task.cancelled():
470                 cancelled.append(task)
471             elif task.exception():
472                 report.failed(src, str(task.exception()))
473             else:
474                 changed = Changed.YES if task.result() else Changed.NO
475                 # If the file was written back or was successfully checked as
476                 # well-formatted, store this information in the cache.
477                 if write_back is WriteBack.YES or (
478                     write_back is WriteBack.CHECK and changed is Changed.NO
479                 ):
480                     sources_to_cache.append(src)
481                 report.done(src, changed)
482     if cancelled:
483         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
484     if sources_to_cache:
485         write_cache(cache, sources_to_cache, line_length, mode)
486
487
488 def format_file_in_place(
489     src: Path,
490     line_length: int,
491     fast: bool,
492     write_back: WriteBack = WriteBack.NO,
493     mode: FileMode = FileMode.AUTO_DETECT,
494     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
495 ) -> bool:
496     """Format file under `src` path. Return True if changed.
497
498     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
499     code to the file.
500     `line_length` and `fast` options are passed to :func:`format_file_contents`.
501     """
502     if src.suffix == ".pyi":
503         mode |= FileMode.PYI
504
505     then = datetime.utcfromtimestamp(src.stat().st_mtime)
506     with open(src, "rb") as buf:
507         src_contents, encoding, newline = decode_bytes(buf.read())
508     try:
509         dst_contents = format_file_contents(
510             src_contents, line_length=line_length, fast=fast, mode=mode
511         )
512     except NothingChanged:
513         return False
514
515     if write_back == write_back.YES:
516         with open(src, "w", encoding=encoding, newline=newline) as f:
517             f.write(dst_contents)
518     elif write_back == write_back.DIFF:
519         now = datetime.utcnow()
520         src_name = f"{src}\t{then} +0000"
521         dst_name = f"{src}\t{now} +0000"
522         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
523         if lock:
524             lock.acquire()
525         try:
526             f = io.TextIOWrapper(
527                 sys.stdout.buffer,
528                 encoding=encoding,
529                 newline=newline,
530                 write_through=True,
531             )
532             f.write(diff_contents)
533             f.detach()
534         finally:
535             if lock:
536                 lock.release()
537     return True
538
539
540 def format_stdin_to_stdout(
541     line_length: int,
542     fast: bool,
543     write_back: WriteBack = WriteBack.NO,
544     mode: FileMode = FileMode.AUTO_DETECT,
545 ) -> bool:
546     """Format file on stdin. Return True if changed.
547
548     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
549     write a diff to stdout.
550     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
551     :func:`format_file_contents`.
552     """
553     then = datetime.utcnow()
554     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
555     dst = src
556     try:
557         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
558         return True
559
560     except NothingChanged:
561         return False
562
563     finally:
564         f = io.TextIOWrapper(
565             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
566         )
567         if write_back == WriteBack.YES:
568             f.write(dst)
569         elif write_back == WriteBack.DIFF:
570             now = datetime.utcnow()
571             src_name = f"STDIN\t{then} +0000"
572             dst_name = f"STDOUT\t{now} +0000"
573             f.write(diff(src, dst, src_name, dst_name))
574         f.detach()
575
576
577 def format_file_contents(
578     src_contents: str,
579     *,
580     line_length: int,
581     fast: bool,
582     mode: FileMode = FileMode.AUTO_DETECT,
583 ) -> FileContent:
584     """Reformat contents a file and return new contents.
585
586     If `fast` is False, additionally confirm that the reformatted code is
587     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
588     `line_length` is passed to :func:`format_str`.
589     """
590     if src_contents.strip() == "":
591         raise NothingChanged
592
593     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
594     if src_contents == dst_contents:
595         raise NothingChanged
596
597     if not fast:
598         assert_equivalent(src_contents, dst_contents)
599         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
600     return dst_contents
601
602
603 def format_str(
604     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
605 ) -> FileContent:
606     """Reformat a string and return new contents.
607
608     `line_length` determines how many characters per line are allowed.
609     """
610     src_node = lib2to3_parse(src_contents)
611     dst_contents = ""
612     future_imports = get_future_imports(src_node)
613     is_pyi = bool(mode & FileMode.PYI)
614     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
615     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
616     normalize_fmt_off(src_node)
617     lines = LineGenerator(
618         remove_u_prefix=py36 or "unicode_literals" in future_imports,
619         is_pyi=is_pyi,
620         normalize_strings=normalize_strings,
621         allow_underscores=py36,
622     )
623     elt = EmptyLineTracker(is_pyi=is_pyi)
624     empty_line = Line()
625     after = 0
626     for current_line in lines.visit(src_node):
627         for _ in range(after):
628             dst_contents += str(empty_line)
629         before, after = elt.maybe_empty_lines(current_line)
630         for _ in range(before):
631             dst_contents += str(empty_line)
632         for line in split_line(current_line, line_length=line_length, py36=py36):
633             dst_contents += str(line)
634     return dst_contents
635
636
637 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
638     """Return a tuple of (decoded_contents, encoding, newline).
639
640     `newline` is either CRLF or LF but `decoded_contents` is decoded with
641     universal newlines (i.e. only contains LF).
642     """
643     srcbuf = io.BytesIO(src)
644     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
645     if not lines:
646         return "", encoding, "\n"
647
648     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
649     srcbuf.seek(0)
650     with io.TextIOWrapper(srcbuf, encoding) as tiow:
651         return tiow.read(), encoding, newline
652
653
654 GRAMMARS = [
655     pygram.python_grammar_no_print_statement_no_exec_statement,
656     pygram.python_grammar_no_print_statement,
657     pygram.python_grammar,
658 ]
659
660
661 def lib2to3_parse(src_txt: str) -> Node:
662     """Given a string with source, return the lib2to3 Node."""
663     grammar = pygram.python_grammar_no_print_statement
664     if src_txt[-1:] != "\n":
665         src_txt += "\n"
666     for grammar in GRAMMARS:
667         drv = driver.Driver(grammar, pytree.convert)
668         try:
669             result = drv.parse_string(src_txt, True)
670             break
671
672         except ParseError as pe:
673             lineno, column = pe.context[1]
674             lines = src_txt.splitlines()
675             try:
676                 faulty_line = lines[lineno - 1]
677             except IndexError:
678                 faulty_line = "<line number missing in source>"
679             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
680     else:
681         raise exc from None
682
683     if isinstance(result, Leaf):
684         result = Node(syms.file_input, [result])
685     return result
686
687
688 def lib2to3_unparse(node: Node) -> str:
689     """Given a lib2to3 node, return its string representation."""
690     code = str(node)
691     return code
692
693
694 T = TypeVar("T")
695
696
697 class Visitor(Generic[T]):
698     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
699
700     def visit(self, node: LN) -> Iterator[T]:
701         """Main method to visit `node` and its children.
702
703         It tries to find a `visit_*()` method for the given `node.type`, like
704         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
705         If no dedicated `visit_*()` method is found, chooses `visit_default()`
706         instead.
707
708         Then yields objects of type `T` from the selected visitor.
709         """
710         if node.type < 256:
711             name = token.tok_name[node.type]
712         else:
713             name = type_repr(node.type)
714         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
715
716     def visit_default(self, node: LN) -> Iterator[T]:
717         """Default `visit_*()` implementation. Recurses to children of `node`."""
718         if isinstance(node, Node):
719             for child in node.children:
720                 yield from self.visit(child)
721
722
723 @dataclass
724 class DebugVisitor(Visitor[T]):
725     tree_depth: int = 0
726
727     def visit_default(self, node: LN) -> Iterator[T]:
728         indent = " " * (2 * self.tree_depth)
729         if isinstance(node, Node):
730             _type = type_repr(node.type)
731             out(f"{indent}{_type}", fg="yellow")
732             self.tree_depth += 1
733             for child in node.children:
734                 yield from self.visit(child)
735
736             self.tree_depth -= 1
737             out(f"{indent}/{_type}", fg="yellow", bold=False)
738         else:
739             _type = token.tok_name.get(node.type, str(node.type))
740             out(f"{indent}{_type}", fg="blue", nl=False)
741             if node.prefix:
742                 # We don't have to handle prefixes for `Node` objects since
743                 # that delegates to the first child anyway.
744                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
745             out(f" {node.value!r}", fg="blue", bold=False)
746
747     @classmethod
748     def show(cls, code: Union[str, Leaf, Node]) -> None:
749         """Pretty-print the lib2to3 AST of a given string of `code`.
750
751         Convenience method for debugging.
752         """
753         v: DebugVisitor[None] = DebugVisitor()
754         if isinstance(code, str):
755             code = lib2to3_parse(code)
756         list(v.visit(code))
757
758
759 KEYWORDS = set(keyword.kwlist)
760 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
761 FLOW_CONTROL = {"return", "raise", "break", "continue"}
762 STATEMENT = {
763     syms.if_stmt,
764     syms.while_stmt,
765     syms.for_stmt,
766     syms.try_stmt,
767     syms.except_clause,
768     syms.with_stmt,
769     syms.funcdef,
770     syms.classdef,
771 }
772 STANDALONE_COMMENT = 153
773 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
774 LOGIC_OPERATORS = {"and", "or"}
775 COMPARATORS = {
776     token.LESS,
777     token.GREATER,
778     token.EQEQUAL,
779     token.NOTEQUAL,
780     token.LESSEQUAL,
781     token.GREATEREQUAL,
782 }
783 MATH_OPERATORS = {
784     token.VBAR,
785     token.CIRCUMFLEX,
786     token.AMPER,
787     token.LEFTSHIFT,
788     token.RIGHTSHIFT,
789     token.PLUS,
790     token.MINUS,
791     token.STAR,
792     token.SLASH,
793     token.DOUBLESLASH,
794     token.PERCENT,
795     token.AT,
796     token.TILDE,
797     token.DOUBLESTAR,
798 }
799 STARS = {token.STAR, token.DOUBLESTAR}
800 VARARGS_PARENTS = {
801     syms.arglist,
802     syms.argument,  # double star in arglist
803     syms.trailer,  # single argument to call
804     syms.typedargslist,
805     syms.varargslist,  # lambdas
806 }
807 UNPACKING_PARENTS = {
808     syms.atom,  # single element of a list or set literal
809     syms.dictsetmaker,
810     syms.listmaker,
811     syms.testlist_gexp,
812     syms.testlist_star_expr,
813 }
814 TEST_DESCENDANTS = {
815     syms.test,
816     syms.lambdef,
817     syms.or_test,
818     syms.and_test,
819     syms.not_test,
820     syms.comparison,
821     syms.star_expr,
822     syms.expr,
823     syms.xor_expr,
824     syms.and_expr,
825     syms.shift_expr,
826     syms.arith_expr,
827     syms.trailer,
828     syms.term,
829     syms.power,
830 }
831 ASSIGNMENTS = {
832     "=",
833     "+=",
834     "-=",
835     "*=",
836     "@=",
837     "/=",
838     "%=",
839     "&=",
840     "|=",
841     "^=",
842     "<<=",
843     ">>=",
844     "**=",
845     "//=",
846 }
847 COMPREHENSION_PRIORITY = 20
848 COMMA_PRIORITY = 18
849 TERNARY_PRIORITY = 16
850 LOGIC_PRIORITY = 14
851 STRING_PRIORITY = 12
852 COMPARATOR_PRIORITY = 10
853 MATH_PRIORITIES = {
854     token.VBAR: 9,
855     token.CIRCUMFLEX: 8,
856     token.AMPER: 7,
857     token.LEFTSHIFT: 6,
858     token.RIGHTSHIFT: 6,
859     token.PLUS: 5,
860     token.MINUS: 5,
861     token.STAR: 4,
862     token.SLASH: 4,
863     token.DOUBLESLASH: 4,
864     token.PERCENT: 4,
865     token.AT: 4,
866     token.TILDE: 3,
867     token.DOUBLESTAR: 2,
868 }
869 DOT_PRIORITY = 1
870
871
872 @dataclass
873 class BracketTracker:
874     """Keeps track of brackets on a line."""
875
876     depth: int = 0
877     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
878     delimiters: Dict[LeafID, Priority] = Factory(dict)
879     previous: Optional[Leaf] = None
880     _for_loop_depths: List[int] = Factory(list)
881     _lambda_argument_depths: List[int] = Factory(list)
882
883     def mark(self, leaf: Leaf) -> None:
884         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
885
886         All leaves receive an int `bracket_depth` field that stores how deep
887         within brackets a given leaf is. 0 means there are no enclosing brackets
888         that started on this line.
889
890         If a leaf is itself a closing bracket, it receives an `opening_bracket`
891         field that it forms a pair with. This is a one-directional link to
892         avoid reference cycles.
893
894         If a leaf is a delimiter (a token on which Black can split the line if
895         needed) and it's on depth 0, its `id()` is stored in the tracker's
896         `delimiters` field.
897         """
898         if leaf.type == token.COMMENT:
899             return
900
901         self.maybe_decrement_after_for_loop_variable(leaf)
902         self.maybe_decrement_after_lambda_arguments(leaf)
903         if leaf.type in CLOSING_BRACKETS:
904             self.depth -= 1
905             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
906             leaf.opening_bracket = opening_bracket
907         leaf.bracket_depth = self.depth
908         if self.depth == 0:
909             delim = is_split_before_delimiter(leaf, self.previous)
910             if delim and self.previous is not None:
911                 self.delimiters[id(self.previous)] = delim
912             else:
913                 delim = is_split_after_delimiter(leaf, self.previous)
914                 if delim:
915                     self.delimiters[id(leaf)] = delim
916         if leaf.type in OPENING_BRACKETS:
917             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
918             self.depth += 1
919         self.previous = leaf
920         self.maybe_increment_lambda_arguments(leaf)
921         self.maybe_increment_for_loop_variable(leaf)
922
923     def any_open_brackets(self) -> bool:
924         """Return True if there is an yet unmatched open bracket on the line."""
925         return bool(self.bracket_match)
926
927     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
928         """Return the highest priority of a delimiter found on the line.
929
930         Values are consistent with what `is_split_*_delimiter()` return.
931         Raises ValueError on no delimiters.
932         """
933         return max(v for k, v in self.delimiters.items() if k not in exclude)
934
935     def delimiter_count_with_priority(self, priority: int = 0) -> int:
936         """Return the number of delimiters with the given `priority`.
937
938         If no `priority` is passed, defaults to max priority on the line.
939         """
940         if not self.delimiters:
941             return 0
942
943         priority = priority or self.max_delimiter_priority()
944         return sum(1 for p in self.delimiters.values() if p == priority)
945
946     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
947         """In a for loop, or comprehension, the variables are often unpacks.
948
949         To avoid splitting on the comma in this situation, increase the depth of
950         tokens between `for` and `in`.
951         """
952         if leaf.type == token.NAME and leaf.value == "for":
953             self.depth += 1
954             self._for_loop_depths.append(self.depth)
955             return True
956
957         return False
958
959     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
960         """See `maybe_increment_for_loop_variable` above for explanation."""
961         if (
962             self._for_loop_depths
963             and self._for_loop_depths[-1] == self.depth
964             and leaf.type == token.NAME
965             and leaf.value == "in"
966         ):
967             self.depth -= 1
968             self._for_loop_depths.pop()
969             return True
970
971         return False
972
973     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
974         """In a lambda expression, there might be more than one argument.
975
976         To avoid splitting on the comma in this situation, increase the depth of
977         tokens between `lambda` and `:`.
978         """
979         if leaf.type == token.NAME and leaf.value == "lambda":
980             self.depth += 1
981             self._lambda_argument_depths.append(self.depth)
982             return True
983
984         return False
985
986     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
987         """See `maybe_increment_lambda_arguments` above for explanation."""
988         if (
989             self._lambda_argument_depths
990             and self._lambda_argument_depths[-1] == self.depth
991             and leaf.type == token.COLON
992         ):
993             self.depth -= 1
994             self._lambda_argument_depths.pop()
995             return True
996
997         return False
998
999     def get_open_lsqb(self) -> Optional[Leaf]:
1000         """Return the most recent opening square bracket (if any)."""
1001         return self.bracket_match.get((self.depth - 1, token.RSQB))
1002
1003
1004 @dataclass
1005 class Line:
1006     """Holds leaves and comments. Can be printed with `str(line)`."""
1007
1008     depth: int = 0
1009     leaves: List[Leaf] = Factory(list)
1010     comments: List[Tuple[Index, Leaf]] = Factory(list)
1011     bracket_tracker: BracketTracker = Factory(BracketTracker)
1012     inside_brackets: bool = False
1013     should_explode: bool = False
1014
1015     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1016         """Add a new `leaf` to the end of the line.
1017
1018         Unless `preformatted` is True, the `leaf` will receive a new consistent
1019         whitespace prefix and metadata applied by :class:`BracketTracker`.
1020         Trailing commas are maybe removed, unpacked for loop variables are
1021         demoted from being delimiters.
1022
1023         Inline comments are put aside.
1024         """
1025         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1026         if not has_value:
1027             return
1028
1029         if token.COLON == leaf.type and self.is_class_paren_empty:
1030             del self.leaves[-2:]
1031         if self.leaves and not preformatted:
1032             # Note: at this point leaf.prefix should be empty except for
1033             # imports, for which we only preserve newlines.
1034             leaf.prefix += whitespace(
1035                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1036             )
1037         if self.inside_brackets or not preformatted:
1038             self.bracket_tracker.mark(leaf)
1039             self.maybe_remove_trailing_comma(leaf)
1040         if not self.append_comment(leaf):
1041             self.leaves.append(leaf)
1042
1043     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1044         """Like :func:`append()` but disallow invalid standalone comment structure.
1045
1046         Raises ValueError when any `leaf` is appended after a standalone comment
1047         or when a standalone comment is not the first leaf on the line.
1048         """
1049         if self.bracket_tracker.depth == 0:
1050             if self.is_comment:
1051                 raise ValueError("cannot append to standalone comments")
1052
1053             if self.leaves and leaf.type == STANDALONE_COMMENT:
1054                 raise ValueError(
1055                     "cannot append standalone comments to a populated line"
1056                 )
1057
1058         self.append(leaf, preformatted=preformatted)
1059
1060     @property
1061     def is_comment(self) -> bool:
1062         """Is this line a standalone comment?"""
1063         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1064
1065     @property
1066     def is_decorator(self) -> bool:
1067         """Is this line a decorator?"""
1068         return bool(self) and self.leaves[0].type == token.AT
1069
1070     @property
1071     def is_import(self) -> bool:
1072         """Is this an import line?"""
1073         return bool(self) and is_import(self.leaves[0])
1074
1075     @property
1076     def is_class(self) -> bool:
1077         """Is this line a class definition?"""
1078         return (
1079             bool(self)
1080             and self.leaves[0].type == token.NAME
1081             and self.leaves[0].value == "class"
1082         )
1083
1084     @property
1085     def is_stub_class(self) -> bool:
1086         """Is this line a class definition with a body consisting only of "..."?"""
1087         return self.is_class and self.leaves[-3:] == [
1088             Leaf(token.DOT, ".") for _ in range(3)
1089         ]
1090
1091     @property
1092     def is_def(self) -> bool:
1093         """Is this a function definition? (Also returns True for async defs.)"""
1094         try:
1095             first_leaf = self.leaves[0]
1096         except IndexError:
1097             return False
1098
1099         try:
1100             second_leaf: Optional[Leaf] = self.leaves[1]
1101         except IndexError:
1102             second_leaf = None
1103         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1104             first_leaf.type == token.ASYNC
1105             and second_leaf is not None
1106             and second_leaf.type == token.NAME
1107             and second_leaf.value == "def"
1108         )
1109
1110     @property
1111     def is_class_paren_empty(self) -> bool:
1112         """Is this a class with no base classes but using parentheses?
1113
1114         Those are unnecessary and should be removed.
1115         """
1116         return (
1117             bool(self)
1118             and len(self.leaves) == 4
1119             and self.is_class
1120             and self.leaves[2].type == token.LPAR
1121             and self.leaves[2].value == "("
1122             and self.leaves[3].type == token.RPAR
1123             and self.leaves[3].value == ")"
1124         )
1125
1126     @property
1127     def is_triple_quoted_string(self) -> bool:
1128         """Is the line a triple quoted string?"""
1129         return (
1130             bool(self)
1131             and self.leaves[0].type == token.STRING
1132             and self.leaves[0].value.startswith(('"""', "'''"))
1133         )
1134
1135     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1136         """If so, needs to be split before emitting."""
1137         for leaf in self.leaves:
1138             if leaf.type == STANDALONE_COMMENT:
1139                 if leaf.bracket_depth <= depth_limit:
1140                     return True
1141
1142         return False
1143
1144     def contains_multiline_strings(self) -> bool:
1145         for leaf in self.leaves:
1146             if is_multiline_string(leaf):
1147                 return True
1148
1149         return False
1150
1151     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1152         """Remove trailing comma if there is one and it's safe."""
1153         if not (
1154             self.leaves
1155             and self.leaves[-1].type == token.COMMA
1156             and closing.type in CLOSING_BRACKETS
1157         ):
1158             return False
1159
1160         if closing.type == token.RBRACE:
1161             self.remove_trailing_comma()
1162             return True
1163
1164         if closing.type == token.RSQB:
1165             comma = self.leaves[-1]
1166             if comma.parent and comma.parent.type == syms.listmaker:
1167                 self.remove_trailing_comma()
1168                 return True
1169
1170         # For parens let's check if it's safe to remove the comma.
1171         # Imports are always safe.
1172         if self.is_import:
1173             self.remove_trailing_comma()
1174             return True
1175
1176         # Otherwise, if the trailing one is the only one, we might mistakenly
1177         # change a tuple into a different type by removing the comma.
1178         depth = closing.bracket_depth + 1
1179         commas = 0
1180         opening = closing.opening_bracket
1181         for _opening_index, leaf in enumerate(self.leaves):
1182             if leaf is opening:
1183                 break
1184
1185         else:
1186             return False
1187
1188         for leaf in self.leaves[_opening_index + 1 :]:
1189             if leaf is closing:
1190                 break
1191
1192             bracket_depth = leaf.bracket_depth
1193             if bracket_depth == depth and leaf.type == token.COMMA:
1194                 commas += 1
1195                 if leaf.parent and leaf.parent.type == syms.arglist:
1196                     commas += 1
1197                     break
1198
1199         if commas > 1:
1200             self.remove_trailing_comma()
1201             return True
1202
1203         return False
1204
1205     def append_comment(self, comment: Leaf) -> bool:
1206         """Add an inline or standalone comment to the line."""
1207         if (
1208             comment.type == STANDALONE_COMMENT
1209             and self.bracket_tracker.any_open_brackets()
1210         ):
1211             comment.prefix = ""
1212             return False
1213
1214         if comment.type != token.COMMENT:
1215             return False
1216
1217         after = len(self.leaves) - 1
1218         if after == -1:
1219             comment.type = STANDALONE_COMMENT
1220             comment.prefix = ""
1221             return False
1222
1223         else:
1224             self.comments.append((after, comment))
1225             return True
1226
1227     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1228         """Generate comments that should appear directly after `leaf`.
1229
1230         Provide a non-negative leaf `_index` to speed up the function.
1231         """
1232         if not self.comments:
1233             return
1234
1235         if _index == -1:
1236             for _index, _leaf in enumerate(self.leaves):
1237                 if leaf is _leaf:
1238                     break
1239
1240             else:
1241                 return
1242
1243         for index, comment_after in self.comments:
1244             if _index == index:
1245                 yield comment_after
1246
1247     def remove_trailing_comma(self) -> None:
1248         """Remove the trailing comma and moves the comments attached to it."""
1249         comma_index = len(self.leaves) - 1
1250         for i in range(len(self.comments)):
1251             comment_index, comment = self.comments[i]
1252             if comment_index == comma_index:
1253                 self.comments[i] = (comma_index - 1, comment)
1254         self.leaves.pop()
1255
1256     def is_complex_subscript(self, leaf: Leaf) -> bool:
1257         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1258         open_lsqb = self.bracket_tracker.get_open_lsqb()
1259         if open_lsqb is None:
1260             return False
1261
1262         subscript_start = open_lsqb.next_sibling
1263
1264         if isinstance(subscript_start, Node):
1265             if subscript_start.type == syms.listmaker:
1266                 return False
1267
1268             if subscript_start.type == syms.subscriptlist:
1269                 subscript_start = child_towards(subscript_start, leaf)
1270         return subscript_start is not None and any(
1271             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1272         )
1273
1274     def __str__(self) -> str:
1275         """Render the line."""
1276         if not self:
1277             return "\n"
1278
1279         indent = "    " * self.depth
1280         leaves = iter(self.leaves)
1281         first = next(leaves)
1282         res = f"{first.prefix}{indent}{first.value}"
1283         for leaf in leaves:
1284             res += str(leaf)
1285         for _, comment in self.comments:
1286             res += str(comment)
1287         return res + "\n"
1288
1289     def __bool__(self) -> bool:
1290         """Return True if the line has leaves or comments."""
1291         return bool(self.leaves or self.comments)
1292
1293
1294 @dataclass
1295 class EmptyLineTracker:
1296     """Provides a stateful method that returns the number of potential extra
1297     empty lines needed before and after the currently processed line.
1298
1299     Note: this tracker works on lines that haven't been split yet.  It assumes
1300     the prefix of the first leaf consists of optional newlines.  Those newlines
1301     are consumed by `maybe_empty_lines()` and included in the computation.
1302     """
1303
1304     is_pyi: bool = False
1305     previous_line: Optional[Line] = None
1306     previous_after: int = 0
1307     previous_defs: List[int] = Factory(list)
1308
1309     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1310         """Return the number of extra empty lines before and after the `current_line`.
1311
1312         This is for separating `def`, `async def` and `class` with extra empty
1313         lines (two on module-level).
1314         """
1315         before, after = self._maybe_empty_lines(current_line)
1316         before -= self.previous_after
1317         self.previous_after = after
1318         self.previous_line = current_line
1319         return before, after
1320
1321     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1322         max_allowed = 1
1323         if current_line.depth == 0:
1324             max_allowed = 1 if self.is_pyi else 2
1325         if current_line.leaves:
1326             # Consume the first leaf's extra newlines.
1327             first_leaf = current_line.leaves[0]
1328             before = first_leaf.prefix.count("\n")
1329             before = min(before, max_allowed)
1330             first_leaf.prefix = ""
1331         else:
1332             before = 0
1333         depth = current_line.depth
1334         while self.previous_defs and self.previous_defs[-1] >= depth:
1335             self.previous_defs.pop()
1336             if self.is_pyi:
1337                 before = 0 if depth else 1
1338             else:
1339                 before = 1 if depth else 2
1340         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1341             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1342
1343         if (
1344             self.previous_line
1345             and self.previous_line.is_import
1346             and not current_line.is_import
1347             and depth == self.previous_line.depth
1348         ):
1349             return (before or 1), 0
1350
1351         if (
1352             self.previous_line
1353             and self.previous_line.is_class
1354             and current_line.is_triple_quoted_string
1355         ):
1356             return before, 1
1357
1358         return before, 0
1359
1360     def _maybe_empty_lines_for_class_or_def(
1361         self, current_line: Line, before: int
1362     ) -> Tuple[int, int]:
1363         if not current_line.is_decorator:
1364             self.previous_defs.append(current_line.depth)
1365         if self.previous_line is None:
1366             # Don't insert empty lines before the first line in the file.
1367             return 0, 0
1368
1369         if self.previous_line.is_decorator:
1370             return 0, 0
1371
1372         if self.previous_line.depth < current_line.depth and (
1373             self.previous_line.is_class or self.previous_line.is_def
1374         ):
1375             return 0, 0
1376
1377         if (
1378             self.previous_line.is_comment
1379             and self.previous_line.depth == current_line.depth
1380             and before == 0
1381         ):
1382             return 0, 0
1383
1384         if self.is_pyi:
1385             if self.previous_line.depth > current_line.depth:
1386                 newlines = 1
1387             elif current_line.is_class or self.previous_line.is_class:
1388                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1389                     # No blank line between classes with an empty body
1390                     newlines = 0
1391                 else:
1392                     newlines = 1
1393             elif current_line.is_def and not self.previous_line.is_def:
1394                 # Blank line between a block of functions and a block of non-functions
1395                 newlines = 1
1396             else:
1397                 newlines = 0
1398         else:
1399             newlines = 2
1400         if current_line.depth and newlines:
1401             newlines -= 1
1402         return newlines, 0
1403
1404
1405 @dataclass
1406 class LineGenerator(Visitor[Line]):
1407     """Generates reformatted Line objects.  Empty lines are not emitted.
1408
1409     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1410     in ways that will no longer stringify to valid Python code on the tree.
1411     """
1412
1413     is_pyi: bool = False
1414     normalize_strings: bool = True
1415     current_line: Line = Factory(Line)
1416     remove_u_prefix: bool = False
1417     allow_underscores: bool = False
1418
1419     def line(self, indent: int = 0) -> Iterator[Line]:
1420         """Generate a line.
1421
1422         If the line is empty, only emit if it makes sense.
1423         If the line is too long, split it first and then generate.
1424
1425         If any lines were generated, set up a new current_line.
1426         """
1427         if not self.current_line:
1428             self.current_line.depth += indent
1429             return  # Line is empty, don't emit. Creating a new one unnecessary.
1430
1431         complete_line = self.current_line
1432         self.current_line = Line(depth=complete_line.depth + indent)
1433         yield complete_line
1434
1435     def visit_default(self, node: LN) -> Iterator[Line]:
1436         """Default `visit_*()` implementation. Recurses to children of `node`."""
1437         if isinstance(node, Leaf):
1438             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1439             for comment in generate_comments(node):
1440                 if any_open_brackets:
1441                     # any comment within brackets is subject to splitting
1442                     self.current_line.append(comment)
1443                 elif comment.type == token.COMMENT:
1444                     # regular trailing comment
1445                     self.current_line.append(comment)
1446                     yield from self.line()
1447
1448                 else:
1449                     # regular standalone comment
1450                     yield from self.line()
1451
1452                     self.current_line.append(comment)
1453                     yield from self.line()
1454
1455             normalize_prefix(node, inside_brackets=any_open_brackets)
1456             if self.normalize_strings and node.type == token.STRING:
1457                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1458                 normalize_string_quotes(node)
1459             if node.type == token.NUMBER:
1460                 normalize_numeric_literal(node, self.allow_underscores)
1461             if node.type not in WHITESPACE:
1462                 self.current_line.append(node)
1463         yield from super().visit_default(node)
1464
1465     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1466         """Increase indentation level, maybe yield a line."""
1467         # In blib2to3 INDENT never holds comments.
1468         yield from self.line(+1)
1469         yield from self.visit_default(node)
1470
1471     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1472         """Decrease indentation level, maybe yield a line."""
1473         # The current line might still wait for trailing comments.  At DEDENT time
1474         # there won't be any (they would be prefixes on the preceding NEWLINE).
1475         # Emit the line then.
1476         yield from self.line()
1477
1478         # While DEDENT has no value, its prefix may contain standalone comments
1479         # that belong to the current indentation level.  Get 'em.
1480         yield from self.visit_default(node)
1481
1482         # Finally, emit the dedent.
1483         yield from self.line(-1)
1484
1485     def visit_stmt(
1486         self, node: Node, keywords: Set[str], parens: Set[str]
1487     ) -> Iterator[Line]:
1488         """Visit a statement.
1489
1490         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1491         `def`, `with`, `class`, `assert` and assignments.
1492
1493         The relevant Python language `keywords` for a given statement will be
1494         NAME leaves within it. This methods puts those on a separate line.
1495
1496         `parens` holds a set of string leaf values immediately after which
1497         invisible parens should be put.
1498         """
1499         normalize_invisible_parens(node, parens_after=parens)
1500         for child in node.children:
1501             if child.type == token.NAME and child.value in keywords:  # type: ignore
1502                 yield from self.line()
1503
1504             yield from self.visit(child)
1505
1506     def visit_suite(self, node: Node) -> Iterator[Line]:
1507         """Visit a suite."""
1508         if self.is_pyi and is_stub_suite(node):
1509             yield from self.visit(node.children[2])
1510         else:
1511             yield from self.visit_default(node)
1512
1513     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1514         """Visit a statement without nested statements."""
1515         is_suite_like = node.parent and node.parent.type in STATEMENT
1516         if is_suite_like:
1517             if self.is_pyi and is_stub_body(node):
1518                 yield from self.visit_default(node)
1519             else:
1520                 yield from self.line(+1)
1521                 yield from self.visit_default(node)
1522                 yield from self.line(-1)
1523
1524         else:
1525             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1526                 yield from self.line()
1527             yield from self.visit_default(node)
1528
1529     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1530         """Visit `async def`, `async for`, `async with`."""
1531         yield from self.line()
1532
1533         children = iter(node.children)
1534         for child in children:
1535             yield from self.visit(child)
1536
1537             if child.type == token.ASYNC:
1538                 break
1539
1540         internal_stmt = next(children)
1541         for child in internal_stmt.children:
1542             yield from self.visit(child)
1543
1544     def visit_decorators(self, node: Node) -> Iterator[Line]:
1545         """Visit decorators."""
1546         for child in node.children:
1547             yield from self.line()
1548             yield from self.visit(child)
1549
1550     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1551         """Remove a semicolon and put the other statement on a separate line."""
1552         yield from self.line()
1553
1554     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1555         """End of file. Process outstanding comments and end with a newline."""
1556         yield from self.visit_default(leaf)
1557         yield from self.line()
1558
1559     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1560         if not self.current_line.bracket_tracker.any_open_brackets():
1561             yield from self.line()
1562         yield from self.visit_default(leaf)
1563
1564     def __attrs_post_init__(self) -> None:
1565         """You are in a twisty little maze of passages."""
1566         v = self.visit_stmt
1567         Ø: Set[str] = set()
1568         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1569         self.visit_if_stmt = partial(
1570             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1571         )
1572         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1573         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1574         self.visit_try_stmt = partial(
1575             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1576         )
1577         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1578         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1579         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1580         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1581         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1582         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1583         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1584         self.visit_async_funcdef = self.visit_async_stmt
1585         self.visit_decorated = self.visit_decorators
1586
1587
1588 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1589 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1590 OPENING_BRACKETS = set(BRACKET.keys())
1591 CLOSING_BRACKETS = set(BRACKET.values())
1592 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1593 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1594
1595
1596 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1597     """Return whitespace prefix if needed for the given `leaf`.
1598
1599     `complex_subscript` signals whether the given leaf is part of a subscription
1600     which has non-trivial arguments, like arithmetic expressions or function calls.
1601     """
1602     NO = ""
1603     SPACE = " "
1604     DOUBLESPACE = "  "
1605     t = leaf.type
1606     p = leaf.parent
1607     v = leaf.value
1608     if t in ALWAYS_NO_SPACE:
1609         return NO
1610
1611     if t == token.COMMENT:
1612         return DOUBLESPACE
1613
1614     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1615     if t == token.COLON and p.type not in {
1616         syms.subscript,
1617         syms.subscriptlist,
1618         syms.sliceop,
1619     }:
1620         return NO
1621
1622     prev = leaf.prev_sibling
1623     if not prev:
1624         prevp = preceding_leaf(p)
1625         if not prevp or prevp.type in OPENING_BRACKETS:
1626             return NO
1627
1628         if t == token.COLON:
1629             if prevp.type == token.COLON:
1630                 return NO
1631
1632             elif prevp.type != token.COMMA and not complex_subscript:
1633                 return NO
1634
1635             return SPACE
1636
1637         if prevp.type == token.EQUAL:
1638             if prevp.parent:
1639                 if prevp.parent.type in {
1640                     syms.arglist,
1641                     syms.argument,
1642                     syms.parameters,
1643                     syms.varargslist,
1644                 }:
1645                     return NO
1646
1647                 elif prevp.parent.type == syms.typedargslist:
1648                     # A bit hacky: if the equal sign has whitespace, it means we
1649                     # previously found it's a typed argument.  So, we're using
1650                     # that, too.
1651                     return prevp.prefix
1652
1653         elif prevp.type in STARS:
1654             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1655                 return NO
1656
1657         elif prevp.type == token.COLON:
1658             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1659                 return SPACE if complex_subscript else NO
1660
1661         elif (
1662             prevp.parent
1663             and prevp.parent.type == syms.factor
1664             and prevp.type in MATH_OPERATORS
1665         ):
1666             return NO
1667
1668         elif (
1669             prevp.type == token.RIGHTSHIFT
1670             and prevp.parent
1671             and prevp.parent.type == syms.shift_expr
1672             and prevp.prev_sibling
1673             and prevp.prev_sibling.type == token.NAME
1674             and prevp.prev_sibling.value == "print"  # type: ignore
1675         ):
1676             # Python 2 print chevron
1677             return NO
1678
1679     elif prev.type in OPENING_BRACKETS:
1680         return NO
1681
1682     if p.type in {syms.parameters, syms.arglist}:
1683         # untyped function signatures or calls
1684         if not prev or prev.type != token.COMMA:
1685             return NO
1686
1687     elif p.type == syms.varargslist:
1688         # lambdas
1689         if prev and prev.type != token.COMMA:
1690             return NO
1691
1692     elif p.type == syms.typedargslist:
1693         # typed function signatures
1694         if not prev:
1695             return NO
1696
1697         if t == token.EQUAL:
1698             if prev.type != syms.tname:
1699                 return NO
1700
1701         elif prev.type == token.EQUAL:
1702             # A bit hacky: if the equal sign has whitespace, it means we
1703             # previously found it's a typed argument.  So, we're using that, too.
1704             return prev.prefix
1705
1706         elif prev.type != token.COMMA:
1707             return NO
1708
1709     elif p.type == syms.tname:
1710         # type names
1711         if not prev:
1712             prevp = preceding_leaf(p)
1713             if not prevp or prevp.type != token.COMMA:
1714                 return NO
1715
1716     elif p.type == syms.trailer:
1717         # attributes and calls
1718         if t == token.LPAR or t == token.RPAR:
1719             return NO
1720
1721         if not prev:
1722             if t == token.DOT:
1723                 prevp = preceding_leaf(p)
1724                 if not prevp or prevp.type != token.NUMBER:
1725                     return NO
1726
1727             elif t == token.LSQB:
1728                 return NO
1729
1730         elif prev.type != token.COMMA:
1731             return NO
1732
1733     elif p.type == syms.argument:
1734         # single argument
1735         if t == token.EQUAL:
1736             return NO
1737
1738         if not prev:
1739             prevp = preceding_leaf(p)
1740             if not prevp or prevp.type == token.LPAR:
1741                 return NO
1742
1743         elif prev.type in {token.EQUAL} | STARS:
1744             return NO
1745
1746     elif p.type == syms.decorator:
1747         # decorators
1748         return NO
1749
1750     elif p.type == syms.dotted_name:
1751         if prev:
1752             return NO
1753
1754         prevp = preceding_leaf(p)
1755         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1756             return NO
1757
1758     elif p.type == syms.classdef:
1759         if t == token.LPAR:
1760             return NO
1761
1762         if prev and prev.type == token.LPAR:
1763             return NO
1764
1765     elif p.type in {syms.subscript, syms.sliceop}:
1766         # indexing
1767         if not prev:
1768             assert p.parent is not None, "subscripts are always parented"
1769             if p.parent.type == syms.subscriptlist:
1770                 return SPACE
1771
1772             return NO
1773
1774         elif not complex_subscript:
1775             return NO
1776
1777     elif p.type == syms.atom:
1778         if prev and t == token.DOT:
1779             # dots, but not the first one.
1780             return NO
1781
1782     elif p.type == syms.dictsetmaker:
1783         # dict unpacking
1784         if prev and prev.type == token.DOUBLESTAR:
1785             return NO
1786
1787     elif p.type in {syms.factor, syms.star_expr}:
1788         # unary ops
1789         if not prev:
1790             prevp = preceding_leaf(p)
1791             if not prevp or prevp.type in OPENING_BRACKETS:
1792                 return NO
1793
1794             prevp_parent = prevp.parent
1795             assert prevp_parent is not None
1796             if prevp.type == token.COLON and prevp_parent.type in {
1797                 syms.subscript,
1798                 syms.sliceop,
1799             }:
1800                 return NO
1801
1802             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1803                 return NO
1804
1805         elif t in {token.NAME, token.NUMBER, token.STRING}:
1806             return NO
1807
1808     elif p.type == syms.import_from:
1809         if t == token.DOT:
1810             if prev and prev.type == token.DOT:
1811                 return NO
1812
1813         elif t == token.NAME:
1814             if v == "import":
1815                 return SPACE
1816
1817             if prev and prev.type == token.DOT:
1818                 return NO
1819
1820     elif p.type == syms.sliceop:
1821         return NO
1822
1823     return SPACE
1824
1825
1826 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1827     """Return the first leaf that precedes `node`, if any."""
1828     while node:
1829         res = node.prev_sibling
1830         if res:
1831             if isinstance(res, Leaf):
1832                 return res
1833
1834             try:
1835                 return list(res.leaves())[-1]
1836
1837             except IndexError:
1838                 return None
1839
1840         node = node.parent
1841     return None
1842
1843
1844 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1845     """Return the child of `ancestor` that contains `descendant`."""
1846     node: Optional[LN] = descendant
1847     while node and node.parent != ancestor:
1848         node = node.parent
1849     return node
1850
1851
1852 def container_of(leaf: Leaf) -> LN:
1853     """Return `leaf` or one of its ancestors that is the topmost container of it.
1854
1855     By "container" we mean a node where `leaf` is the very first child.
1856     """
1857     same_prefix = leaf.prefix
1858     container: LN = leaf
1859     while container:
1860         parent = container.parent
1861         if parent is None:
1862             break
1863
1864         if parent.children[0].prefix != same_prefix:
1865             break
1866
1867         if parent.type == syms.file_input:
1868             break
1869
1870         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1871             break
1872
1873         container = parent
1874     return container
1875
1876
1877 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1878     """Return the priority of the `leaf` delimiter, given a line break after it.
1879
1880     The delimiter priorities returned here are from those delimiters that would
1881     cause a line break after themselves.
1882
1883     Higher numbers are higher priority.
1884     """
1885     if leaf.type == token.COMMA:
1886         return COMMA_PRIORITY
1887
1888     return 0
1889
1890
1891 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1892     """Return the priority of the `leaf` delimiter, given a line before after it.
1893
1894     The delimiter priorities returned here are from those delimiters that would
1895     cause a line break before themselves.
1896
1897     Higher numbers are higher priority.
1898     """
1899     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1900         # * and ** might also be MATH_OPERATORS but in this case they are not.
1901         # Don't treat them as a delimiter.
1902         return 0
1903
1904     if (
1905         leaf.type == token.DOT
1906         and leaf.parent
1907         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1908         and (previous is None or previous.type in CLOSING_BRACKETS)
1909     ):
1910         return DOT_PRIORITY
1911
1912     if (
1913         leaf.type in MATH_OPERATORS
1914         and leaf.parent
1915         and leaf.parent.type not in {syms.factor, syms.star_expr}
1916     ):
1917         return MATH_PRIORITIES[leaf.type]
1918
1919     if leaf.type in COMPARATORS:
1920         return COMPARATOR_PRIORITY
1921
1922     if (
1923         leaf.type == token.STRING
1924         and previous is not None
1925         and previous.type == token.STRING
1926     ):
1927         return STRING_PRIORITY
1928
1929     if leaf.type != token.NAME:
1930         return 0
1931
1932     if (
1933         leaf.value == "for"
1934         and leaf.parent
1935         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1936     ):
1937         return COMPREHENSION_PRIORITY
1938
1939     if (
1940         leaf.value == "if"
1941         and leaf.parent
1942         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1943     ):
1944         return COMPREHENSION_PRIORITY
1945
1946     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1947         return TERNARY_PRIORITY
1948
1949     if leaf.value == "is":
1950         return COMPARATOR_PRIORITY
1951
1952     if (
1953         leaf.value == "in"
1954         and leaf.parent
1955         and leaf.parent.type in {syms.comp_op, syms.comparison}
1956         and not (
1957             previous is not None
1958             and previous.type == token.NAME
1959             and previous.value == "not"
1960         )
1961     ):
1962         return COMPARATOR_PRIORITY
1963
1964     if (
1965         leaf.value == "not"
1966         and leaf.parent
1967         and leaf.parent.type == syms.comp_op
1968         and not (
1969             previous is not None
1970             and previous.type == token.NAME
1971             and previous.value == "is"
1972         )
1973     ):
1974         return COMPARATOR_PRIORITY
1975
1976     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1977         return LOGIC_PRIORITY
1978
1979     return 0
1980
1981
1982 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
1983 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
1984
1985
1986 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1987     """Clean the prefix of the `leaf` and generate comments from it, if any.
1988
1989     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1990     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1991     move because it does away with modifying the grammar to include all the
1992     possible places in which comments can be placed.
1993
1994     The sad consequence for us though is that comments don't "belong" anywhere.
1995     This is why this function generates simple parentless Leaf objects for
1996     comments.  We simply don't know what the correct parent should be.
1997
1998     No matter though, we can live without this.  We really only need to
1999     differentiate between inline and standalone comments.  The latter don't
2000     share the line with any code.
2001
2002     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2003     are emitted with a fake STANDALONE_COMMENT token identifier.
2004     """
2005     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2006         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2007
2008
2009 @dataclass
2010 class ProtoComment:
2011     type: int  # token.COMMENT or STANDALONE_COMMENT
2012     value: str  # content of the comment
2013     newlines: int  # how many newlines before the comment
2014     consumed: int  # how many characters of the original leaf's prefix did we consume
2015
2016
2017 @lru_cache(maxsize=4096)
2018 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2019     result: List[ProtoComment] = []
2020     if not prefix or "#" not in prefix:
2021         return result
2022
2023     consumed = 0
2024     nlines = 0
2025     for index, line in enumerate(prefix.split("\n")):
2026         consumed += len(line) + 1  # adding the length of the split '\n'
2027         line = line.lstrip()
2028         if not line:
2029             nlines += 1
2030         if not line.startswith("#"):
2031             continue
2032
2033         if index == 0 and not is_endmarker:
2034             comment_type = token.COMMENT  # simple trailing comment
2035         else:
2036             comment_type = STANDALONE_COMMENT
2037         comment = make_comment(line)
2038         result.append(
2039             ProtoComment(
2040                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2041             )
2042         )
2043         nlines = 0
2044     return result
2045
2046
2047 def make_comment(content: str) -> str:
2048     """Return a consistently formatted comment from the given `content` string.
2049
2050     All comments (except for "##", "#!", "#:") should have a single space between
2051     the hash sign and the content.
2052
2053     If `content` didn't start with a hash sign, one is provided.
2054     """
2055     content = content.rstrip()
2056     if not content:
2057         return "#"
2058
2059     if content[0] == "#":
2060         content = content[1:]
2061     if content and content[0] not in " !:#":
2062         content = " " + content
2063     return "#" + content
2064
2065
2066 def split_line(
2067     line: Line, line_length: int, inner: bool = False, py36: bool = False
2068 ) -> Iterator[Line]:
2069     """Split a `line` into potentially many lines.
2070
2071     They should fit in the allotted `line_length` but might not be able to.
2072     `inner` signifies that there were a pair of brackets somewhere around the
2073     current `line`, possibly transitively. This means we can fallback to splitting
2074     by delimiters if the LHS/RHS don't yield any results.
2075
2076     If `py36` is True, splitting may generate syntax that is only compatible
2077     with Python 3.6 and later.
2078     """
2079     if line.is_comment:
2080         yield line
2081         return
2082
2083     line_str = str(line).strip("\n")
2084     if not line.should_explode and is_line_short_enough(
2085         line, line_length=line_length, line_str=line_str
2086     ):
2087         yield line
2088         return
2089
2090     split_funcs: List[SplitFunc]
2091     if line.is_def:
2092         split_funcs = [left_hand_split]
2093     else:
2094
2095         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2096             for omit in generate_trailers_to_omit(line, line_length):
2097                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2098                 if is_line_short_enough(lines[0], line_length=line_length):
2099                     yield from lines
2100                     return
2101
2102             # All splits failed, best effort split with no omits.
2103             # This mostly happens to multiline strings that are by definition
2104             # reported as not fitting a single line.
2105             yield from right_hand_split(line, py36)
2106
2107         if line.inside_brackets:
2108             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2109         else:
2110             split_funcs = [rhs]
2111     for split_func in split_funcs:
2112         # We are accumulating lines in `result` because we might want to abort
2113         # mission and return the original line in the end, or attempt a different
2114         # split altogether.
2115         result: List[Line] = []
2116         try:
2117             for l in split_func(line, py36):
2118                 if str(l).strip("\n") == line_str:
2119                     raise CannotSplit("Split function returned an unchanged result")
2120
2121                 result.extend(
2122                     split_line(l, line_length=line_length, inner=True, py36=py36)
2123                 )
2124         except CannotSplit as cs:
2125             continue
2126
2127         else:
2128             yield from result
2129             break
2130
2131     else:
2132         yield line
2133
2134
2135 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2136     """Split line into many lines, starting with the first matching bracket pair.
2137
2138     Note: this usually looks weird, only use this for function definitions.
2139     Prefer RHS otherwise.  This is why this function is not symmetrical with
2140     :func:`right_hand_split` which also handles optional parentheses.
2141     """
2142     head = Line(depth=line.depth)
2143     body = Line(depth=line.depth + 1, inside_brackets=True)
2144     tail = Line(depth=line.depth)
2145     tail_leaves: List[Leaf] = []
2146     body_leaves: List[Leaf] = []
2147     head_leaves: List[Leaf] = []
2148     current_leaves = head_leaves
2149     matching_bracket = None
2150     for leaf in line.leaves:
2151         if (
2152             current_leaves is body_leaves
2153             and leaf.type in CLOSING_BRACKETS
2154             and leaf.opening_bracket is matching_bracket
2155         ):
2156             current_leaves = tail_leaves if body_leaves else head_leaves
2157         current_leaves.append(leaf)
2158         if current_leaves is head_leaves:
2159             if leaf.type in OPENING_BRACKETS:
2160                 matching_bracket = leaf
2161                 current_leaves = body_leaves
2162     # Since body is a new indent level, remove spurious leading whitespace.
2163     if body_leaves:
2164         normalize_prefix(body_leaves[0], inside_brackets=True)
2165     # Build the new lines.
2166     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2167         for leaf in leaves:
2168             result.append(leaf, preformatted=True)
2169             for comment_after in line.comments_after(leaf):
2170                 result.append(comment_after, preformatted=True)
2171     bracket_split_succeeded_or_raise(head, body, tail)
2172     for result in (head, body, tail):
2173         if result:
2174             yield result
2175
2176
2177 def right_hand_split(
2178     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2179 ) -> Iterator[Line]:
2180     """Split line into many lines, starting with the last matching bracket pair.
2181
2182     If the split was by optional parentheses, attempt splitting without them, too.
2183     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2184     this split.
2185
2186     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2187     """
2188     head = Line(depth=line.depth)
2189     body = Line(depth=line.depth + 1, inside_brackets=True)
2190     tail = Line(depth=line.depth)
2191     tail_leaves: List[Leaf] = []
2192     body_leaves: List[Leaf] = []
2193     head_leaves: List[Leaf] = []
2194     current_leaves = tail_leaves
2195     opening_bracket = None
2196     closing_bracket = None
2197     for leaf in reversed(line.leaves):
2198         if current_leaves is body_leaves:
2199             if leaf is opening_bracket:
2200                 current_leaves = head_leaves if body_leaves else tail_leaves
2201         current_leaves.append(leaf)
2202         if current_leaves is tail_leaves:
2203             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2204                 opening_bracket = leaf.opening_bracket
2205                 closing_bracket = leaf
2206                 current_leaves = body_leaves
2207     tail_leaves.reverse()
2208     body_leaves.reverse()
2209     head_leaves.reverse()
2210     # Since body is a new indent level, remove spurious leading whitespace.
2211     if body_leaves:
2212         normalize_prefix(body_leaves[0], inside_brackets=True)
2213     if not head_leaves:
2214         # No `head` means the split failed. Either `tail` has all content or
2215         # the matching `opening_bracket` wasn't available on `line` anymore.
2216         raise CannotSplit("No brackets found")
2217
2218     # Build the new lines.
2219     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2220         for leaf in leaves:
2221             result.append(leaf, preformatted=True)
2222             for comment_after in line.comments_after(leaf):
2223                 result.append(comment_after, preformatted=True)
2224     assert opening_bracket and closing_bracket
2225     body.should_explode = should_explode(body, opening_bracket)
2226     bracket_split_succeeded_or_raise(head, body, tail)
2227     if (
2228         # the body shouldn't be exploded
2229         not body.should_explode
2230         # the opening bracket is an optional paren
2231         and opening_bracket.type == token.LPAR
2232         and not opening_bracket.value
2233         # the closing bracket is an optional paren
2234         and closing_bracket.type == token.RPAR
2235         and not closing_bracket.value
2236         # it's not an import (optional parens are the only thing we can split on
2237         # in this case; attempting a split without them is a waste of time)
2238         and not line.is_import
2239         # there are no standalone comments in the body
2240         and not body.contains_standalone_comments(0)
2241         # and we can actually remove the parens
2242         and can_omit_invisible_parens(body, line_length)
2243     ):
2244         omit = {id(closing_bracket), *omit}
2245         try:
2246             yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2247             return
2248
2249         except CannotSplit:
2250             if not (
2251                 can_be_split(body)
2252                 or is_line_short_enough(body, line_length=line_length)
2253             ):
2254                 raise CannotSplit(
2255                     "Splitting failed, body is still too long and can't be split."
2256                 )
2257
2258             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2259                 raise CannotSplit(
2260                     "The current optional pair of parentheses is bound to fail to "
2261                     "satisfy the splitting algorithm because the head or the tail "
2262                     "contains multiline strings which by definition never fit one "
2263                     "line."
2264                 )
2265
2266     ensure_visible(opening_bracket)
2267     ensure_visible(closing_bracket)
2268     for result in (head, body, tail):
2269         if result:
2270             yield result
2271
2272
2273 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2274     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2275
2276     Do nothing otherwise.
2277
2278     A left- or right-hand split is based on a pair of brackets. Content before
2279     (and including) the opening bracket is left on one line, content inside the
2280     brackets is put on a separate line, and finally content starting with and
2281     following the closing bracket is put on a separate line.
2282
2283     Those are called `head`, `body`, and `tail`, respectively. If the split
2284     produced the same line (all content in `head`) or ended up with an empty `body`
2285     and the `tail` is just the closing bracket, then it's considered failed.
2286     """
2287     tail_len = len(str(tail).strip())
2288     if not body:
2289         if tail_len == 0:
2290             raise CannotSplit("Splitting brackets produced the same line")
2291
2292         elif tail_len < 3:
2293             raise CannotSplit(
2294                 f"Splitting brackets on an empty body to save "
2295                 f"{tail_len} characters is not worth it"
2296             )
2297
2298
2299 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2300     """Normalize prefix of the first leaf in every line returned by `split_func`.
2301
2302     This is a decorator over relevant split functions.
2303     """
2304
2305     @wraps(split_func)
2306     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2307         for l in split_func(line, py36):
2308             normalize_prefix(l.leaves[0], inside_brackets=True)
2309             yield l
2310
2311     return split_wrapper
2312
2313
2314 @dont_increase_indentation
2315 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2316     """Split according to delimiters of the highest priority.
2317
2318     If `py36` is True, the split will add trailing commas also in function
2319     signatures that contain `*` and `**`.
2320     """
2321     try:
2322         last_leaf = line.leaves[-1]
2323     except IndexError:
2324         raise CannotSplit("Line empty")
2325
2326     bt = line.bracket_tracker
2327     try:
2328         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2329     except ValueError:
2330         raise CannotSplit("No delimiters found")
2331
2332     if delimiter_priority == DOT_PRIORITY:
2333         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2334             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2335
2336     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2337     lowest_depth = sys.maxsize
2338     trailing_comma_safe = True
2339
2340     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2341         """Append `leaf` to current line or to new line if appending impossible."""
2342         nonlocal current_line
2343         try:
2344             current_line.append_safe(leaf, preformatted=True)
2345         except ValueError as ve:
2346             yield current_line
2347
2348             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2349             current_line.append(leaf)
2350
2351     for index, leaf in enumerate(line.leaves):
2352         yield from append_to_line(leaf)
2353
2354         for comment_after in line.comments_after(leaf, index):
2355             yield from append_to_line(comment_after)
2356
2357         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2358         if leaf.bracket_depth == lowest_depth and is_vararg(
2359             leaf, within=VARARGS_PARENTS
2360         ):
2361             trailing_comma_safe = trailing_comma_safe and py36
2362         leaf_priority = bt.delimiters.get(id(leaf))
2363         if leaf_priority == delimiter_priority:
2364             yield current_line
2365
2366             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2367     if current_line:
2368         if (
2369             trailing_comma_safe
2370             and delimiter_priority == COMMA_PRIORITY
2371             and current_line.leaves[-1].type != token.COMMA
2372             and current_line.leaves[-1].type != STANDALONE_COMMENT
2373         ):
2374             current_line.append(Leaf(token.COMMA, ","))
2375         yield current_line
2376
2377
2378 @dont_increase_indentation
2379 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2380     """Split standalone comments from the rest of the line."""
2381     if not line.contains_standalone_comments(0):
2382         raise CannotSplit("Line does not have any standalone comments")
2383
2384     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2385
2386     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2387         """Append `leaf` to current line or to new line if appending impossible."""
2388         nonlocal current_line
2389         try:
2390             current_line.append_safe(leaf, preformatted=True)
2391         except ValueError as ve:
2392             yield current_line
2393
2394             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2395             current_line.append(leaf)
2396
2397     for index, leaf in enumerate(line.leaves):
2398         yield from append_to_line(leaf)
2399
2400         for comment_after in line.comments_after(leaf, index):
2401             yield from append_to_line(comment_after)
2402
2403     if current_line:
2404         yield current_line
2405
2406
2407 def is_import(leaf: Leaf) -> bool:
2408     """Return True if the given leaf starts an import statement."""
2409     p = leaf.parent
2410     t = leaf.type
2411     v = leaf.value
2412     return bool(
2413         t == token.NAME
2414         and (
2415             (v == "import" and p and p.type == syms.import_name)
2416             or (v == "from" and p and p.type == syms.import_from)
2417         )
2418     )
2419
2420
2421 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2422     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2423     else.
2424
2425     Note: don't use backslashes for formatting or you'll lose your voting rights.
2426     """
2427     if not inside_brackets:
2428         spl = leaf.prefix.split("#")
2429         if "\\" not in spl[0]:
2430             nl_count = spl[-1].count("\n")
2431             if len(spl) > 1:
2432                 nl_count -= 1
2433             leaf.prefix = "\n" * nl_count
2434             return
2435
2436     leaf.prefix = ""
2437
2438
2439 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2440     """Make all string prefixes lowercase.
2441
2442     If remove_u_prefix is given, also removes any u prefix from the string.
2443
2444     Note: Mutates its argument.
2445     """
2446     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2447     assert match is not None, f"failed to match string {leaf.value!r}"
2448     orig_prefix = match.group(1)
2449     new_prefix = orig_prefix.lower()
2450     if remove_u_prefix:
2451         new_prefix = new_prefix.replace("u", "")
2452     leaf.value = f"{new_prefix}{match.group(2)}"
2453
2454
2455 def normalize_string_quotes(leaf: Leaf) -> None:
2456     """Prefer double quotes but only if it doesn't cause more escaping.
2457
2458     Adds or removes backslashes as appropriate. Doesn't parse and fix
2459     strings nested in f-strings (yet).
2460
2461     Note: Mutates its argument.
2462     """
2463     value = leaf.value.lstrip("furbFURB")
2464     if value[:3] == '"""':
2465         return
2466
2467     elif value[:3] == "'''":
2468         orig_quote = "'''"
2469         new_quote = '"""'
2470     elif value[0] == '"':
2471         orig_quote = '"'
2472         new_quote = "'"
2473     else:
2474         orig_quote = "'"
2475         new_quote = '"'
2476     first_quote_pos = leaf.value.find(orig_quote)
2477     if first_quote_pos == -1:
2478         return  # There's an internal error
2479
2480     prefix = leaf.value[:first_quote_pos]
2481     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2482     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2483     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2484     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2485     if "r" in prefix.casefold():
2486         if unescaped_new_quote.search(body):
2487             # There's at least one unescaped new_quote in this raw string
2488             # so converting is impossible
2489             return
2490
2491         # Do not introduce or remove backslashes in raw strings
2492         new_body = body
2493     else:
2494         # remove unnecessary escapes
2495         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2496         if body != new_body:
2497             # Consider the string without unnecessary escapes as the original
2498             body = new_body
2499             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2500         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2501         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2502     if "f" in prefix.casefold():
2503         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2504         for m in matches:
2505             if "\\" in str(m):
2506                 # Do not introduce backslashes in interpolated expressions
2507                 return
2508     if new_quote == '"""' and new_body[-1:] == '"':
2509         # edge case:
2510         new_body = new_body[:-1] + '\\"'
2511     orig_escape_count = body.count("\\")
2512     new_escape_count = new_body.count("\\")
2513     if new_escape_count > orig_escape_count:
2514         return  # Do not introduce more escaping
2515
2516     if new_escape_count == orig_escape_count and orig_quote == '"':
2517         return  # Prefer double quotes
2518
2519     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2520
2521
2522 def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
2523     """Normalizes numeric (float, int, and complex) literals.
2524
2525     All letters used in the representation are normalized to lowercase (except
2526     in Python 2 long literals), and long number literals are split using underscores.
2527     """
2528     text = leaf.value.lower()
2529     if text.startswith(("0o", "0x", "0b")):
2530         # Leave octal, hex, and binary literals alone.
2531         pass
2532     elif "e" in text:
2533         before, after = text.split("e")
2534         sign = ""
2535         if after.startswith("-"):
2536             after = after[1:]
2537             sign = "-"
2538         elif after.startswith("+"):
2539             after = after[1:]
2540         before = format_float_or_int_string(before, allow_underscores)
2541         after = format_int_string(after, allow_underscores)
2542         text = f"{before}e{sign}{after}"
2543     elif text.endswith(("j", "l")):
2544         number = text[:-1]
2545         suffix = text[-1]
2546         # Capitalize in "2L" because "l" looks too similar to "1".
2547         if suffix == "l":
2548             suffix = "L"
2549         text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
2550     else:
2551         text = format_float_or_int_string(text, allow_underscores)
2552     leaf.value = text
2553
2554
2555 def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
2556     """Formats a float string like "1.0"."""
2557     if "." not in text:
2558         return format_int_string(text, allow_underscores)
2559
2560     before, after = text.split(".")
2561     before = format_int_string(before, allow_underscores) if before else "0"
2562     if after:
2563         after = format_int_string(after, allow_underscores, count_from_end=False)
2564     else:
2565         after = "0"
2566     return f"{before}.{after}"
2567
2568
2569 def format_int_string(
2570     text: str, allow_underscores: bool, count_from_end: bool = True
2571 ) -> str:
2572     """Normalizes underscores in a string to e.g. 1_000_000.
2573
2574     Input must be a string of digits and optional underscores.
2575     If count_from_end is False, we add underscores after groups of three digits
2576     counting from the beginning instead of the end of the strings. This is used
2577     for the fractional part of float literals.
2578     """
2579     if not allow_underscores:
2580         return text
2581
2582     text = text.replace("_", "")
2583     if len(text) <= 6:
2584         # No underscores for numbers <= 6 digits long.
2585         return text
2586
2587     if count_from_end:
2588         # Avoid removing leading zeros, which are important if we're formatting
2589         # part of a number like "0.001".
2590         return format(int("1" + text), "3_")[1:].lstrip("_")
2591     else:
2592         return "_".join(text[i : i + 3] for i in range(0, len(text), 3))
2593
2594
2595 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2596     """Make existing optional parentheses invisible or create new ones.
2597
2598     `parens_after` is a set of string leaf values immeditely after which parens
2599     should be put.
2600
2601     Standardizes on visible parentheses for single-element tuples, and keeps
2602     existing visible parentheses for other tuples and generator expressions.
2603     """
2604     for pc in list_comments(node.prefix, is_endmarker=False):
2605         if pc.value in FMT_OFF:
2606             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2607             return
2608
2609     check_lpar = False
2610     for index, child in enumerate(list(node.children)):
2611         if check_lpar:
2612             if child.type == syms.atom:
2613                 if maybe_make_parens_invisible_in_atom(child):
2614                     lpar = Leaf(token.LPAR, "")
2615                     rpar = Leaf(token.RPAR, "")
2616                     index = child.remove() or 0
2617                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2618             elif is_one_tuple(child):
2619                 # wrap child in visible parentheses
2620                 lpar = Leaf(token.LPAR, "(")
2621                 rpar = Leaf(token.RPAR, ")")
2622                 child.remove()
2623                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2624             elif node.type == syms.import_from:
2625                 # "import from" nodes store parentheses directly as part of
2626                 # the statement
2627                 if child.type == token.LPAR:
2628                     # make parentheses invisible
2629                     child.value = ""  # type: ignore
2630                     node.children[-1].value = ""  # type: ignore
2631                 elif child.type != token.STAR:
2632                     # insert invisible parentheses
2633                     node.insert_child(index, Leaf(token.LPAR, ""))
2634                     node.append_child(Leaf(token.RPAR, ""))
2635                 break
2636
2637             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2638                 # wrap child in invisible parentheses
2639                 lpar = Leaf(token.LPAR, "")
2640                 rpar = Leaf(token.RPAR, "")
2641                 index = child.remove() or 0
2642                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2643
2644         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2645
2646
2647 def normalize_fmt_off(node: Node) -> None:
2648     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2649     try_again = True
2650     while try_again:
2651         try_again = convert_one_fmt_off_pair(node)
2652
2653
2654 def convert_one_fmt_off_pair(node: Node) -> bool:
2655     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2656
2657     Returns True if a pair was converted.
2658     """
2659     for leaf in node.leaves():
2660         previous_consumed = 0
2661         for comment in list_comments(leaf.prefix, is_endmarker=False):
2662             if comment.value in FMT_OFF:
2663                 # We only want standalone comments. If there's no previous leaf or
2664                 # the previous leaf is indentation, it's a standalone comment in
2665                 # disguise.
2666                 if comment.type != STANDALONE_COMMENT:
2667                     prev = preceding_leaf(leaf)
2668                     if prev and prev.type not in WHITESPACE:
2669                         continue
2670
2671                 ignored_nodes = list(generate_ignored_nodes(leaf))
2672                 if not ignored_nodes:
2673                     continue
2674
2675                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2676                 parent = first.parent
2677                 prefix = first.prefix
2678                 first.prefix = prefix[comment.consumed :]
2679                 hidden_value = (
2680                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2681                 )
2682                 if hidden_value.endswith("\n"):
2683                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2684                     # leaf (possibly followed by a DEDENT).
2685                     hidden_value = hidden_value[:-1]
2686                 first_idx = None
2687                 for ignored in ignored_nodes:
2688                     index = ignored.remove()
2689                     if first_idx is None:
2690                         first_idx = index
2691                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2692                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2693                 parent.insert_child(
2694                     first_idx,
2695                     Leaf(
2696                         STANDALONE_COMMENT,
2697                         hidden_value,
2698                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2699                     ),
2700                 )
2701                 return True
2702
2703             previous_consumed = comment.consumed
2704
2705     return False
2706
2707
2708 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2709     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2710
2711     Stops at the end of the block.
2712     """
2713     container: Optional[LN] = container_of(leaf)
2714     while container is not None and container.type != token.ENDMARKER:
2715         for comment in list_comments(container.prefix, is_endmarker=False):
2716             if comment.value in FMT_ON:
2717                 return
2718
2719         yield container
2720
2721         container = container.next_sibling
2722
2723
2724 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2725     """If it's safe, make the parens in the atom `node` invisible, recursively.
2726
2727     Returns whether the node should itself be wrapped in invisible parentheses.
2728
2729     """
2730     if (
2731         node.type != syms.atom
2732         or is_empty_tuple(node)
2733         or is_one_tuple(node)
2734         or is_yield(node)
2735         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2736     ):
2737         return False
2738
2739     first = node.children[0]
2740     last = node.children[-1]
2741     if first.type == token.LPAR and last.type == token.RPAR:
2742         # make parentheses invisible
2743         first.value = ""  # type: ignore
2744         last.value = ""  # type: ignore
2745         if len(node.children) > 1:
2746             maybe_make_parens_invisible_in_atom(node.children[1])
2747         return False
2748
2749     return True
2750
2751
2752 def is_empty_tuple(node: LN) -> bool:
2753     """Return True if `node` holds an empty tuple."""
2754     return (
2755         node.type == syms.atom
2756         and len(node.children) == 2
2757         and node.children[0].type == token.LPAR
2758         and node.children[1].type == token.RPAR
2759     )
2760
2761
2762 def is_one_tuple(node: LN) -> bool:
2763     """Return True if `node` holds a tuple with one element, with or without parens."""
2764     if node.type == syms.atom:
2765         if len(node.children) != 3:
2766             return False
2767
2768         lpar, gexp, rpar = node.children
2769         if not (
2770             lpar.type == token.LPAR
2771             and gexp.type == syms.testlist_gexp
2772             and rpar.type == token.RPAR
2773         ):
2774             return False
2775
2776         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2777
2778     return (
2779         node.type in IMPLICIT_TUPLE
2780         and len(node.children) == 2
2781         and node.children[1].type == token.COMMA
2782     )
2783
2784
2785 def is_yield(node: LN) -> bool:
2786     """Return True if `node` holds a `yield` or `yield from` expression."""
2787     if node.type == syms.yield_expr:
2788         return True
2789
2790     if node.type == token.NAME and node.value == "yield":  # type: ignore
2791         return True
2792
2793     if node.type != syms.atom:
2794         return False
2795
2796     if len(node.children) != 3:
2797         return False
2798
2799     lpar, expr, rpar = node.children
2800     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2801         return is_yield(expr)
2802
2803     return False
2804
2805
2806 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2807     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2808
2809     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2810     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2811     extended iterable unpacking (PEP 3132) and additional unpacking
2812     generalizations (PEP 448).
2813     """
2814     if leaf.type not in STARS or not leaf.parent:
2815         return False
2816
2817     p = leaf.parent
2818     if p.type == syms.star_expr:
2819         # Star expressions are also used as assignment targets in extended
2820         # iterable unpacking (PEP 3132).  See what its parent is instead.
2821         if not p.parent:
2822             return False
2823
2824         p = p.parent
2825
2826     return p.type in within
2827
2828
2829 def is_multiline_string(leaf: Leaf) -> bool:
2830     """Return True if `leaf` is a multiline string that actually spans many lines."""
2831     value = leaf.value.lstrip("furbFURB")
2832     return value[:3] in {'"""', "'''"} and "\n" in value
2833
2834
2835 def is_stub_suite(node: Node) -> bool:
2836     """Return True if `node` is a suite with a stub body."""
2837     if (
2838         len(node.children) != 4
2839         or node.children[0].type != token.NEWLINE
2840         or node.children[1].type != token.INDENT
2841         or node.children[3].type != token.DEDENT
2842     ):
2843         return False
2844
2845     return is_stub_body(node.children[2])
2846
2847
2848 def is_stub_body(node: LN) -> bool:
2849     """Return True if `node` is a simple statement containing an ellipsis."""
2850     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2851         return False
2852
2853     if len(node.children) != 2:
2854         return False
2855
2856     child = node.children[0]
2857     return (
2858         child.type == syms.atom
2859         and len(child.children) == 3
2860         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2861     )
2862
2863
2864 def max_delimiter_priority_in_atom(node: LN) -> int:
2865     """Return maximum delimiter priority inside `node`.
2866
2867     This is specific to atoms with contents contained in a pair of parentheses.
2868     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2869     """
2870     if node.type != syms.atom:
2871         return 0
2872
2873     first = node.children[0]
2874     last = node.children[-1]
2875     if not (first.type == token.LPAR and last.type == token.RPAR):
2876         return 0
2877
2878     bt = BracketTracker()
2879     for c in node.children[1:-1]:
2880         if isinstance(c, Leaf):
2881             bt.mark(c)
2882         else:
2883             for leaf in c.leaves():
2884                 bt.mark(leaf)
2885     try:
2886         return bt.max_delimiter_priority()
2887
2888     except ValueError:
2889         return 0
2890
2891
2892 def ensure_visible(leaf: Leaf) -> None:
2893     """Make sure parentheses are visible.
2894
2895     They could be invisible as part of some statements (see
2896     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2897     """
2898     if leaf.type == token.LPAR:
2899         leaf.value = "("
2900     elif leaf.type == token.RPAR:
2901         leaf.value = ")"
2902
2903
2904 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2905     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2906     if not (
2907         opening_bracket.parent
2908         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2909         and opening_bracket.value in "[{("
2910     ):
2911         return False
2912
2913     try:
2914         last_leaf = line.leaves[-1]
2915         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2916         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2917     except (IndexError, ValueError):
2918         return False
2919
2920     return max_priority == COMMA_PRIORITY
2921
2922
2923 def is_python36(node: Node) -> bool:
2924     """Return True if the current file is using Python 3.6+ features.
2925
2926     Currently looking for:
2927     - f-strings;
2928     - underscores in numeric literals; and
2929     - trailing commas after * or ** in function signatures and calls.
2930     """
2931     for n in node.pre_order():
2932         if n.type == token.STRING:
2933             value_head = n.value[:2]  # type: ignore
2934             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2935                 return True
2936
2937         elif n.type == token.NUMBER:
2938             if "_" in n.value:  # type: ignore
2939                 return True
2940
2941         elif (
2942             n.type in {syms.typedargslist, syms.arglist}
2943             and n.children
2944             and n.children[-1].type == token.COMMA
2945         ):
2946             for ch in n.children:
2947                 if ch.type in STARS:
2948                     return True
2949
2950                 if ch.type == syms.argument:
2951                     for argch in ch.children:
2952                         if argch.type in STARS:
2953                             return True
2954
2955     return False
2956
2957
2958 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2959     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2960
2961     Brackets can be omitted if the entire trailer up to and including
2962     a preceding closing bracket fits in one line.
2963
2964     Yielded sets are cumulative (contain results of previous yields, too).  First
2965     set is empty.
2966     """
2967
2968     omit: Set[LeafID] = set()
2969     yield omit
2970
2971     length = 4 * line.depth
2972     opening_bracket = None
2973     closing_bracket = None
2974     optional_brackets: Set[LeafID] = set()
2975     inner_brackets: Set[LeafID] = set()
2976     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2977         length += leaf_length
2978         if length > line_length:
2979             break
2980
2981         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2982         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2983             break
2984
2985         optional_brackets.discard(id(leaf))
2986         if opening_bracket:
2987             if leaf is opening_bracket:
2988                 opening_bracket = None
2989             elif leaf.type in CLOSING_BRACKETS:
2990                 inner_brackets.add(id(leaf))
2991         elif leaf.type in CLOSING_BRACKETS:
2992             if not leaf.value:
2993                 optional_brackets.add(id(opening_bracket))
2994                 continue
2995
2996             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2997                 # Empty brackets would fail a split so treat them as "inner"
2998                 # brackets (e.g. only add them to the `omit` set if another
2999                 # pair of brackets was good enough.
3000                 inner_brackets.add(id(leaf))
3001                 continue
3002
3003             opening_bracket = leaf.opening_bracket
3004             if closing_bracket:
3005                 omit.add(id(closing_bracket))
3006                 omit.update(inner_brackets)
3007                 inner_brackets.clear()
3008                 yield omit
3009             closing_bracket = leaf
3010
3011
3012 def get_future_imports(node: Node) -> Set[str]:
3013     """Return a set of __future__ imports in the file."""
3014     imports: Set[str] = set()
3015
3016     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3017         for child in children:
3018             if isinstance(child, Leaf):
3019                 if child.type == token.NAME:
3020                     yield child.value
3021             elif child.type == syms.import_as_name:
3022                 orig_name = child.children[0]
3023                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3024                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3025                 yield orig_name.value
3026             elif child.type == syms.import_as_names:
3027                 yield from get_imports_from_children(child.children)
3028             else:
3029                 assert False, "Invalid syntax parsing imports"
3030
3031     for child in node.children:
3032         if child.type != syms.simple_stmt:
3033             break
3034         first_child = child.children[0]
3035         if isinstance(first_child, Leaf):
3036             # Continue looking if we see a docstring; otherwise stop.
3037             if (
3038                 len(child.children) == 2
3039                 and first_child.type == token.STRING
3040                 and child.children[1].type == token.NEWLINE
3041             ):
3042                 continue
3043             else:
3044                 break
3045         elif first_child.type == syms.import_from:
3046             module_name = first_child.children[1]
3047             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3048                 break
3049             imports |= set(get_imports_from_children(first_child.children[3:]))
3050         else:
3051             break
3052     return imports
3053
3054
3055 def gen_python_files_in_dir(
3056     path: Path,
3057     root: Path,
3058     include: Pattern[str],
3059     exclude: Pattern[str],
3060     report: "Report",
3061 ) -> Iterator[Path]:
3062     """Generate all files under `path` whose paths are not excluded by the
3063     `exclude` regex, but are included by the `include` regex.
3064
3065     Symbolic links pointing outside of the `root` directory are ignored.
3066
3067     `report` is where output about exclusions goes.
3068     """
3069     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3070     for child in path.iterdir():
3071         try:
3072             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3073         except ValueError:
3074             if child.is_symlink():
3075                 report.path_ignored(
3076                     child, f"is a symbolic link that points outside {root}"
3077                 )
3078                 continue
3079
3080             raise
3081
3082         if child.is_dir():
3083             normalized_path += "/"
3084         exclude_match = exclude.search(normalized_path)
3085         if exclude_match and exclude_match.group(0):
3086             report.path_ignored(child, f"matches the --exclude regular expression")
3087             continue
3088
3089         if child.is_dir():
3090             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3091
3092         elif child.is_file():
3093             include_match = include.search(normalized_path)
3094             if include_match:
3095                 yield child
3096
3097
3098 @lru_cache()
3099 def find_project_root(srcs: Iterable[str]) -> Path:
3100     """Return a directory containing .git, .hg, or pyproject.toml.
3101
3102     That directory can be one of the directories passed in `srcs` or their
3103     common parent.
3104
3105     If no directory in the tree contains a marker that would specify it's the
3106     project root, the root of the file system is returned.
3107     """
3108     if not srcs:
3109         return Path("/").resolve()
3110
3111     common_base = min(Path(src).resolve() for src in srcs)
3112     if common_base.is_dir():
3113         # Append a fake file so `parents` below returns `common_base_dir`, too.
3114         common_base /= "fake-file"
3115     for directory in common_base.parents:
3116         if (directory / ".git").is_dir():
3117             return directory
3118
3119         if (directory / ".hg").is_dir():
3120             return directory
3121
3122         if (directory / "pyproject.toml").is_file():
3123             return directory
3124
3125     return directory
3126
3127
3128 @dataclass
3129 class Report:
3130     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3131
3132     check: bool = False
3133     quiet: bool = False
3134     verbose: bool = False
3135     change_count: int = 0
3136     same_count: int = 0
3137     failure_count: int = 0
3138
3139     def done(self, src: Path, changed: Changed) -> None:
3140         """Increment the counter for successful reformatting. Write out a message."""
3141         if changed is Changed.YES:
3142             reformatted = "would reformat" if self.check else "reformatted"
3143             if self.verbose or not self.quiet:
3144                 out(f"{reformatted} {src}")
3145             self.change_count += 1
3146         else:
3147             if self.verbose:
3148                 if changed is Changed.NO:
3149                     msg = f"{src} already well formatted, good job."
3150                 else:
3151                     msg = f"{src} wasn't modified on disk since last run."
3152                 out(msg, bold=False)
3153             self.same_count += 1
3154
3155     def failed(self, src: Path, message: str) -> None:
3156         """Increment the counter for failed reformatting. Write out a message."""
3157         err(f"error: cannot format {src}: {message}")
3158         self.failure_count += 1
3159
3160     def path_ignored(self, path: Path, message: str) -> None:
3161         if self.verbose:
3162             out(f"{path} ignored: {message}", bold=False)
3163
3164     @property
3165     def return_code(self) -> int:
3166         """Return the exit code that the app should use.
3167
3168         This considers the current state of changed files and failures:
3169         - if there were any failures, return 123;
3170         - if any files were changed and --check is being used, return 1;
3171         - otherwise return 0.
3172         """
3173         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3174         # 126 we have special return codes reserved by the shell.
3175         if self.failure_count:
3176             return 123
3177
3178         elif self.change_count and self.check:
3179             return 1
3180
3181         return 0
3182
3183     def __str__(self) -> str:
3184         """Render a color report of the current state.
3185
3186         Use `click.unstyle` to remove colors.
3187         """
3188         if self.check:
3189             reformatted = "would be reformatted"
3190             unchanged = "would be left unchanged"
3191             failed = "would fail to reformat"
3192         else:
3193             reformatted = "reformatted"
3194             unchanged = "left unchanged"
3195             failed = "failed to reformat"
3196         report = []
3197         if self.change_count:
3198             s = "s" if self.change_count > 1 else ""
3199             report.append(
3200                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3201             )
3202         if self.same_count:
3203             s = "s" if self.same_count > 1 else ""
3204             report.append(f"{self.same_count} file{s} {unchanged}")
3205         if self.failure_count:
3206             s = "s" if self.failure_count > 1 else ""
3207             report.append(
3208                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3209             )
3210         return ", ".join(report) + "."
3211
3212
3213 def assert_equivalent(src: str, dst: str) -> None:
3214     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3215
3216     import ast
3217     import traceback
3218
3219     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3220         """Simple visitor generating strings to compare ASTs by content."""
3221         yield f"{'  ' * depth}{node.__class__.__name__}("
3222
3223         for field in sorted(node._fields):
3224             try:
3225                 value = getattr(node, field)
3226             except AttributeError:
3227                 continue
3228
3229             yield f"{'  ' * (depth+1)}{field}="
3230
3231             if isinstance(value, list):
3232                 for item in value:
3233                     if isinstance(item, ast.AST):
3234                         yield from _v(item, depth + 2)
3235
3236             elif isinstance(value, ast.AST):
3237                 yield from _v(value, depth + 2)
3238
3239             else:
3240                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3241
3242         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3243
3244     try:
3245         src_ast = ast.parse(src)
3246     except Exception as exc:
3247         major, minor = sys.version_info[:2]
3248         raise AssertionError(
3249             f"cannot use --safe with this file; failed to parse source file "
3250             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3251             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3252         )
3253
3254     try:
3255         dst_ast = ast.parse(dst)
3256     except Exception as exc:
3257         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3258         raise AssertionError(
3259             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3260             f"Please report a bug on https://github.com/ambv/black/issues.  "
3261             f"This invalid output might be helpful: {log}"
3262         ) from None
3263
3264     src_ast_str = "\n".join(_v(src_ast))
3265     dst_ast_str = "\n".join(_v(dst_ast))
3266     if src_ast_str != dst_ast_str:
3267         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3268         raise AssertionError(
3269             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3270             f"the source.  "
3271             f"Please report a bug on https://github.com/ambv/black/issues.  "
3272             f"This diff might be helpful: {log}"
3273         ) from None
3274
3275
3276 def assert_stable(
3277     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3278 ) -> None:
3279     """Raise AssertionError if `dst` reformats differently the second time."""
3280     newdst = format_str(dst, line_length=line_length, mode=mode)
3281     if dst != newdst:
3282         log = dump_to_file(
3283             diff(src, dst, "source", "first pass"),
3284             diff(dst, newdst, "first pass", "second pass"),
3285         )
3286         raise AssertionError(
3287             f"INTERNAL ERROR: Black produced different code on the second pass "
3288             f"of the formatter.  "
3289             f"Please report a bug on https://github.com/ambv/black/issues.  "
3290             f"This diff might be helpful: {log}"
3291         ) from None
3292
3293
3294 def dump_to_file(*output: str) -> str:
3295     """Dump `output` to a temporary file. Return path to the file."""
3296     import tempfile
3297
3298     with tempfile.NamedTemporaryFile(
3299         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3300     ) as f:
3301         for lines in output:
3302             f.write(lines)
3303             if lines and lines[-1] != "\n":
3304                 f.write("\n")
3305     return f.name
3306
3307
3308 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3309     """Return a unified diff string between strings `a` and `b`."""
3310     import difflib
3311
3312     a_lines = [line + "\n" for line in a.split("\n")]
3313     b_lines = [line + "\n" for line in b.split("\n")]
3314     return "".join(
3315         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3316     )
3317
3318
3319 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3320     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3321     err("Aborted!")
3322     for task in tasks:
3323         task.cancel()
3324
3325
3326 def shutdown(loop: BaseEventLoop) -> None:
3327     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3328     try:
3329         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3330         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3331         if not to_cancel:
3332             return
3333
3334         for task in to_cancel:
3335             task.cancel()
3336         loop.run_until_complete(
3337             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3338         )
3339     finally:
3340         # `concurrent.futures.Future` objects cannot be cancelled once they
3341         # are already running. There might be some when the `shutdown()` happened.
3342         # Silence their logger's spew about the event loop being closed.
3343         cf_logger = logging.getLogger("concurrent.futures")
3344         cf_logger.setLevel(logging.CRITICAL)
3345         loop.close()
3346
3347
3348 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3349     """Replace `regex` with `replacement` twice on `original`.
3350
3351     This is used by string normalization to perform replaces on
3352     overlapping matches.
3353     """
3354     return regex.sub(replacement, regex.sub(replacement, original))
3355
3356
3357 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3358     """Compile a regular expression string in `regex`.
3359
3360     If it contains newlines, use verbose mode.
3361     """
3362     if "\n" in regex:
3363         regex = "(?x)" + regex
3364     return re.compile(regex)
3365
3366
3367 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3368     """Like `reversed(enumerate(sequence))` if that were possible."""
3369     index = len(sequence) - 1
3370     for element in reversed(sequence):
3371         yield (index, element)
3372         index -= 1
3373
3374
3375 def enumerate_with_length(
3376     line: Line, reversed: bool = False
3377 ) -> Iterator[Tuple[Index, Leaf, int]]:
3378     """Return an enumeration of leaves with their length.
3379
3380     Stops prematurely on multiline strings and standalone comments.
3381     """
3382     op = cast(
3383         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3384         enumerate_reversed if reversed else enumerate,
3385     )
3386     for index, leaf in op(line.leaves):
3387         length = len(leaf.prefix) + len(leaf.value)
3388         if "\n" in leaf.value:
3389             return  # Multiline strings, we can't continue.
3390
3391         comment: Optional[Leaf]
3392         for comment in line.comments_after(leaf, index):
3393             length += len(comment.value)
3394
3395         yield index, leaf, length
3396
3397
3398 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3399     """Return True if `line` is no longer than `line_length`.
3400
3401     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3402     """
3403     if not line_str:
3404         line_str = str(line).strip("\n")
3405     return (
3406         len(line_str) <= line_length
3407         and "\n" not in line_str  # multiline strings
3408         and not line.contains_standalone_comments()
3409     )
3410
3411
3412 def can_be_split(line: Line) -> bool:
3413     """Return False if the line cannot be split *for sure*.
3414
3415     This is not an exhaustive search but a cheap heuristic that we can use to
3416     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3417     in unnecessary parentheses).
3418     """
3419     leaves = line.leaves
3420     if len(leaves) < 2:
3421         return False
3422
3423     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3424         call_count = 0
3425         dot_count = 0
3426         next = leaves[-1]
3427         for leaf in leaves[-2::-1]:
3428             if leaf.type in OPENING_BRACKETS:
3429                 if next.type not in CLOSING_BRACKETS:
3430                     return False
3431
3432                 call_count += 1
3433             elif leaf.type == token.DOT:
3434                 dot_count += 1
3435             elif leaf.type == token.NAME:
3436                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3437                     return False
3438
3439             elif leaf.type not in CLOSING_BRACKETS:
3440                 return False
3441
3442             if dot_count > 1 and call_count > 1:
3443                 return False
3444
3445     return True
3446
3447
3448 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3449     """Does `line` have a shape safe to reformat without optional parens around it?
3450
3451     Returns True for only a subset of potentially nice looking formattings but
3452     the point is to not return false positives that end up producing lines that
3453     are too long.
3454     """
3455     bt = line.bracket_tracker
3456     if not bt.delimiters:
3457         # Without delimiters the optional parentheses are useless.
3458         return True
3459
3460     max_priority = bt.max_delimiter_priority()
3461     if bt.delimiter_count_with_priority(max_priority) > 1:
3462         # With more than one delimiter of a kind the optional parentheses read better.
3463         return False
3464
3465     if max_priority == DOT_PRIORITY:
3466         # A single stranded method call doesn't require optional parentheses.
3467         return True
3468
3469     assert len(line.leaves) >= 2, "Stranded delimiter"
3470
3471     first = line.leaves[0]
3472     second = line.leaves[1]
3473     penultimate = line.leaves[-2]
3474     last = line.leaves[-1]
3475
3476     # With a single delimiter, omit if the expression starts or ends with
3477     # a bracket.
3478     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3479         remainder = False
3480         length = 4 * line.depth
3481         for _index, leaf, leaf_length in enumerate_with_length(line):
3482             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3483                 remainder = True
3484             if remainder:
3485                 length += leaf_length
3486                 if length > line_length:
3487                     break
3488
3489                 if leaf.type in OPENING_BRACKETS:
3490                     # There are brackets we can further split on.
3491                     remainder = False
3492
3493         else:
3494             # checked the entire string and line length wasn't exceeded
3495             if len(line.leaves) == _index + 1:
3496                 return True
3497
3498         # Note: we are not returning False here because a line might have *both*
3499         # a leading opening bracket and a trailing closing bracket.  If the
3500         # opening bracket doesn't match our rule, maybe the closing will.
3501
3502     if (
3503         last.type == token.RPAR
3504         or last.type == token.RBRACE
3505         or (
3506             # don't use indexing for omitting optional parentheses;
3507             # it looks weird
3508             last.type == token.RSQB
3509             and last.parent
3510             and last.parent.type != syms.trailer
3511         )
3512     ):
3513         if penultimate.type in OPENING_BRACKETS:
3514             # Empty brackets don't help.
3515             return False
3516
3517         if is_multiline_string(first):
3518             # Additional wrapping of a multiline string in this situation is
3519             # unnecessary.
3520             return True
3521
3522         length = 4 * line.depth
3523         seen_other_brackets = False
3524         for _index, leaf, leaf_length in enumerate_with_length(line):
3525             length += leaf_length
3526             if leaf is last.opening_bracket:
3527                 if seen_other_brackets or length <= line_length:
3528                     return True
3529
3530             elif leaf.type in OPENING_BRACKETS:
3531                 # There are brackets we can further split on.
3532                 seen_other_brackets = True
3533
3534     return False
3535
3536
3537 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3538     return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
3539
3540
3541 def read_cache(line_length: int, mode: FileMode) -> Cache:
3542     """Read the cache if it exists and is well formed.
3543
3544     If it is not well formed, the call to write_cache later should resolve the issue.
3545     """
3546     cache_file = get_cache_file(line_length, mode)
3547     if not cache_file.exists():
3548         return {}
3549
3550     with cache_file.open("rb") as fobj:
3551         try:
3552             cache: Cache = pickle.load(fobj)
3553         except pickle.UnpicklingError:
3554             return {}
3555
3556     return cache
3557
3558
3559 def get_cache_info(path: Path) -> CacheInfo:
3560     """Return the information used to check if a file is already formatted or not."""
3561     stat = path.stat()
3562     return stat.st_mtime, stat.st_size
3563
3564
3565 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3566     """Split an iterable of paths in `sources` into two sets.
3567
3568     The first contains paths of files that modified on disk or are not in the
3569     cache. The other contains paths to non-modified files.
3570     """
3571     todo, done = set(), set()
3572     for src in sources:
3573         src = src.resolve()
3574         if cache.get(src) != get_cache_info(src):
3575             todo.add(src)
3576         else:
3577             done.add(src)
3578     return todo, done
3579
3580
3581 def write_cache(
3582     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3583 ) -> None:
3584     """Update the cache file."""
3585     cache_file = get_cache_file(line_length, mode)
3586     try:
3587         if not CACHE_DIR.exists():
3588             CACHE_DIR.mkdir(parents=True)
3589         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3590         with cache_file.open("wb") as fobj:
3591             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3592     except OSError:
3593         pass
3594
3595
3596 def patch_click() -> None:
3597     """Make Click not crash.
3598
3599     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3600     default which restricts paths that it can access during the lifetime of the
3601     application.  Click refuses to work in this scenario by raising a RuntimeError.
3602
3603     In case of Black the likelihood that non-ASCII characters are going to be used in
3604     file paths is minimal since it's Python source code.  Moreover, this crash was
3605     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3606     """
3607     try:
3608         from click import core
3609         from click import _unicodefun  # type: ignore
3610     except ModuleNotFoundError:
3611         return
3612
3613     for module in (core, _unicodefun):
3614         if hasattr(module, "_verify_python3_env"):
3615             module._verify_python3_env = lambda: None
3616
3617
3618 if __name__ == "__main__":
3619     patch_click()
3620     main()