]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

6f7496ede2345e597c3b3145fc268d65c8f7b416
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum, Flag
6 from functools import lru_cache, partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tokenize
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generator,
24     Generic,
25     Iterable,
26     Iterator,
27     List,
28     Optional,
29     Pattern,
30     Sequence,
31     Set,
32     Tuple,
33     TypeVar,
34     Union,
35     cast,
36 )
37
38 from appdirs import user_cache_dir
39 from attr import dataclass, Factory
40 import click
41 import toml
42
43 # lib2to3 fork
44 from blib2to3.pytree import Node, Leaf, type_repr
45 from blib2to3 import pygram, pytree
46 from blib2to3.pgen2 import driver, token
47 from blib2to3.pgen2.parse import ParseError
48
49
50 __version__ = "18.6b4"
51 DEFAULT_LINE_LENGTH = 88
52 DEFAULT_EXCLUDES = (
53     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
54 )
55 DEFAULT_INCLUDES = r"\.pyi?$"
56 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
57
58
59 # types
60 FileContent = str
61 Encoding = str
62 NewLine = str
63 Depth = int
64 NodeType = int
65 LeafID = int
66 Priority = int
67 Index = int
68 LN = Union[Leaf, Node]
69 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
70 Timestamp = float
71 FileSize = int
72 CacheInfo = Tuple[Timestamp, FileSize]
73 Cache = Dict[Path, CacheInfo]
74 out = partial(click.secho, bold=True, err=True)
75 err = partial(click.secho, fg="red", err=True)
76
77 pygram.initialize(CACHE_DIR)
78 syms = pygram.python_symbols
79
80
81 class NothingChanged(UserWarning):
82     """Raised by :func:`format_file` when reformatted code is the same as source."""
83
84
85 class CannotSplit(Exception):
86     """A readable split that fits the allotted line length is impossible.
87
88     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
89     :func:`delimiter_split`.
90     """
91
92
93 class WriteBack(Enum):
94     NO = 0
95     YES = 1
96     DIFF = 2
97     CHECK = 3
98
99     @classmethod
100     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
101         if check and not diff:
102             return cls.CHECK
103
104         return cls.DIFF if diff else cls.YES
105
106
107 class Changed(Enum):
108     NO = 0
109     CACHED = 1
110     YES = 2
111
112
113 class FileMode(Flag):
114     AUTO_DETECT = 0
115     PYTHON36 = 1
116     PYI = 2
117     NO_STRING_NORMALIZATION = 4
118
119     @classmethod
120     def from_configuration(
121         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
122     ) -> "FileMode":
123         mode = cls.AUTO_DETECT
124         if py36:
125             mode |= cls.PYTHON36
126         if pyi:
127             mode |= cls.PYI
128         if skip_string_normalization:
129             mode |= cls.NO_STRING_NORMALIZATION
130         return mode
131
132
133 def read_pyproject_toml(
134     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
135 ) -> Optional[str]:
136     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
137
138     Returns the path to a successfully found and read configuration file, None
139     otherwise.
140     """
141     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
142     if not value:
143         root = find_project_root(ctx.params.get("src", ()))
144         path = root / "pyproject.toml"
145         if path.is_file():
146             value = str(path)
147         else:
148             return None
149
150     try:
151         pyproject_toml = toml.load(value)
152         config = pyproject_toml.get("tool", {}).get("black", {})
153     except (toml.TomlDecodeError, OSError) as e:
154         raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
155
156     if not config:
157         return None
158
159     if ctx.default_map is None:
160         ctx.default_map = {}
161     ctx.default_map.update(  # type: ignore  # bad types in .pyi
162         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
163     )
164     return value
165
166
167 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
168 @click.option(
169     "-l",
170     "--line-length",
171     type=int,
172     default=DEFAULT_LINE_LENGTH,
173     help="How many characters per line to allow.",
174     show_default=True,
175 )
176 @click.option(
177     "--py36",
178     is_flag=True,
179     help=(
180         "Allow using Python 3.6-only syntax on all input files.  This will put "
181         "trailing commas in function signatures and calls also after *args and "
182         "**kwargs.  [default: per-file auto-detection]"
183     ),
184 )
185 @click.option(
186     "--pyi",
187     is_flag=True,
188     help=(
189         "Format all input files like typing stubs regardless of file extension "
190         "(useful when piping source on standard input)."
191     ),
192 )
193 @click.option(
194     "-S",
195     "--skip-string-normalization",
196     is_flag=True,
197     help="Don't normalize string quotes or prefixes.",
198 )
199 @click.option(
200     "--check",
201     is_flag=True,
202     help=(
203         "Don't write the files back, just return the status.  Return code 0 "
204         "means nothing would change.  Return code 1 means some files would be "
205         "reformatted.  Return code 123 means there was an internal error."
206     ),
207 )
208 @click.option(
209     "--diff",
210     is_flag=True,
211     help="Don't write the files back, just output a diff for each file on stdout.",
212 )
213 @click.option(
214     "--fast/--safe",
215     is_flag=True,
216     help="If --fast given, skip temporary sanity checks. [default: --safe]",
217 )
218 @click.option(
219     "--include",
220     type=str,
221     default=DEFAULT_INCLUDES,
222     help=(
223         "A regular expression that matches files and directories that should be "
224         "included on recursive searches.  An empty value means all files are "
225         "included regardless of the name.  Use forward slashes for directories on "
226         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
227         "later."
228     ),
229     show_default=True,
230 )
231 @click.option(
232     "--exclude",
233     type=str,
234     default=DEFAULT_EXCLUDES,
235     help=(
236         "A regular expression that matches files and directories that should be "
237         "excluded on recursive searches.  An empty value means no paths are excluded. "
238         "Use forward slashes for directories on all platforms (Windows, too).  "
239         "Exclusions are calculated first, inclusions later."
240     ),
241     show_default=True,
242 )
243 @click.option(
244     "-q",
245     "--quiet",
246     is_flag=True,
247     help=(
248         "Don't emit non-error messages to stderr. Errors are still emitted, "
249         "silence those with 2>/dev/null."
250     ),
251 )
252 @click.option(
253     "-v",
254     "--verbose",
255     is_flag=True,
256     help=(
257         "Also emit messages to stderr about files that were not changed or were "
258         "ignored due to --exclude=."
259     ),
260 )
261 @click.version_option(version=__version__)
262 @click.argument(
263     "src",
264     nargs=-1,
265     type=click.Path(
266         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
267     ),
268     is_eager=True,
269 )
270 @click.option(
271     "--config",
272     type=click.Path(
273         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
274     ),
275     is_eager=True,
276     callback=read_pyproject_toml,
277     help="Read configuration from PATH.",
278 )
279 @click.pass_context
280 def main(
281     ctx: click.Context,
282     line_length: int,
283     check: bool,
284     diff: bool,
285     fast: bool,
286     pyi: bool,
287     py36: bool,
288     skip_string_normalization: bool,
289     quiet: bool,
290     verbose: bool,
291     include: str,
292     exclude: str,
293     src: Tuple[str],
294     config: Optional[str],
295 ) -> None:
296     """The uncompromising code formatter."""
297     write_back = WriteBack.from_configuration(check=check, diff=diff)
298     mode = FileMode.from_configuration(
299         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
300     )
301     if config and verbose:
302         out(f"Using configuration from {config}.", bold=False, fg="blue")
303     try:
304         include_regex = re_compile_maybe_verbose(include)
305     except re.error:
306         err(f"Invalid regular expression for include given: {include!r}")
307         ctx.exit(2)
308     try:
309         exclude_regex = re_compile_maybe_verbose(exclude)
310     except re.error:
311         err(f"Invalid regular expression for exclude given: {exclude!r}")
312         ctx.exit(2)
313     report = Report(check=check, quiet=quiet, verbose=verbose)
314     root = find_project_root(src)
315     sources: Set[Path] = set()
316     for s in src:
317         p = Path(s)
318         if p.is_dir():
319             sources.update(
320                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
321             )
322         elif p.is_file() or s == "-":
323             # if a file was explicitly given, we don't care about its extension
324             sources.add(p)
325         else:
326             err(f"invalid path: {s}")
327     if len(sources) == 0:
328         if verbose or not quiet:
329             out("No paths given. Nothing to do 😴")
330         ctx.exit(0)
331
332     if len(sources) == 1:
333         reformat_one(
334             src=sources.pop(),
335             line_length=line_length,
336             fast=fast,
337             write_back=write_back,
338             mode=mode,
339             report=report,
340         )
341     else:
342         loop = asyncio.get_event_loop()
343         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
344         try:
345             loop.run_until_complete(
346                 schedule_formatting(
347                     sources=sources,
348                     line_length=line_length,
349                     fast=fast,
350                     write_back=write_back,
351                     mode=mode,
352                     report=report,
353                     loop=loop,
354                     executor=executor,
355                 )
356             )
357         finally:
358             shutdown(loop)
359     if verbose or not quiet:
360         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
361         out(f"All done! {bang}")
362         click.secho(str(report), err=True)
363     ctx.exit(report.return_code)
364
365
366 def reformat_one(
367     src: Path,
368     line_length: int,
369     fast: bool,
370     write_back: WriteBack,
371     mode: FileMode,
372     report: "Report",
373 ) -> None:
374     """Reformat a single file under `src` without spawning child processes.
375
376     If `quiet` is True, non-error messages are not output. `line_length`,
377     `write_back`, `fast` and `pyi` options are passed to
378     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
379     """
380     try:
381         changed = Changed.NO
382         if not src.is_file() and str(src) == "-":
383             if format_stdin_to_stdout(
384                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
385             ):
386                 changed = Changed.YES
387         else:
388             cache: Cache = {}
389             if write_back != WriteBack.DIFF:
390                 cache = read_cache(line_length, mode)
391                 res_src = src.resolve()
392                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
393                     changed = Changed.CACHED
394             if changed is not Changed.CACHED and format_file_in_place(
395                 src,
396                 line_length=line_length,
397                 fast=fast,
398                 write_back=write_back,
399                 mode=mode,
400             ):
401                 changed = Changed.YES
402             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
403                 write_back is WriteBack.CHECK and changed is Changed.NO
404             ):
405                 write_cache(cache, [src], line_length, mode)
406         report.done(src, changed)
407     except Exception as exc:
408         report.failed(src, str(exc))
409
410
411 async def schedule_formatting(
412     sources: Set[Path],
413     line_length: int,
414     fast: bool,
415     write_back: WriteBack,
416     mode: FileMode,
417     report: "Report",
418     loop: BaseEventLoop,
419     executor: Executor,
420 ) -> None:
421     """Run formatting of `sources` in parallel using the provided `executor`.
422
423     (Use ProcessPoolExecutors for actual parallelism.)
424
425     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
426     :func:`format_file_in_place`.
427     """
428     cache: Cache = {}
429     if write_back != WriteBack.DIFF:
430         cache = read_cache(line_length, mode)
431         sources, cached = filter_cached(cache, sources)
432         for src in sorted(cached):
433             report.done(src, Changed.CACHED)
434     if not sources:
435         return
436
437     cancelled = []
438     sources_to_cache = []
439     lock = None
440     if write_back == WriteBack.DIFF:
441         # For diff output, we need locks to ensure we don't interleave output
442         # from different processes.
443         manager = Manager()
444         lock = manager.Lock()
445     tasks = {
446         loop.run_in_executor(
447             executor,
448             format_file_in_place,
449             src,
450             line_length,
451             fast,
452             write_back,
453             mode,
454             lock,
455         ): src
456         for src in sorted(sources)
457     }
458     pending: Iterable[asyncio.Task] = tasks.keys()
459     try:
460         loop.add_signal_handler(signal.SIGINT, cancel, pending)
461         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
462     except NotImplementedError:
463         # There are no good alternatives for these on Windows.
464         pass
465     while pending:
466         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
467         for task in done:
468             src = tasks.pop(task)
469             if task.cancelled():
470                 cancelled.append(task)
471             elif task.exception():
472                 report.failed(src, str(task.exception()))
473             else:
474                 changed = Changed.YES if task.result() else Changed.NO
475                 # If the file was written back or was successfully checked as
476                 # well-formatted, store this information in the cache.
477                 if write_back is WriteBack.YES or (
478                     write_back is WriteBack.CHECK and changed is Changed.NO
479                 ):
480                     sources_to_cache.append(src)
481                 report.done(src, changed)
482     if cancelled:
483         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
484     if sources_to_cache:
485         write_cache(cache, sources_to_cache, line_length, mode)
486
487
488 def format_file_in_place(
489     src: Path,
490     line_length: int,
491     fast: bool,
492     write_back: WriteBack = WriteBack.NO,
493     mode: FileMode = FileMode.AUTO_DETECT,
494     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
495 ) -> bool:
496     """Format file under `src` path. Return True if changed.
497
498     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
499     code to the file.
500     `line_length` and `fast` options are passed to :func:`format_file_contents`.
501     """
502     if src.suffix == ".pyi":
503         mode |= FileMode.PYI
504
505     then = datetime.utcfromtimestamp(src.stat().st_mtime)
506     with open(src, "rb") as buf:
507         src_contents, encoding, newline = decode_bytes(buf.read())
508     try:
509         dst_contents = format_file_contents(
510             src_contents, line_length=line_length, fast=fast, mode=mode
511         )
512     except NothingChanged:
513         return False
514
515     if write_back == write_back.YES:
516         with open(src, "w", encoding=encoding, newline=newline) as f:
517             f.write(dst_contents)
518     elif write_back == write_back.DIFF:
519         now = datetime.utcnow()
520         src_name = f"{src}\t{then} +0000"
521         dst_name = f"{src}\t{now} +0000"
522         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
523         if lock:
524             lock.acquire()
525         try:
526             f = io.TextIOWrapper(
527                 sys.stdout.buffer,
528                 encoding=encoding,
529                 newline=newline,
530                 write_through=True,
531             )
532             f.write(diff_contents)
533             f.detach()
534         finally:
535             if lock:
536                 lock.release()
537     return True
538
539
540 def format_stdin_to_stdout(
541     line_length: int,
542     fast: bool,
543     write_back: WriteBack = WriteBack.NO,
544     mode: FileMode = FileMode.AUTO_DETECT,
545 ) -> bool:
546     """Format file on stdin. Return True if changed.
547
548     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
549     write a diff to stdout.
550     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
551     :func:`format_file_contents`.
552     """
553     then = datetime.utcnow()
554     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
555     dst = src
556     try:
557         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
558         return True
559
560     except NothingChanged:
561         return False
562
563     finally:
564         f = io.TextIOWrapper(
565             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
566         )
567         if write_back == WriteBack.YES:
568             f.write(dst)
569         elif write_back == WriteBack.DIFF:
570             now = datetime.utcnow()
571             src_name = f"STDIN\t{then} +0000"
572             dst_name = f"STDOUT\t{now} +0000"
573             f.write(diff(src, dst, src_name, dst_name))
574         f.detach()
575
576
577 def format_file_contents(
578     src_contents: str,
579     *,
580     line_length: int,
581     fast: bool,
582     mode: FileMode = FileMode.AUTO_DETECT,
583 ) -> FileContent:
584     """Reformat contents a file and return new contents.
585
586     If `fast` is False, additionally confirm that the reformatted code is
587     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
588     `line_length` is passed to :func:`format_str`.
589     """
590     if src_contents.strip() == "":
591         raise NothingChanged
592
593     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
594     if src_contents == dst_contents:
595         raise NothingChanged
596
597     if not fast:
598         assert_equivalent(src_contents, dst_contents)
599         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
600     return dst_contents
601
602
603 def format_str(
604     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
605 ) -> FileContent:
606     """Reformat a string and return new contents.
607
608     `line_length` determines how many characters per line are allowed.
609     """
610     src_node = lib2to3_parse(src_contents)
611     dst_contents = ""
612     future_imports = get_future_imports(src_node)
613     is_pyi = bool(mode & FileMode.PYI)
614     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
615     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
616     normalize_fmt_off(src_node)
617     lines = LineGenerator(
618         remove_u_prefix=py36 or "unicode_literals" in future_imports,
619         is_pyi=is_pyi,
620         normalize_strings=normalize_strings,
621         allow_underscores=py36,
622     )
623     elt = EmptyLineTracker(is_pyi=is_pyi)
624     empty_line = Line()
625     after = 0
626     for current_line in lines.visit(src_node):
627         for _ in range(after):
628             dst_contents += str(empty_line)
629         before, after = elt.maybe_empty_lines(current_line)
630         for _ in range(before):
631             dst_contents += str(empty_line)
632         for line in split_line(current_line, line_length=line_length, py36=py36):
633             dst_contents += str(line)
634     return dst_contents
635
636
637 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
638     """Return a tuple of (decoded_contents, encoding, newline).
639
640     `newline` is either CRLF or LF but `decoded_contents` is decoded with
641     universal newlines (i.e. only contains LF).
642     """
643     srcbuf = io.BytesIO(src)
644     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
645     if not lines:
646         return "", encoding, "\n"
647
648     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
649     srcbuf.seek(0)
650     with io.TextIOWrapper(srcbuf, encoding) as tiow:
651         return tiow.read(), encoding, newline
652
653
654 GRAMMARS = [
655     pygram.python_grammar_no_print_statement_no_exec_statement,
656     pygram.python_grammar_no_print_statement,
657     pygram.python_grammar,
658 ]
659
660
661 def lib2to3_parse(src_txt: str) -> Node:
662     """Given a string with source, return the lib2to3 Node."""
663     grammar = pygram.python_grammar_no_print_statement
664     if src_txt[-1:] != "\n":
665         src_txt += "\n"
666     for grammar in GRAMMARS:
667         drv = driver.Driver(grammar, pytree.convert)
668         try:
669             result = drv.parse_string(src_txt, True)
670             break
671
672         except ParseError as pe:
673             lineno, column = pe.context[1]
674             lines = src_txt.splitlines()
675             try:
676                 faulty_line = lines[lineno - 1]
677             except IndexError:
678                 faulty_line = "<line number missing in source>"
679             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
680     else:
681         raise exc from None
682
683     if isinstance(result, Leaf):
684         result = Node(syms.file_input, [result])
685     return result
686
687
688 def lib2to3_unparse(node: Node) -> str:
689     """Given a lib2to3 node, return its string representation."""
690     code = str(node)
691     return code
692
693
694 T = TypeVar("T")
695
696
697 class Visitor(Generic[T]):
698     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
699
700     def visit(self, node: LN) -> Iterator[T]:
701         """Main method to visit `node` and its children.
702
703         It tries to find a `visit_*()` method for the given `node.type`, like
704         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
705         If no dedicated `visit_*()` method is found, chooses `visit_default()`
706         instead.
707
708         Then yields objects of type `T` from the selected visitor.
709         """
710         if node.type < 256:
711             name = token.tok_name[node.type]
712         else:
713             name = type_repr(node.type)
714         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
715
716     def visit_default(self, node: LN) -> Iterator[T]:
717         """Default `visit_*()` implementation. Recurses to children of `node`."""
718         if isinstance(node, Node):
719             for child in node.children:
720                 yield from self.visit(child)
721
722
723 @dataclass
724 class DebugVisitor(Visitor[T]):
725     tree_depth: int = 0
726
727     def visit_default(self, node: LN) -> Iterator[T]:
728         indent = " " * (2 * self.tree_depth)
729         if isinstance(node, Node):
730             _type = type_repr(node.type)
731             out(f"{indent}{_type}", fg="yellow")
732             self.tree_depth += 1
733             for child in node.children:
734                 yield from self.visit(child)
735
736             self.tree_depth -= 1
737             out(f"{indent}/{_type}", fg="yellow", bold=False)
738         else:
739             _type = token.tok_name.get(node.type, str(node.type))
740             out(f"{indent}{_type}", fg="blue", nl=False)
741             if node.prefix:
742                 # We don't have to handle prefixes for `Node` objects since
743                 # that delegates to the first child anyway.
744                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
745             out(f" {node.value!r}", fg="blue", bold=False)
746
747     @classmethod
748     def show(cls, code: Union[str, Leaf, Node]) -> None:
749         """Pretty-print the lib2to3 AST of a given string of `code`.
750
751         Convenience method for debugging.
752         """
753         v: DebugVisitor[None] = DebugVisitor()
754         if isinstance(code, str):
755             code = lib2to3_parse(code)
756         list(v.visit(code))
757
758
759 KEYWORDS = set(keyword.kwlist)
760 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
761 FLOW_CONTROL = {"return", "raise", "break", "continue"}
762 STATEMENT = {
763     syms.if_stmt,
764     syms.while_stmt,
765     syms.for_stmt,
766     syms.try_stmt,
767     syms.except_clause,
768     syms.with_stmt,
769     syms.funcdef,
770     syms.classdef,
771 }
772 STANDALONE_COMMENT = 153
773 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
774 LOGIC_OPERATORS = {"and", "or"}
775 COMPARATORS = {
776     token.LESS,
777     token.GREATER,
778     token.EQEQUAL,
779     token.NOTEQUAL,
780     token.LESSEQUAL,
781     token.GREATEREQUAL,
782 }
783 MATH_OPERATORS = {
784     token.VBAR,
785     token.CIRCUMFLEX,
786     token.AMPER,
787     token.LEFTSHIFT,
788     token.RIGHTSHIFT,
789     token.PLUS,
790     token.MINUS,
791     token.STAR,
792     token.SLASH,
793     token.DOUBLESLASH,
794     token.PERCENT,
795     token.AT,
796     token.TILDE,
797     token.DOUBLESTAR,
798 }
799 STARS = {token.STAR, token.DOUBLESTAR}
800 VARARGS_PARENTS = {
801     syms.arglist,
802     syms.argument,  # double star in arglist
803     syms.trailer,  # single argument to call
804     syms.typedargslist,
805     syms.varargslist,  # lambdas
806 }
807 UNPACKING_PARENTS = {
808     syms.atom,  # single element of a list or set literal
809     syms.dictsetmaker,
810     syms.listmaker,
811     syms.testlist_gexp,
812     syms.testlist_star_expr,
813 }
814 TEST_DESCENDANTS = {
815     syms.test,
816     syms.lambdef,
817     syms.or_test,
818     syms.and_test,
819     syms.not_test,
820     syms.comparison,
821     syms.star_expr,
822     syms.expr,
823     syms.xor_expr,
824     syms.and_expr,
825     syms.shift_expr,
826     syms.arith_expr,
827     syms.trailer,
828     syms.term,
829     syms.power,
830 }
831 ASSIGNMENTS = {
832     "=",
833     "+=",
834     "-=",
835     "*=",
836     "@=",
837     "/=",
838     "%=",
839     "&=",
840     "|=",
841     "^=",
842     "<<=",
843     ">>=",
844     "**=",
845     "//=",
846 }
847 COMPREHENSION_PRIORITY = 20
848 COMMA_PRIORITY = 18
849 TERNARY_PRIORITY = 16
850 LOGIC_PRIORITY = 14
851 STRING_PRIORITY = 12
852 COMPARATOR_PRIORITY = 10
853 MATH_PRIORITIES = {
854     token.VBAR: 9,
855     token.CIRCUMFLEX: 8,
856     token.AMPER: 7,
857     token.LEFTSHIFT: 6,
858     token.RIGHTSHIFT: 6,
859     token.PLUS: 5,
860     token.MINUS: 5,
861     token.STAR: 4,
862     token.SLASH: 4,
863     token.DOUBLESLASH: 4,
864     token.PERCENT: 4,
865     token.AT: 4,
866     token.TILDE: 3,
867     token.DOUBLESTAR: 2,
868 }
869 DOT_PRIORITY = 1
870
871
872 @dataclass
873 class BracketTracker:
874     """Keeps track of brackets on a line."""
875
876     depth: int = 0
877     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
878     delimiters: Dict[LeafID, Priority] = Factory(dict)
879     previous: Optional[Leaf] = None
880     _for_loop_depths: List[int] = Factory(list)
881     _lambda_argument_depths: List[int] = Factory(list)
882
883     def mark(self, leaf: Leaf) -> None:
884         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
885
886         All leaves receive an int `bracket_depth` field that stores how deep
887         within brackets a given leaf is. 0 means there are no enclosing brackets
888         that started on this line.
889
890         If a leaf is itself a closing bracket, it receives an `opening_bracket`
891         field that it forms a pair with. This is a one-directional link to
892         avoid reference cycles.
893
894         If a leaf is a delimiter (a token on which Black can split the line if
895         needed) and it's on depth 0, its `id()` is stored in the tracker's
896         `delimiters` field.
897         """
898         if leaf.type == token.COMMENT:
899             return
900
901         self.maybe_decrement_after_for_loop_variable(leaf)
902         self.maybe_decrement_after_lambda_arguments(leaf)
903         if leaf.type in CLOSING_BRACKETS:
904             self.depth -= 1
905             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
906             leaf.opening_bracket = opening_bracket
907         leaf.bracket_depth = self.depth
908         if self.depth == 0:
909             delim = is_split_before_delimiter(leaf, self.previous)
910             if delim and self.previous is not None:
911                 self.delimiters[id(self.previous)] = delim
912             else:
913                 delim = is_split_after_delimiter(leaf, self.previous)
914                 if delim:
915                     self.delimiters[id(leaf)] = delim
916         if leaf.type in OPENING_BRACKETS:
917             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
918             self.depth += 1
919         self.previous = leaf
920         self.maybe_increment_lambda_arguments(leaf)
921         self.maybe_increment_for_loop_variable(leaf)
922
923     def any_open_brackets(self) -> bool:
924         """Return True if there is an yet unmatched open bracket on the line."""
925         return bool(self.bracket_match)
926
927     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
928         """Return the highest priority of a delimiter found on the line.
929
930         Values are consistent with what `is_split_*_delimiter()` return.
931         Raises ValueError on no delimiters.
932         """
933         return max(v for k, v in self.delimiters.items() if k not in exclude)
934
935     def delimiter_count_with_priority(self, priority: int = 0) -> int:
936         """Return the number of delimiters with the given `priority`.
937
938         If no `priority` is passed, defaults to max priority on the line.
939         """
940         if not self.delimiters:
941             return 0
942
943         priority = priority or self.max_delimiter_priority()
944         return sum(1 for p in self.delimiters.values() if p == priority)
945
946     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
947         """In a for loop, or comprehension, the variables are often unpacks.
948
949         To avoid splitting on the comma in this situation, increase the depth of
950         tokens between `for` and `in`.
951         """
952         if leaf.type == token.NAME and leaf.value == "for":
953             self.depth += 1
954             self._for_loop_depths.append(self.depth)
955             return True
956
957         return False
958
959     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
960         """See `maybe_increment_for_loop_variable` above for explanation."""
961         if (
962             self._for_loop_depths
963             and self._for_loop_depths[-1] == self.depth
964             and leaf.type == token.NAME
965             and leaf.value == "in"
966         ):
967             self.depth -= 1
968             self._for_loop_depths.pop()
969             return True
970
971         return False
972
973     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
974         """In a lambda expression, there might be more than one argument.
975
976         To avoid splitting on the comma in this situation, increase the depth of
977         tokens between `lambda` and `:`.
978         """
979         if leaf.type == token.NAME and leaf.value == "lambda":
980             self.depth += 1
981             self._lambda_argument_depths.append(self.depth)
982             return True
983
984         return False
985
986     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
987         """See `maybe_increment_lambda_arguments` above for explanation."""
988         if (
989             self._lambda_argument_depths
990             and self._lambda_argument_depths[-1] == self.depth
991             and leaf.type == token.COLON
992         ):
993             self.depth -= 1
994             self._lambda_argument_depths.pop()
995             return True
996
997         return False
998
999     def get_open_lsqb(self) -> Optional[Leaf]:
1000         """Return the most recent opening square bracket (if any)."""
1001         return self.bracket_match.get((self.depth - 1, token.RSQB))
1002
1003
1004 @dataclass
1005 class Line:
1006     """Holds leaves and comments. Can be printed with `str(line)`."""
1007
1008     depth: int = 0
1009     leaves: List[Leaf] = Factory(list)
1010     comments: List[Tuple[Index, Leaf]] = Factory(list)
1011     bracket_tracker: BracketTracker = Factory(BracketTracker)
1012     inside_brackets: bool = False
1013     should_explode: bool = False
1014
1015     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1016         """Add a new `leaf` to the end of the line.
1017
1018         Unless `preformatted` is True, the `leaf` will receive a new consistent
1019         whitespace prefix and metadata applied by :class:`BracketTracker`.
1020         Trailing commas are maybe removed, unpacked for loop variables are
1021         demoted from being delimiters.
1022
1023         Inline comments are put aside.
1024         """
1025         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1026         if not has_value:
1027             return
1028
1029         if token.COLON == leaf.type and self.is_class_paren_empty:
1030             del self.leaves[-2:]
1031         if self.leaves and not preformatted:
1032             # Note: at this point leaf.prefix should be empty except for
1033             # imports, for which we only preserve newlines.
1034             leaf.prefix += whitespace(
1035                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1036             )
1037         if self.inside_brackets or not preformatted:
1038             self.bracket_tracker.mark(leaf)
1039             self.maybe_remove_trailing_comma(leaf)
1040         if not self.append_comment(leaf):
1041             self.leaves.append(leaf)
1042
1043     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1044         """Like :func:`append()` but disallow invalid standalone comment structure.
1045
1046         Raises ValueError when any `leaf` is appended after a standalone comment
1047         or when a standalone comment is not the first leaf on the line.
1048         """
1049         if self.bracket_tracker.depth == 0:
1050             if self.is_comment:
1051                 raise ValueError("cannot append to standalone comments")
1052
1053             if self.leaves and leaf.type == STANDALONE_COMMENT:
1054                 raise ValueError(
1055                     "cannot append standalone comments to a populated line"
1056                 )
1057
1058         self.append(leaf, preformatted=preformatted)
1059
1060     @property
1061     def is_comment(self) -> bool:
1062         """Is this line a standalone comment?"""
1063         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1064
1065     @property
1066     def is_decorator(self) -> bool:
1067         """Is this line a decorator?"""
1068         return bool(self) and self.leaves[0].type == token.AT
1069
1070     @property
1071     def is_import(self) -> bool:
1072         """Is this an import line?"""
1073         return bool(self) and is_import(self.leaves[0])
1074
1075     @property
1076     def is_class(self) -> bool:
1077         """Is this line a class definition?"""
1078         return (
1079             bool(self)
1080             and self.leaves[0].type == token.NAME
1081             and self.leaves[0].value == "class"
1082         )
1083
1084     @property
1085     def is_stub_class(self) -> bool:
1086         """Is this line a class definition with a body consisting only of "..."?"""
1087         return self.is_class and self.leaves[-3:] == [
1088             Leaf(token.DOT, ".") for _ in range(3)
1089         ]
1090
1091     @property
1092     def is_def(self) -> bool:
1093         """Is this a function definition? (Also returns True for async defs.)"""
1094         try:
1095             first_leaf = self.leaves[0]
1096         except IndexError:
1097             return False
1098
1099         try:
1100             second_leaf: Optional[Leaf] = self.leaves[1]
1101         except IndexError:
1102             second_leaf = None
1103         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1104             first_leaf.type == token.ASYNC
1105             and second_leaf is not None
1106             and second_leaf.type == token.NAME
1107             and second_leaf.value == "def"
1108         )
1109
1110     @property
1111     def is_class_paren_empty(self) -> bool:
1112         """Is this a class with no base classes but using parentheses?
1113
1114         Those are unnecessary and should be removed.
1115         """
1116         return (
1117             bool(self)
1118             and len(self.leaves) == 4
1119             and self.is_class
1120             and self.leaves[2].type == token.LPAR
1121             and self.leaves[2].value == "("
1122             and self.leaves[3].type == token.RPAR
1123             and self.leaves[3].value == ")"
1124         )
1125
1126     @property
1127     def is_triple_quoted_string(self) -> bool:
1128         """Is the line a triple quoted string?"""
1129         return (
1130             bool(self)
1131             and self.leaves[0].type == token.STRING
1132             and self.leaves[0].value.startswith(('"""', "'''"))
1133         )
1134
1135     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1136         """If so, needs to be split before emitting."""
1137         for leaf in self.leaves:
1138             if leaf.type == STANDALONE_COMMENT:
1139                 if leaf.bracket_depth <= depth_limit:
1140                     return True
1141
1142         return False
1143
1144     def contains_multiline_strings(self) -> bool:
1145         for leaf in self.leaves:
1146             if is_multiline_string(leaf):
1147                 return True
1148
1149         return False
1150
1151     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1152         """Remove trailing comma if there is one and it's safe."""
1153         if not (
1154             self.leaves
1155             and self.leaves[-1].type == token.COMMA
1156             and closing.type in CLOSING_BRACKETS
1157         ):
1158             return False
1159
1160         if closing.type == token.RBRACE:
1161             self.remove_trailing_comma()
1162             return True
1163
1164         if closing.type == token.RSQB:
1165             comma = self.leaves[-1]
1166             if comma.parent and comma.parent.type == syms.listmaker:
1167                 self.remove_trailing_comma()
1168                 return True
1169
1170         # For parens let's check if it's safe to remove the comma.
1171         # Imports are always safe.
1172         if self.is_import:
1173             self.remove_trailing_comma()
1174             return True
1175
1176         # Otherwise, if the trailing one is the only one, we might mistakenly
1177         # change a tuple into a different type by removing the comma.
1178         depth = closing.bracket_depth + 1
1179         commas = 0
1180         opening = closing.opening_bracket
1181         for _opening_index, leaf in enumerate(self.leaves):
1182             if leaf is opening:
1183                 break
1184
1185         else:
1186             return False
1187
1188         for leaf in self.leaves[_opening_index + 1 :]:
1189             if leaf is closing:
1190                 break
1191
1192             bracket_depth = leaf.bracket_depth
1193             if bracket_depth == depth and leaf.type == token.COMMA:
1194                 commas += 1
1195                 if leaf.parent and leaf.parent.type == syms.arglist:
1196                     commas += 1
1197                     break
1198
1199         if commas > 1:
1200             self.remove_trailing_comma()
1201             return True
1202
1203         return False
1204
1205     def append_comment(self, comment: Leaf) -> bool:
1206         """Add an inline or standalone comment to the line."""
1207         if (
1208             comment.type == STANDALONE_COMMENT
1209             and self.bracket_tracker.any_open_brackets()
1210         ):
1211             comment.prefix = ""
1212             return False
1213
1214         if comment.type != token.COMMENT:
1215             return False
1216
1217         after = len(self.leaves) - 1
1218         if after == -1:
1219             comment.type = STANDALONE_COMMENT
1220             comment.prefix = ""
1221             return False
1222
1223         else:
1224             self.comments.append((after, comment))
1225             return True
1226
1227     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1228         """Generate comments that should appear directly after `leaf`.
1229
1230         Provide a non-negative leaf `_index` to speed up the function.
1231         """
1232         if not self.comments:
1233             return
1234
1235         if _index == -1:
1236             for _index, _leaf in enumerate(self.leaves):
1237                 if leaf is _leaf:
1238                     break
1239
1240             else:
1241                 return
1242
1243         for index, comment_after in self.comments:
1244             if _index == index:
1245                 yield comment_after
1246
1247     def remove_trailing_comma(self) -> None:
1248         """Remove the trailing comma and moves the comments attached to it."""
1249         comma_index = len(self.leaves) - 1
1250         for i in range(len(self.comments)):
1251             comment_index, comment = self.comments[i]
1252             if comment_index == comma_index:
1253                 self.comments[i] = (comma_index - 1, comment)
1254         self.leaves.pop()
1255
1256     def is_complex_subscript(self, leaf: Leaf) -> bool:
1257         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1258         open_lsqb = self.bracket_tracker.get_open_lsqb()
1259         if open_lsqb is None:
1260             return False
1261
1262         subscript_start = open_lsqb.next_sibling
1263
1264         if isinstance(subscript_start, Node):
1265             if subscript_start.type == syms.listmaker:
1266                 return False
1267
1268             if subscript_start.type == syms.subscriptlist:
1269                 subscript_start = child_towards(subscript_start, leaf)
1270         return subscript_start is not None and any(
1271             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1272         )
1273
1274     def __str__(self) -> str:
1275         """Render the line."""
1276         if not self:
1277             return "\n"
1278
1279         indent = "    " * self.depth
1280         leaves = iter(self.leaves)
1281         first = next(leaves)
1282         res = f"{first.prefix}{indent}{first.value}"
1283         for leaf in leaves:
1284             res += str(leaf)
1285         for _, comment in self.comments:
1286             res += str(comment)
1287         return res + "\n"
1288
1289     def __bool__(self) -> bool:
1290         """Return True if the line has leaves or comments."""
1291         return bool(self.leaves or self.comments)
1292
1293
1294 @dataclass
1295 class EmptyLineTracker:
1296     """Provides a stateful method that returns the number of potential extra
1297     empty lines needed before and after the currently processed line.
1298
1299     Note: this tracker works on lines that haven't been split yet.  It assumes
1300     the prefix of the first leaf consists of optional newlines.  Those newlines
1301     are consumed by `maybe_empty_lines()` and included in the computation.
1302     """
1303
1304     is_pyi: bool = False
1305     previous_line: Optional[Line] = None
1306     previous_after: int = 0
1307     previous_defs: List[int] = Factory(list)
1308
1309     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1310         """Return the number of extra empty lines before and after the `current_line`.
1311
1312         This is for separating `def`, `async def` and `class` with extra empty
1313         lines (two on module-level).
1314         """
1315         before, after = self._maybe_empty_lines(current_line)
1316         before -= self.previous_after
1317         self.previous_after = after
1318         self.previous_line = current_line
1319         return before, after
1320
1321     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1322         max_allowed = 1
1323         if current_line.depth == 0:
1324             max_allowed = 1 if self.is_pyi else 2
1325         if current_line.leaves:
1326             # Consume the first leaf's extra newlines.
1327             first_leaf = current_line.leaves[0]
1328             before = first_leaf.prefix.count("\n")
1329             before = min(before, max_allowed)
1330             first_leaf.prefix = ""
1331         else:
1332             before = 0
1333         depth = current_line.depth
1334         while self.previous_defs and self.previous_defs[-1] >= depth:
1335             self.previous_defs.pop()
1336             if self.is_pyi:
1337                 before = 0 if depth else 1
1338             else:
1339                 before = 1 if depth else 2
1340         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1341             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1342
1343         if (
1344             self.previous_line
1345             and self.previous_line.is_import
1346             and not current_line.is_import
1347             and depth == self.previous_line.depth
1348         ):
1349             return (before or 1), 0
1350
1351         if (
1352             self.previous_line
1353             and self.previous_line.is_class
1354             and current_line.is_triple_quoted_string
1355         ):
1356             return before, 1
1357
1358         return before, 0
1359
1360     def _maybe_empty_lines_for_class_or_def(
1361         self, current_line: Line, before: int
1362     ) -> Tuple[int, int]:
1363         if not current_line.is_decorator:
1364             self.previous_defs.append(current_line.depth)
1365         if self.previous_line is None:
1366             # Don't insert empty lines before the first line in the file.
1367             return 0, 0
1368
1369         if self.previous_line.is_decorator:
1370             return 0, 0
1371
1372         if self.previous_line.depth < current_line.depth and (
1373             self.previous_line.is_class or self.previous_line.is_def
1374         ):
1375             return 0, 0
1376
1377         if (
1378             self.previous_line.is_comment
1379             and self.previous_line.depth == current_line.depth
1380             and before == 0
1381         ):
1382             return 0, 0
1383
1384         if self.is_pyi:
1385             if self.previous_line.depth > current_line.depth:
1386                 newlines = 1
1387             elif current_line.is_class or self.previous_line.is_class:
1388                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1389                     # No blank line between classes with an empty body
1390                     newlines = 0
1391                 else:
1392                     newlines = 1
1393             elif current_line.is_def and not self.previous_line.is_def:
1394                 # Blank line between a block of functions and a block of non-functions
1395                 newlines = 1
1396             else:
1397                 newlines = 0
1398         else:
1399             newlines = 2
1400         if current_line.depth and newlines:
1401             newlines -= 1
1402         return newlines, 0
1403
1404
1405 @dataclass
1406 class LineGenerator(Visitor[Line]):
1407     """Generates reformatted Line objects.  Empty lines are not emitted.
1408
1409     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1410     in ways that will no longer stringify to valid Python code on the tree.
1411     """
1412
1413     is_pyi: bool = False
1414     normalize_strings: bool = True
1415     current_line: Line = Factory(Line)
1416     remove_u_prefix: bool = False
1417     allow_underscores: bool = False
1418
1419     def line(self, indent: int = 0) -> Iterator[Line]:
1420         """Generate a line.
1421
1422         If the line is empty, only emit if it makes sense.
1423         If the line is too long, split it first and then generate.
1424
1425         If any lines were generated, set up a new current_line.
1426         """
1427         if not self.current_line:
1428             self.current_line.depth += indent
1429             return  # Line is empty, don't emit. Creating a new one unnecessary.
1430
1431         complete_line = self.current_line
1432         self.current_line = Line(depth=complete_line.depth + indent)
1433         yield complete_line
1434
1435     def visit_default(self, node: LN) -> Iterator[Line]:
1436         """Default `visit_*()` implementation. Recurses to children of `node`."""
1437         if isinstance(node, Leaf):
1438             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1439             for comment in generate_comments(node):
1440                 if any_open_brackets:
1441                     # any comment within brackets is subject to splitting
1442                     self.current_line.append(comment)
1443                 elif comment.type == token.COMMENT:
1444                     # regular trailing comment
1445                     self.current_line.append(comment)
1446                     yield from self.line()
1447
1448                 else:
1449                     # regular standalone comment
1450                     yield from self.line()
1451
1452                     self.current_line.append(comment)
1453                     yield from self.line()
1454
1455             normalize_prefix(node, inside_brackets=any_open_brackets)
1456             if self.normalize_strings and node.type == token.STRING:
1457                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1458                 normalize_string_quotes(node)
1459             if node.type == token.NUMBER:
1460                 normalize_numeric_literal(node, self.allow_underscores)
1461             if node.type not in WHITESPACE:
1462                 self.current_line.append(node)
1463         yield from super().visit_default(node)
1464
1465     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1466         """Increase indentation level, maybe yield a line."""
1467         # In blib2to3 INDENT never holds comments.
1468         yield from self.line(+1)
1469         yield from self.visit_default(node)
1470
1471     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1472         """Decrease indentation level, maybe yield a line."""
1473         # The current line might still wait for trailing comments.  At DEDENT time
1474         # there won't be any (they would be prefixes on the preceding NEWLINE).
1475         # Emit the line then.
1476         yield from self.line()
1477
1478         # While DEDENT has no value, its prefix may contain standalone comments
1479         # that belong to the current indentation level.  Get 'em.
1480         yield from self.visit_default(node)
1481
1482         # Finally, emit the dedent.
1483         yield from self.line(-1)
1484
1485     def visit_stmt(
1486         self, node: Node, keywords: Set[str], parens: Set[str]
1487     ) -> Iterator[Line]:
1488         """Visit a statement.
1489
1490         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1491         `def`, `with`, `class`, `assert` and assignments.
1492
1493         The relevant Python language `keywords` for a given statement will be
1494         NAME leaves within it. This methods puts those on a separate line.
1495
1496         `parens` holds a set of string leaf values immediately after which
1497         invisible parens should be put.
1498         """
1499         normalize_invisible_parens(node, parens_after=parens)
1500         for child in node.children:
1501             if child.type == token.NAME and child.value in keywords:  # type: ignore
1502                 yield from self.line()
1503
1504             yield from self.visit(child)
1505
1506     def visit_suite(self, node: Node) -> Iterator[Line]:
1507         """Visit a suite."""
1508         if self.is_pyi and is_stub_suite(node):
1509             yield from self.visit(node.children[2])
1510         else:
1511             yield from self.visit_default(node)
1512
1513     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1514         """Visit a statement without nested statements."""
1515         is_suite_like = node.parent and node.parent.type in STATEMENT
1516         if is_suite_like:
1517             if self.is_pyi and is_stub_body(node):
1518                 yield from self.visit_default(node)
1519             else:
1520                 yield from self.line(+1)
1521                 yield from self.visit_default(node)
1522                 yield from self.line(-1)
1523
1524         else:
1525             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1526                 yield from self.line()
1527             yield from self.visit_default(node)
1528
1529     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1530         """Visit `async def`, `async for`, `async with`."""
1531         yield from self.line()
1532
1533         children = iter(node.children)
1534         for child in children:
1535             yield from self.visit(child)
1536
1537             if child.type == token.ASYNC:
1538                 break
1539
1540         internal_stmt = next(children)
1541         for child in internal_stmt.children:
1542             yield from self.visit(child)
1543
1544     def visit_decorators(self, node: Node) -> Iterator[Line]:
1545         """Visit decorators."""
1546         for child in node.children:
1547             yield from self.line()
1548             yield from self.visit(child)
1549
1550     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1551         """Remove a semicolon and put the other statement on a separate line."""
1552         yield from self.line()
1553
1554     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1555         """End of file. Process outstanding comments and end with a newline."""
1556         yield from self.visit_default(leaf)
1557         yield from self.line()
1558
1559     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1560         if not self.current_line.bracket_tracker.any_open_brackets():
1561             yield from self.line()
1562         yield from self.visit_default(leaf)
1563
1564     def __attrs_post_init__(self) -> None:
1565         """You are in a twisty little maze of passages."""
1566         v = self.visit_stmt
1567         Ø: Set[str] = set()
1568         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1569         self.visit_if_stmt = partial(
1570             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1571         )
1572         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1573         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1574         self.visit_try_stmt = partial(
1575             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1576         )
1577         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1578         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1579         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1580         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1581         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1582         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1583         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1584         self.visit_async_funcdef = self.visit_async_stmt
1585         self.visit_decorated = self.visit_decorators
1586
1587
1588 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1589 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1590 OPENING_BRACKETS = set(BRACKET.keys())
1591 CLOSING_BRACKETS = set(BRACKET.values())
1592 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1593 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1594
1595
1596 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1597     """Return whitespace prefix if needed for the given `leaf`.
1598
1599     `complex_subscript` signals whether the given leaf is part of a subscription
1600     which has non-trivial arguments, like arithmetic expressions or function calls.
1601     """
1602     NO = ""
1603     SPACE = " "
1604     DOUBLESPACE = "  "
1605     t = leaf.type
1606     p = leaf.parent
1607     v = leaf.value
1608     if t in ALWAYS_NO_SPACE:
1609         return NO
1610
1611     if t == token.COMMENT:
1612         return DOUBLESPACE
1613
1614     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1615     if t == token.COLON and p.type not in {
1616         syms.subscript,
1617         syms.subscriptlist,
1618         syms.sliceop,
1619     }:
1620         return NO
1621
1622     prev = leaf.prev_sibling
1623     if not prev:
1624         prevp = preceding_leaf(p)
1625         if not prevp or prevp.type in OPENING_BRACKETS:
1626             return NO
1627
1628         if t == token.COLON:
1629             if prevp.type == token.COLON:
1630                 return NO
1631
1632             elif prevp.type != token.COMMA and not complex_subscript:
1633                 return NO
1634
1635             return SPACE
1636
1637         if prevp.type == token.EQUAL:
1638             if prevp.parent:
1639                 if prevp.parent.type in {
1640                     syms.arglist,
1641                     syms.argument,
1642                     syms.parameters,
1643                     syms.varargslist,
1644                 }:
1645                     return NO
1646
1647                 elif prevp.parent.type == syms.typedargslist:
1648                     # A bit hacky: if the equal sign has whitespace, it means we
1649                     # previously found it's a typed argument.  So, we're using
1650                     # that, too.
1651                     return prevp.prefix
1652
1653         elif prevp.type in STARS:
1654             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1655                 return NO
1656
1657         elif prevp.type == token.COLON:
1658             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1659                 return SPACE if complex_subscript else NO
1660
1661         elif (
1662             prevp.parent
1663             and prevp.parent.type == syms.factor
1664             and prevp.type in MATH_OPERATORS
1665         ):
1666             return NO
1667
1668         elif (
1669             prevp.type == token.RIGHTSHIFT
1670             and prevp.parent
1671             and prevp.parent.type == syms.shift_expr
1672             and prevp.prev_sibling
1673             and prevp.prev_sibling.type == token.NAME
1674             and prevp.prev_sibling.value == "print"  # type: ignore
1675         ):
1676             # Python 2 print chevron
1677             return NO
1678
1679     elif prev.type in OPENING_BRACKETS:
1680         return NO
1681
1682     if p.type in {syms.parameters, syms.arglist}:
1683         # untyped function signatures or calls
1684         if not prev or prev.type != token.COMMA:
1685             return NO
1686
1687     elif p.type == syms.varargslist:
1688         # lambdas
1689         if prev and prev.type != token.COMMA:
1690             return NO
1691
1692     elif p.type == syms.typedargslist:
1693         # typed function signatures
1694         if not prev:
1695             return NO
1696
1697         if t == token.EQUAL:
1698             if prev.type != syms.tname:
1699                 return NO
1700
1701         elif prev.type == token.EQUAL:
1702             # A bit hacky: if the equal sign has whitespace, it means we
1703             # previously found it's a typed argument.  So, we're using that, too.
1704             return prev.prefix
1705
1706         elif prev.type != token.COMMA:
1707             return NO
1708
1709     elif p.type == syms.tname:
1710         # type names
1711         if not prev:
1712             prevp = preceding_leaf(p)
1713             if not prevp or prevp.type != token.COMMA:
1714                 return NO
1715
1716     elif p.type == syms.trailer:
1717         # attributes and calls
1718         if t == token.LPAR or t == token.RPAR:
1719             return NO
1720
1721         if not prev:
1722             if t == token.DOT:
1723                 prevp = preceding_leaf(p)
1724                 if not prevp or prevp.type != token.NUMBER:
1725                     return NO
1726
1727             elif t == token.LSQB:
1728                 return NO
1729
1730         elif prev.type != token.COMMA:
1731             return NO
1732
1733     elif p.type == syms.argument:
1734         # single argument
1735         if t == token.EQUAL:
1736             return NO
1737
1738         if not prev:
1739             prevp = preceding_leaf(p)
1740             if not prevp or prevp.type == token.LPAR:
1741                 return NO
1742
1743         elif prev.type in {token.EQUAL} | STARS:
1744             return NO
1745
1746     elif p.type == syms.decorator:
1747         # decorators
1748         return NO
1749
1750     elif p.type == syms.dotted_name:
1751         if prev:
1752             return NO
1753
1754         prevp = preceding_leaf(p)
1755         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1756             return NO
1757
1758     elif p.type == syms.classdef:
1759         if t == token.LPAR:
1760             return NO
1761
1762         if prev and prev.type == token.LPAR:
1763             return NO
1764
1765     elif p.type in {syms.subscript, syms.sliceop}:
1766         # indexing
1767         if not prev:
1768             assert p.parent is not None, "subscripts are always parented"
1769             if p.parent.type == syms.subscriptlist:
1770                 return SPACE
1771
1772             return NO
1773
1774         elif not complex_subscript:
1775             return NO
1776
1777     elif p.type == syms.atom:
1778         if prev and t == token.DOT:
1779             # dots, but not the first one.
1780             return NO
1781
1782     elif p.type == syms.dictsetmaker:
1783         # dict unpacking
1784         if prev and prev.type == token.DOUBLESTAR:
1785             return NO
1786
1787     elif p.type in {syms.factor, syms.star_expr}:
1788         # unary ops
1789         if not prev:
1790             prevp = preceding_leaf(p)
1791             if not prevp or prevp.type in OPENING_BRACKETS:
1792                 return NO
1793
1794             prevp_parent = prevp.parent
1795             assert prevp_parent is not None
1796             if prevp.type == token.COLON and prevp_parent.type in {
1797                 syms.subscript,
1798                 syms.sliceop,
1799             }:
1800                 return NO
1801
1802             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1803                 return NO
1804
1805         elif t in {token.NAME, token.NUMBER, token.STRING}:
1806             return NO
1807
1808     elif p.type == syms.import_from:
1809         if t == token.DOT:
1810             if prev and prev.type == token.DOT:
1811                 return NO
1812
1813         elif t == token.NAME:
1814             if v == "import":
1815                 return SPACE
1816
1817             if prev and prev.type == token.DOT:
1818                 return NO
1819
1820     elif p.type == syms.sliceop:
1821         return NO
1822
1823     return SPACE
1824
1825
1826 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1827     """Return the first leaf that precedes `node`, if any."""
1828     while node:
1829         res = node.prev_sibling
1830         if res:
1831             if isinstance(res, Leaf):
1832                 return res
1833
1834             try:
1835                 return list(res.leaves())[-1]
1836
1837             except IndexError:
1838                 return None
1839
1840         node = node.parent
1841     return None
1842
1843
1844 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1845     """Return the child of `ancestor` that contains `descendant`."""
1846     node: Optional[LN] = descendant
1847     while node and node.parent != ancestor:
1848         node = node.parent
1849     return node
1850
1851
1852 def container_of(leaf: Leaf) -> LN:
1853     """Return `leaf` or one of its ancestors that is the topmost container of it.
1854
1855     By "container" we mean a node where `leaf` is the very first child.
1856     """
1857     same_prefix = leaf.prefix
1858     container: LN = leaf
1859     while container:
1860         parent = container.parent
1861         if parent is None:
1862             break
1863
1864         if parent.children[0].prefix != same_prefix:
1865             break
1866
1867         if parent.type == syms.file_input:
1868             break
1869
1870         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1871             break
1872
1873         container = parent
1874     return container
1875
1876
1877 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1878     """Return the priority of the `leaf` delimiter, given a line break after it.
1879
1880     The delimiter priorities returned here are from those delimiters that would
1881     cause a line break after themselves.
1882
1883     Higher numbers are higher priority.
1884     """
1885     if leaf.type == token.COMMA:
1886         return COMMA_PRIORITY
1887
1888     return 0
1889
1890
1891 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1892     """Return the priority of the `leaf` delimiter, given a line break before it.
1893
1894     The delimiter priorities returned here are from those delimiters that would
1895     cause a line break before themselves.
1896
1897     Higher numbers are higher priority.
1898     """
1899     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1900         # * and ** might also be MATH_OPERATORS but in this case they are not.
1901         # Don't treat them as a delimiter.
1902         return 0
1903
1904     if (
1905         leaf.type == token.DOT
1906         and leaf.parent
1907         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1908         and (previous is None or previous.type in CLOSING_BRACKETS)
1909     ):
1910         return DOT_PRIORITY
1911
1912     if (
1913         leaf.type in MATH_OPERATORS
1914         and leaf.parent
1915         and leaf.parent.type not in {syms.factor, syms.star_expr}
1916     ):
1917         return MATH_PRIORITIES[leaf.type]
1918
1919     if leaf.type in COMPARATORS:
1920         return COMPARATOR_PRIORITY
1921
1922     if (
1923         leaf.type == token.STRING
1924         and previous is not None
1925         and previous.type == token.STRING
1926     ):
1927         return STRING_PRIORITY
1928
1929     if leaf.type not in {token.NAME, token.ASYNC}:
1930         return 0
1931
1932     if (
1933         leaf.value == "for"
1934         and leaf.parent
1935         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1936         or leaf.type == token.ASYNC
1937     ):
1938         if (
1939             not isinstance(leaf.prev_sibling, Leaf)
1940             or leaf.prev_sibling.value != "async"
1941         ):
1942             return COMPREHENSION_PRIORITY
1943
1944     if (
1945         leaf.value == "if"
1946         and leaf.parent
1947         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1948     ):
1949         return COMPREHENSION_PRIORITY
1950
1951     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1952         return TERNARY_PRIORITY
1953
1954     if leaf.value == "is":
1955         return COMPARATOR_PRIORITY
1956
1957     if (
1958         leaf.value == "in"
1959         and leaf.parent
1960         and leaf.parent.type in {syms.comp_op, syms.comparison}
1961         and not (
1962             previous is not None
1963             and previous.type == token.NAME
1964             and previous.value == "not"
1965         )
1966     ):
1967         return COMPARATOR_PRIORITY
1968
1969     if (
1970         leaf.value == "not"
1971         and leaf.parent
1972         and leaf.parent.type == syms.comp_op
1973         and not (
1974             previous is not None
1975             and previous.type == token.NAME
1976             and previous.value == "is"
1977         )
1978     ):
1979         return COMPARATOR_PRIORITY
1980
1981     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1982         return LOGIC_PRIORITY
1983
1984     return 0
1985
1986
1987 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
1988 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
1989
1990
1991 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1992     """Clean the prefix of the `leaf` and generate comments from it, if any.
1993
1994     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1995     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1996     move because it does away with modifying the grammar to include all the
1997     possible places in which comments can be placed.
1998
1999     The sad consequence for us though is that comments don't "belong" anywhere.
2000     This is why this function generates simple parentless Leaf objects for
2001     comments.  We simply don't know what the correct parent should be.
2002
2003     No matter though, we can live without this.  We really only need to
2004     differentiate between inline and standalone comments.  The latter don't
2005     share the line with any code.
2006
2007     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2008     are emitted with a fake STANDALONE_COMMENT token identifier.
2009     """
2010     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2011         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2012
2013
2014 @dataclass
2015 class ProtoComment:
2016     type: int  # token.COMMENT or STANDALONE_COMMENT
2017     value: str  # content of the comment
2018     newlines: int  # how many newlines before the comment
2019     consumed: int  # how many characters of the original leaf's prefix did we consume
2020
2021
2022 @lru_cache(maxsize=4096)
2023 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2024     result: List[ProtoComment] = []
2025     if not prefix or "#" not in prefix:
2026         return result
2027
2028     consumed = 0
2029     nlines = 0
2030     for index, line in enumerate(prefix.split("\n")):
2031         consumed += len(line) + 1  # adding the length of the split '\n'
2032         line = line.lstrip()
2033         if not line:
2034             nlines += 1
2035         if not line.startswith("#"):
2036             continue
2037
2038         if index == 0 and not is_endmarker:
2039             comment_type = token.COMMENT  # simple trailing comment
2040         else:
2041             comment_type = STANDALONE_COMMENT
2042         comment = make_comment(line)
2043         result.append(
2044             ProtoComment(
2045                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2046             )
2047         )
2048         nlines = 0
2049     return result
2050
2051
2052 def make_comment(content: str) -> str:
2053     """Return a consistently formatted comment from the given `content` string.
2054
2055     All comments (except for "##", "#!", "#:") should have a single space between
2056     the hash sign and the content.
2057
2058     If `content` didn't start with a hash sign, one is provided.
2059     """
2060     content = content.rstrip()
2061     if not content:
2062         return "#"
2063
2064     if content[0] == "#":
2065         content = content[1:]
2066     if content and content[0] not in " !:#":
2067         content = " " + content
2068     return "#" + content
2069
2070
2071 def split_line(
2072     line: Line, line_length: int, inner: bool = False, py36: bool = False
2073 ) -> Iterator[Line]:
2074     """Split a `line` into potentially many lines.
2075
2076     They should fit in the allotted `line_length` but might not be able to.
2077     `inner` signifies that there were a pair of brackets somewhere around the
2078     current `line`, possibly transitively. This means we can fallback to splitting
2079     by delimiters if the LHS/RHS don't yield any results.
2080
2081     If `py36` is True, splitting may generate syntax that is only compatible
2082     with Python 3.6 and later.
2083     """
2084     if line.is_comment:
2085         yield line
2086         return
2087
2088     line_str = str(line).strip("\n")
2089     if not line.should_explode and is_line_short_enough(
2090         line, line_length=line_length, line_str=line_str
2091     ):
2092         yield line
2093         return
2094
2095     split_funcs: List[SplitFunc]
2096     if line.is_def:
2097         split_funcs = [left_hand_split]
2098     else:
2099
2100         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2101             for omit in generate_trailers_to_omit(line, line_length):
2102                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2103                 if is_line_short_enough(lines[0], line_length=line_length):
2104                     yield from lines
2105                     return
2106
2107             # All splits failed, best effort split with no omits.
2108             # This mostly happens to multiline strings that are by definition
2109             # reported as not fitting a single line.
2110             yield from right_hand_split(line, py36)
2111
2112         if line.inside_brackets:
2113             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2114         else:
2115             split_funcs = [rhs]
2116     for split_func in split_funcs:
2117         # We are accumulating lines in `result` because we might want to abort
2118         # mission and return the original line in the end, or attempt a different
2119         # split altogether.
2120         result: List[Line] = []
2121         try:
2122             for l in split_func(line, py36):
2123                 if str(l).strip("\n") == line_str:
2124                     raise CannotSplit("Split function returned an unchanged result")
2125
2126                 result.extend(
2127                     split_line(l, line_length=line_length, inner=True, py36=py36)
2128                 )
2129         except CannotSplit as cs:
2130             continue
2131
2132         else:
2133             yield from result
2134             break
2135
2136     else:
2137         yield line
2138
2139
2140 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2141     """Split line into many lines, starting with the first matching bracket pair.
2142
2143     Note: this usually looks weird, only use this for function definitions.
2144     Prefer RHS otherwise.  This is why this function is not symmetrical with
2145     :func:`right_hand_split` which also handles optional parentheses.
2146     """
2147     head = Line(depth=line.depth)
2148     body = Line(depth=line.depth + 1, inside_brackets=True)
2149     tail = Line(depth=line.depth)
2150     tail_leaves: List[Leaf] = []
2151     body_leaves: List[Leaf] = []
2152     head_leaves: List[Leaf] = []
2153     current_leaves = head_leaves
2154     matching_bracket = None
2155     for leaf in line.leaves:
2156         if (
2157             current_leaves is body_leaves
2158             and leaf.type in CLOSING_BRACKETS
2159             and leaf.opening_bracket is matching_bracket
2160         ):
2161             current_leaves = tail_leaves if body_leaves else head_leaves
2162         current_leaves.append(leaf)
2163         if current_leaves is head_leaves:
2164             if leaf.type in OPENING_BRACKETS:
2165                 matching_bracket = leaf
2166                 current_leaves = body_leaves
2167     # Since body is a new indent level, remove spurious leading whitespace.
2168     if body_leaves:
2169         normalize_prefix(body_leaves[0], inside_brackets=True)
2170     # Build the new lines.
2171     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2172         for leaf in leaves:
2173             result.append(leaf, preformatted=True)
2174             for comment_after in line.comments_after(leaf):
2175                 result.append(comment_after, preformatted=True)
2176     bracket_split_succeeded_or_raise(head, body, tail)
2177     for result in (head, body, tail):
2178         if result:
2179             yield result
2180
2181
2182 def right_hand_split(
2183     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2184 ) -> Iterator[Line]:
2185     """Split line into many lines, starting with the last matching bracket pair.
2186
2187     If the split was by optional parentheses, attempt splitting without them, too.
2188     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2189     this split.
2190
2191     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2192     """
2193     head = Line(depth=line.depth)
2194     body = Line(depth=line.depth + 1, inside_brackets=True)
2195     tail = Line(depth=line.depth)
2196     tail_leaves: List[Leaf] = []
2197     body_leaves: List[Leaf] = []
2198     head_leaves: List[Leaf] = []
2199     current_leaves = tail_leaves
2200     opening_bracket = None
2201     closing_bracket = None
2202     for leaf in reversed(line.leaves):
2203         if current_leaves is body_leaves:
2204             if leaf is opening_bracket:
2205                 current_leaves = head_leaves if body_leaves else tail_leaves
2206         current_leaves.append(leaf)
2207         if current_leaves is tail_leaves:
2208             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2209                 opening_bracket = leaf.opening_bracket
2210                 closing_bracket = leaf
2211                 current_leaves = body_leaves
2212     tail_leaves.reverse()
2213     body_leaves.reverse()
2214     head_leaves.reverse()
2215     # Since body is a new indent level, remove spurious leading whitespace.
2216     if body_leaves:
2217         normalize_prefix(body_leaves[0], inside_brackets=True)
2218     if not head_leaves:
2219         # No `head` means the split failed. Either `tail` has all content or
2220         # the matching `opening_bracket` wasn't available on `line` anymore.
2221         raise CannotSplit("No brackets found")
2222
2223     # Build the new lines.
2224     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2225         for leaf in leaves:
2226             result.append(leaf, preformatted=True)
2227             for comment_after in line.comments_after(leaf):
2228                 result.append(comment_after, preformatted=True)
2229     assert opening_bracket and closing_bracket
2230     body.should_explode = should_explode(body, opening_bracket)
2231     bracket_split_succeeded_or_raise(head, body, tail)
2232     if (
2233         # the body shouldn't be exploded
2234         not body.should_explode
2235         # the opening bracket is an optional paren
2236         and opening_bracket.type == token.LPAR
2237         and not opening_bracket.value
2238         # the closing bracket is an optional paren
2239         and closing_bracket.type == token.RPAR
2240         and not closing_bracket.value
2241         # it's not an import (optional parens are the only thing we can split on
2242         # in this case; attempting a split without them is a waste of time)
2243         and not line.is_import
2244         # there are no standalone comments in the body
2245         and not body.contains_standalone_comments(0)
2246         # and we can actually remove the parens
2247         and can_omit_invisible_parens(body, line_length)
2248     ):
2249         omit = {id(closing_bracket), *omit}
2250         try:
2251             yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2252             return
2253
2254         except CannotSplit:
2255             if not (
2256                 can_be_split(body)
2257                 or is_line_short_enough(body, line_length=line_length)
2258             ):
2259                 raise CannotSplit(
2260                     "Splitting failed, body is still too long and can't be split."
2261                 )
2262
2263             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2264                 raise CannotSplit(
2265                     "The current optional pair of parentheses is bound to fail to "
2266                     "satisfy the splitting algorithm because the head or the tail "
2267                     "contains multiline strings which by definition never fit one "
2268                     "line."
2269                 )
2270
2271     ensure_visible(opening_bracket)
2272     ensure_visible(closing_bracket)
2273     for result in (head, body, tail):
2274         if result:
2275             yield result
2276
2277
2278 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2279     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2280
2281     Do nothing otherwise.
2282
2283     A left- or right-hand split is based on a pair of brackets. Content before
2284     (and including) the opening bracket is left on one line, content inside the
2285     brackets is put on a separate line, and finally content starting with and
2286     following the closing bracket is put on a separate line.
2287
2288     Those are called `head`, `body`, and `tail`, respectively. If the split
2289     produced the same line (all content in `head`) or ended up with an empty `body`
2290     and the `tail` is just the closing bracket, then it's considered failed.
2291     """
2292     tail_len = len(str(tail).strip())
2293     if not body:
2294         if tail_len == 0:
2295             raise CannotSplit("Splitting brackets produced the same line")
2296
2297         elif tail_len < 3:
2298             raise CannotSplit(
2299                 f"Splitting brackets on an empty body to save "
2300                 f"{tail_len} characters is not worth it"
2301             )
2302
2303
2304 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2305     """Normalize prefix of the first leaf in every line returned by `split_func`.
2306
2307     This is a decorator over relevant split functions.
2308     """
2309
2310     @wraps(split_func)
2311     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2312         for l in split_func(line, py36):
2313             normalize_prefix(l.leaves[0], inside_brackets=True)
2314             yield l
2315
2316     return split_wrapper
2317
2318
2319 @dont_increase_indentation
2320 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2321     """Split according to delimiters of the highest priority.
2322
2323     If `py36` is True, the split will add trailing commas also in function
2324     signatures that contain `*` and `**`.
2325     """
2326     try:
2327         last_leaf = line.leaves[-1]
2328     except IndexError:
2329         raise CannotSplit("Line empty")
2330
2331     bt = line.bracket_tracker
2332     try:
2333         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2334     except ValueError:
2335         raise CannotSplit("No delimiters found")
2336
2337     if delimiter_priority == DOT_PRIORITY:
2338         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2339             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2340
2341     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2342     lowest_depth = sys.maxsize
2343     trailing_comma_safe = True
2344
2345     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2346         """Append `leaf` to current line or to new line if appending impossible."""
2347         nonlocal current_line
2348         try:
2349             current_line.append_safe(leaf, preformatted=True)
2350         except ValueError as ve:
2351             yield current_line
2352
2353             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2354             current_line.append(leaf)
2355
2356     for index, leaf in enumerate(line.leaves):
2357         yield from append_to_line(leaf)
2358
2359         for comment_after in line.comments_after(leaf, index):
2360             yield from append_to_line(comment_after)
2361
2362         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2363         if leaf.bracket_depth == lowest_depth and is_vararg(
2364             leaf, within=VARARGS_PARENTS
2365         ):
2366             trailing_comma_safe = trailing_comma_safe and py36
2367         leaf_priority = bt.delimiters.get(id(leaf))
2368         if leaf_priority == delimiter_priority:
2369             yield current_line
2370
2371             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2372     if current_line:
2373         if (
2374             trailing_comma_safe
2375             and delimiter_priority == COMMA_PRIORITY
2376             and current_line.leaves[-1].type != token.COMMA
2377             and current_line.leaves[-1].type != STANDALONE_COMMENT
2378         ):
2379             current_line.append(Leaf(token.COMMA, ","))
2380         yield current_line
2381
2382
2383 @dont_increase_indentation
2384 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2385     """Split standalone comments from the rest of the line."""
2386     if not line.contains_standalone_comments(0):
2387         raise CannotSplit("Line does not have any standalone comments")
2388
2389     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2390
2391     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2392         """Append `leaf` to current line or to new line if appending impossible."""
2393         nonlocal current_line
2394         try:
2395             current_line.append_safe(leaf, preformatted=True)
2396         except ValueError as ve:
2397             yield current_line
2398
2399             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2400             current_line.append(leaf)
2401
2402     for index, leaf in enumerate(line.leaves):
2403         yield from append_to_line(leaf)
2404
2405         for comment_after in line.comments_after(leaf, index):
2406             yield from append_to_line(comment_after)
2407
2408     if current_line:
2409         yield current_line
2410
2411
2412 def is_import(leaf: Leaf) -> bool:
2413     """Return True if the given leaf starts an import statement."""
2414     p = leaf.parent
2415     t = leaf.type
2416     v = leaf.value
2417     return bool(
2418         t == token.NAME
2419         and (
2420             (v == "import" and p and p.type == syms.import_name)
2421             or (v == "from" and p and p.type == syms.import_from)
2422         )
2423     )
2424
2425
2426 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2427     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2428     else.
2429
2430     Note: don't use backslashes for formatting or you'll lose your voting rights.
2431     """
2432     if not inside_brackets:
2433         spl = leaf.prefix.split("#")
2434         if "\\" not in spl[0]:
2435             nl_count = spl[-1].count("\n")
2436             if len(spl) > 1:
2437                 nl_count -= 1
2438             leaf.prefix = "\n" * nl_count
2439             return
2440
2441     leaf.prefix = ""
2442
2443
2444 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2445     """Make all string prefixes lowercase.
2446
2447     If remove_u_prefix is given, also removes any u prefix from the string.
2448
2449     Note: Mutates its argument.
2450     """
2451     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2452     assert match is not None, f"failed to match string {leaf.value!r}"
2453     orig_prefix = match.group(1)
2454     new_prefix = orig_prefix.lower()
2455     if remove_u_prefix:
2456         new_prefix = new_prefix.replace("u", "")
2457     leaf.value = f"{new_prefix}{match.group(2)}"
2458
2459
2460 def normalize_string_quotes(leaf: Leaf) -> None:
2461     """Prefer double quotes but only if it doesn't cause more escaping.
2462
2463     Adds or removes backslashes as appropriate. Doesn't parse and fix
2464     strings nested in f-strings (yet).
2465
2466     Note: Mutates its argument.
2467     """
2468     value = leaf.value.lstrip("furbFURB")
2469     if value[:3] == '"""':
2470         return
2471
2472     elif value[:3] == "'''":
2473         orig_quote = "'''"
2474         new_quote = '"""'
2475     elif value[0] == '"':
2476         orig_quote = '"'
2477         new_quote = "'"
2478     else:
2479         orig_quote = "'"
2480         new_quote = '"'
2481     first_quote_pos = leaf.value.find(orig_quote)
2482     if first_quote_pos == -1:
2483         return  # There's an internal error
2484
2485     prefix = leaf.value[:first_quote_pos]
2486     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2487     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2488     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2489     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2490     if "r" in prefix.casefold():
2491         if unescaped_new_quote.search(body):
2492             # There's at least one unescaped new_quote in this raw string
2493             # so converting is impossible
2494             return
2495
2496         # Do not introduce or remove backslashes in raw strings
2497         new_body = body
2498     else:
2499         # remove unnecessary escapes
2500         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2501         if body != new_body:
2502             # Consider the string without unnecessary escapes as the original
2503             body = new_body
2504             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2505         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2506         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2507     if "f" in prefix.casefold():
2508         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2509         for m in matches:
2510             if "\\" in str(m):
2511                 # Do not introduce backslashes in interpolated expressions
2512                 return
2513     if new_quote == '"""' and new_body[-1:] == '"':
2514         # edge case:
2515         new_body = new_body[:-1] + '\\"'
2516     orig_escape_count = body.count("\\")
2517     new_escape_count = new_body.count("\\")
2518     if new_escape_count > orig_escape_count:
2519         return  # Do not introduce more escaping
2520
2521     if new_escape_count == orig_escape_count and orig_quote == '"':
2522         return  # Prefer double quotes
2523
2524     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2525
2526
2527 def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
2528     """Normalizes numeric (float, int, and complex) literals.
2529
2530     All letters used in the representation are normalized to lowercase (except
2531     in Python 2 long literals), and long number literals are split using underscores.
2532     """
2533     text = leaf.value.lower()
2534     if text.startswith(("0o", "0x", "0b")):
2535         # Leave octal, hex, and binary literals alone.
2536         pass
2537     elif "e" in text:
2538         before, after = text.split("e")
2539         sign = ""
2540         if after.startswith("-"):
2541             after = after[1:]
2542             sign = "-"
2543         elif after.startswith("+"):
2544             after = after[1:]
2545         before = format_float_or_int_string(before, allow_underscores)
2546         after = format_int_string(after, allow_underscores)
2547         text = f"{before}e{sign}{after}"
2548     elif text.endswith(("j", "l")):
2549         number = text[:-1]
2550         suffix = text[-1]
2551         # Capitalize in "2L" because "l" looks too similar to "1".
2552         if suffix == "l":
2553             suffix = "L"
2554         text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
2555     else:
2556         text = format_float_or_int_string(text, allow_underscores)
2557     leaf.value = text
2558
2559
2560 def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
2561     """Formats a float string like "1.0"."""
2562     if "." not in text:
2563         return format_int_string(text, allow_underscores)
2564
2565     before, after = text.split(".")
2566     before = format_int_string(before, allow_underscores) if before else "0"
2567     if after:
2568         after = format_int_string(after, allow_underscores, count_from_end=False)
2569     else:
2570         after = "0"
2571     return f"{before}.{after}"
2572
2573
2574 def format_int_string(
2575     text: str, allow_underscores: bool, count_from_end: bool = True
2576 ) -> str:
2577     """Normalizes underscores in a string to e.g. 1_000_000.
2578
2579     Input must be a string of digits and optional underscores.
2580     If count_from_end is False, we add underscores after groups of three digits
2581     counting from the beginning instead of the end of the strings. This is used
2582     for the fractional part of float literals.
2583     """
2584     if not allow_underscores:
2585         return text
2586
2587     text = text.replace("_", "")
2588     if len(text) <= 6:
2589         # No underscores for numbers <= 6 digits long.
2590         return text
2591
2592     if count_from_end:
2593         # Avoid removing leading zeros, which are important if we're formatting
2594         # part of a number like "0.001".
2595         return format(int("1" + text), "3_")[1:].lstrip("_")
2596     else:
2597         return "_".join(text[i : i + 3] for i in range(0, len(text), 3))
2598
2599
2600 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2601     """Make existing optional parentheses invisible or create new ones.
2602
2603     `parens_after` is a set of string leaf values immeditely after which parens
2604     should be put.
2605
2606     Standardizes on visible parentheses for single-element tuples, and keeps
2607     existing visible parentheses for other tuples and generator expressions.
2608     """
2609     for pc in list_comments(node.prefix, is_endmarker=False):
2610         if pc.value in FMT_OFF:
2611             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2612             return
2613
2614     check_lpar = False
2615     for index, child in enumerate(list(node.children)):
2616         if check_lpar:
2617             if child.type == syms.atom:
2618                 if maybe_make_parens_invisible_in_atom(child):
2619                     lpar = Leaf(token.LPAR, "")
2620                     rpar = Leaf(token.RPAR, "")
2621                     index = child.remove() or 0
2622                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2623             elif is_one_tuple(child):
2624                 # wrap child in visible parentheses
2625                 lpar = Leaf(token.LPAR, "(")
2626                 rpar = Leaf(token.RPAR, ")")
2627                 child.remove()
2628                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2629             elif node.type == syms.import_from:
2630                 # "import from" nodes store parentheses directly as part of
2631                 # the statement
2632                 if child.type == token.LPAR:
2633                     # make parentheses invisible
2634                     child.value = ""  # type: ignore
2635                     node.children[-1].value = ""  # type: ignore
2636                 elif child.type != token.STAR:
2637                     # insert invisible parentheses
2638                     node.insert_child(index, Leaf(token.LPAR, ""))
2639                     node.append_child(Leaf(token.RPAR, ""))
2640                 break
2641
2642             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2643                 # wrap child in invisible parentheses
2644                 lpar = Leaf(token.LPAR, "")
2645                 rpar = Leaf(token.RPAR, "")
2646                 index = child.remove() or 0
2647                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2648
2649         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2650
2651
2652 def normalize_fmt_off(node: Node) -> None:
2653     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2654     try_again = True
2655     while try_again:
2656         try_again = convert_one_fmt_off_pair(node)
2657
2658
2659 def convert_one_fmt_off_pair(node: Node) -> bool:
2660     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2661
2662     Returns True if a pair was converted.
2663     """
2664     for leaf in node.leaves():
2665         previous_consumed = 0
2666         for comment in list_comments(leaf.prefix, is_endmarker=False):
2667             if comment.value in FMT_OFF:
2668                 # We only want standalone comments. If there's no previous leaf or
2669                 # the previous leaf is indentation, it's a standalone comment in
2670                 # disguise.
2671                 if comment.type != STANDALONE_COMMENT:
2672                     prev = preceding_leaf(leaf)
2673                     if prev and prev.type not in WHITESPACE:
2674                         continue
2675
2676                 ignored_nodes = list(generate_ignored_nodes(leaf))
2677                 if not ignored_nodes:
2678                     continue
2679
2680                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2681                 parent = first.parent
2682                 prefix = first.prefix
2683                 first.prefix = prefix[comment.consumed :]
2684                 hidden_value = (
2685                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2686                 )
2687                 if hidden_value.endswith("\n"):
2688                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2689                     # leaf (possibly followed by a DEDENT).
2690                     hidden_value = hidden_value[:-1]
2691                 first_idx = None
2692                 for ignored in ignored_nodes:
2693                     index = ignored.remove()
2694                     if first_idx is None:
2695                         first_idx = index
2696                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2697                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2698                 parent.insert_child(
2699                     first_idx,
2700                     Leaf(
2701                         STANDALONE_COMMENT,
2702                         hidden_value,
2703                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2704                     ),
2705                 )
2706                 return True
2707
2708             previous_consumed = comment.consumed
2709
2710     return False
2711
2712
2713 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2714     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2715
2716     Stops at the end of the block.
2717     """
2718     container: Optional[LN] = container_of(leaf)
2719     while container is not None and container.type != token.ENDMARKER:
2720         for comment in list_comments(container.prefix, is_endmarker=False):
2721             if comment.value in FMT_ON:
2722                 return
2723
2724         yield container
2725
2726         container = container.next_sibling
2727
2728
2729 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2730     """If it's safe, make the parens in the atom `node` invisible, recursively.
2731
2732     Returns whether the node should itself be wrapped in invisible parentheses.
2733
2734     """
2735     if (
2736         node.type != syms.atom
2737         or is_empty_tuple(node)
2738         or is_one_tuple(node)
2739         or is_yield(node)
2740         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2741     ):
2742         return False
2743
2744     first = node.children[0]
2745     last = node.children[-1]
2746     if first.type == token.LPAR and last.type == token.RPAR:
2747         # make parentheses invisible
2748         first.value = ""  # type: ignore
2749         last.value = ""  # type: ignore
2750         if len(node.children) > 1:
2751             maybe_make_parens_invisible_in_atom(node.children[1])
2752         return False
2753
2754     return True
2755
2756
2757 def is_empty_tuple(node: LN) -> bool:
2758     """Return True if `node` holds an empty tuple."""
2759     return (
2760         node.type == syms.atom
2761         and len(node.children) == 2
2762         and node.children[0].type == token.LPAR
2763         and node.children[1].type == token.RPAR
2764     )
2765
2766
2767 def is_one_tuple(node: LN) -> bool:
2768     """Return True if `node` holds a tuple with one element, with or without parens."""
2769     if node.type == syms.atom:
2770         if len(node.children) != 3:
2771             return False
2772
2773         lpar, gexp, rpar = node.children
2774         if not (
2775             lpar.type == token.LPAR
2776             and gexp.type == syms.testlist_gexp
2777             and rpar.type == token.RPAR
2778         ):
2779             return False
2780
2781         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2782
2783     return (
2784         node.type in IMPLICIT_TUPLE
2785         and len(node.children) == 2
2786         and node.children[1].type == token.COMMA
2787     )
2788
2789
2790 def is_yield(node: LN) -> bool:
2791     """Return True if `node` holds a `yield` or `yield from` expression."""
2792     if node.type == syms.yield_expr:
2793         return True
2794
2795     if node.type == token.NAME and node.value == "yield":  # type: ignore
2796         return True
2797
2798     if node.type != syms.atom:
2799         return False
2800
2801     if len(node.children) != 3:
2802         return False
2803
2804     lpar, expr, rpar = node.children
2805     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2806         return is_yield(expr)
2807
2808     return False
2809
2810
2811 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2812     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2813
2814     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2815     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2816     extended iterable unpacking (PEP 3132) and additional unpacking
2817     generalizations (PEP 448).
2818     """
2819     if leaf.type not in STARS or not leaf.parent:
2820         return False
2821
2822     p = leaf.parent
2823     if p.type == syms.star_expr:
2824         # Star expressions are also used as assignment targets in extended
2825         # iterable unpacking (PEP 3132).  See what its parent is instead.
2826         if not p.parent:
2827             return False
2828
2829         p = p.parent
2830
2831     return p.type in within
2832
2833
2834 def is_multiline_string(leaf: Leaf) -> bool:
2835     """Return True if `leaf` is a multiline string that actually spans many lines."""
2836     value = leaf.value.lstrip("furbFURB")
2837     return value[:3] in {'"""', "'''"} and "\n" in value
2838
2839
2840 def is_stub_suite(node: Node) -> bool:
2841     """Return True if `node` is a suite with a stub body."""
2842     if (
2843         len(node.children) != 4
2844         or node.children[0].type != token.NEWLINE
2845         or node.children[1].type != token.INDENT
2846         or node.children[3].type != token.DEDENT
2847     ):
2848         return False
2849
2850     return is_stub_body(node.children[2])
2851
2852
2853 def is_stub_body(node: LN) -> bool:
2854     """Return True if `node` is a simple statement containing an ellipsis."""
2855     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2856         return False
2857
2858     if len(node.children) != 2:
2859         return False
2860
2861     child = node.children[0]
2862     return (
2863         child.type == syms.atom
2864         and len(child.children) == 3
2865         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2866     )
2867
2868
2869 def max_delimiter_priority_in_atom(node: LN) -> int:
2870     """Return maximum delimiter priority inside `node`.
2871
2872     This is specific to atoms with contents contained in a pair of parentheses.
2873     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2874     """
2875     if node.type != syms.atom:
2876         return 0
2877
2878     first = node.children[0]
2879     last = node.children[-1]
2880     if not (first.type == token.LPAR and last.type == token.RPAR):
2881         return 0
2882
2883     bt = BracketTracker()
2884     for c in node.children[1:-1]:
2885         if isinstance(c, Leaf):
2886             bt.mark(c)
2887         else:
2888             for leaf in c.leaves():
2889                 bt.mark(leaf)
2890     try:
2891         return bt.max_delimiter_priority()
2892
2893     except ValueError:
2894         return 0
2895
2896
2897 def ensure_visible(leaf: Leaf) -> None:
2898     """Make sure parentheses are visible.
2899
2900     They could be invisible as part of some statements (see
2901     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2902     """
2903     if leaf.type == token.LPAR:
2904         leaf.value = "("
2905     elif leaf.type == token.RPAR:
2906         leaf.value = ")"
2907
2908
2909 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2910     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2911     if not (
2912         opening_bracket.parent
2913         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2914         and opening_bracket.value in "[{("
2915     ):
2916         return False
2917
2918     try:
2919         last_leaf = line.leaves[-1]
2920         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2921         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2922     except (IndexError, ValueError):
2923         return False
2924
2925     return max_priority == COMMA_PRIORITY
2926
2927
2928 def is_python36(node: Node) -> bool:
2929     """Return True if the current file is using Python 3.6+ features.
2930
2931     Currently looking for:
2932     - f-strings;
2933     - underscores in numeric literals; and
2934     - trailing commas after * or ** in function signatures and calls.
2935     """
2936     for n in node.pre_order():
2937         if n.type == token.STRING:
2938             value_head = n.value[:2]  # type: ignore
2939             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2940                 return True
2941
2942         elif n.type == token.NUMBER:
2943             if "_" in n.value:  # type: ignore
2944                 return True
2945
2946         elif (
2947             n.type in {syms.typedargslist, syms.arglist}
2948             and n.children
2949             and n.children[-1].type == token.COMMA
2950         ):
2951             for ch in n.children:
2952                 if ch.type in STARS:
2953                     return True
2954
2955                 if ch.type == syms.argument:
2956                     for argch in ch.children:
2957                         if argch.type in STARS:
2958                             return True
2959
2960     return False
2961
2962
2963 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2964     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2965
2966     Brackets can be omitted if the entire trailer up to and including
2967     a preceding closing bracket fits in one line.
2968
2969     Yielded sets are cumulative (contain results of previous yields, too).  First
2970     set is empty.
2971     """
2972
2973     omit: Set[LeafID] = set()
2974     yield omit
2975
2976     length = 4 * line.depth
2977     opening_bracket = None
2978     closing_bracket = None
2979     optional_brackets: Set[LeafID] = set()
2980     inner_brackets: Set[LeafID] = set()
2981     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2982         length += leaf_length
2983         if length > line_length:
2984             break
2985
2986         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2987         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2988             break
2989
2990         optional_brackets.discard(id(leaf))
2991         if opening_bracket:
2992             if leaf is opening_bracket:
2993                 opening_bracket = None
2994             elif leaf.type in CLOSING_BRACKETS:
2995                 inner_brackets.add(id(leaf))
2996         elif leaf.type in CLOSING_BRACKETS:
2997             if not leaf.value:
2998                 optional_brackets.add(id(opening_bracket))
2999                 continue
3000
3001             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3002                 # Empty brackets would fail a split so treat them as "inner"
3003                 # brackets (e.g. only add them to the `omit` set if another
3004                 # pair of brackets was good enough.
3005                 inner_brackets.add(id(leaf))
3006                 continue
3007
3008             opening_bracket = leaf.opening_bracket
3009             if closing_bracket:
3010                 omit.add(id(closing_bracket))
3011                 omit.update(inner_brackets)
3012                 inner_brackets.clear()
3013                 yield omit
3014             closing_bracket = leaf
3015
3016
3017 def get_future_imports(node: Node) -> Set[str]:
3018     """Return a set of __future__ imports in the file."""
3019     imports: Set[str] = set()
3020
3021     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3022         for child in children:
3023             if isinstance(child, Leaf):
3024                 if child.type == token.NAME:
3025                     yield child.value
3026             elif child.type == syms.import_as_name:
3027                 orig_name = child.children[0]
3028                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3029                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3030                 yield orig_name.value
3031             elif child.type == syms.import_as_names:
3032                 yield from get_imports_from_children(child.children)
3033             else:
3034                 assert False, "Invalid syntax parsing imports"
3035
3036     for child in node.children:
3037         if child.type != syms.simple_stmt:
3038             break
3039         first_child = child.children[0]
3040         if isinstance(first_child, Leaf):
3041             # Continue looking if we see a docstring; otherwise stop.
3042             if (
3043                 len(child.children) == 2
3044                 and first_child.type == token.STRING
3045                 and child.children[1].type == token.NEWLINE
3046             ):
3047                 continue
3048             else:
3049                 break
3050         elif first_child.type == syms.import_from:
3051             module_name = first_child.children[1]
3052             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3053                 break
3054             imports |= set(get_imports_from_children(first_child.children[3:]))
3055         else:
3056             break
3057     return imports
3058
3059
3060 def gen_python_files_in_dir(
3061     path: Path,
3062     root: Path,
3063     include: Pattern[str],
3064     exclude: Pattern[str],
3065     report: "Report",
3066 ) -> Iterator[Path]:
3067     """Generate all files under `path` whose paths are not excluded by the
3068     `exclude` regex, but are included by the `include` regex.
3069
3070     Symbolic links pointing outside of the `root` directory are ignored.
3071
3072     `report` is where output about exclusions goes.
3073     """
3074     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3075     for child in path.iterdir():
3076         try:
3077             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3078         except ValueError:
3079             if child.is_symlink():
3080                 report.path_ignored(
3081                     child, f"is a symbolic link that points outside {root}"
3082                 )
3083                 continue
3084
3085             raise
3086
3087         if child.is_dir():
3088             normalized_path += "/"
3089         exclude_match = exclude.search(normalized_path)
3090         if exclude_match and exclude_match.group(0):
3091             report.path_ignored(child, f"matches the --exclude regular expression")
3092             continue
3093
3094         if child.is_dir():
3095             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3096
3097         elif child.is_file():
3098             include_match = include.search(normalized_path)
3099             if include_match:
3100                 yield child
3101
3102
3103 @lru_cache()
3104 def find_project_root(srcs: Iterable[str]) -> Path:
3105     """Return a directory containing .git, .hg, or pyproject.toml.
3106
3107     That directory can be one of the directories passed in `srcs` or their
3108     common parent.
3109
3110     If no directory in the tree contains a marker that would specify it's the
3111     project root, the root of the file system is returned.
3112     """
3113     if not srcs:
3114         return Path("/").resolve()
3115
3116     common_base = min(Path(src).resolve() for src in srcs)
3117     if common_base.is_dir():
3118         # Append a fake file so `parents` below returns `common_base_dir`, too.
3119         common_base /= "fake-file"
3120     for directory in common_base.parents:
3121         if (directory / ".git").is_dir():
3122             return directory
3123
3124         if (directory / ".hg").is_dir():
3125             return directory
3126
3127         if (directory / "pyproject.toml").is_file():
3128             return directory
3129
3130     return directory
3131
3132
3133 @dataclass
3134 class Report:
3135     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3136
3137     check: bool = False
3138     quiet: bool = False
3139     verbose: bool = False
3140     change_count: int = 0
3141     same_count: int = 0
3142     failure_count: int = 0
3143
3144     def done(self, src: Path, changed: Changed) -> None:
3145         """Increment the counter for successful reformatting. Write out a message."""
3146         if changed is Changed.YES:
3147             reformatted = "would reformat" if self.check else "reformatted"
3148             if self.verbose or not self.quiet:
3149                 out(f"{reformatted} {src}")
3150             self.change_count += 1
3151         else:
3152             if self.verbose:
3153                 if changed is Changed.NO:
3154                     msg = f"{src} already well formatted, good job."
3155                 else:
3156                     msg = f"{src} wasn't modified on disk since last run."
3157                 out(msg, bold=False)
3158             self.same_count += 1
3159
3160     def failed(self, src: Path, message: str) -> None:
3161         """Increment the counter for failed reformatting. Write out a message."""
3162         err(f"error: cannot format {src}: {message}")
3163         self.failure_count += 1
3164
3165     def path_ignored(self, path: Path, message: str) -> None:
3166         if self.verbose:
3167             out(f"{path} ignored: {message}", bold=False)
3168
3169     @property
3170     def return_code(self) -> int:
3171         """Return the exit code that the app should use.
3172
3173         This considers the current state of changed files and failures:
3174         - if there were any failures, return 123;
3175         - if any files were changed and --check is being used, return 1;
3176         - otherwise return 0.
3177         """
3178         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3179         # 126 we have special return codes reserved by the shell.
3180         if self.failure_count:
3181             return 123
3182
3183         elif self.change_count and self.check:
3184             return 1
3185
3186         return 0
3187
3188     def __str__(self) -> str:
3189         """Render a color report of the current state.
3190
3191         Use `click.unstyle` to remove colors.
3192         """
3193         if self.check:
3194             reformatted = "would be reformatted"
3195             unchanged = "would be left unchanged"
3196             failed = "would fail to reformat"
3197         else:
3198             reformatted = "reformatted"
3199             unchanged = "left unchanged"
3200             failed = "failed to reformat"
3201         report = []
3202         if self.change_count:
3203             s = "s" if self.change_count > 1 else ""
3204             report.append(
3205                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3206             )
3207         if self.same_count:
3208             s = "s" if self.same_count > 1 else ""
3209             report.append(f"{self.same_count} file{s} {unchanged}")
3210         if self.failure_count:
3211             s = "s" if self.failure_count > 1 else ""
3212             report.append(
3213                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3214             )
3215         return ", ".join(report) + "."
3216
3217
3218 def assert_equivalent(src: str, dst: str) -> None:
3219     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3220
3221     import ast
3222     import traceback
3223
3224     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3225         """Simple visitor generating strings to compare ASTs by content."""
3226         yield f"{'  ' * depth}{node.__class__.__name__}("
3227
3228         for field in sorted(node._fields):
3229             try:
3230                 value = getattr(node, field)
3231             except AttributeError:
3232                 continue
3233
3234             yield f"{'  ' * (depth+1)}{field}="
3235
3236             if isinstance(value, list):
3237                 for item in value:
3238                     if isinstance(item, ast.AST):
3239                         yield from _v(item, depth + 2)
3240
3241             elif isinstance(value, ast.AST):
3242                 yield from _v(value, depth + 2)
3243
3244             else:
3245                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3246
3247         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3248
3249     try:
3250         src_ast = ast.parse(src)
3251     except Exception as exc:
3252         major, minor = sys.version_info[:2]
3253         raise AssertionError(
3254             f"cannot use --safe with this file; failed to parse source file "
3255             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3256             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3257         )
3258
3259     try:
3260         dst_ast = ast.parse(dst)
3261     except Exception as exc:
3262         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3263         raise AssertionError(
3264             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3265             f"Please report a bug on https://github.com/ambv/black/issues.  "
3266             f"This invalid output might be helpful: {log}"
3267         ) from None
3268
3269     src_ast_str = "\n".join(_v(src_ast))
3270     dst_ast_str = "\n".join(_v(dst_ast))
3271     if src_ast_str != dst_ast_str:
3272         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3273         raise AssertionError(
3274             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3275             f"the source.  "
3276             f"Please report a bug on https://github.com/ambv/black/issues.  "
3277             f"This diff might be helpful: {log}"
3278         ) from None
3279
3280
3281 def assert_stable(
3282     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3283 ) -> None:
3284     """Raise AssertionError if `dst` reformats differently the second time."""
3285     newdst = format_str(dst, line_length=line_length, mode=mode)
3286     if dst != newdst:
3287         log = dump_to_file(
3288             diff(src, dst, "source", "first pass"),
3289             diff(dst, newdst, "first pass", "second pass"),
3290         )
3291         raise AssertionError(
3292             f"INTERNAL ERROR: Black produced different code on the second pass "
3293             f"of the formatter.  "
3294             f"Please report a bug on https://github.com/ambv/black/issues.  "
3295             f"This diff might be helpful: {log}"
3296         ) from None
3297
3298
3299 def dump_to_file(*output: str) -> str:
3300     """Dump `output` to a temporary file. Return path to the file."""
3301     import tempfile
3302
3303     with tempfile.NamedTemporaryFile(
3304         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3305     ) as f:
3306         for lines in output:
3307             f.write(lines)
3308             if lines and lines[-1] != "\n":
3309                 f.write("\n")
3310     return f.name
3311
3312
3313 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3314     """Return a unified diff string between strings `a` and `b`."""
3315     import difflib
3316
3317     a_lines = [line + "\n" for line in a.split("\n")]
3318     b_lines = [line + "\n" for line in b.split("\n")]
3319     return "".join(
3320         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3321     )
3322
3323
3324 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3325     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3326     err("Aborted!")
3327     for task in tasks:
3328         task.cancel()
3329
3330
3331 def shutdown(loop: BaseEventLoop) -> None:
3332     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3333     try:
3334         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3335         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3336         if not to_cancel:
3337             return
3338
3339         for task in to_cancel:
3340             task.cancel()
3341         loop.run_until_complete(
3342             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3343         )
3344     finally:
3345         # `concurrent.futures.Future` objects cannot be cancelled once they
3346         # are already running. There might be some when the `shutdown()` happened.
3347         # Silence their logger's spew about the event loop being closed.
3348         cf_logger = logging.getLogger("concurrent.futures")
3349         cf_logger.setLevel(logging.CRITICAL)
3350         loop.close()
3351
3352
3353 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3354     """Replace `regex` with `replacement` twice on `original`.
3355
3356     This is used by string normalization to perform replaces on
3357     overlapping matches.
3358     """
3359     return regex.sub(replacement, regex.sub(replacement, original))
3360
3361
3362 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3363     """Compile a regular expression string in `regex`.
3364
3365     If it contains newlines, use verbose mode.
3366     """
3367     if "\n" in regex:
3368         regex = "(?x)" + regex
3369     return re.compile(regex)
3370
3371
3372 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3373     """Like `reversed(enumerate(sequence))` if that were possible."""
3374     index = len(sequence) - 1
3375     for element in reversed(sequence):
3376         yield (index, element)
3377         index -= 1
3378
3379
3380 def enumerate_with_length(
3381     line: Line, reversed: bool = False
3382 ) -> Iterator[Tuple[Index, Leaf, int]]:
3383     """Return an enumeration of leaves with their length.
3384
3385     Stops prematurely on multiline strings and standalone comments.
3386     """
3387     op = cast(
3388         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3389         enumerate_reversed if reversed else enumerate,
3390     )
3391     for index, leaf in op(line.leaves):
3392         length = len(leaf.prefix) + len(leaf.value)
3393         if "\n" in leaf.value:
3394             return  # Multiline strings, we can't continue.
3395
3396         comment: Optional[Leaf]
3397         for comment in line.comments_after(leaf, index):
3398             length += len(comment.value)
3399
3400         yield index, leaf, length
3401
3402
3403 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3404     """Return True if `line` is no longer than `line_length`.
3405
3406     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3407     """
3408     if not line_str:
3409         line_str = str(line).strip("\n")
3410     return (
3411         len(line_str) <= line_length
3412         and "\n" not in line_str  # multiline strings
3413         and not line.contains_standalone_comments()
3414     )
3415
3416
3417 def can_be_split(line: Line) -> bool:
3418     """Return False if the line cannot be split *for sure*.
3419
3420     This is not an exhaustive search but a cheap heuristic that we can use to
3421     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3422     in unnecessary parentheses).
3423     """
3424     leaves = line.leaves
3425     if len(leaves) < 2:
3426         return False
3427
3428     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3429         call_count = 0
3430         dot_count = 0
3431         next = leaves[-1]
3432         for leaf in leaves[-2::-1]:
3433             if leaf.type in OPENING_BRACKETS:
3434                 if next.type not in CLOSING_BRACKETS:
3435                     return False
3436
3437                 call_count += 1
3438             elif leaf.type == token.DOT:
3439                 dot_count += 1
3440             elif leaf.type == token.NAME:
3441                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3442                     return False
3443
3444             elif leaf.type not in CLOSING_BRACKETS:
3445                 return False
3446
3447             if dot_count > 1 and call_count > 1:
3448                 return False
3449
3450     return True
3451
3452
3453 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3454     """Does `line` have a shape safe to reformat without optional parens around it?
3455
3456     Returns True for only a subset of potentially nice looking formattings but
3457     the point is to not return false positives that end up producing lines that
3458     are too long.
3459     """
3460     bt = line.bracket_tracker
3461     if not bt.delimiters:
3462         # Without delimiters the optional parentheses are useless.
3463         return True
3464
3465     max_priority = bt.max_delimiter_priority()
3466     if bt.delimiter_count_with_priority(max_priority) > 1:
3467         # With more than one delimiter of a kind the optional parentheses read better.
3468         return False
3469
3470     if max_priority == DOT_PRIORITY:
3471         # A single stranded method call doesn't require optional parentheses.
3472         return True
3473
3474     assert len(line.leaves) >= 2, "Stranded delimiter"
3475
3476     first = line.leaves[0]
3477     second = line.leaves[1]
3478     penultimate = line.leaves[-2]
3479     last = line.leaves[-1]
3480
3481     # With a single delimiter, omit if the expression starts or ends with
3482     # a bracket.
3483     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3484         remainder = False
3485         length = 4 * line.depth
3486         for _index, leaf, leaf_length in enumerate_with_length(line):
3487             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3488                 remainder = True
3489             if remainder:
3490                 length += leaf_length
3491                 if length > line_length:
3492                     break
3493
3494                 if leaf.type in OPENING_BRACKETS:
3495                     # There are brackets we can further split on.
3496                     remainder = False
3497
3498         else:
3499             # checked the entire string and line length wasn't exceeded
3500             if len(line.leaves) == _index + 1:
3501                 return True
3502
3503         # Note: we are not returning False here because a line might have *both*
3504         # a leading opening bracket and a trailing closing bracket.  If the
3505         # opening bracket doesn't match our rule, maybe the closing will.
3506
3507     if (
3508         last.type == token.RPAR
3509         or last.type == token.RBRACE
3510         or (
3511             # don't use indexing for omitting optional parentheses;
3512             # it looks weird
3513             last.type == token.RSQB
3514             and last.parent
3515             and last.parent.type != syms.trailer
3516         )
3517     ):
3518         if penultimate.type in OPENING_BRACKETS:
3519             # Empty brackets don't help.
3520             return False
3521
3522         if is_multiline_string(first):
3523             # Additional wrapping of a multiline string in this situation is
3524             # unnecessary.
3525             return True
3526
3527         length = 4 * line.depth
3528         seen_other_brackets = False
3529         for _index, leaf, leaf_length in enumerate_with_length(line):
3530             length += leaf_length
3531             if leaf is last.opening_bracket:
3532                 if seen_other_brackets or length <= line_length:
3533                     return True
3534
3535             elif leaf.type in OPENING_BRACKETS:
3536                 # There are brackets we can further split on.
3537                 seen_other_brackets = True
3538
3539     return False
3540
3541
3542 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3543     return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
3544
3545
3546 def read_cache(line_length: int, mode: FileMode) -> Cache:
3547     """Read the cache if it exists and is well formed.
3548
3549     If it is not well formed, the call to write_cache later should resolve the issue.
3550     """
3551     cache_file = get_cache_file(line_length, mode)
3552     if not cache_file.exists():
3553         return {}
3554
3555     with cache_file.open("rb") as fobj:
3556         try:
3557             cache: Cache = pickle.load(fobj)
3558         except pickle.UnpicklingError:
3559             return {}
3560
3561     return cache
3562
3563
3564 def get_cache_info(path: Path) -> CacheInfo:
3565     """Return the information used to check if a file is already formatted or not."""
3566     stat = path.stat()
3567     return stat.st_mtime, stat.st_size
3568
3569
3570 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3571     """Split an iterable of paths in `sources` into two sets.
3572
3573     The first contains paths of files that modified on disk or are not in the
3574     cache. The other contains paths to non-modified files.
3575     """
3576     todo, done = set(), set()
3577     for src in sources:
3578         src = src.resolve()
3579         if cache.get(src) != get_cache_info(src):
3580             todo.add(src)
3581         else:
3582             done.add(src)
3583     return todo, done
3584
3585
3586 def write_cache(
3587     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3588 ) -> None:
3589     """Update the cache file."""
3590     cache_file = get_cache_file(line_length, mode)
3591     try:
3592         if not CACHE_DIR.exists():
3593             CACHE_DIR.mkdir(parents=True)
3594         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3595         with cache_file.open("wb") as fobj:
3596             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3597     except OSError:
3598         pass
3599
3600
3601 def patch_click() -> None:
3602     """Make Click not crash.
3603
3604     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3605     default which restricts paths that it can access during the lifetime of the
3606     application.  Click refuses to work in this scenario by raising a RuntimeError.
3607
3608     In case of Black the likelihood that non-ASCII characters are going to be used in
3609     file paths is minimal since it's Python source code.  Moreover, this crash was
3610     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3611     """
3612     try:
3613         from click import core
3614         from click import _unicodefun  # type: ignore
3615     except ModuleNotFoundError:
3616         return
3617
3618     for module in (core, _unicodefun):
3619         if hasattr(module, "_verify_python3_env"):
3620             module._verify_python3_env = lambda: None
3621
3622
3623 if __name__ == "__main__":
3624     patch_click()
3625     main()