]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

85cb45bded5cd7794d6d69bc2f4adea22aa10dda
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum, Flag
6 from functools import lru_cache, partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tokenize
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generator,
24     Generic,
25     Iterable,
26     Iterator,
27     List,
28     Optional,
29     Pattern,
30     Sequence,
31     Set,
32     Tuple,
33     TypeVar,
34     Union,
35     cast,
36 )
37
38 from appdirs import user_cache_dir
39 from attr import dataclass, Factory
40 import click
41 import toml
42
43 # lib2to3 fork
44 from blib2to3.pytree import Node, Leaf, type_repr
45 from blib2to3 import pygram, pytree
46 from blib2to3.pgen2 import driver, token
47 from blib2to3.pgen2.parse import ParseError
48
49
50 __version__ = "18.6b4"
51 DEFAULT_LINE_LENGTH = 88
52 DEFAULT_EXCLUDES = (
53     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
54 )
55 DEFAULT_INCLUDES = r"\.pyi?$"
56 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
57
58
59 # types
60 FileContent = str
61 Encoding = str
62 NewLine = str
63 Depth = int
64 NodeType = int
65 LeafID = int
66 Priority = int
67 Index = int
68 LN = Union[Leaf, Node]
69 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
70 Timestamp = float
71 FileSize = int
72 CacheInfo = Tuple[Timestamp, FileSize]
73 Cache = Dict[Path, CacheInfo]
74 out = partial(click.secho, bold=True, err=True)
75 err = partial(click.secho, fg="red", err=True)
76
77 pygram.initialize(CACHE_DIR)
78 syms = pygram.python_symbols
79
80
81 class NothingChanged(UserWarning):
82     """Raised by :func:`format_file` when reformatted code is the same as source."""
83
84
85 class CannotSplit(Exception):
86     """A readable split that fits the allotted line length is impossible.
87
88     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
89     :func:`delimiter_split`.
90     """
91
92
93 class WriteBack(Enum):
94     NO = 0
95     YES = 1
96     DIFF = 2
97     CHECK = 3
98
99     @classmethod
100     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
101         if check and not diff:
102             return cls.CHECK
103
104         return cls.DIFF if diff else cls.YES
105
106
107 class Changed(Enum):
108     NO = 0
109     CACHED = 1
110     YES = 2
111
112
113 class FileMode(Flag):
114     AUTO_DETECT = 0
115     PYTHON36 = 1
116     PYI = 2
117     NO_STRING_NORMALIZATION = 4
118
119     @classmethod
120     def from_configuration(
121         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
122     ) -> "FileMode":
123         mode = cls.AUTO_DETECT
124         if py36:
125             mode |= cls.PYTHON36
126         if pyi:
127             mode |= cls.PYI
128         if skip_string_normalization:
129             mode |= cls.NO_STRING_NORMALIZATION
130         return mode
131
132
133 def read_pyproject_toml(
134     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
135 ) -> Optional[str]:
136     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
137
138     Returns the path to a successfully found and read configuration file, None
139     otherwise.
140     """
141     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
142     if not value:
143         root = find_project_root(ctx.params.get("src", ()))
144         path = root / "pyproject.toml"
145         if path.is_file():
146             value = str(path)
147         else:
148             return None
149
150     try:
151         pyproject_toml = toml.load(value)
152         config = pyproject_toml.get("tool", {}).get("black", {})
153     except (toml.TomlDecodeError, OSError) as e:
154         raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
155
156     if not config:
157         return None
158
159     if ctx.default_map is None:
160         ctx.default_map = {}
161     ctx.default_map.update(  # type: ignore  # bad types in .pyi
162         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
163     )
164     return value
165
166
167 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
168 @click.option(
169     "-l",
170     "--line-length",
171     type=int,
172     default=DEFAULT_LINE_LENGTH,
173     help="How many characters per line to allow.",
174     show_default=True,
175 )
176 @click.option(
177     "--py36",
178     is_flag=True,
179     help=(
180         "Allow using Python 3.6-only syntax on all input files.  This will put "
181         "trailing commas in function signatures and calls also after *args and "
182         "**kwargs.  [default: per-file auto-detection]"
183     ),
184 )
185 @click.option(
186     "--pyi",
187     is_flag=True,
188     help=(
189         "Format all input files like typing stubs regardless of file extension "
190         "(useful when piping source on standard input)."
191     ),
192 )
193 @click.option(
194     "-S",
195     "--skip-string-normalization",
196     is_flag=True,
197     help="Don't normalize string quotes or prefixes.",
198 )
199 @click.option(
200     "--check",
201     is_flag=True,
202     help=(
203         "Don't write the files back, just return the status.  Return code 0 "
204         "means nothing would change.  Return code 1 means some files would be "
205         "reformatted.  Return code 123 means there was an internal error."
206     ),
207 )
208 @click.option(
209     "--diff",
210     is_flag=True,
211     help="Don't write the files back, just output a diff for each file on stdout.",
212 )
213 @click.option(
214     "--fast/--safe",
215     is_flag=True,
216     help="If --fast given, skip temporary sanity checks. [default: --safe]",
217 )
218 @click.option(
219     "--include",
220     type=str,
221     default=DEFAULT_INCLUDES,
222     help=(
223         "A regular expression that matches files and directories that should be "
224         "included on recursive searches.  An empty value means all files are "
225         "included regardless of the name.  Use forward slashes for directories on "
226         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
227         "later."
228     ),
229     show_default=True,
230 )
231 @click.option(
232     "--exclude",
233     type=str,
234     default=DEFAULT_EXCLUDES,
235     help=(
236         "A regular expression that matches files and directories that should be "
237         "excluded on recursive searches.  An empty value means no paths are excluded. "
238         "Use forward slashes for directories on all platforms (Windows, too).  "
239         "Exclusions are calculated first, inclusions later."
240     ),
241     show_default=True,
242 )
243 @click.option(
244     "-q",
245     "--quiet",
246     is_flag=True,
247     help=(
248         "Don't emit non-error messages to stderr. Errors are still emitted, "
249         "silence those with 2>/dev/null."
250     ),
251 )
252 @click.option(
253     "-v",
254     "--verbose",
255     is_flag=True,
256     help=(
257         "Also emit messages to stderr about files that were not changed or were "
258         "ignored due to --exclude=."
259     ),
260 )
261 @click.version_option(version=__version__)
262 @click.argument(
263     "src",
264     nargs=-1,
265     type=click.Path(
266         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
267     ),
268     is_eager=True,
269 )
270 @click.option(
271     "--config",
272     type=click.Path(
273         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
274     ),
275     is_eager=True,
276     callback=read_pyproject_toml,
277     help="Read configuration from PATH.",
278 )
279 @click.pass_context
280 def main(
281     ctx: click.Context,
282     line_length: int,
283     check: bool,
284     diff: bool,
285     fast: bool,
286     pyi: bool,
287     py36: bool,
288     skip_string_normalization: bool,
289     quiet: bool,
290     verbose: bool,
291     include: str,
292     exclude: str,
293     src: Tuple[str],
294     config: Optional[str],
295 ) -> None:
296     """The uncompromising code formatter."""
297     write_back = WriteBack.from_configuration(check=check, diff=diff)
298     mode = FileMode.from_configuration(
299         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
300     )
301     if config and verbose:
302         out(f"Using configuration from {config}.", bold=False, fg="blue")
303     try:
304         include_regex = re_compile_maybe_verbose(include)
305     except re.error:
306         err(f"Invalid regular expression for include given: {include!r}")
307         ctx.exit(2)
308     try:
309         exclude_regex = re_compile_maybe_verbose(exclude)
310     except re.error:
311         err(f"Invalid regular expression for exclude given: {exclude!r}")
312         ctx.exit(2)
313     report = Report(check=check, quiet=quiet, verbose=verbose)
314     root = find_project_root(src)
315     sources: Set[Path] = set()
316     for s in src:
317         p = Path(s)
318         if p.is_dir():
319             sources.update(
320                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
321             )
322         elif p.is_file() or s == "-":
323             # if a file was explicitly given, we don't care about its extension
324             sources.add(p)
325         else:
326             err(f"invalid path: {s}")
327     if len(sources) == 0:
328         if verbose or not quiet:
329             out("No paths given. Nothing to do 😴")
330         ctx.exit(0)
331
332     if len(sources) == 1:
333         reformat_one(
334             src=sources.pop(),
335             line_length=line_length,
336             fast=fast,
337             write_back=write_back,
338             mode=mode,
339             report=report,
340         )
341     else:
342         loop = asyncio.get_event_loop()
343         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
344         try:
345             loop.run_until_complete(
346                 schedule_formatting(
347                     sources=sources,
348                     line_length=line_length,
349                     fast=fast,
350                     write_back=write_back,
351                     mode=mode,
352                     report=report,
353                     loop=loop,
354                     executor=executor,
355                 )
356             )
357         finally:
358             shutdown(loop)
359     if verbose or not quiet:
360         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
361         out(f"All done! {bang}")
362         click.secho(str(report), err=True)
363     ctx.exit(report.return_code)
364
365
366 def reformat_one(
367     src: Path,
368     line_length: int,
369     fast: bool,
370     write_back: WriteBack,
371     mode: FileMode,
372     report: "Report",
373 ) -> None:
374     """Reformat a single file under `src` without spawning child processes.
375
376     If `quiet` is True, non-error messages are not output. `line_length`,
377     `write_back`, `fast` and `pyi` options are passed to
378     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
379     """
380     try:
381         changed = Changed.NO
382         if not src.is_file() and str(src) == "-":
383             if format_stdin_to_stdout(
384                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
385             ):
386                 changed = Changed.YES
387         else:
388             cache: Cache = {}
389             if write_back != WriteBack.DIFF:
390                 cache = read_cache(line_length, mode)
391                 res_src = src.resolve()
392                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
393                     changed = Changed.CACHED
394             if changed is not Changed.CACHED and format_file_in_place(
395                 src,
396                 line_length=line_length,
397                 fast=fast,
398                 write_back=write_back,
399                 mode=mode,
400             ):
401                 changed = Changed.YES
402             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
403                 write_back is WriteBack.CHECK and changed is Changed.NO
404             ):
405                 write_cache(cache, [src], line_length, mode)
406         report.done(src, changed)
407     except Exception as exc:
408         report.failed(src, str(exc))
409
410
411 async def schedule_formatting(
412     sources: Set[Path],
413     line_length: int,
414     fast: bool,
415     write_back: WriteBack,
416     mode: FileMode,
417     report: "Report",
418     loop: BaseEventLoop,
419     executor: Executor,
420 ) -> None:
421     """Run formatting of `sources` in parallel using the provided `executor`.
422
423     (Use ProcessPoolExecutors for actual parallelism.)
424
425     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
426     :func:`format_file_in_place`.
427     """
428     cache: Cache = {}
429     if write_back != WriteBack.DIFF:
430         cache = read_cache(line_length, mode)
431         sources, cached = filter_cached(cache, sources)
432         for src in sorted(cached):
433             report.done(src, Changed.CACHED)
434     if not sources:
435         return
436
437     cancelled = []
438     sources_to_cache = []
439     lock = None
440     if write_back == WriteBack.DIFF:
441         # For diff output, we need locks to ensure we don't interleave output
442         # from different processes.
443         manager = Manager()
444         lock = manager.Lock()
445     tasks = {
446         loop.run_in_executor(
447             executor,
448             format_file_in_place,
449             src,
450             line_length,
451             fast,
452             write_back,
453             mode,
454             lock,
455         ): src
456         for src in sorted(sources)
457     }
458     pending: Iterable[asyncio.Task] = tasks.keys()
459     try:
460         loop.add_signal_handler(signal.SIGINT, cancel, pending)
461         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
462     except NotImplementedError:
463         # There are no good alternatives for these on Windows.
464         pass
465     while pending:
466         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
467         for task in done:
468             src = tasks.pop(task)
469             if task.cancelled():
470                 cancelled.append(task)
471             elif task.exception():
472                 report.failed(src, str(task.exception()))
473             else:
474                 changed = Changed.YES if task.result() else Changed.NO
475                 # If the file was written back or was successfully checked as
476                 # well-formatted, store this information in the cache.
477                 if write_back is WriteBack.YES or (
478                     write_back is WriteBack.CHECK and changed is Changed.NO
479                 ):
480                     sources_to_cache.append(src)
481                 report.done(src, changed)
482     if cancelled:
483         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
484     if sources_to_cache:
485         write_cache(cache, sources_to_cache, line_length, mode)
486
487
488 def format_file_in_place(
489     src: Path,
490     line_length: int,
491     fast: bool,
492     write_back: WriteBack = WriteBack.NO,
493     mode: FileMode = FileMode.AUTO_DETECT,
494     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
495 ) -> bool:
496     """Format file under `src` path. Return True if changed.
497
498     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
499     code to the file.
500     `line_length` and `fast` options are passed to :func:`format_file_contents`.
501     """
502     if src.suffix == ".pyi":
503         mode |= FileMode.PYI
504
505     then = datetime.utcfromtimestamp(src.stat().st_mtime)
506     with open(src, "rb") as buf:
507         src_contents, encoding, newline = decode_bytes(buf.read())
508     try:
509         dst_contents = format_file_contents(
510             src_contents, line_length=line_length, fast=fast, mode=mode
511         )
512     except NothingChanged:
513         return False
514
515     if write_back == write_back.YES:
516         with open(src, "w", encoding=encoding, newline=newline) as f:
517             f.write(dst_contents)
518     elif write_back == write_back.DIFF:
519         now = datetime.utcnow()
520         src_name = f"{src}\t{then} +0000"
521         dst_name = f"{src}\t{now} +0000"
522         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
523         if lock:
524             lock.acquire()
525         try:
526             f = io.TextIOWrapper(
527                 sys.stdout.buffer,
528                 encoding=encoding,
529                 newline=newline,
530                 write_through=True,
531             )
532             f.write(diff_contents)
533             f.detach()
534         finally:
535             if lock:
536                 lock.release()
537     return True
538
539
540 def format_stdin_to_stdout(
541     line_length: int,
542     fast: bool,
543     write_back: WriteBack = WriteBack.NO,
544     mode: FileMode = FileMode.AUTO_DETECT,
545 ) -> bool:
546     """Format file on stdin. Return True if changed.
547
548     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
549     write a diff to stdout.
550     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
551     :func:`format_file_contents`.
552     """
553     then = datetime.utcnow()
554     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
555     dst = src
556     try:
557         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
558         return True
559
560     except NothingChanged:
561         return False
562
563     finally:
564         f = io.TextIOWrapper(
565             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
566         )
567         if write_back == WriteBack.YES:
568             f.write(dst)
569         elif write_back == WriteBack.DIFF:
570             now = datetime.utcnow()
571             src_name = f"STDIN\t{then} +0000"
572             dst_name = f"STDOUT\t{now} +0000"
573             f.write(diff(src, dst, src_name, dst_name))
574         f.detach()
575
576
577 def format_file_contents(
578     src_contents: str,
579     *,
580     line_length: int,
581     fast: bool,
582     mode: FileMode = FileMode.AUTO_DETECT,
583 ) -> FileContent:
584     """Reformat contents a file and return new contents.
585
586     If `fast` is False, additionally confirm that the reformatted code is
587     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
588     `line_length` is passed to :func:`format_str`.
589     """
590     if src_contents.strip() == "":
591         raise NothingChanged
592
593     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
594     if src_contents == dst_contents:
595         raise NothingChanged
596
597     if not fast:
598         assert_equivalent(src_contents, dst_contents)
599         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
600     return dst_contents
601
602
603 def format_str(
604     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
605 ) -> FileContent:
606     """Reformat a string and return new contents.
607
608     `line_length` determines how many characters per line are allowed.
609     """
610     src_node = lib2to3_parse(src_contents)
611     dst_contents = ""
612     future_imports = get_future_imports(src_node)
613     is_pyi = bool(mode & FileMode.PYI)
614     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
615     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
616     normalize_fmt_off(src_node)
617     lines = LineGenerator(
618         remove_u_prefix=py36 or "unicode_literals" in future_imports,
619         is_pyi=is_pyi,
620         normalize_strings=normalize_strings,
621         allow_underscores=py36,
622     )
623     elt = EmptyLineTracker(is_pyi=is_pyi)
624     empty_line = Line()
625     after = 0
626     for current_line in lines.visit(src_node):
627         for _ in range(after):
628             dst_contents += str(empty_line)
629         before, after = elt.maybe_empty_lines(current_line)
630         for _ in range(before):
631             dst_contents += str(empty_line)
632         for line in split_line(current_line, line_length=line_length, py36=py36):
633             dst_contents += str(line)
634     return dst_contents
635
636
637 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
638     """Return a tuple of (decoded_contents, encoding, newline).
639
640     `newline` is either CRLF or LF but `decoded_contents` is decoded with
641     universal newlines (i.e. only contains LF).
642     """
643     srcbuf = io.BytesIO(src)
644     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
645     if not lines:
646         return "", encoding, "\n"
647
648     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
649     srcbuf.seek(0)
650     with io.TextIOWrapper(srcbuf, encoding) as tiow:
651         return tiow.read(), encoding, newline
652
653
654 GRAMMARS = [
655     pygram.python_grammar_no_print_statement_no_exec_statement,
656     pygram.python_grammar_no_print_statement,
657     pygram.python_grammar,
658 ]
659
660
661 def lib2to3_parse(src_txt: str) -> Node:
662     """Given a string with source, return the lib2to3 Node."""
663     grammar = pygram.python_grammar_no_print_statement
664     if src_txt[-1:] != "\n":
665         src_txt += "\n"
666     for grammar in GRAMMARS:
667         drv = driver.Driver(grammar, pytree.convert)
668         try:
669             result = drv.parse_string(src_txt, True)
670             break
671
672         except ParseError as pe:
673             lineno, column = pe.context[1]
674             lines = src_txt.splitlines()
675             try:
676                 faulty_line = lines[lineno - 1]
677             except IndexError:
678                 faulty_line = "<line number missing in source>"
679             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
680     else:
681         raise exc from None
682
683     if isinstance(result, Leaf):
684         result = Node(syms.file_input, [result])
685     return result
686
687
688 def lib2to3_unparse(node: Node) -> str:
689     """Given a lib2to3 node, return its string representation."""
690     code = str(node)
691     return code
692
693
694 T = TypeVar("T")
695
696
697 class Visitor(Generic[T]):
698     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
699
700     def visit(self, node: LN) -> Iterator[T]:
701         """Main method to visit `node` and its children.
702
703         It tries to find a `visit_*()` method for the given `node.type`, like
704         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
705         If no dedicated `visit_*()` method is found, chooses `visit_default()`
706         instead.
707
708         Then yields objects of type `T` from the selected visitor.
709         """
710         if node.type < 256:
711             name = token.tok_name[node.type]
712         else:
713             name = type_repr(node.type)
714         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
715
716     def visit_default(self, node: LN) -> Iterator[T]:
717         """Default `visit_*()` implementation. Recurses to children of `node`."""
718         if isinstance(node, Node):
719             for child in node.children:
720                 yield from self.visit(child)
721
722
723 @dataclass
724 class DebugVisitor(Visitor[T]):
725     tree_depth: int = 0
726
727     def visit_default(self, node: LN) -> Iterator[T]:
728         indent = " " * (2 * self.tree_depth)
729         if isinstance(node, Node):
730             _type = type_repr(node.type)
731             out(f"{indent}{_type}", fg="yellow")
732             self.tree_depth += 1
733             for child in node.children:
734                 yield from self.visit(child)
735
736             self.tree_depth -= 1
737             out(f"{indent}/{_type}", fg="yellow", bold=False)
738         else:
739             _type = token.tok_name.get(node.type, str(node.type))
740             out(f"{indent}{_type}", fg="blue", nl=False)
741             if node.prefix:
742                 # We don't have to handle prefixes for `Node` objects since
743                 # that delegates to the first child anyway.
744                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
745             out(f" {node.value!r}", fg="blue", bold=False)
746
747     @classmethod
748     def show(cls, code: Union[str, Leaf, Node]) -> None:
749         """Pretty-print the lib2to3 AST of a given string of `code`.
750
751         Convenience method for debugging.
752         """
753         v: DebugVisitor[None] = DebugVisitor()
754         if isinstance(code, str):
755             code = lib2to3_parse(code)
756         list(v.visit(code))
757
758
759 KEYWORDS = set(keyword.kwlist)
760 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
761 FLOW_CONTROL = {"return", "raise", "break", "continue"}
762 STATEMENT = {
763     syms.if_stmt,
764     syms.while_stmt,
765     syms.for_stmt,
766     syms.try_stmt,
767     syms.except_clause,
768     syms.with_stmt,
769     syms.funcdef,
770     syms.classdef,
771 }
772 STANDALONE_COMMENT = 153
773 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
774 LOGIC_OPERATORS = {"and", "or"}
775 COMPARATORS = {
776     token.LESS,
777     token.GREATER,
778     token.EQEQUAL,
779     token.NOTEQUAL,
780     token.LESSEQUAL,
781     token.GREATEREQUAL,
782 }
783 MATH_OPERATORS = {
784     token.VBAR,
785     token.CIRCUMFLEX,
786     token.AMPER,
787     token.LEFTSHIFT,
788     token.RIGHTSHIFT,
789     token.PLUS,
790     token.MINUS,
791     token.STAR,
792     token.SLASH,
793     token.DOUBLESLASH,
794     token.PERCENT,
795     token.AT,
796     token.TILDE,
797     token.DOUBLESTAR,
798 }
799 STARS = {token.STAR, token.DOUBLESTAR}
800 VARARGS_PARENTS = {
801     syms.arglist,
802     syms.argument,  # double star in arglist
803     syms.trailer,  # single argument to call
804     syms.typedargslist,
805     syms.varargslist,  # lambdas
806 }
807 UNPACKING_PARENTS = {
808     syms.atom,  # single element of a list or set literal
809     syms.dictsetmaker,
810     syms.listmaker,
811     syms.testlist_gexp,
812     syms.testlist_star_expr,
813 }
814 TEST_DESCENDANTS = {
815     syms.test,
816     syms.lambdef,
817     syms.or_test,
818     syms.and_test,
819     syms.not_test,
820     syms.comparison,
821     syms.star_expr,
822     syms.expr,
823     syms.xor_expr,
824     syms.and_expr,
825     syms.shift_expr,
826     syms.arith_expr,
827     syms.trailer,
828     syms.term,
829     syms.power,
830 }
831 ASSIGNMENTS = {
832     "=",
833     "+=",
834     "-=",
835     "*=",
836     "@=",
837     "/=",
838     "%=",
839     "&=",
840     "|=",
841     "^=",
842     "<<=",
843     ">>=",
844     "**=",
845     "//=",
846 }
847 COMPREHENSION_PRIORITY = 20
848 COMMA_PRIORITY = 18
849 TERNARY_PRIORITY = 16
850 LOGIC_PRIORITY = 14
851 STRING_PRIORITY = 12
852 COMPARATOR_PRIORITY = 10
853 MATH_PRIORITIES = {
854     token.VBAR: 9,
855     token.CIRCUMFLEX: 8,
856     token.AMPER: 7,
857     token.LEFTSHIFT: 6,
858     token.RIGHTSHIFT: 6,
859     token.PLUS: 5,
860     token.MINUS: 5,
861     token.STAR: 4,
862     token.SLASH: 4,
863     token.DOUBLESLASH: 4,
864     token.PERCENT: 4,
865     token.AT: 4,
866     token.TILDE: 3,
867     token.DOUBLESTAR: 2,
868 }
869 DOT_PRIORITY = 1
870
871
872 @dataclass
873 class BracketTracker:
874     """Keeps track of brackets on a line."""
875
876     depth: int = 0
877     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
878     delimiters: Dict[LeafID, Priority] = Factory(dict)
879     previous: Optional[Leaf] = None
880     _for_loop_variable: int = 0
881     _lambda_arguments: int = 0
882
883     def mark(self, leaf: Leaf) -> None:
884         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
885
886         All leaves receive an int `bracket_depth` field that stores how deep
887         within brackets a given leaf is. 0 means there are no enclosing brackets
888         that started on this line.
889
890         If a leaf is itself a closing bracket, it receives an `opening_bracket`
891         field that it forms a pair with. This is a one-directional link to
892         avoid reference cycles.
893
894         If a leaf is a delimiter (a token on which Black can split the line if
895         needed) and it's on depth 0, its `id()` is stored in the tracker's
896         `delimiters` field.
897         """
898         if leaf.type == token.COMMENT:
899             return
900
901         self.maybe_decrement_after_for_loop_variable(leaf)
902         self.maybe_decrement_after_lambda_arguments(leaf)
903         if leaf.type in CLOSING_BRACKETS:
904             self.depth -= 1
905             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
906             leaf.opening_bracket = opening_bracket
907         leaf.bracket_depth = self.depth
908         if self.depth == 0:
909             delim = is_split_before_delimiter(leaf, self.previous)
910             if delim and self.previous is not None:
911                 self.delimiters[id(self.previous)] = delim
912             else:
913                 delim = is_split_after_delimiter(leaf, self.previous)
914                 if delim:
915                     self.delimiters[id(leaf)] = delim
916         if leaf.type in OPENING_BRACKETS:
917             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
918             self.depth += 1
919         self.previous = leaf
920         self.maybe_increment_lambda_arguments(leaf)
921         self.maybe_increment_for_loop_variable(leaf)
922
923     def any_open_brackets(self) -> bool:
924         """Return True if there is an yet unmatched open bracket on the line."""
925         return bool(self.bracket_match)
926
927     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
928         """Return the highest priority of a delimiter found on the line.
929
930         Values are consistent with what `is_split_*_delimiter()` return.
931         Raises ValueError on no delimiters.
932         """
933         return max(v for k, v in self.delimiters.items() if k not in exclude)
934
935     def delimiter_count_with_priority(self, priority: int = 0) -> int:
936         """Return the number of delimiters with the given `priority`.
937
938         If no `priority` is passed, defaults to max priority on the line.
939         """
940         if not self.delimiters:
941             return 0
942
943         priority = priority or self.max_delimiter_priority()
944         return sum(1 for p in self.delimiters.values() if p == priority)
945
946     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
947         """In a for loop, or comprehension, the variables are often unpacks.
948
949         To avoid splitting on the comma in this situation, increase the depth of
950         tokens between `for` and `in`.
951         """
952         if leaf.type == token.NAME and leaf.value == "for":
953             self.depth += 1
954             self._for_loop_variable += 1
955             return True
956
957         return False
958
959     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
960         """See `maybe_increment_for_loop_variable` above for explanation."""
961         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
962             self.depth -= 1
963             self._for_loop_variable -= 1
964             return True
965
966         return False
967
968     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
969         """In a lambda expression, there might be more than one argument.
970
971         To avoid splitting on the comma in this situation, increase the depth of
972         tokens between `lambda` and `:`.
973         """
974         if leaf.type == token.NAME and leaf.value == "lambda":
975             self.depth += 1
976             self._lambda_arguments += 1
977             return True
978
979         return False
980
981     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
982         """See `maybe_increment_lambda_arguments` above for explanation."""
983         if self._lambda_arguments and leaf.type == token.COLON:
984             self.depth -= 1
985             self._lambda_arguments -= 1
986             return True
987
988         return False
989
990     def get_open_lsqb(self) -> Optional[Leaf]:
991         """Return the most recent opening square bracket (if any)."""
992         return self.bracket_match.get((self.depth - 1, token.RSQB))
993
994
995 @dataclass
996 class Line:
997     """Holds leaves and comments. Can be printed with `str(line)`."""
998
999     depth: int = 0
1000     leaves: List[Leaf] = Factory(list)
1001     comments: List[Tuple[Index, Leaf]] = Factory(list)
1002     bracket_tracker: BracketTracker = Factory(BracketTracker)
1003     inside_brackets: bool = False
1004     should_explode: bool = False
1005
1006     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1007         """Add a new `leaf` to the end of the line.
1008
1009         Unless `preformatted` is True, the `leaf` will receive a new consistent
1010         whitespace prefix and metadata applied by :class:`BracketTracker`.
1011         Trailing commas are maybe removed, unpacked for loop variables are
1012         demoted from being delimiters.
1013
1014         Inline comments are put aside.
1015         """
1016         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1017         if not has_value:
1018             return
1019
1020         if token.COLON == leaf.type and self.is_class_paren_empty:
1021             del self.leaves[-2:]
1022         if self.leaves and not preformatted:
1023             # Note: at this point leaf.prefix should be empty except for
1024             # imports, for which we only preserve newlines.
1025             leaf.prefix += whitespace(
1026                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1027             )
1028         if self.inside_brackets or not preformatted:
1029             self.bracket_tracker.mark(leaf)
1030             self.maybe_remove_trailing_comma(leaf)
1031         if not self.append_comment(leaf):
1032             self.leaves.append(leaf)
1033
1034     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1035         """Like :func:`append()` but disallow invalid standalone comment structure.
1036
1037         Raises ValueError when any `leaf` is appended after a standalone comment
1038         or when a standalone comment is not the first leaf on the line.
1039         """
1040         if self.bracket_tracker.depth == 0:
1041             if self.is_comment:
1042                 raise ValueError("cannot append to standalone comments")
1043
1044             if self.leaves and leaf.type == STANDALONE_COMMENT:
1045                 raise ValueError(
1046                     "cannot append standalone comments to a populated line"
1047                 )
1048
1049         self.append(leaf, preformatted=preformatted)
1050
1051     @property
1052     def is_comment(self) -> bool:
1053         """Is this line a standalone comment?"""
1054         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1055
1056     @property
1057     def is_decorator(self) -> bool:
1058         """Is this line a decorator?"""
1059         return bool(self) and self.leaves[0].type == token.AT
1060
1061     @property
1062     def is_import(self) -> bool:
1063         """Is this an import line?"""
1064         return bool(self) and is_import(self.leaves[0])
1065
1066     @property
1067     def is_class(self) -> bool:
1068         """Is this line a class definition?"""
1069         return (
1070             bool(self)
1071             and self.leaves[0].type == token.NAME
1072             and self.leaves[0].value == "class"
1073         )
1074
1075     @property
1076     def is_stub_class(self) -> bool:
1077         """Is this line a class definition with a body consisting only of "..."?"""
1078         return self.is_class and self.leaves[-3:] == [
1079             Leaf(token.DOT, ".") for _ in range(3)
1080         ]
1081
1082     @property
1083     def is_def(self) -> bool:
1084         """Is this a function definition? (Also returns True for async defs.)"""
1085         try:
1086             first_leaf = self.leaves[0]
1087         except IndexError:
1088             return False
1089
1090         try:
1091             second_leaf: Optional[Leaf] = self.leaves[1]
1092         except IndexError:
1093             second_leaf = None
1094         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1095             first_leaf.type == token.ASYNC
1096             and second_leaf is not None
1097             and second_leaf.type == token.NAME
1098             and second_leaf.value == "def"
1099         )
1100
1101     @property
1102     def is_class_paren_empty(self) -> bool:
1103         """Is this a class with no base classes but using parentheses?
1104
1105         Those are unnecessary and should be removed.
1106         """
1107         return (
1108             bool(self)
1109             and len(self.leaves) == 4
1110             and self.is_class
1111             and self.leaves[2].type == token.LPAR
1112             and self.leaves[2].value == "("
1113             and self.leaves[3].type == token.RPAR
1114             and self.leaves[3].value == ")"
1115         )
1116
1117     @property
1118     def is_triple_quoted_string(self) -> bool:
1119         """Is the line a triple quoted string?"""
1120         return (
1121             bool(self)
1122             and self.leaves[0].type == token.STRING
1123             and self.leaves[0].value.startswith(('"""', "'''"))
1124         )
1125
1126     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1127         """If so, needs to be split before emitting."""
1128         for leaf in self.leaves:
1129             if leaf.type == STANDALONE_COMMENT:
1130                 if leaf.bracket_depth <= depth_limit:
1131                     return True
1132
1133         return False
1134
1135     def contains_multiline_strings(self) -> bool:
1136         for leaf in self.leaves:
1137             if is_multiline_string(leaf):
1138                 return True
1139
1140         return False
1141
1142     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1143         """Remove trailing comma if there is one and it's safe."""
1144         if not (
1145             self.leaves
1146             and self.leaves[-1].type == token.COMMA
1147             and closing.type in CLOSING_BRACKETS
1148         ):
1149             return False
1150
1151         if closing.type == token.RBRACE:
1152             self.remove_trailing_comma()
1153             return True
1154
1155         if closing.type == token.RSQB:
1156             comma = self.leaves[-1]
1157             if comma.parent and comma.parent.type == syms.listmaker:
1158                 self.remove_trailing_comma()
1159                 return True
1160
1161         # For parens let's check if it's safe to remove the comma.
1162         # Imports are always safe.
1163         if self.is_import:
1164             self.remove_trailing_comma()
1165             return True
1166
1167         # Otherwise, if the trailing one is the only one, we might mistakenly
1168         # change a tuple into a different type by removing the comma.
1169         depth = closing.bracket_depth + 1
1170         commas = 0
1171         opening = closing.opening_bracket
1172         for _opening_index, leaf in enumerate(self.leaves):
1173             if leaf is opening:
1174                 break
1175
1176         else:
1177             return False
1178
1179         for leaf in self.leaves[_opening_index + 1 :]:
1180             if leaf is closing:
1181                 break
1182
1183             bracket_depth = leaf.bracket_depth
1184             if bracket_depth == depth and leaf.type == token.COMMA:
1185                 commas += 1
1186                 if leaf.parent and leaf.parent.type == syms.arglist:
1187                     commas += 1
1188                     break
1189
1190         if commas > 1:
1191             self.remove_trailing_comma()
1192             return True
1193
1194         return False
1195
1196     def append_comment(self, comment: Leaf) -> bool:
1197         """Add an inline or standalone comment to the line."""
1198         if (
1199             comment.type == STANDALONE_COMMENT
1200             and self.bracket_tracker.any_open_brackets()
1201         ):
1202             comment.prefix = ""
1203             return False
1204
1205         if comment.type != token.COMMENT:
1206             return False
1207
1208         after = len(self.leaves) - 1
1209         if after == -1:
1210             comment.type = STANDALONE_COMMENT
1211             comment.prefix = ""
1212             return False
1213
1214         else:
1215             self.comments.append((after, comment))
1216             return True
1217
1218     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1219         """Generate comments that should appear directly after `leaf`.
1220
1221         Provide a non-negative leaf `_index` to speed up the function.
1222         """
1223         if not self.comments:
1224             return
1225
1226         if _index == -1:
1227             for _index, _leaf in enumerate(self.leaves):
1228                 if leaf is _leaf:
1229                     break
1230
1231             else:
1232                 return
1233
1234         for index, comment_after in self.comments:
1235             if _index == index:
1236                 yield comment_after
1237
1238     def remove_trailing_comma(self) -> None:
1239         """Remove the trailing comma and moves the comments attached to it."""
1240         comma_index = len(self.leaves) - 1
1241         for i in range(len(self.comments)):
1242             comment_index, comment = self.comments[i]
1243             if comment_index == comma_index:
1244                 self.comments[i] = (comma_index - 1, comment)
1245         self.leaves.pop()
1246
1247     def is_complex_subscript(self, leaf: Leaf) -> bool:
1248         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1249         open_lsqb = self.bracket_tracker.get_open_lsqb()
1250         if open_lsqb is None:
1251             return False
1252
1253         subscript_start = open_lsqb.next_sibling
1254
1255         if isinstance(subscript_start, Node):
1256             if subscript_start.type == syms.listmaker:
1257                 return False
1258
1259             if subscript_start.type == syms.subscriptlist:
1260                 subscript_start = child_towards(subscript_start, leaf)
1261         return subscript_start is not None and any(
1262             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1263         )
1264
1265     def __str__(self) -> str:
1266         """Render the line."""
1267         if not self:
1268             return "\n"
1269
1270         indent = "    " * self.depth
1271         leaves = iter(self.leaves)
1272         first = next(leaves)
1273         res = f"{first.prefix}{indent}{first.value}"
1274         for leaf in leaves:
1275             res += str(leaf)
1276         for _, comment in self.comments:
1277             res += str(comment)
1278         return res + "\n"
1279
1280     def __bool__(self) -> bool:
1281         """Return True if the line has leaves or comments."""
1282         return bool(self.leaves or self.comments)
1283
1284
1285 @dataclass
1286 class EmptyLineTracker:
1287     """Provides a stateful method that returns the number of potential extra
1288     empty lines needed before and after the currently processed line.
1289
1290     Note: this tracker works on lines that haven't been split yet.  It assumes
1291     the prefix of the first leaf consists of optional newlines.  Those newlines
1292     are consumed by `maybe_empty_lines()` and included in the computation.
1293     """
1294
1295     is_pyi: bool = False
1296     previous_line: Optional[Line] = None
1297     previous_after: int = 0
1298     previous_defs: List[int] = Factory(list)
1299
1300     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1301         """Return the number of extra empty lines before and after the `current_line`.
1302
1303         This is for separating `def`, `async def` and `class` with extra empty
1304         lines (two on module-level).
1305         """
1306         before, after = self._maybe_empty_lines(current_line)
1307         before -= self.previous_after
1308         self.previous_after = after
1309         self.previous_line = current_line
1310         return before, after
1311
1312     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1313         max_allowed = 1
1314         if current_line.depth == 0:
1315             max_allowed = 1 if self.is_pyi else 2
1316         if current_line.leaves:
1317             # Consume the first leaf's extra newlines.
1318             first_leaf = current_line.leaves[0]
1319             before = first_leaf.prefix.count("\n")
1320             before = min(before, max_allowed)
1321             first_leaf.prefix = ""
1322         else:
1323             before = 0
1324         depth = current_line.depth
1325         while self.previous_defs and self.previous_defs[-1] >= depth:
1326             self.previous_defs.pop()
1327             if self.is_pyi:
1328                 before = 0 if depth else 1
1329             else:
1330                 before = 1 if depth else 2
1331         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1332             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1333
1334         if (
1335             self.previous_line
1336             and self.previous_line.is_import
1337             and not current_line.is_import
1338             and depth == self.previous_line.depth
1339         ):
1340             return (before or 1), 0
1341
1342         if (
1343             self.previous_line
1344             and self.previous_line.is_class
1345             and current_line.is_triple_quoted_string
1346         ):
1347             return before, 1
1348
1349         return before, 0
1350
1351     def _maybe_empty_lines_for_class_or_def(
1352         self, current_line: Line, before: int
1353     ) -> Tuple[int, int]:
1354         if not current_line.is_decorator:
1355             self.previous_defs.append(current_line.depth)
1356         if self.previous_line is None:
1357             # Don't insert empty lines before the first line in the file.
1358             return 0, 0
1359
1360         if self.previous_line.is_decorator:
1361             return 0, 0
1362
1363         if self.previous_line.depth < current_line.depth and (
1364             self.previous_line.is_class or self.previous_line.is_def
1365         ):
1366             return 0, 0
1367
1368         if (
1369             self.previous_line.is_comment
1370             and self.previous_line.depth == current_line.depth
1371             and before == 0
1372         ):
1373             return 0, 0
1374
1375         if self.is_pyi:
1376             if self.previous_line.depth > current_line.depth:
1377                 newlines = 1
1378             elif current_line.is_class or self.previous_line.is_class:
1379                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1380                     # No blank line between classes with an empty body
1381                     newlines = 0
1382                 else:
1383                     newlines = 1
1384             elif current_line.is_def and not self.previous_line.is_def:
1385                 # Blank line between a block of functions and a block of non-functions
1386                 newlines = 1
1387             else:
1388                 newlines = 0
1389         else:
1390             newlines = 2
1391         if current_line.depth and newlines:
1392             newlines -= 1
1393         return newlines, 0
1394
1395
1396 @dataclass
1397 class LineGenerator(Visitor[Line]):
1398     """Generates reformatted Line objects.  Empty lines are not emitted.
1399
1400     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1401     in ways that will no longer stringify to valid Python code on the tree.
1402     """
1403
1404     is_pyi: bool = False
1405     normalize_strings: bool = True
1406     current_line: Line = Factory(Line)
1407     remove_u_prefix: bool = False
1408     allow_underscores: bool = False
1409
1410     def line(self, indent: int = 0) -> Iterator[Line]:
1411         """Generate a line.
1412
1413         If the line is empty, only emit if it makes sense.
1414         If the line is too long, split it first and then generate.
1415
1416         If any lines were generated, set up a new current_line.
1417         """
1418         if not self.current_line:
1419             self.current_line.depth += indent
1420             return  # Line is empty, don't emit. Creating a new one unnecessary.
1421
1422         complete_line = self.current_line
1423         self.current_line = Line(depth=complete_line.depth + indent)
1424         yield complete_line
1425
1426     def visit_default(self, node: LN) -> Iterator[Line]:
1427         """Default `visit_*()` implementation. Recurses to children of `node`."""
1428         if isinstance(node, Leaf):
1429             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1430             for comment in generate_comments(node):
1431                 if any_open_brackets:
1432                     # any comment within brackets is subject to splitting
1433                     self.current_line.append(comment)
1434                 elif comment.type == token.COMMENT:
1435                     # regular trailing comment
1436                     self.current_line.append(comment)
1437                     yield from self.line()
1438
1439                 else:
1440                     # regular standalone comment
1441                     yield from self.line()
1442
1443                     self.current_line.append(comment)
1444                     yield from self.line()
1445
1446             normalize_prefix(node, inside_brackets=any_open_brackets)
1447             if self.normalize_strings and node.type == token.STRING:
1448                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1449                 normalize_string_quotes(node)
1450             if node.type == token.NUMBER:
1451                 normalize_numeric_literal(node, self.allow_underscores)
1452             if node.type not in WHITESPACE:
1453                 self.current_line.append(node)
1454         yield from super().visit_default(node)
1455
1456     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1457         """Increase indentation level, maybe yield a line."""
1458         # In blib2to3 INDENT never holds comments.
1459         yield from self.line(+1)
1460         yield from self.visit_default(node)
1461
1462     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1463         """Decrease indentation level, maybe yield a line."""
1464         # The current line might still wait for trailing comments.  At DEDENT time
1465         # there won't be any (they would be prefixes on the preceding NEWLINE).
1466         # Emit the line then.
1467         yield from self.line()
1468
1469         # While DEDENT has no value, its prefix may contain standalone comments
1470         # that belong to the current indentation level.  Get 'em.
1471         yield from self.visit_default(node)
1472
1473         # Finally, emit the dedent.
1474         yield from self.line(-1)
1475
1476     def visit_stmt(
1477         self, node: Node, keywords: Set[str], parens: Set[str]
1478     ) -> Iterator[Line]:
1479         """Visit a statement.
1480
1481         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1482         `def`, `with`, `class`, `assert` and assignments.
1483
1484         The relevant Python language `keywords` for a given statement will be
1485         NAME leaves within it. This methods puts those on a separate line.
1486
1487         `parens` holds a set of string leaf values immediately after which
1488         invisible parens should be put.
1489         """
1490         normalize_invisible_parens(node, parens_after=parens)
1491         for child in node.children:
1492             if child.type == token.NAME and child.value in keywords:  # type: ignore
1493                 yield from self.line()
1494
1495             yield from self.visit(child)
1496
1497     def visit_suite(self, node: Node) -> Iterator[Line]:
1498         """Visit a suite."""
1499         if self.is_pyi and is_stub_suite(node):
1500             yield from self.visit(node.children[2])
1501         else:
1502             yield from self.visit_default(node)
1503
1504     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1505         """Visit a statement without nested statements."""
1506         is_suite_like = node.parent and node.parent.type in STATEMENT
1507         if is_suite_like:
1508             if self.is_pyi and is_stub_body(node):
1509                 yield from self.visit_default(node)
1510             else:
1511                 yield from self.line(+1)
1512                 yield from self.visit_default(node)
1513                 yield from self.line(-1)
1514
1515         else:
1516             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1517                 yield from self.line()
1518             yield from self.visit_default(node)
1519
1520     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1521         """Visit `async def`, `async for`, `async with`."""
1522         yield from self.line()
1523
1524         children = iter(node.children)
1525         for child in children:
1526             yield from self.visit(child)
1527
1528             if child.type == token.ASYNC:
1529                 break
1530
1531         internal_stmt = next(children)
1532         for child in internal_stmt.children:
1533             yield from self.visit(child)
1534
1535     def visit_decorators(self, node: Node) -> Iterator[Line]:
1536         """Visit decorators."""
1537         for child in node.children:
1538             yield from self.line()
1539             yield from self.visit(child)
1540
1541     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1542         """Remove a semicolon and put the other statement on a separate line."""
1543         yield from self.line()
1544
1545     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1546         """End of file. Process outstanding comments and end with a newline."""
1547         yield from self.visit_default(leaf)
1548         yield from self.line()
1549
1550     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1551         if not self.current_line.bracket_tracker.any_open_brackets():
1552             yield from self.line()
1553         yield from self.visit_default(leaf)
1554
1555     def __attrs_post_init__(self) -> None:
1556         """You are in a twisty little maze of passages."""
1557         v = self.visit_stmt
1558         Ø: Set[str] = set()
1559         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1560         self.visit_if_stmt = partial(
1561             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1562         )
1563         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1564         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1565         self.visit_try_stmt = partial(
1566             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1567         )
1568         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1569         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1570         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1571         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1572         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1573         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1574         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1575         self.visit_async_funcdef = self.visit_async_stmt
1576         self.visit_decorated = self.visit_decorators
1577
1578
1579 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1580 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1581 OPENING_BRACKETS = set(BRACKET.keys())
1582 CLOSING_BRACKETS = set(BRACKET.values())
1583 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1584 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1585
1586
1587 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1588     """Return whitespace prefix if needed for the given `leaf`.
1589
1590     `complex_subscript` signals whether the given leaf is part of a subscription
1591     which has non-trivial arguments, like arithmetic expressions or function calls.
1592     """
1593     NO = ""
1594     SPACE = " "
1595     DOUBLESPACE = "  "
1596     t = leaf.type
1597     p = leaf.parent
1598     v = leaf.value
1599     if t in ALWAYS_NO_SPACE:
1600         return NO
1601
1602     if t == token.COMMENT:
1603         return DOUBLESPACE
1604
1605     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1606     if t == token.COLON and p.type not in {
1607         syms.subscript,
1608         syms.subscriptlist,
1609         syms.sliceop,
1610     }:
1611         return NO
1612
1613     prev = leaf.prev_sibling
1614     if not prev:
1615         prevp = preceding_leaf(p)
1616         if not prevp or prevp.type in OPENING_BRACKETS:
1617             return NO
1618
1619         if t == token.COLON:
1620             if prevp.type == token.COLON:
1621                 return NO
1622
1623             elif prevp.type != token.COMMA and not complex_subscript:
1624                 return NO
1625
1626             return SPACE
1627
1628         if prevp.type == token.EQUAL:
1629             if prevp.parent:
1630                 if prevp.parent.type in {
1631                     syms.arglist,
1632                     syms.argument,
1633                     syms.parameters,
1634                     syms.varargslist,
1635                 }:
1636                     return NO
1637
1638                 elif prevp.parent.type == syms.typedargslist:
1639                     # A bit hacky: if the equal sign has whitespace, it means we
1640                     # previously found it's a typed argument.  So, we're using
1641                     # that, too.
1642                     return prevp.prefix
1643
1644         elif prevp.type in STARS:
1645             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1646                 return NO
1647
1648         elif prevp.type == token.COLON:
1649             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1650                 return SPACE if complex_subscript else NO
1651
1652         elif (
1653             prevp.parent
1654             and prevp.parent.type == syms.factor
1655             and prevp.type in MATH_OPERATORS
1656         ):
1657             return NO
1658
1659         elif (
1660             prevp.type == token.RIGHTSHIFT
1661             and prevp.parent
1662             and prevp.parent.type == syms.shift_expr
1663             and prevp.prev_sibling
1664             and prevp.prev_sibling.type == token.NAME
1665             and prevp.prev_sibling.value == "print"  # type: ignore
1666         ):
1667             # Python 2 print chevron
1668             return NO
1669
1670     elif prev.type in OPENING_BRACKETS:
1671         return NO
1672
1673     if p.type in {syms.parameters, syms.arglist}:
1674         # untyped function signatures or calls
1675         if not prev or prev.type != token.COMMA:
1676             return NO
1677
1678     elif p.type == syms.varargslist:
1679         # lambdas
1680         if prev and prev.type != token.COMMA:
1681             return NO
1682
1683     elif p.type == syms.typedargslist:
1684         # typed function signatures
1685         if not prev:
1686             return NO
1687
1688         if t == token.EQUAL:
1689             if prev.type != syms.tname:
1690                 return NO
1691
1692         elif prev.type == token.EQUAL:
1693             # A bit hacky: if the equal sign has whitespace, it means we
1694             # previously found it's a typed argument.  So, we're using that, too.
1695             return prev.prefix
1696
1697         elif prev.type != token.COMMA:
1698             return NO
1699
1700     elif p.type == syms.tname:
1701         # type names
1702         if not prev:
1703             prevp = preceding_leaf(p)
1704             if not prevp or prevp.type != token.COMMA:
1705                 return NO
1706
1707     elif p.type == syms.trailer:
1708         # attributes and calls
1709         if t == token.LPAR or t == token.RPAR:
1710             return NO
1711
1712         if not prev:
1713             if t == token.DOT:
1714                 prevp = preceding_leaf(p)
1715                 if not prevp or prevp.type != token.NUMBER:
1716                     return NO
1717
1718             elif t == token.LSQB:
1719                 return NO
1720
1721         elif prev.type != token.COMMA:
1722             return NO
1723
1724     elif p.type == syms.argument:
1725         # single argument
1726         if t == token.EQUAL:
1727             return NO
1728
1729         if not prev:
1730             prevp = preceding_leaf(p)
1731             if not prevp or prevp.type == token.LPAR:
1732                 return NO
1733
1734         elif prev.type in {token.EQUAL} | STARS:
1735             return NO
1736
1737     elif p.type == syms.decorator:
1738         # decorators
1739         return NO
1740
1741     elif p.type == syms.dotted_name:
1742         if prev:
1743             return NO
1744
1745         prevp = preceding_leaf(p)
1746         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1747             return NO
1748
1749     elif p.type == syms.classdef:
1750         if t == token.LPAR:
1751             return NO
1752
1753         if prev and prev.type == token.LPAR:
1754             return NO
1755
1756     elif p.type in {syms.subscript, syms.sliceop}:
1757         # indexing
1758         if not prev:
1759             assert p.parent is not None, "subscripts are always parented"
1760             if p.parent.type == syms.subscriptlist:
1761                 return SPACE
1762
1763             return NO
1764
1765         elif not complex_subscript:
1766             return NO
1767
1768     elif p.type == syms.atom:
1769         if prev and t == token.DOT:
1770             # dots, but not the first one.
1771             return NO
1772
1773     elif p.type == syms.dictsetmaker:
1774         # dict unpacking
1775         if prev and prev.type == token.DOUBLESTAR:
1776             return NO
1777
1778     elif p.type in {syms.factor, syms.star_expr}:
1779         # unary ops
1780         if not prev:
1781             prevp = preceding_leaf(p)
1782             if not prevp or prevp.type in OPENING_BRACKETS:
1783                 return NO
1784
1785             prevp_parent = prevp.parent
1786             assert prevp_parent is not None
1787             if prevp.type == token.COLON and prevp_parent.type in {
1788                 syms.subscript,
1789                 syms.sliceop,
1790             }:
1791                 return NO
1792
1793             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1794                 return NO
1795
1796         elif t in {token.NAME, token.NUMBER, token.STRING}:
1797             return NO
1798
1799     elif p.type == syms.import_from:
1800         if t == token.DOT:
1801             if prev and prev.type == token.DOT:
1802                 return NO
1803
1804         elif t == token.NAME:
1805             if v == "import":
1806                 return SPACE
1807
1808             if prev and prev.type == token.DOT:
1809                 return NO
1810
1811     elif p.type == syms.sliceop:
1812         return NO
1813
1814     return SPACE
1815
1816
1817 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1818     """Return the first leaf that precedes `node`, if any."""
1819     while node:
1820         res = node.prev_sibling
1821         if res:
1822             if isinstance(res, Leaf):
1823                 return res
1824
1825             try:
1826                 return list(res.leaves())[-1]
1827
1828             except IndexError:
1829                 return None
1830
1831         node = node.parent
1832     return None
1833
1834
1835 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1836     """Return the child of `ancestor` that contains `descendant`."""
1837     node: Optional[LN] = descendant
1838     while node and node.parent != ancestor:
1839         node = node.parent
1840     return node
1841
1842
1843 def container_of(leaf: Leaf) -> LN:
1844     """Return `leaf` or one of its ancestors that is the topmost container of it.
1845
1846     By "container" we mean a node where `leaf` is the very first child.
1847     """
1848     same_prefix = leaf.prefix
1849     container: LN = leaf
1850     while container:
1851         parent = container.parent
1852         if parent is None:
1853             break
1854
1855         if parent.children[0].prefix != same_prefix:
1856             break
1857
1858         if parent.type == syms.file_input:
1859             break
1860
1861         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1862             break
1863
1864         container = parent
1865     return container
1866
1867
1868 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1869     """Return the priority of the `leaf` delimiter, given a line break after it.
1870
1871     The delimiter priorities returned here are from those delimiters that would
1872     cause a line break after themselves.
1873
1874     Higher numbers are higher priority.
1875     """
1876     if leaf.type == token.COMMA:
1877         return COMMA_PRIORITY
1878
1879     return 0
1880
1881
1882 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1883     """Return the priority of the `leaf` delimiter, given a line before after it.
1884
1885     The delimiter priorities returned here are from those delimiters that would
1886     cause a line break before themselves.
1887
1888     Higher numbers are higher priority.
1889     """
1890     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1891         # * and ** might also be MATH_OPERATORS but in this case they are not.
1892         # Don't treat them as a delimiter.
1893         return 0
1894
1895     if (
1896         leaf.type == token.DOT
1897         and leaf.parent
1898         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1899         and (previous is None or previous.type in CLOSING_BRACKETS)
1900     ):
1901         return DOT_PRIORITY
1902
1903     if (
1904         leaf.type in MATH_OPERATORS
1905         and leaf.parent
1906         and leaf.parent.type not in {syms.factor, syms.star_expr}
1907     ):
1908         return MATH_PRIORITIES[leaf.type]
1909
1910     if leaf.type in COMPARATORS:
1911         return COMPARATOR_PRIORITY
1912
1913     if (
1914         leaf.type == token.STRING
1915         and previous is not None
1916         and previous.type == token.STRING
1917     ):
1918         return STRING_PRIORITY
1919
1920     if leaf.type != token.NAME:
1921         return 0
1922
1923     if (
1924         leaf.value == "for"
1925         and leaf.parent
1926         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1927     ):
1928         return COMPREHENSION_PRIORITY
1929
1930     if (
1931         leaf.value == "if"
1932         and leaf.parent
1933         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1934     ):
1935         return COMPREHENSION_PRIORITY
1936
1937     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1938         return TERNARY_PRIORITY
1939
1940     if leaf.value == "is":
1941         return COMPARATOR_PRIORITY
1942
1943     if (
1944         leaf.value == "in"
1945         and leaf.parent
1946         and leaf.parent.type in {syms.comp_op, syms.comparison}
1947         and not (
1948             previous is not None
1949             and previous.type == token.NAME
1950             and previous.value == "not"
1951         )
1952     ):
1953         return COMPARATOR_PRIORITY
1954
1955     if (
1956         leaf.value == "not"
1957         and leaf.parent
1958         and leaf.parent.type == syms.comp_op
1959         and not (
1960             previous is not None
1961             and previous.type == token.NAME
1962             and previous.value == "is"
1963         )
1964     ):
1965         return COMPARATOR_PRIORITY
1966
1967     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1968         return LOGIC_PRIORITY
1969
1970     return 0
1971
1972
1973 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
1974 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
1975
1976
1977 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1978     """Clean the prefix of the `leaf` and generate comments from it, if any.
1979
1980     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1981     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1982     move because it does away with modifying the grammar to include all the
1983     possible places in which comments can be placed.
1984
1985     The sad consequence for us though is that comments don't "belong" anywhere.
1986     This is why this function generates simple parentless Leaf objects for
1987     comments.  We simply don't know what the correct parent should be.
1988
1989     No matter though, we can live without this.  We really only need to
1990     differentiate between inline and standalone comments.  The latter don't
1991     share the line with any code.
1992
1993     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1994     are emitted with a fake STANDALONE_COMMENT token identifier.
1995     """
1996     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
1997         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
1998
1999
2000 @dataclass
2001 class ProtoComment:
2002     type: int  # token.COMMENT or STANDALONE_COMMENT
2003     value: str  # content of the comment
2004     newlines: int  # how many newlines before the comment
2005     consumed: int  # how many characters of the original leaf's prefix did we consume
2006
2007
2008 @lru_cache(maxsize=4096)
2009 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2010     result: List[ProtoComment] = []
2011     if not prefix or "#" not in prefix:
2012         return result
2013
2014     consumed = 0
2015     nlines = 0
2016     for index, line in enumerate(prefix.split("\n")):
2017         consumed += len(line) + 1  # adding the length of the split '\n'
2018         line = line.lstrip()
2019         if not line:
2020             nlines += 1
2021         if not line.startswith("#"):
2022             continue
2023
2024         if index == 0 and not is_endmarker:
2025             comment_type = token.COMMENT  # simple trailing comment
2026         else:
2027             comment_type = STANDALONE_COMMENT
2028         comment = make_comment(line)
2029         result.append(
2030             ProtoComment(
2031                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2032             )
2033         )
2034         nlines = 0
2035     return result
2036
2037
2038 def make_comment(content: str) -> str:
2039     """Return a consistently formatted comment from the given `content` string.
2040
2041     All comments (except for "##", "#!", "#:") should have a single space between
2042     the hash sign and the content.
2043
2044     If `content` didn't start with a hash sign, one is provided.
2045     """
2046     content = content.rstrip()
2047     if not content:
2048         return "#"
2049
2050     if content[0] == "#":
2051         content = content[1:]
2052     if content and content[0] not in " !:#":
2053         content = " " + content
2054     return "#" + content
2055
2056
2057 def split_line(
2058     line: Line, line_length: int, inner: bool = False, py36: bool = False
2059 ) -> Iterator[Line]:
2060     """Split a `line` into potentially many lines.
2061
2062     They should fit in the allotted `line_length` but might not be able to.
2063     `inner` signifies that there were a pair of brackets somewhere around the
2064     current `line`, possibly transitively. This means we can fallback to splitting
2065     by delimiters if the LHS/RHS don't yield any results.
2066
2067     If `py36` is True, splitting may generate syntax that is only compatible
2068     with Python 3.6 and later.
2069     """
2070     if line.is_comment:
2071         yield line
2072         return
2073
2074     line_str = str(line).strip("\n")
2075     if not line.should_explode and is_line_short_enough(
2076         line, line_length=line_length, line_str=line_str
2077     ):
2078         yield line
2079         return
2080
2081     split_funcs: List[SplitFunc]
2082     if line.is_def:
2083         split_funcs = [left_hand_split]
2084     else:
2085
2086         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2087             for omit in generate_trailers_to_omit(line, line_length):
2088                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2089                 if is_line_short_enough(lines[0], line_length=line_length):
2090                     yield from lines
2091                     return
2092
2093             # All splits failed, best effort split with no omits.
2094             # This mostly happens to multiline strings that are by definition
2095             # reported as not fitting a single line.
2096             yield from right_hand_split(line, py36)
2097
2098         if line.inside_brackets:
2099             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2100         else:
2101             split_funcs = [rhs]
2102     for split_func in split_funcs:
2103         # We are accumulating lines in `result` because we might want to abort
2104         # mission and return the original line in the end, or attempt a different
2105         # split altogether.
2106         result: List[Line] = []
2107         try:
2108             for l in split_func(line, py36):
2109                 if str(l).strip("\n") == line_str:
2110                     raise CannotSplit("Split function returned an unchanged result")
2111
2112                 result.extend(
2113                     split_line(l, line_length=line_length, inner=True, py36=py36)
2114                 )
2115         except CannotSplit as cs:
2116             continue
2117
2118         else:
2119             yield from result
2120             break
2121
2122     else:
2123         yield line
2124
2125
2126 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2127     """Split line into many lines, starting with the first matching bracket pair.
2128
2129     Note: this usually looks weird, only use this for function definitions.
2130     Prefer RHS otherwise.  This is why this function is not symmetrical with
2131     :func:`right_hand_split` which also handles optional parentheses.
2132     """
2133     head = Line(depth=line.depth)
2134     body = Line(depth=line.depth + 1, inside_brackets=True)
2135     tail = Line(depth=line.depth)
2136     tail_leaves: List[Leaf] = []
2137     body_leaves: List[Leaf] = []
2138     head_leaves: List[Leaf] = []
2139     current_leaves = head_leaves
2140     matching_bracket = None
2141     for leaf in line.leaves:
2142         if (
2143             current_leaves is body_leaves
2144             and leaf.type in CLOSING_BRACKETS
2145             and leaf.opening_bracket is matching_bracket
2146         ):
2147             current_leaves = tail_leaves if body_leaves else head_leaves
2148         current_leaves.append(leaf)
2149         if current_leaves is head_leaves:
2150             if leaf.type in OPENING_BRACKETS:
2151                 matching_bracket = leaf
2152                 current_leaves = body_leaves
2153     # Since body is a new indent level, remove spurious leading whitespace.
2154     if body_leaves:
2155         normalize_prefix(body_leaves[0], inside_brackets=True)
2156     # Build the new lines.
2157     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2158         for leaf in leaves:
2159             result.append(leaf, preformatted=True)
2160             for comment_after in line.comments_after(leaf):
2161                 result.append(comment_after, preformatted=True)
2162     bracket_split_succeeded_or_raise(head, body, tail)
2163     for result in (head, body, tail):
2164         if result:
2165             yield result
2166
2167
2168 def right_hand_split(
2169     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2170 ) -> Iterator[Line]:
2171     """Split line into many lines, starting with the last matching bracket pair.
2172
2173     If the split was by optional parentheses, attempt splitting without them, too.
2174     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2175     this split.
2176
2177     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2178     """
2179     head = Line(depth=line.depth)
2180     body = Line(depth=line.depth + 1, inside_brackets=True)
2181     tail = Line(depth=line.depth)
2182     tail_leaves: List[Leaf] = []
2183     body_leaves: List[Leaf] = []
2184     head_leaves: List[Leaf] = []
2185     current_leaves = tail_leaves
2186     opening_bracket = None
2187     closing_bracket = None
2188     for leaf in reversed(line.leaves):
2189         if current_leaves is body_leaves:
2190             if leaf is opening_bracket:
2191                 current_leaves = head_leaves if body_leaves else tail_leaves
2192         current_leaves.append(leaf)
2193         if current_leaves is tail_leaves:
2194             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2195                 opening_bracket = leaf.opening_bracket
2196                 closing_bracket = leaf
2197                 current_leaves = body_leaves
2198     tail_leaves.reverse()
2199     body_leaves.reverse()
2200     head_leaves.reverse()
2201     # Since body is a new indent level, remove spurious leading whitespace.
2202     if body_leaves:
2203         normalize_prefix(body_leaves[0], inside_brackets=True)
2204     if not head_leaves:
2205         # No `head` means the split failed. Either `tail` has all content or
2206         # the matching `opening_bracket` wasn't available on `line` anymore.
2207         raise CannotSplit("No brackets found")
2208
2209     # Build the new lines.
2210     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2211         for leaf in leaves:
2212             result.append(leaf, preformatted=True)
2213             for comment_after in line.comments_after(leaf):
2214                 result.append(comment_after, preformatted=True)
2215     assert opening_bracket and closing_bracket
2216     body.should_explode = should_explode(body, opening_bracket)
2217     bracket_split_succeeded_or_raise(head, body, tail)
2218     if (
2219         # the body shouldn't be exploded
2220         not body.should_explode
2221         # the opening bracket is an optional paren
2222         and opening_bracket.type == token.LPAR
2223         and not opening_bracket.value
2224         # the closing bracket is an optional paren
2225         and closing_bracket.type == token.RPAR
2226         and not closing_bracket.value
2227         # it's not an import (optional parens are the only thing we can split on
2228         # in this case; attempting a split without them is a waste of time)
2229         and not line.is_import
2230         # there are no standalone comments in the body
2231         and not body.contains_standalone_comments(0)
2232         # and we can actually remove the parens
2233         and can_omit_invisible_parens(body, line_length)
2234     ):
2235         omit = {id(closing_bracket), *omit}
2236         try:
2237             yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2238             return
2239
2240         except CannotSplit:
2241             if not (
2242                 can_be_split(body)
2243                 or is_line_short_enough(body, line_length=line_length)
2244             ):
2245                 raise CannotSplit(
2246                     "Splitting failed, body is still too long and can't be split."
2247                 )
2248
2249             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2250                 raise CannotSplit(
2251                     "The current optional pair of parentheses is bound to fail to "
2252                     "satisfy the splitting algorithm because the head or the tail "
2253                     "contains multiline strings which by definition never fit one "
2254                     "line."
2255                 )
2256
2257     ensure_visible(opening_bracket)
2258     ensure_visible(closing_bracket)
2259     for result in (head, body, tail):
2260         if result:
2261             yield result
2262
2263
2264 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2265     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2266
2267     Do nothing otherwise.
2268
2269     A left- or right-hand split is based on a pair of brackets. Content before
2270     (and including) the opening bracket is left on one line, content inside the
2271     brackets is put on a separate line, and finally content starting with and
2272     following the closing bracket is put on a separate line.
2273
2274     Those are called `head`, `body`, and `tail`, respectively. If the split
2275     produced the same line (all content in `head`) or ended up with an empty `body`
2276     and the `tail` is just the closing bracket, then it's considered failed.
2277     """
2278     tail_len = len(str(tail).strip())
2279     if not body:
2280         if tail_len == 0:
2281             raise CannotSplit("Splitting brackets produced the same line")
2282
2283         elif tail_len < 3:
2284             raise CannotSplit(
2285                 f"Splitting brackets on an empty body to save "
2286                 f"{tail_len} characters is not worth it"
2287             )
2288
2289
2290 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2291     """Normalize prefix of the first leaf in every line returned by `split_func`.
2292
2293     This is a decorator over relevant split functions.
2294     """
2295
2296     @wraps(split_func)
2297     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2298         for l in split_func(line, py36):
2299             normalize_prefix(l.leaves[0], inside_brackets=True)
2300             yield l
2301
2302     return split_wrapper
2303
2304
2305 @dont_increase_indentation
2306 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2307     """Split according to delimiters of the highest priority.
2308
2309     If `py36` is True, the split will add trailing commas also in function
2310     signatures that contain `*` and `**`.
2311     """
2312     try:
2313         last_leaf = line.leaves[-1]
2314     except IndexError:
2315         raise CannotSplit("Line empty")
2316
2317     bt = line.bracket_tracker
2318     try:
2319         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2320     except ValueError:
2321         raise CannotSplit("No delimiters found")
2322
2323     if delimiter_priority == DOT_PRIORITY:
2324         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2325             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2326
2327     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2328     lowest_depth = sys.maxsize
2329     trailing_comma_safe = True
2330
2331     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2332         """Append `leaf` to current line or to new line if appending impossible."""
2333         nonlocal current_line
2334         try:
2335             current_line.append_safe(leaf, preformatted=True)
2336         except ValueError as ve:
2337             yield current_line
2338
2339             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2340             current_line.append(leaf)
2341
2342     for index, leaf in enumerate(line.leaves):
2343         yield from append_to_line(leaf)
2344
2345         for comment_after in line.comments_after(leaf, index):
2346             yield from append_to_line(comment_after)
2347
2348         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2349         if leaf.bracket_depth == lowest_depth and is_vararg(
2350             leaf, within=VARARGS_PARENTS
2351         ):
2352             trailing_comma_safe = trailing_comma_safe and py36
2353         leaf_priority = bt.delimiters.get(id(leaf))
2354         if leaf_priority == delimiter_priority:
2355             yield current_line
2356
2357             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2358     if current_line:
2359         if (
2360             trailing_comma_safe
2361             and delimiter_priority == COMMA_PRIORITY
2362             and current_line.leaves[-1].type != token.COMMA
2363             and current_line.leaves[-1].type != STANDALONE_COMMENT
2364         ):
2365             current_line.append(Leaf(token.COMMA, ","))
2366         yield current_line
2367
2368
2369 @dont_increase_indentation
2370 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2371     """Split standalone comments from the rest of the line."""
2372     if not line.contains_standalone_comments(0):
2373         raise CannotSplit("Line does not have any standalone comments")
2374
2375     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2376
2377     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2378         """Append `leaf` to current line or to new line if appending impossible."""
2379         nonlocal current_line
2380         try:
2381             current_line.append_safe(leaf, preformatted=True)
2382         except ValueError as ve:
2383             yield current_line
2384
2385             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2386             current_line.append(leaf)
2387
2388     for index, leaf in enumerate(line.leaves):
2389         yield from append_to_line(leaf)
2390
2391         for comment_after in line.comments_after(leaf, index):
2392             yield from append_to_line(comment_after)
2393
2394     if current_line:
2395         yield current_line
2396
2397
2398 def is_import(leaf: Leaf) -> bool:
2399     """Return True if the given leaf starts an import statement."""
2400     p = leaf.parent
2401     t = leaf.type
2402     v = leaf.value
2403     return bool(
2404         t == token.NAME
2405         and (
2406             (v == "import" and p and p.type == syms.import_name)
2407             or (v == "from" and p and p.type == syms.import_from)
2408         )
2409     )
2410
2411
2412 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2413     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2414     else.
2415
2416     Note: don't use backslashes for formatting or you'll lose your voting rights.
2417     """
2418     if not inside_brackets:
2419         spl = leaf.prefix.split("#")
2420         if "\\" not in spl[0]:
2421             nl_count = spl[-1].count("\n")
2422             if len(spl) > 1:
2423                 nl_count -= 1
2424             leaf.prefix = "\n" * nl_count
2425             return
2426
2427     leaf.prefix = ""
2428
2429
2430 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2431     """Make all string prefixes lowercase.
2432
2433     If remove_u_prefix is given, also removes any u prefix from the string.
2434
2435     Note: Mutates its argument.
2436     """
2437     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2438     assert match is not None, f"failed to match string {leaf.value!r}"
2439     orig_prefix = match.group(1)
2440     new_prefix = orig_prefix.lower()
2441     if remove_u_prefix:
2442         new_prefix = new_prefix.replace("u", "")
2443     leaf.value = f"{new_prefix}{match.group(2)}"
2444
2445
2446 def normalize_string_quotes(leaf: Leaf) -> None:
2447     """Prefer double quotes but only if it doesn't cause more escaping.
2448
2449     Adds or removes backslashes as appropriate. Doesn't parse and fix
2450     strings nested in f-strings (yet).
2451
2452     Note: Mutates its argument.
2453     """
2454     value = leaf.value.lstrip("furbFURB")
2455     if value[:3] == '"""':
2456         return
2457
2458     elif value[:3] == "'''":
2459         orig_quote = "'''"
2460         new_quote = '"""'
2461     elif value[0] == '"':
2462         orig_quote = '"'
2463         new_quote = "'"
2464     else:
2465         orig_quote = "'"
2466         new_quote = '"'
2467     first_quote_pos = leaf.value.find(orig_quote)
2468     if first_quote_pos == -1:
2469         return  # There's an internal error
2470
2471     prefix = leaf.value[:first_quote_pos]
2472     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2473     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2474     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2475     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2476     if "r" in prefix.casefold():
2477         if unescaped_new_quote.search(body):
2478             # There's at least one unescaped new_quote in this raw string
2479             # so converting is impossible
2480             return
2481
2482         # Do not introduce or remove backslashes in raw strings
2483         new_body = body
2484     else:
2485         # remove unnecessary escapes
2486         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2487         if body != new_body:
2488             # Consider the string without unnecessary escapes as the original
2489             body = new_body
2490             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2491         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2492         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2493     if "f" in prefix.casefold():
2494         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2495         for m in matches:
2496             if "\\" in str(m):
2497                 # Do not introduce backslashes in interpolated expressions
2498                 return
2499     if new_quote == '"""' and new_body[-1:] == '"':
2500         # edge case:
2501         new_body = new_body[:-1] + '\\"'
2502     orig_escape_count = body.count("\\")
2503     new_escape_count = new_body.count("\\")
2504     if new_escape_count > orig_escape_count:
2505         return  # Do not introduce more escaping
2506
2507     if new_escape_count == orig_escape_count and orig_quote == '"':
2508         return  # Prefer double quotes
2509
2510     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2511
2512
2513 def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
2514     """Normalizes numeric (float, int, and complex) literals.
2515
2516     All letters used in the representation are normalized to lowercase, long number
2517     literals are split using underscores.
2518     """
2519     text = leaf.value.lower()
2520     if text.startswith(("0o", "0x", "0b")):
2521         # Leave octal, hex, and binary literals alone.
2522         pass
2523     elif "e" in text:
2524         before, after = text.split("e")
2525         sign = ""
2526         if after.startswith("-"):
2527             after = after[1:]
2528             sign = "-"
2529         elif after.startswith("+"):
2530             after = after[1:]
2531         before = format_float_or_int_string(before, allow_underscores)
2532         after = format_int_string(after, allow_underscores)
2533         text = f"{before}e{sign}{after}"
2534     elif text.endswith(("j", "l")):
2535         number = text[:-1]
2536         suffix = text[-1]
2537         text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
2538     else:
2539         text = format_float_or_int_string(text, allow_underscores)
2540     leaf.value = text
2541
2542
2543 def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
2544     """Formats a float string like "1.0"."""
2545     if "." not in text:
2546         return format_int_string(text, allow_underscores)
2547
2548     before, after = text.split(".")
2549     before = format_int_string(before, allow_underscores) if before else "0"
2550     after = format_int_string(after, allow_underscores) if after else "0"
2551     return f"{before}.{after}"
2552
2553
2554 def format_int_string(text: str, allow_underscores: bool) -> str:
2555     """Normalizes underscores in a string to e.g. 1_000_000.
2556
2557     Input must be a string of at least six digits and optional underscores.
2558     """
2559     if not allow_underscores:
2560         return text
2561
2562     text = text.replace("_", "")
2563     if len(text) <= 6:
2564         # No underscores for numbers <= 6 digits long.
2565         return text
2566
2567     # Avoid removing leading zeros, which are important if we're formatting
2568     # part of a number like "0.001".
2569     return format(int("1" + text), "3_")[1:].lstrip("_")
2570
2571
2572 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2573     """Make existing optional parentheses invisible or create new ones.
2574
2575     `parens_after` is a set of string leaf values immeditely after which parens
2576     should be put.
2577
2578     Standardizes on visible parentheses for single-element tuples, and keeps
2579     existing visible parentheses for other tuples and generator expressions.
2580     """
2581     for pc in list_comments(node.prefix, is_endmarker=False):
2582         if pc.value in FMT_OFF:
2583             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2584             return
2585
2586     check_lpar = False
2587     for index, child in enumerate(list(node.children)):
2588         if check_lpar:
2589             if child.type == syms.atom:
2590                 if maybe_make_parens_invisible_in_atom(child):
2591                     lpar = Leaf(token.LPAR, "")
2592                     rpar = Leaf(token.RPAR, "")
2593                     index = child.remove() or 0
2594                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2595             elif is_one_tuple(child):
2596                 # wrap child in visible parentheses
2597                 lpar = Leaf(token.LPAR, "(")
2598                 rpar = Leaf(token.RPAR, ")")
2599                 child.remove()
2600                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2601             elif node.type == syms.import_from:
2602                 # "import from" nodes store parentheses directly as part of
2603                 # the statement
2604                 if child.type == token.LPAR:
2605                     # make parentheses invisible
2606                     child.value = ""  # type: ignore
2607                     node.children[-1].value = ""  # type: ignore
2608                 elif child.type != token.STAR:
2609                     # insert invisible parentheses
2610                     node.insert_child(index, Leaf(token.LPAR, ""))
2611                     node.append_child(Leaf(token.RPAR, ""))
2612                 break
2613
2614             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2615                 # wrap child in invisible parentheses
2616                 lpar = Leaf(token.LPAR, "")
2617                 rpar = Leaf(token.RPAR, "")
2618                 index = child.remove() or 0
2619                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2620
2621         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2622
2623
2624 def normalize_fmt_off(node: Node) -> None:
2625     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2626     try_again = True
2627     while try_again:
2628         try_again = convert_one_fmt_off_pair(node)
2629
2630
2631 def convert_one_fmt_off_pair(node: Node) -> bool:
2632     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2633
2634     Returns True if a pair was converted.
2635     """
2636     for leaf in node.leaves():
2637         previous_consumed = 0
2638         for comment in list_comments(leaf.prefix, is_endmarker=False):
2639             if comment.value in FMT_OFF:
2640                 # We only want standalone comments. If there's no previous leaf or
2641                 # the previous leaf is indentation, it's a standalone comment in
2642                 # disguise.
2643                 if comment.type != STANDALONE_COMMENT:
2644                     prev = preceding_leaf(leaf)
2645                     if prev and prev.type not in WHITESPACE:
2646                         continue
2647
2648                 ignored_nodes = list(generate_ignored_nodes(leaf))
2649                 if not ignored_nodes:
2650                     continue
2651
2652                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2653                 parent = first.parent
2654                 prefix = first.prefix
2655                 first.prefix = prefix[comment.consumed :]
2656                 hidden_value = (
2657                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2658                 )
2659                 if hidden_value.endswith("\n"):
2660                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2661                     # leaf (possibly followed by a DEDENT).
2662                     hidden_value = hidden_value[:-1]
2663                 first_idx = None
2664                 for ignored in ignored_nodes:
2665                     index = ignored.remove()
2666                     if first_idx is None:
2667                         first_idx = index
2668                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2669                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2670                 parent.insert_child(
2671                     first_idx,
2672                     Leaf(
2673                         STANDALONE_COMMENT,
2674                         hidden_value,
2675                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2676                     ),
2677                 )
2678                 return True
2679
2680             previous_consumed = comment.consumed
2681
2682     return False
2683
2684
2685 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2686     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2687
2688     Stops at the end of the block.
2689     """
2690     container: Optional[LN] = container_of(leaf)
2691     while container is not None and container.type != token.ENDMARKER:
2692         for comment in list_comments(container.prefix, is_endmarker=False):
2693             if comment.value in FMT_ON:
2694                 return
2695
2696         yield container
2697
2698         container = container.next_sibling
2699
2700
2701 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2702     """If it's safe, make the parens in the atom `node` invisible, recursively.
2703
2704     Returns whether the node should itself be wrapped in invisible parentheses.
2705
2706     """
2707     if (
2708         node.type != syms.atom
2709         or is_empty_tuple(node)
2710         or is_one_tuple(node)
2711         or is_yield(node)
2712         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2713     ):
2714         return False
2715
2716     first = node.children[0]
2717     last = node.children[-1]
2718     if first.type == token.LPAR and last.type == token.RPAR:
2719         # make parentheses invisible
2720         first.value = ""  # type: ignore
2721         last.value = ""  # type: ignore
2722         if len(node.children) > 1:
2723             maybe_make_parens_invisible_in_atom(node.children[1])
2724         return False
2725
2726     return True
2727
2728
2729 def is_empty_tuple(node: LN) -> bool:
2730     """Return True if `node` holds an empty tuple."""
2731     return (
2732         node.type == syms.atom
2733         and len(node.children) == 2
2734         and node.children[0].type == token.LPAR
2735         and node.children[1].type == token.RPAR
2736     )
2737
2738
2739 def is_one_tuple(node: LN) -> bool:
2740     """Return True if `node` holds a tuple with one element, with or without parens."""
2741     if node.type == syms.atom:
2742         if len(node.children) != 3:
2743             return False
2744
2745         lpar, gexp, rpar = node.children
2746         if not (
2747             lpar.type == token.LPAR
2748             and gexp.type == syms.testlist_gexp
2749             and rpar.type == token.RPAR
2750         ):
2751             return False
2752
2753         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2754
2755     return (
2756         node.type in IMPLICIT_TUPLE
2757         and len(node.children) == 2
2758         and node.children[1].type == token.COMMA
2759     )
2760
2761
2762 def is_yield(node: LN) -> bool:
2763     """Return True if `node` holds a `yield` or `yield from` expression."""
2764     if node.type == syms.yield_expr:
2765         return True
2766
2767     if node.type == token.NAME and node.value == "yield":  # type: ignore
2768         return True
2769
2770     if node.type != syms.atom:
2771         return False
2772
2773     if len(node.children) != 3:
2774         return False
2775
2776     lpar, expr, rpar = node.children
2777     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2778         return is_yield(expr)
2779
2780     return False
2781
2782
2783 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2784     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2785
2786     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2787     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2788     extended iterable unpacking (PEP 3132) and additional unpacking
2789     generalizations (PEP 448).
2790     """
2791     if leaf.type not in STARS or not leaf.parent:
2792         return False
2793
2794     p = leaf.parent
2795     if p.type == syms.star_expr:
2796         # Star expressions are also used as assignment targets in extended
2797         # iterable unpacking (PEP 3132).  See what its parent is instead.
2798         if not p.parent:
2799             return False
2800
2801         p = p.parent
2802
2803     return p.type in within
2804
2805
2806 def is_multiline_string(leaf: Leaf) -> bool:
2807     """Return True if `leaf` is a multiline string that actually spans many lines."""
2808     value = leaf.value.lstrip("furbFURB")
2809     return value[:3] in {'"""', "'''"} and "\n" in value
2810
2811
2812 def is_stub_suite(node: Node) -> bool:
2813     """Return True if `node` is a suite with a stub body."""
2814     if (
2815         len(node.children) != 4
2816         or node.children[0].type != token.NEWLINE
2817         or node.children[1].type != token.INDENT
2818         or node.children[3].type != token.DEDENT
2819     ):
2820         return False
2821
2822     return is_stub_body(node.children[2])
2823
2824
2825 def is_stub_body(node: LN) -> bool:
2826     """Return True if `node` is a simple statement containing an ellipsis."""
2827     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2828         return False
2829
2830     if len(node.children) != 2:
2831         return False
2832
2833     child = node.children[0]
2834     return (
2835         child.type == syms.atom
2836         and len(child.children) == 3
2837         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2838     )
2839
2840
2841 def max_delimiter_priority_in_atom(node: LN) -> int:
2842     """Return maximum delimiter priority inside `node`.
2843
2844     This is specific to atoms with contents contained in a pair of parentheses.
2845     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2846     """
2847     if node.type != syms.atom:
2848         return 0
2849
2850     first = node.children[0]
2851     last = node.children[-1]
2852     if not (first.type == token.LPAR and last.type == token.RPAR):
2853         return 0
2854
2855     bt = BracketTracker()
2856     for c in node.children[1:-1]:
2857         if isinstance(c, Leaf):
2858             bt.mark(c)
2859         else:
2860             for leaf in c.leaves():
2861                 bt.mark(leaf)
2862     try:
2863         return bt.max_delimiter_priority()
2864
2865     except ValueError:
2866         return 0
2867
2868
2869 def ensure_visible(leaf: Leaf) -> None:
2870     """Make sure parentheses are visible.
2871
2872     They could be invisible as part of some statements (see
2873     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2874     """
2875     if leaf.type == token.LPAR:
2876         leaf.value = "("
2877     elif leaf.type == token.RPAR:
2878         leaf.value = ")"
2879
2880
2881 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2882     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2883     if not (
2884         opening_bracket.parent
2885         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2886         and opening_bracket.value in "[{("
2887     ):
2888         return False
2889
2890     try:
2891         last_leaf = line.leaves[-1]
2892         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2893         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2894     except (IndexError, ValueError):
2895         return False
2896
2897     return max_priority == COMMA_PRIORITY
2898
2899
2900 def is_python36(node: Node) -> bool:
2901     """Return True if the current file is using Python 3.6+ features.
2902
2903     Currently looking for:
2904     - f-strings;
2905     - underscores in numeric literals; and
2906     - trailing commas after * or ** in function signatures and calls.
2907     """
2908     for n in node.pre_order():
2909         if n.type == token.STRING:
2910             value_head = n.value[:2]  # type: ignore
2911             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2912                 return True
2913
2914         elif n.type == token.NUMBER:
2915             if "_" in n.value:  # type: ignore
2916                 return True
2917
2918         elif (
2919             n.type in {syms.typedargslist, syms.arglist}
2920             and n.children
2921             and n.children[-1].type == token.COMMA
2922         ):
2923             for ch in n.children:
2924                 if ch.type in STARS:
2925                     return True
2926
2927                 if ch.type == syms.argument:
2928                     for argch in ch.children:
2929                         if argch.type in STARS:
2930                             return True
2931
2932     return False
2933
2934
2935 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2936     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2937
2938     Brackets can be omitted if the entire trailer up to and including
2939     a preceding closing bracket fits in one line.
2940
2941     Yielded sets are cumulative (contain results of previous yields, too).  First
2942     set is empty.
2943     """
2944
2945     omit: Set[LeafID] = set()
2946     yield omit
2947
2948     length = 4 * line.depth
2949     opening_bracket = None
2950     closing_bracket = None
2951     optional_brackets: Set[LeafID] = set()
2952     inner_brackets: Set[LeafID] = set()
2953     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2954         length += leaf_length
2955         if length > line_length:
2956             break
2957
2958         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2959         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2960             break
2961
2962         optional_brackets.discard(id(leaf))
2963         if opening_bracket:
2964             if leaf is opening_bracket:
2965                 opening_bracket = None
2966             elif leaf.type in CLOSING_BRACKETS:
2967                 inner_brackets.add(id(leaf))
2968         elif leaf.type in CLOSING_BRACKETS:
2969             if not leaf.value:
2970                 optional_brackets.add(id(opening_bracket))
2971                 continue
2972
2973             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2974                 # Empty brackets would fail a split so treat them as "inner"
2975                 # brackets (e.g. only add them to the `omit` set if another
2976                 # pair of brackets was good enough.
2977                 inner_brackets.add(id(leaf))
2978                 continue
2979
2980             opening_bracket = leaf.opening_bracket
2981             if closing_bracket:
2982                 omit.add(id(closing_bracket))
2983                 omit.update(inner_brackets)
2984                 inner_brackets.clear()
2985                 yield omit
2986             closing_bracket = leaf
2987
2988
2989 def get_future_imports(node: Node) -> Set[str]:
2990     """Return a set of __future__ imports in the file."""
2991     imports: Set[str] = set()
2992
2993     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
2994         for child in children:
2995             if isinstance(child, Leaf):
2996                 if child.type == token.NAME:
2997                     yield child.value
2998             elif child.type == syms.import_as_name:
2999                 orig_name = child.children[0]
3000                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3001                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3002                 yield orig_name.value
3003             elif child.type == syms.import_as_names:
3004                 yield from get_imports_from_children(child.children)
3005             else:
3006                 assert False, "Invalid syntax parsing imports"
3007
3008     for child in node.children:
3009         if child.type != syms.simple_stmt:
3010             break
3011         first_child = child.children[0]
3012         if isinstance(first_child, Leaf):
3013             # Continue looking if we see a docstring; otherwise stop.
3014             if (
3015                 len(child.children) == 2
3016                 and first_child.type == token.STRING
3017                 and child.children[1].type == token.NEWLINE
3018             ):
3019                 continue
3020             else:
3021                 break
3022         elif first_child.type == syms.import_from:
3023             module_name = first_child.children[1]
3024             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3025                 break
3026             imports |= set(get_imports_from_children(first_child.children[3:]))
3027         else:
3028             break
3029     return imports
3030
3031
3032 def gen_python_files_in_dir(
3033     path: Path,
3034     root: Path,
3035     include: Pattern[str],
3036     exclude: Pattern[str],
3037     report: "Report",
3038 ) -> Iterator[Path]:
3039     """Generate all files under `path` whose paths are not excluded by the
3040     `exclude` regex, but are included by the `include` regex.
3041
3042     Symbolic links pointing outside of the `root` directory are ignored.
3043
3044     `report` is where output about exclusions goes.
3045     """
3046     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3047     for child in path.iterdir():
3048         try:
3049             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3050         except ValueError:
3051             if child.is_symlink():
3052                 report.path_ignored(
3053                     child, f"is a symbolic link that points outside {root}"
3054                 )
3055                 continue
3056
3057             raise
3058
3059         if child.is_dir():
3060             normalized_path += "/"
3061         exclude_match = exclude.search(normalized_path)
3062         if exclude_match and exclude_match.group(0):
3063             report.path_ignored(child, f"matches the --exclude regular expression")
3064             continue
3065
3066         if child.is_dir():
3067             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3068
3069         elif child.is_file():
3070             include_match = include.search(normalized_path)
3071             if include_match:
3072                 yield child
3073
3074
3075 @lru_cache()
3076 def find_project_root(srcs: Iterable[str]) -> Path:
3077     """Return a directory containing .git, .hg, or pyproject.toml.
3078
3079     That directory can be one of the directories passed in `srcs` or their
3080     common parent.
3081
3082     If no directory in the tree contains a marker that would specify it's the
3083     project root, the root of the file system is returned.
3084     """
3085     if not srcs:
3086         return Path("/").resolve()
3087
3088     common_base = min(Path(src).resolve() for src in srcs)
3089     if common_base.is_dir():
3090         # Append a fake file so `parents` below returns `common_base_dir`, too.
3091         common_base /= "fake-file"
3092     for directory in common_base.parents:
3093         if (directory / ".git").is_dir():
3094             return directory
3095
3096         if (directory / ".hg").is_dir():
3097             return directory
3098
3099         if (directory / "pyproject.toml").is_file():
3100             return directory
3101
3102     return directory
3103
3104
3105 @dataclass
3106 class Report:
3107     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3108
3109     check: bool = False
3110     quiet: bool = False
3111     verbose: bool = False
3112     change_count: int = 0
3113     same_count: int = 0
3114     failure_count: int = 0
3115
3116     def done(self, src: Path, changed: Changed) -> None:
3117         """Increment the counter for successful reformatting. Write out a message."""
3118         if changed is Changed.YES:
3119             reformatted = "would reformat" if self.check else "reformatted"
3120             if self.verbose or not self.quiet:
3121                 out(f"{reformatted} {src}")
3122             self.change_count += 1
3123         else:
3124             if self.verbose:
3125                 if changed is Changed.NO:
3126                     msg = f"{src} already well formatted, good job."
3127                 else:
3128                     msg = f"{src} wasn't modified on disk since last run."
3129                 out(msg, bold=False)
3130             self.same_count += 1
3131
3132     def failed(self, src: Path, message: str) -> None:
3133         """Increment the counter for failed reformatting. Write out a message."""
3134         err(f"error: cannot format {src}: {message}")
3135         self.failure_count += 1
3136
3137     def path_ignored(self, path: Path, message: str) -> None:
3138         if self.verbose:
3139             out(f"{path} ignored: {message}", bold=False)
3140
3141     @property
3142     def return_code(self) -> int:
3143         """Return the exit code that the app should use.
3144
3145         This considers the current state of changed files and failures:
3146         - if there were any failures, return 123;
3147         - if any files were changed and --check is being used, return 1;
3148         - otherwise return 0.
3149         """
3150         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3151         # 126 we have special return codes reserved by the shell.
3152         if self.failure_count:
3153             return 123
3154
3155         elif self.change_count and self.check:
3156             return 1
3157
3158         return 0
3159
3160     def __str__(self) -> str:
3161         """Render a color report of the current state.
3162
3163         Use `click.unstyle` to remove colors.
3164         """
3165         if self.check:
3166             reformatted = "would be reformatted"
3167             unchanged = "would be left unchanged"
3168             failed = "would fail to reformat"
3169         else:
3170             reformatted = "reformatted"
3171             unchanged = "left unchanged"
3172             failed = "failed to reformat"
3173         report = []
3174         if self.change_count:
3175             s = "s" if self.change_count > 1 else ""
3176             report.append(
3177                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3178             )
3179         if self.same_count:
3180             s = "s" if self.same_count > 1 else ""
3181             report.append(f"{self.same_count} file{s} {unchanged}")
3182         if self.failure_count:
3183             s = "s" if self.failure_count > 1 else ""
3184             report.append(
3185                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3186             )
3187         return ", ".join(report) + "."
3188
3189
3190 def assert_equivalent(src: str, dst: str) -> None:
3191     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3192
3193     import ast
3194     import traceback
3195
3196     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3197         """Simple visitor generating strings to compare ASTs by content."""
3198         yield f"{'  ' * depth}{node.__class__.__name__}("
3199
3200         for field in sorted(node._fields):
3201             try:
3202                 value = getattr(node, field)
3203             except AttributeError:
3204                 continue
3205
3206             yield f"{'  ' * (depth+1)}{field}="
3207
3208             if isinstance(value, list):
3209                 for item in value:
3210                     if isinstance(item, ast.AST):
3211                         yield from _v(item, depth + 2)
3212
3213             elif isinstance(value, ast.AST):
3214                 yield from _v(value, depth + 2)
3215
3216             else:
3217                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3218
3219         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3220
3221     try:
3222         src_ast = ast.parse(src)
3223     except Exception as exc:
3224         major, minor = sys.version_info[:2]
3225         raise AssertionError(
3226             f"cannot use --safe with this file; failed to parse source file "
3227             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3228             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3229         )
3230
3231     try:
3232         dst_ast = ast.parse(dst)
3233     except Exception as exc:
3234         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3235         raise AssertionError(
3236             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3237             f"Please report a bug on https://github.com/ambv/black/issues.  "
3238             f"This invalid output might be helpful: {log}"
3239         ) from None
3240
3241     src_ast_str = "\n".join(_v(src_ast))
3242     dst_ast_str = "\n".join(_v(dst_ast))
3243     if src_ast_str != dst_ast_str:
3244         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3245         raise AssertionError(
3246             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3247             f"the source.  "
3248             f"Please report a bug on https://github.com/ambv/black/issues.  "
3249             f"This diff might be helpful: {log}"
3250         ) from None
3251
3252
3253 def assert_stable(
3254     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3255 ) -> None:
3256     """Raise AssertionError if `dst` reformats differently the second time."""
3257     newdst = format_str(dst, line_length=line_length, mode=mode)
3258     if dst != newdst:
3259         log = dump_to_file(
3260             diff(src, dst, "source", "first pass"),
3261             diff(dst, newdst, "first pass", "second pass"),
3262         )
3263         raise AssertionError(
3264             f"INTERNAL ERROR: Black produced different code on the second pass "
3265             f"of the formatter.  "
3266             f"Please report a bug on https://github.com/ambv/black/issues.  "
3267             f"This diff might be helpful: {log}"
3268         ) from None
3269
3270
3271 def dump_to_file(*output: str) -> str:
3272     """Dump `output` to a temporary file. Return path to the file."""
3273     import tempfile
3274
3275     with tempfile.NamedTemporaryFile(
3276         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3277     ) as f:
3278         for lines in output:
3279             f.write(lines)
3280             if lines and lines[-1] != "\n":
3281                 f.write("\n")
3282     return f.name
3283
3284
3285 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3286     """Return a unified diff string between strings `a` and `b`."""
3287     import difflib
3288
3289     a_lines = [line + "\n" for line in a.split("\n")]
3290     b_lines = [line + "\n" for line in b.split("\n")]
3291     return "".join(
3292         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3293     )
3294
3295
3296 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3297     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3298     err("Aborted!")
3299     for task in tasks:
3300         task.cancel()
3301
3302
3303 def shutdown(loop: BaseEventLoop) -> None:
3304     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3305     try:
3306         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3307         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3308         if not to_cancel:
3309             return
3310
3311         for task in to_cancel:
3312             task.cancel()
3313         loop.run_until_complete(
3314             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3315         )
3316     finally:
3317         # `concurrent.futures.Future` objects cannot be cancelled once they
3318         # are already running. There might be some when the `shutdown()` happened.
3319         # Silence their logger's spew about the event loop being closed.
3320         cf_logger = logging.getLogger("concurrent.futures")
3321         cf_logger.setLevel(logging.CRITICAL)
3322         loop.close()
3323
3324
3325 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3326     """Replace `regex` with `replacement` twice on `original`.
3327
3328     This is used by string normalization to perform replaces on
3329     overlapping matches.
3330     """
3331     return regex.sub(replacement, regex.sub(replacement, original))
3332
3333
3334 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3335     """Compile a regular expression string in `regex`.
3336
3337     If it contains newlines, use verbose mode.
3338     """
3339     if "\n" in regex:
3340         regex = "(?x)" + regex
3341     return re.compile(regex)
3342
3343
3344 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3345     """Like `reversed(enumerate(sequence))` if that were possible."""
3346     index = len(sequence) - 1
3347     for element in reversed(sequence):
3348         yield (index, element)
3349         index -= 1
3350
3351
3352 def enumerate_with_length(
3353     line: Line, reversed: bool = False
3354 ) -> Iterator[Tuple[Index, Leaf, int]]:
3355     """Return an enumeration of leaves with their length.
3356
3357     Stops prematurely on multiline strings and standalone comments.
3358     """
3359     op = cast(
3360         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3361         enumerate_reversed if reversed else enumerate,
3362     )
3363     for index, leaf in op(line.leaves):
3364         length = len(leaf.prefix) + len(leaf.value)
3365         if "\n" in leaf.value:
3366             return  # Multiline strings, we can't continue.
3367
3368         comment: Optional[Leaf]
3369         for comment in line.comments_after(leaf, index):
3370             length += len(comment.value)
3371
3372         yield index, leaf, length
3373
3374
3375 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3376     """Return True if `line` is no longer than `line_length`.
3377
3378     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3379     """
3380     if not line_str:
3381         line_str = str(line).strip("\n")
3382     return (
3383         len(line_str) <= line_length
3384         and "\n" not in line_str  # multiline strings
3385         and not line.contains_standalone_comments()
3386     )
3387
3388
3389 def can_be_split(line: Line) -> bool:
3390     """Return False if the line cannot be split *for sure*.
3391
3392     This is not an exhaustive search but a cheap heuristic that we can use to
3393     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3394     in unnecessary parentheses).
3395     """
3396     leaves = line.leaves
3397     if len(leaves) < 2:
3398         return False
3399
3400     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3401         call_count = 0
3402         dot_count = 0
3403         next = leaves[-1]
3404         for leaf in leaves[-2::-1]:
3405             if leaf.type in OPENING_BRACKETS:
3406                 if next.type not in CLOSING_BRACKETS:
3407                     return False
3408
3409                 call_count += 1
3410             elif leaf.type == token.DOT:
3411                 dot_count += 1
3412             elif leaf.type == token.NAME:
3413                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3414                     return False
3415
3416             elif leaf.type not in CLOSING_BRACKETS:
3417                 return False
3418
3419             if dot_count > 1 and call_count > 1:
3420                 return False
3421
3422     return True
3423
3424
3425 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3426     """Does `line` have a shape safe to reformat without optional parens around it?
3427
3428     Returns True for only a subset of potentially nice looking formattings but
3429     the point is to not return false positives that end up producing lines that
3430     are too long.
3431     """
3432     bt = line.bracket_tracker
3433     if not bt.delimiters:
3434         # Without delimiters the optional parentheses are useless.
3435         return True
3436
3437     max_priority = bt.max_delimiter_priority()
3438     if bt.delimiter_count_with_priority(max_priority) > 1:
3439         # With more than one delimiter of a kind the optional parentheses read better.
3440         return False
3441
3442     if max_priority == DOT_PRIORITY:
3443         # A single stranded method call doesn't require optional parentheses.
3444         return True
3445
3446     assert len(line.leaves) >= 2, "Stranded delimiter"
3447
3448     first = line.leaves[0]
3449     second = line.leaves[1]
3450     penultimate = line.leaves[-2]
3451     last = line.leaves[-1]
3452
3453     # With a single delimiter, omit if the expression starts or ends with
3454     # a bracket.
3455     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3456         remainder = False
3457         length = 4 * line.depth
3458         for _index, leaf, leaf_length in enumerate_with_length(line):
3459             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3460                 remainder = True
3461             if remainder:
3462                 length += leaf_length
3463                 if length > line_length:
3464                     break
3465
3466                 if leaf.type in OPENING_BRACKETS:
3467                     # There are brackets we can further split on.
3468                     remainder = False
3469
3470         else:
3471             # checked the entire string and line length wasn't exceeded
3472             if len(line.leaves) == _index + 1:
3473                 return True
3474
3475         # Note: we are not returning False here because a line might have *both*
3476         # a leading opening bracket and a trailing closing bracket.  If the
3477         # opening bracket doesn't match our rule, maybe the closing will.
3478
3479     if (
3480         last.type == token.RPAR
3481         or last.type == token.RBRACE
3482         or (
3483             # don't use indexing for omitting optional parentheses;
3484             # it looks weird
3485             last.type == token.RSQB
3486             and last.parent
3487             and last.parent.type != syms.trailer
3488         )
3489     ):
3490         if penultimate.type in OPENING_BRACKETS:
3491             # Empty brackets don't help.
3492             return False
3493
3494         if is_multiline_string(first):
3495             # Additional wrapping of a multiline string in this situation is
3496             # unnecessary.
3497             return True
3498
3499         length = 4 * line.depth
3500         seen_other_brackets = False
3501         for _index, leaf, leaf_length in enumerate_with_length(line):
3502             length += leaf_length
3503             if leaf is last.opening_bracket:
3504                 if seen_other_brackets or length <= line_length:
3505                     return True
3506
3507             elif leaf.type in OPENING_BRACKETS:
3508                 # There are brackets we can further split on.
3509                 seen_other_brackets = True
3510
3511     return False
3512
3513
3514 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3515     return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
3516
3517
3518 def read_cache(line_length: int, mode: FileMode) -> Cache:
3519     """Read the cache if it exists and is well formed.
3520
3521     If it is not well formed, the call to write_cache later should resolve the issue.
3522     """
3523     cache_file = get_cache_file(line_length, mode)
3524     if not cache_file.exists():
3525         return {}
3526
3527     with cache_file.open("rb") as fobj:
3528         try:
3529             cache: Cache = pickle.load(fobj)
3530         except pickle.UnpicklingError:
3531             return {}
3532
3533     return cache
3534
3535
3536 def get_cache_info(path: Path) -> CacheInfo:
3537     """Return the information used to check if a file is already formatted or not."""
3538     stat = path.stat()
3539     return stat.st_mtime, stat.st_size
3540
3541
3542 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3543     """Split an iterable of paths in `sources` into two sets.
3544
3545     The first contains paths of files that modified on disk or are not in the
3546     cache. The other contains paths to non-modified files.
3547     """
3548     todo, done = set(), set()
3549     for src in sources:
3550         src = src.resolve()
3551         if cache.get(src) != get_cache_info(src):
3552             todo.add(src)
3553         else:
3554             done.add(src)
3555     return todo, done
3556
3557
3558 def write_cache(
3559     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3560 ) -> None:
3561     """Update the cache file."""
3562     cache_file = get_cache_file(line_length, mode)
3563     try:
3564         if not CACHE_DIR.exists():
3565             CACHE_DIR.mkdir(parents=True)
3566         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3567         with cache_file.open("wb") as fobj:
3568             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3569     except OSError:
3570         pass
3571
3572
3573 def patch_click() -> None:
3574     """Make Click not crash.
3575
3576     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3577     default which restricts paths that it can access during the lifetime of the
3578     application.  Click refuses to work in this scenario by raising a RuntimeError.
3579
3580     In case of Black the likelihood that non-ASCII characters are going to be used in
3581     file paths is minimal since it's Python source code.  Moreover, this crash was
3582     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3583     """
3584     try:
3585         from click import core
3586         from click import _unicodefun  # type: ignore
3587     except ModuleNotFoundError:
3588         return
3589
3590     for module in (core, _unicodefun):
3591         if hasattr(module, "_verify_python3_env"):
3592             module._verify_python3_env = lambda: None
3593
3594
3595 if __name__ == "__main__":
3596     patch_click()
3597     main()