]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

add changelog entry for #468
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum, Flag
6 from functools import lru_cache, partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tokenize
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generator,
24     Generic,
25     Iterable,
26     Iterator,
27     List,
28     Optional,
29     Pattern,
30     Sequence,
31     Set,
32     Tuple,
33     TypeVar,
34     Union,
35     cast,
36 )
37
38 from appdirs import user_cache_dir
39 from attr import dataclass, Factory
40 import click
41 import toml
42
43 # lib2to3 fork
44 from blib2to3.pytree import Node, Leaf, type_repr
45 from blib2to3 import pygram, pytree
46 from blib2to3.pgen2 import driver, token
47 from blib2to3.pgen2.parse import ParseError
48
49
50 __version__ = "18.6b4"
51 DEFAULT_LINE_LENGTH = 88
52 DEFAULT_EXCLUDES = (
53     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
54 )
55 DEFAULT_INCLUDES = r"\.pyi?$"
56 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
57
58
59 # types
60 FileContent = str
61 Encoding = str
62 NewLine = str
63 Depth = int
64 NodeType = int
65 LeafID = int
66 Priority = int
67 Index = int
68 LN = Union[Leaf, Node]
69 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
70 Timestamp = float
71 FileSize = int
72 CacheInfo = Tuple[Timestamp, FileSize]
73 Cache = Dict[Path, CacheInfo]
74 out = partial(click.secho, bold=True, err=True)
75 err = partial(click.secho, fg="red", err=True)
76
77 pygram.initialize(CACHE_DIR)
78 syms = pygram.python_symbols
79
80
81 class NothingChanged(UserWarning):
82     """Raised by :func:`format_file` when reformatted code is the same as source."""
83
84
85 class CannotSplit(Exception):
86     """A readable split that fits the allotted line length is impossible.
87
88     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
89     :func:`delimiter_split`.
90     """
91
92
93 class WriteBack(Enum):
94     NO = 0
95     YES = 1
96     DIFF = 2
97     CHECK = 3
98
99     @classmethod
100     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
101         if check and not diff:
102             return cls.CHECK
103
104         return cls.DIFF if diff else cls.YES
105
106
107 class Changed(Enum):
108     NO = 0
109     CACHED = 1
110     YES = 2
111
112
113 class FileMode(Flag):
114     AUTO_DETECT = 0
115     PYTHON36 = 1
116     PYI = 2
117     NO_STRING_NORMALIZATION = 4
118
119     @classmethod
120     def from_configuration(
121         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
122     ) -> "FileMode":
123         mode = cls.AUTO_DETECT
124         if py36:
125             mode |= cls.PYTHON36
126         if pyi:
127             mode |= cls.PYI
128         if skip_string_normalization:
129             mode |= cls.NO_STRING_NORMALIZATION
130         return mode
131
132
133 def read_pyproject_toml(
134     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
135 ) -> Optional[str]:
136     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
137
138     Returns the path to a successfully found and read configuration file, None
139     otherwise.
140     """
141     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
142     if not value:
143         root = find_project_root(ctx.params.get("src", ()))
144         path = root / "pyproject.toml"
145         if path.is_file():
146             value = str(path)
147         else:
148             return None
149
150     try:
151         pyproject_toml = toml.load(value)
152         config = pyproject_toml.get("tool", {}).get("black", {})
153     except (toml.TomlDecodeError, OSError) as e:
154         raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
155
156     if not config:
157         return None
158
159     if ctx.default_map is None:
160         ctx.default_map = {}
161     ctx.default_map.update(  # type: ignore  # bad types in .pyi
162         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
163     )
164     return value
165
166
167 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
168 @click.option(
169     "-l",
170     "--line-length",
171     type=int,
172     default=DEFAULT_LINE_LENGTH,
173     help="How many characters per line to allow.",
174     show_default=True,
175 )
176 @click.option(
177     "--py36",
178     is_flag=True,
179     help=(
180         "Allow using Python 3.6-only syntax on all input files.  This will put "
181         "trailing commas in function signatures and calls also after *args and "
182         "**kwargs.  [default: per-file auto-detection]"
183     ),
184 )
185 @click.option(
186     "--pyi",
187     is_flag=True,
188     help=(
189         "Format all input files like typing stubs regardless of file extension "
190         "(useful when piping source on standard input)."
191     ),
192 )
193 @click.option(
194     "-S",
195     "--skip-string-normalization",
196     is_flag=True,
197     help="Don't normalize string quotes or prefixes.",
198 )
199 @click.option(
200     "--check",
201     is_flag=True,
202     help=(
203         "Don't write the files back, just return the status.  Return code 0 "
204         "means nothing would change.  Return code 1 means some files would be "
205         "reformatted.  Return code 123 means there was an internal error."
206     ),
207 )
208 @click.option(
209     "--diff",
210     is_flag=True,
211     help="Don't write the files back, just output a diff for each file on stdout.",
212 )
213 @click.option(
214     "--fast/--safe",
215     is_flag=True,
216     help="If --fast given, skip temporary sanity checks. [default: --safe]",
217 )
218 @click.option(
219     "--include",
220     type=str,
221     default=DEFAULT_INCLUDES,
222     help=(
223         "A regular expression that matches files and directories that should be "
224         "included on recursive searches.  An empty value means all files are "
225         "included regardless of the name.  Use forward slashes for directories on "
226         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
227         "later."
228     ),
229     show_default=True,
230 )
231 @click.option(
232     "--exclude",
233     type=str,
234     default=DEFAULT_EXCLUDES,
235     help=(
236         "A regular expression that matches files and directories that should be "
237         "excluded on recursive searches.  An empty value means no paths are excluded. "
238         "Use forward slashes for directories on all platforms (Windows, too).  "
239         "Exclusions are calculated first, inclusions later."
240     ),
241     show_default=True,
242 )
243 @click.option(
244     "-q",
245     "--quiet",
246     is_flag=True,
247     help=(
248         "Don't emit non-error messages to stderr. Errors are still emitted, "
249         "silence those with 2>/dev/null."
250     ),
251 )
252 @click.option(
253     "-v",
254     "--verbose",
255     is_flag=True,
256     help=(
257         "Also emit messages to stderr about files that were not changed or were "
258         "ignored due to --exclude=."
259     ),
260 )
261 @click.version_option(version=__version__)
262 @click.argument(
263     "src",
264     nargs=-1,
265     type=click.Path(
266         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
267     ),
268     is_eager=True,
269 )
270 @click.option(
271     "--config",
272     type=click.Path(
273         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
274     ),
275     is_eager=True,
276     callback=read_pyproject_toml,
277     help="Read configuration from PATH.",
278 )
279 @click.pass_context
280 def main(
281     ctx: click.Context,
282     line_length: int,
283     check: bool,
284     diff: bool,
285     fast: bool,
286     pyi: bool,
287     py36: bool,
288     skip_string_normalization: bool,
289     quiet: bool,
290     verbose: bool,
291     include: str,
292     exclude: str,
293     src: Tuple[str],
294     config: Optional[str],
295 ) -> None:
296     """The uncompromising code formatter."""
297     write_back = WriteBack.from_configuration(check=check, diff=diff)
298     mode = FileMode.from_configuration(
299         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
300     )
301     if config and verbose:
302         out(f"Using configuration from {config}.", bold=False, fg="blue")
303     try:
304         include_regex = re_compile_maybe_verbose(include)
305     except re.error:
306         err(f"Invalid regular expression for include given: {include!r}")
307         ctx.exit(2)
308     try:
309         exclude_regex = re_compile_maybe_verbose(exclude)
310     except re.error:
311         err(f"Invalid regular expression for exclude given: {exclude!r}")
312         ctx.exit(2)
313     report = Report(check=check, quiet=quiet, verbose=verbose)
314     root = find_project_root(src)
315     sources: Set[Path] = set()
316     for s in src:
317         p = Path(s)
318         if p.is_dir():
319             sources.update(
320                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
321             )
322         elif p.is_file() or s == "-":
323             # if a file was explicitly given, we don't care about its extension
324             sources.add(p)
325         else:
326             err(f"invalid path: {s}")
327     if len(sources) == 0:
328         if verbose or not quiet:
329             out("No paths given. Nothing to do 😴")
330         ctx.exit(0)
331
332     if len(sources) == 1:
333         reformat_one(
334             src=sources.pop(),
335             line_length=line_length,
336             fast=fast,
337             write_back=write_back,
338             mode=mode,
339             report=report,
340         )
341     else:
342         loop = asyncio.get_event_loop()
343         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
344         try:
345             loop.run_until_complete(
346                 schedule_formatting(
347                     sources=sources,
348                     line_length=line_length,
349                     fast=fast,
350                     write_back=write_back,
351                     mode=mode,
352                     report=report,
353                     loop=loop,
354                     executor=executor,
355                 )
356             )
357         finally:
358             shutdown(loop)
359     if verbose or not quiet:
360         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
361         out(f"All done! {bang}")
362         click.secho(str(report), err=True)
363     ctx.exit(report.return_code)
364
365
366 def reformat_one(
367     src: Path,
368     line_length: int,
369     fast: bool,
370     write_back: WriteBack,
371     mode: FileMode,
372     report: "Report",
373 ) -> None:
374     """Reformat a single file under `src` without spawning child processes.
375
376     If `quiet` is True, non-error messages are not output. `line_length`,
377     `write_back`, `fast` and `pyi` options are passed to
378     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
379     """
380     try:
381         changed = Changed.NO
382         if not src.is_file() and str(src) == "-":
383             if format_stdin_to_stdout(
384                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
385             ):
386                 changed = Changed.YES
387         else:
388             cache: Cache = {}
389             if write_back != WriteBack.DIFF:
390                 cache = read_cache(line_length, mode)
391                 res_src = src.resolve()
392                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
393                     changed = Changed.CACHED
394             if changed is not Changed.CACHED and format_file_in_place(
395                 src,
396                 line_length=line_length,
397                 fast=fast,
398                 write_back=write_back,
399                 mode=mode,
400             ):
401                 changed = Changed.YES
402             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
403                 write_back is WriteBack.CHECK and changed is Changed.NO
404             ):
405                 write_cache(cache, [src], line_length, mode)
406         report.done(src, changed)
407     except Exception as exc:
408         report.failed(src, str(exc))
409
410
411 async def schedule_formatting(
412     sources: Set[Path],
413     line_length: int,
414     fast: bool,
415     write_back: WriteBack,
416     mode: FileMode,
417     report: "Report",
418     loop: BaseEventLoop,
419     executor: Executor,
420 ) -> None:
421     """Run formatting of `sources` in parallel using the provided `executor`.
422
423     (Use ProcessPoolExecutors for actual parallelism.)
424
425     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
426     :func:`format_file_in_place`.
427     """
428     cache: Cache = {}
429     if write_back != WriteBack.DIFF:
430         cache = read_cache(line_length, mode)
431         sources, cached = filter_cached(cache, sources)
432         for src in sorted(cached):
433             report.done(src, Changed.CACHED)
434     if not sources:
435         return
436
437     cancelled = []
438     sources_to_cache = []
439     lock = None
440     if write_back == WriteBack.DIFF:
441         # For diff output, we need locks to ensure we don't interleave output
442         # from different processes.
443         manager = Manager()
444         lock = manager.Lock()
445     tasks = {
446         loop.run_in_executor(
447             executor,
448             format_file_in_place,
449             src,
450             line_length,
451             fast,
452             write_back,
453             mode,
454             lock,
455         ): src
456         for src in sorted(sources)
457     }
458     pending: Iterable[asyncio.Task] = tasks.keys()
459     try:
460         loop.add_signal_handler(signal.SIGINT, cancel, pending)
461         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
462     except NotImplementedError:
463         # There are no good alternatives for these on Windows.
464         pass
465     while pending:
466         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
467         for task in done:
468             src = tasks.pop(task)
469             if task.cancelled():
470                 cancelled.append(task)
471             elif task.exception():
472                 report.failed(src, str(task.exception()))
473             else:
474                 changed = Changed.YES if task.result() else Changed.NO
475                 # If the file was written back or was successfully checked as
476                 # well-formatted, store this information in the cache.
477                 if write_back is WriteBack.YES or (
478                     write_back is WriteBack.CHECK and changed is Changed.NO
479                 ):
480                     sources_to_cache.append(src)
481                 report.done(src, changed)
482     if cancelled:
483         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
484     if sources_to_cache:
485         write_cache(cache, sources_to_cache, line_length, mode)
486
487
488 def format_file_in_place(
489     src: Path,
490     line_length: int,
491     fast: bool,
492     write_back: WriteBack = WriteBack.NO,
493     mode: FileMode = FileMode.AUTO_DETECT,
494     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
495 ) -> bool:
496     """Format file under `src` path. Return True if changed.
497
498     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
499     code to the file.
500     `line_length` and `fast` options are passed to :func:`format_file_contents`.
501     """
502     if src.suffix == ".pyi":
503         mode |= FileMode.PYI
504
505     then = datetime.utcfromtimestamp(src.stat().st_mtime)
506     with open(src, "rb") as buf:
507         src_contents, encoding, newline = decode_bytes(buf.read())
508     try:
509         dst_contents = format_file_contents(
510             src_contents, line_length=line_length, fast=fast, mode=mode
511         )
512     except NothingChanged:
513         return False
514
515     if write_back == write_back.YES:
516         with open(src, "w", encoding=encoding, newline=newline) as f:
517             f.write(dst_contents)
518     elif write_back == write_back.DIFF:
519         now = datetime.utcnow()
520         src_name = f"{src}\t{then} +0000"
521         dst_name = f"{src}\t{now} +0000"
522         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
523         if lock:
524             lock.acquire()
525         try:
526             f = io.TextIOWrapper(
527                 sys.stdout.buffer,
528                 encoding=encoding,
529                 newline=newline,
530                 write_through=True,
531             )
532             f.write(diff_contents)
533             f.detach()
534         finally:
535             if lock:
536                 lock.release()
537     return True
538
539
540 def format_stdin_to_stdout(
541     line_length: int,
542     fast: bool,
543     write_back: WriteBack = WriteBack.NO,
544     mode: FileMode = FileMode.AUTO_DETECT,
545 ) -> bool:
546     """Format file on stdin. Return True if changed.
547
548     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
549     write a diff to stdout.
550     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
551     :func:`format_file_contents`.
552     """
553     then = datetime.utcnow()
554     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
555     dst = src
556     try:
557         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
558         return True
559
560     except NothingChanged:
561         return False
562
563     finally:
564         f = io.TextIOWrapper(
565             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
566         )
567         if write_back == WriteBack.YES:
568             f.write(dst)
569         elif write_back == WriteBack.DIFF:
570             now = datetime.utcnow()
571             src_name = f"STDIN\t{then} +0000"
572             dst_name = f"STDOUT\t{now} +0000"
573             f.write(diff(src, dst, src_name, dst_name))
574         f.detach()
575
576
577 def format_file_contents(
578     src_contents: str,
579     *,
580     line_length: int,
581     fast: bool,
582     mode: FileMode = FileMode.AUTO_DETECT,
583 ) -> FileContent:
584     """Reformat contents a file and return new contents.
585
586     If `fast` is False, additionally confirm that the reformatted code is
587     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
588     `line_length` is passed to :func:`format_str`.
589     """
590     if src_contents.strip() == "":
591         raise NothingChanged
592
593     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
594     if src_contents == dst_contents:
595         raise NothingChanged
596
597     if not fast:
598         assert_equivalent(src_contents, dst_contents)
599         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
600     return dst_contents
601
602
603 def format_str(
604     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
605 ) -> FileContent:
606     """Reformat a string and return new contents.
607
608     `line_length` determines how many characters per line are allowed.
609     """
610     src_node = lib2to3_parse(src_contents)
611     dst_contents = ""
612     future_imports = get_future_imports(src_node)
613     is_pyi = bool(mode & FileMode.PYI)
614     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
615     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
616     normalize_fmt_off(src_node)
617     lines = LineGenerator(
618         remove_u_prefix=py36 or "unicode_literals" in future_imports,
619         is_pyi=is_pyi,
620         normalize_strings=normalize_strings,
621         allow_underscores=py36,
622     )
623     elt = EmptyLineTracker(is_pyi=is_pyi)
624     empty_line = Line()
625     after = 0
626     for current_line in lines.visit(src_node):
627         for _ in range(after):
628             dst_contents += str(empty_line)
629         before, after = elt.maybe_empty_lines(current_line)
630         for _ in range(before):
631             dst_contents += str(empty_line)
632         for line in split_line(current_line, line_length=line_length, py36=py36):
633             dst_contents += str(line)
634     return dst_contents
635
636
637 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
638     """Return a tuple of (decoded_contents, encoding, newline).
639
640     `newline` is either CRLF or LF but `decoded_contents` is decoded with
641     universal newlines (i.e. only contains LF).
642     """
643     srcbuf = io.BytesIO(src)
644     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
645     if not lines:
646         return "", encoding, "\n"
647
648     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
649     srcbuf.seek(0)
650     with io.TextIOWrapper(srcbuf, encoding) as tiow:
651         return tiow.read(), encoding, newline
652
653
654 GRAMMARS = [
655     pygram.python_grammar_no_print_statement_no_exec_statement,
656     pygram.python_grammar_no_print_statement,
657     pygram.python_grammar,
658 ]
659
660
661 def lib2to3_parse(src_txt: str) -> Node:
662     """Given a string with source, return the lib2to3 Node."""
663     grammar = pygram.python_grammar_no_print_statement
664     if src_txt[-1:] != "\n":
665         src_txt += "\n"
666     for grammar in GRAMMARS:
667         drv = driver.Driver(grammar, pytree.convert)
668         try:
669             result = drv.parse_string(src_txt, True)
670             break
671
672         except ParseError as pe:
673             lineno, column = pe.context[1]
674             lines = src_txt.splitlines()
675             try:
676                 faulty_line = lines[lineno - 1]
677             except IndexError:
678                 faulty_line = "<line number missing in source>"
679             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
680     else:
681         raise exc from None
682
683     if isinstance(result, Leaf):
684         result = Node(syms.file_input, [result])
685     return result
686
687
688 def lib2to3_unparse(node: Node) -> str:
689     """Given a lib2to3 node, return its string representation."""
690     code = str(node)
691     return code
692
693
694 T = TypeVar("T")
695
696
697 class Visitor(Generic[T]):
698     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
699
700     def visit(self, node: LN) -> Iterator[T]:
701         """Main method to visit `node` and its children.
702
703         It tries to find a `visit_*()` method for the given `node.type`, like
704         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
705         If no dedicated `visit_*()` method is found, chooses `visit_default()`
706         instead.
707
708         Then yields objects of type `T` from the selected visitor.
709         """
710         if node.type < 256:
711             name = token.tok_name[node.type]
712         else:
713             name = type_repr(node.type)
714         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
715
716     def visit_default(self, node: LN) -> Iterator[T]:
717         """Default `visit_*()` implementation. Recurses to children of `node`."""
718         if isinstance(node, Node):
719             for child in node.children:
720                 yield from self.visit(child)
721
722
723 @dataclass
724 class DebugVisitor(Visitor[T]):
725     tree_depth: int = 0
726
727     def visit_default(self, node: LN) -> Iterator[T]:
728         indent = " " * (2 * self.tree_depth)
729         if isinstance(node, Node):
730             _type = type_repr(node.type)
731             out(f"{indent}{_type}", fg="yellow")
732             self.tree_depth += 1
733             for child in node.children:
734                 yield from self.visit(child)
735
736             self.tree_depth -= 1
737             out(f"{indent}/{_type}", fg="yellow", bold=False)
738         else:
739             _type = token.tok_name.get(node.type, str(node.type))
740             out(f"{indent}{_type}", fg="blue", nl=False)
741             if node.prefix:
742                 # We don't have to handle prefixes for `Node` objects since
743                 # that delegates to the first child anyway.
744                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
745             out(f" {node.value!r}", fg="blue", bold=False)
746
747     @classmethod
748     def show(cls, code: Union[str, Leaf, Node]) -> None:
749         """Pretty-print the lib2to3 AST of a given string of `code`.
750
751         Convenience method for debugging.
752         """
753         v: DebugVisitor[None] = DebugVisitor()
754         if isinstance(code, str):
755             code = lib2to3_parse(code)
756         list(v.visit(code))
757
758
759 KEYWORDS = set(keyword.kwlist)
760 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
761 FLOW_CONTROL = {"return", "raise", "break", "continue"}
762 STATEMENT = {
763     syms.if_stmt,
764     syms.while_stmt,
765     syms.for_stmt,
766     syms.try_stmt,
767     syms.except_clause,
768     syms.with_stmt,
769     syms.funcdef,
770     syms.classdef,
771 }
772 STANDALONE_COMMENT = 153
773 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
774 LOGIC_OPERATORS = {"and", "or"}
775 COMPARATORS = {
776     token.LESS,
777     token.GREATER,
778     token.EQEQUAL,
779     token.NOTEQUAL,
780     token.LESSEQUAL,
781     token.GREATEREQUAL,
782 }
783 MATH_OPERATORS = {
784     token.VBAR,
785     token.CIRCUMFLEX,
786     token.AMPER,
787     token.LEFTSHIFT,
788     token.RIGHTSHIFT,
789     token.PLUS,
790     token.MINUS,
791     token.STAR,
792     token.SLASH,
793     token.DOUBLESLASH,
794     token.PERCENT,
795     token.AT,
796     token.TILDE,
797     token.DOUBLESTAR,
798 }
799 STARS = {token.STAR, token.DOUBLESTAR}
800 VARARGS_PARENTS = {
801     syms.arglist,
802     syms.argument,  # double star in arglist
803     syms.trailer,  # single argument to call
804     syms.typedargslist,
805     syms.varargslist,  # lambdas
806 }
807 UNPACKING_PARENTS = {
808     syms.atom,  # single element of a list or set literal
809     syms.dictsetmaker,
810     syms.listmaker,
811     syms.testlist_gexp,
812     syms.testlist_star_expr,
813 }
814 TEST_DESCENDANTS = {
815     syms.test,
816     syms.lambdef,
817     syms.or_test,
818     syms.and_test,
819     syms.not_test,
820     syms.comparison,
821     syms.star_expr,
822     syms.expr,
823     syms.xor_expr,
824     syms.and_expr,
825     syms.shift_expr,
826     syms.arith_expr,
827     syms.trailer,
828     syms.term,
829     syms.power,
830 }
831 ASSIGNMENTS = {
832     "=",
833     "+=",
834     "-=",
835     "*=",
836     "@=",
837     "/=",
838     "%=",
839     "&=",
840     "|=",
841     "^=",
842     "<<=",
843     ">>=",
844     "**=",
845     "//=",
846 }
847 COMPREHENSION_PRIORITY = 20
848 COMMA_PRIORITY = 18
849 TERNARY_PRIORITY = 16
850 LOGIC_PRIORITY = 14
851 STRING_PRIORITY = 12
852 COMPARATOR_PRIORITY = 10
853 MATH_PRIORITIES = {
854     token.VBAR: 9,
855     token.CIRCUMFLEX: 8,
856     token.AMPER: 7,
857     token.LEFTSHIFT: 6,
858     token.RIGHTSHIFT: 6,
859     token.PLUS: 5,
860     token.MINUS: 5,
861     token.STAR: 4,
862     token.SLASH: 4,
863     token.DOUBLESLASH: 4,
864     token.PERCENT: 4,
865     token.AT: 4,
866     token.TILDE: 3,
867     token.DOUBLESTAR: 2,
868 }
869 DOT_PRIORITY = 1
870
871
872 @dataclass
873 class BracketTracker:
874     """Keeps track of brackets on a line."""
875
876     depth: int = 0
877     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
878     delimiters: Dict[LeafID, Priority] = Factory(dict)
879     previous: Optional[Leaf] = None
880     _for_loop_depths: List[int] = Factory(list)
881     _lambda_argument_depths: List[int] = Factory(list)
882
883     def mark(self, leaf: Leaf) -> None:
884         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
885
886         All leaves receive an int `bracket_depth` field that stores how deep
887         within brackets a given leaf is. 0 means there are no enclosing brackets
888         that started on this line.
889
890         If a leaf is itself a closing bracket, it receives an `opening_bracket`
891         field that it forms a pair with. This is a one-directional link to
892         avoid reference cycles.
893
894         If a leaf is a delimiter (a token on which Black can split the line if
895         needed) and it's on depth 0, its `id()` is stored in the tracker's
896         `delimiters` field.
897         """
898         if leaf.type == token.COMMENT:
899             return
900
901         self.maybe_decrement_after_for_loop_variable(leaf)
902         self.maybe_decrement_after_lambda_arguments(leaf)
903         if leaf.type in CLOSING_BRACKETS:
904             self.depth -= 1
905             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
906             leaf.opening_bracket = opening_bracket
907         leaf.bracket_depth = self.depth
908         if self.depth == 0:
909             delim = is_split_before_delimiter(leaf, self.previous)
910             if delim and self.previous is not None:
911                 self.delimiters[id(self.previous)] = delim
912             else:
913                 delim = is_split_after_delimiter(leaf, self.previous)
914                 if delim:
915                     self.delimiters[id(leaf)] = delim
916         if leaf.type in OPENING_BRACKETS:
917             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
918             self.depth += 1
919         self.previous = leaf
920         self.maybe_increment_lambda_arguments(leaf)
921         self.maybe_increment_for_loop_variable(leaf)
922
923     def any_open_brackets(self) -> bool:
924         """Return True if there is an yet unmatched open bracket on the line."""
925         return bool(self.bracket_match)
926
927     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
928         """Return the highest priority of a delimiter found on the line.
929
930         Values are consistent with what `is_split_*_delimiter()` return.
931         Raises ValueError on no delimiters.
932         """
933         return max(v for k, v in self.delimiters.items() if k not in exclude)
934
935     def delimiter_count_with_priority(self, priority: int = 0) -> int:
936         """Return the number of delimiters with the given `priority`.
937
938         If no `priority` is passed, defaults to max priority on the line.
939         """
940         if not self.delimiters:
941             return 0
942
943         priority = priority or self.max_delimiter_priority()
944         return sum(1 for p in self.delimiters.values() if p == priority)
945
946     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
947         """In a for loop, or comprehension, the variables are often unpacks.
948
949         To avoid splitting on the comma in this situation, increase the depth of
950         tokens between `for` and `in`.
951         """
952         if leaf.type == token.NAME and leaf.value == "for":
953             self.depth += 1
954             self._for_loop_depths.append(self.depth)
955             return True
956
957         return False
958
959     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
960         """See `maybe_increment_for_loop_variable` above for explanation."""
961         if (
962             self._for_loop_depths
963             and self._for_loop_depths[-1] == self.depth
964             and leaf.type == token.NAME
965             and leaf.value == "in"
966         ):
967             self.depth -= 1
968             self._for_loop_depths.pop()
969             return True
970
971         return False
972
973     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
974         """In a lambda expression, there might be more than one argument.
975
976         To avoid splitting on the comma in this situation, increase the depth of
977         tokens between `lambda` and `:`.
978         """
979         if leaf.type == token.NAME and leaf.value == "lambda":
980             self.depth += 1
981             self._lambda_argument_depths.append(self.depth)
982             return True
983
984         return False
985
986     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
987         """See `maybe_increment_lambda_arguments` above for explanation."""
988         if (
989             self._lambda_argument_depths
990             and self._lambda_argument_depths[-1] == self.depth
991             and leaf.type == token.COLON
992         ):
993             self.depth -= 1
994             self._lambda_argument_depths.pop()
995             return True
996
997         return False
998
999     def get_open_lsqb(self) -> Optional[Leaf]:
1000         """Return the most recent opening square bracket (if any)."""
1001         return self.bracket_match.get((self.depth - 1, token.RSQB))
1002
1003
1004 @dataclass
1005 class Line:
1006     """Holds leaves and comments. Can be printed with `str(line)`."""
1007
1008     depth: int = 0
1009     leaves: List[Leaf] = Factory(list)
1010     comments: List[Tuple[Index, Leaf]] = Factory(list)
1011     bracket_tracker: BracketTracker = Factory(BracketTracker)
1012     inside_brackets: bool = False
1013     should_explode: bool = False
1014
1015     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1016         """Add a new `leaf` to the end of the line.
1017
1018         Unless `preformatted` is True, the `leaf` will receive a new consistent
1019         whitespace prefix and metadata applied by :class:`BracketTracker`.
1020         Trailing commas are maybe removed, unpacked for loop variables are
1021         demoted from being delimiters.
1022
1023         Inline comments are put aside.
1024         """
1025         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1026         if not has_value:
1027             return
1028
1029         if token.COLON == leaf.type and self.is_class_paren_empty:
1030             del self.leaves[-2:]
1031         if self.leaves and not preformatted:
1032             # Note: at this point leaf.prefix should be empty except for
1033             # imports, for which we only preserve newlines.
1034             leaf.prefix += whitespace(
1035                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1036             )
1037         if self.inside_brackets or not preformatted:
1038             self.bracket_tracker.mark(leaf)
1039             self.maybe_remove_trailing_comma(leaf)
1040         if not self.append_comment(leaf):
1041             self.leaves.append(leaf)
1042
1043     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1044         """Like :func:`append()` but disallow invalid standalone comment structure.
1045
1046         Raises ValueError when any `leaf` is appended after a standalone comment
1047         or when a standalone comment is not the first leaf on the line.
1048         """
1049         if self.bracket_tracker.depth == 0:
1050             if self.is_comment:
1051                 raise ValueError("cannot append to standalone comments")
1052
1053             if self.leaves and leaf.type == STANDALONE_COMMENT:
1054                 raise ValueError(
1055                     "cannot append standalone comments to a populated line"
1056                 )
1057
1058         self.append(leaf, preformatted=preformatted)
1059
1060     @property
1061     def is_comment(self) -> bool:
1062         """Is this line a standalone comment?"""
1063         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1064
1065     @property
1066     def is_decorator(self) -> bool:
1067         """Is this line a decorator?"""
1068         return bool(self) and self.leaves[0].type == token.AT
1069
1070     @property
1071     def is_import(self) -> bool:
1072         """Is this an import line?"""
1073         return bool(self) and is_import(self.leaves[0])
1074
1075     @property
1076     def is_class(self) -> bool:
1077         """Is this line a class definition?"""
1078         return (
1079             bool(self)
1080             and self.leaves[0].type == token.NAME
1081             and self.leaves[0].value == "class"
1082         )
1083
1084     @property
1085     def is_stub_class(self) -> bool:
1086         """Is this line a class definition with a body consisting only of "..."?"""
1087         return self.is_class and self.leaves[-3:] == [
1088             Leaf(token.DOT, ".") for _ in range(3)
1089         ]
1090
1091     @property
1092     def is_def(self) -> bool:
1093         """Is this a function definition? (Also returns True for async defs.)"""
1094         try:
1095             first_leaf = self.leaves[0]
1096         except IndexError:
1097             return False
1098
1099         try:
1100             second_leaf: Optional[Leaf] = self.leaves[1]
1101         except IndexError:
1102             second_leaf = None
1103         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1104             first_leaf.type == token.ASYNC
1105             and second_leaf is not None
1106             and second_leaf.type == token.NAME
1107             and second_leaf.value == "def"
1108         )
1109
1110     @property
1111     def is_class_paren_empty(self) -> bool:
1112         """Is this a class with no base classes but using parentheses?
1113
1114         Those are unnecessary and should be removed.
1115         """
1116         return (
1117             bool(self)
1118             and len(self.leaves) == 4
1119             and self.is_class
1120             and self.leaves[2].type == token.LPAR
1121             and self.leaves[2].value == "("
1122             and self.leaves[3].type == token.RPAR
1123             and self.leaves[3].value == ")"
1124         )
1125
1126     @property
1127     def is_triple_quoted_string(self) -> bool:
1128         """Is the line a triple quoted string?"""
1129         return (
1130             bool(self)
1131             and self.leaves[0].type == token.STRING
1132             and self.leaves[0].value.startswith(('"""', "'''"))
1133         )
1134
1135     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1136         """If so, needs to be split before emitting."""
1137         for leaf in self.leaves:
1138             if leaf.type == STANDALONE_COMMENT:
1139                 if leaf.bracket_depth <= depth_limit:
1140                     return True
1141
1142         return False
1143
1144     def contains_multiline_strings(self) -> bool:
1145         for leaf in self.leaves:
1146             if is_multiline_string(leaf):
1147                 return True
1148
1149         return False
1150
1151     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1152         """Remove trailing comma if there is one and it's safe."""
1153         if not (
1154             self.leaves
1155             and self.leaves[-1].type == token.COMMA
1156             and closing.type in CLOSING_BRACKETS
1157         ):
1158             return False
1159
1160         if closing.type == token.RBRACE:
1161             self.remove_trailing_comma()
1162             return True
1163
1164         if closing.type == token.RSQB:
1165             comma = self.leaves[-1]
1166             if comma.parent and comma.parent.type == syms.listmaker:
1167                 self.remove_trailing_comma()
1168                 return True
1169
1170         # For parens let's check if it's safe to remove the comma.
1171         # Imports are always safe.
1172         if self.is_import:
1173             self.remove_trailing_comma()
1174             return True
1175
1176         # Otherwise, if the trailing one is the only one, we might mistakenly
1177         # change a tuple into a different type by removing the comma.
1178         depth = closing.bracket_depth + 1
1179         commas = 0
1180         opening = closing.opening_bracket
1181         for _opening_index, leaf in enumerate(self.leaves):
1182             if leaf is opening:
1183                 break
1184
1185         else:
1186             return False
1187
1188         for leaf in self.leaves[_opening_index + 1 :]:
1189             if leaf is closing:
1190                 break
1191
1192             bracket_depth = leaf.bracket_depth
1193             if bracket_depth == depth and leaf.type == token.COMMA:
1194                 commas += 1
1195                 if leaf.parent and leaf.parent.type == syms.arglist:
1196                     commas += 1
1197                     break
1198
1199         if commas > 1:
1200             self.remove_trailing_comma()
1201             return True
1202
1203         return False
1204
1205     def append_comment(self, comment: Leaf) -> bool:
1206         """Add an inline or standalone comment to the line."""
1207         if (
1208             comment.type == STANDALONE_COMMENT
1209             and self.bracket_tracker.any_open_brackets()
1210         ):
1211             comment.prefix = ""
1212             return False
1213
1214         if comment.type != token.COMMENT:
1215             return False
1216
1217         after = len(self.leaves) - 1
1218         if after == -1:
1219             comment.type = STANDALONE_COMMENT
1220             comment.prefix = ""
1221             return False
1222
1223         else:
1224             self.comments.append((after, comment))
1225             return True
1226
1227     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1228         """Generate comments that should appear directly after `leaf`.
1229
1230         Provide a non-negative leaf `_index` to speed up the function.
1231         """
1232         if not self.comments:
1233             return
1234
1235         if _index == -1:
1236             for _index, _leaf in enumerate(self.leaves):
1237                 if leaf is _leaf:
1238                     break
1239
1240             else:
1241                 return
1242
1243         for index, comment_after in self.comments:
1244             if _index == index:
1245                 yield comment_after
1246
1247     def remove_trailing_comma(self) -> None:
1248         """Remove the trailing comma and moves the comments attached to it."""
1249         comma_index = len(self.leaves) - 1
1250         for i in range(len(self.comments)):
1251             comment_index, comment = self.comments[i]
1252             if comment_index == comma_index:
1253                 self.comments[i] = (comma_index - 1, comment)
1254         self.leaves.pop()
1255
1256     def is_complex_subscript(self, leaf: Leaf) -> bool:
1257         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1258         open_lsqb = self.bracket_tracker.get_open_lsqb()
1259         if open_lsqb is None:
1260             return False
1261
1262         subscript_start = open_lsqb.next_sibling
1263
1264         if isinstance(subscript_start, Node):
1265             if subscript_start.type == syms.listmaker:
1266                 return False
1267
1268             if subscript_start.type == syms.subscriptlist:
1269                 subscript_start = child_towards(subscript_start, leaf)
1270         return subscript_start is not None and any(
1271             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1272         )
1273
1274     def __str__(self) -> str:
1275         """Render the line."""
1276         if not self:
1277             return "\n"
1278
1279         indent = "    " * self.depth
1280         leaves = iter(self.leaves)
1281         first = next(leaves)
1282         res = f"{first.prefix}{indent}{first.value}"
1283         for leaf in leaves:
1284             res += str(leaf)
1285         for _, comment in self.comments:
1286             res += str(comment)
1287         return res + "\n"
1288
1289     def __bool__(self) -> bool:
1290         """Return True if the line has leaves or comments."""
1291         return bool(self.leaves or self.comments)
1292
1293
1294 @dataclass
1295 class EmptyLineTracker:
1296     """Provides a stateful method that returns the number of potential extra
1297     empty lines needed before and after the currently processed line.
1298
1299     Note: this tracker works on lines that haven't been split yet.  It assumes
1300     the prefix of the first leaf consists of optional newlines.  Those newlines
1301     are consumed by `maybe_empty_lines()` and included in the computation.
1302     """
1303
1304     is_pyi: bool = False
1305     previous_line: Optional[Line] = None
1306     previous_after: int = 0
1307     previous_defs: List[int] = Factory(list)
1308
1309     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1310         """Return the number of extra empty lines before and after the `current_line`.
1311
1312         This is for separating `def`, `async def` and `class` with extra empty
1313         lines (two on module-level).
1314         """
1315         before, after = self._maybe_empty_lines(current_line)
1316         before -= self.previous_after
1317         self.previous_after = after
1318         self.previous_line = current_line
1319         return before, after
1320
1321     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1322         max_allowed = 1
1323         if current_line.depth == 0:
1324             max_allowed = 1 if self.is_pyi else 2
1325         if current_line.leaves:
1326             # Consume the first leaf's extra newlines.
1327             first_leaf = current_line.leaves[0]
1328             before = first_leaf.prefix.count("\n")
1329             before = min(before, max_allowed)
1330             first_leaf.prefix = ""
1331         else:
1332             before = 0
1333         depth = current_line.depth
1334         while self.previous_defs and self.previous_defs[-1] >= depth:
1335             self.previous_defs.pop()
1336             if self.is_pyi:
1337                 before = 0 if depth else 1
1338             else:
1339                 before = 1 if depth else 2
1340         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1341             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1342
1343         if (
1344             self.previous_line
1345             and self.previous_line.is_import
1346             and not current_line.is_import
1347             and depth == self.previous_line.depth
1348         ):
1349             return (before or 1), 0
1350
1351         if (
1352             self.previous_line
1353             and self.previous_line.is_class
1354             and current_line.is_triple_quoted_string
1355         ):
1356             return before, 1
1357
1358         return before, 0
1359
1360     def _maybe_empty_lines_for_class_or_def(
1361         self, current_line: Line, before: int
1362     ) -> Tuple[int, int]:
1363         if not current_line.is_decorator:
1364             self.previous_defs.append(current_line.depth)
1365         if self.previous_line is None:
1366             # Don't insert empty lines before the first line in the file.
1367             return 0, 0
1368
1369         if self.previous_line.is_decorator:
1370             return 0, 0
1371
1372         if self.previous_line.depth < current_line.depth and (
1373             self.previous_line.is_class or self.previous_line.is_def
1374         ):
1375             return 0, 0
1376
1377         if (
1378             self.previous_line.is_comment
1379             and self.previous_line.depth == current_line.depth
1380             and before == 0
1381         ):
1382             return 0, 0
1383
1384         if self.is_pyi:
1385             if self.previous_line.depth > current_line.depth:
1386                 newlines = 1
1387             elif current_line.is_class or self.previous_line.is_class:
1388                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1389                     # No blank line between classes with an empty body
1390                     newlines = 0
1391                 else:
1392                     newlines = 1
1393             elif current_line.is_def and not self.previous_line.is_def:
1394                 # Blank line between a block of functions and a block of non-functions
1395                 newlines = 1
1396             else:
1397                 newlines = 0
1398         else:
1399             newlines = 2
1400         if current_line.depth and newlines:
1401             newlines -= 1
1402         return newlines, 0
1403
1404
1405 @dataclass
1406 class LineGenerator(Visitor[Line]):
1407     """Generates reformatted Line objects.  Empty lines are not emitted.
1408
1409     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1410     in ways that will no longer stringify to valid Python code on the tree.
1411     """
1412
1413     is_pyi: bool = False
1414     normalize_strings: bool = True
1415     current_line: Line = Factory(Line)
1416     remove_u_prefix: bool = False
1417     allow_underscores: bool = False
1418
1419     def line(self, indent: int = 0) -> Iterator[Line]:
1420         """Generate a line.
1421
1422         If the line is empty, only emit if it makes sense.
1423         If the line is too long, split it first and then generate.
1424
1425         If any lines were generated, set up a new current_line.
1426         """
1427         if not self.current_line:
1428             self.current_line.depth += indent
1429             return  # Line is empty, don't emit. Creating a new one unnecessary.
1430
1431         complete_line = self.current_line
1432         self.current_line = Line(depth=complete_line.depth + indent)
1433         yield complete_line
1434
1435     def visit_default(self, node: LN) -> Iterator[Line]:
1436         """Default `visit_*()` implementation. Recurses to children of `node`."""
1437         if isinstance(node, Leaf):
1438             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1439             for comment in generate_comments(node):
1440                 if any_open_brackets:
1441                     # any comment within brackets is subject to splitting
1442                     self.current_line.append(comment)
1443                 elif comment.type == token.COMMENT:
1444                     # regular trailing comment
1445                     self.current_line.append(comment)
1446                     yield from self.line()
1447
1448                 else:
1449                     # regular standalone comment
1450                     yield from self.line()
1451
1452                     self.current_line.append(comment)
1453                     yield from self.line()
1454
1455             normalize_prefix(node, inside_brackets=any_open_brackets)
1456             if self.normalize_strings and node.type == token.STRING:
1457                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1458                 normalize_string_quotes(node)
1459             if node.type == token.NUMBER:
1460                 normalize_numeric_literal(node, self.allow_underscores)
1461             if node.type not in WHITESPACE:
1462                 self.current_line.append(node)
1463         yield from super().visit_default(node)
1464
1465     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1466         """Increase indentation level, maybe yield a line."""
1467         # In blib2to3 INDENT never holds comments.
1468         yield from self.line(+1)
1469         yield from self.visit_default(node)
1470
1471     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1472         """Decrease indentation level, maybe yield a line."""
1473         # The current line might still wait for trailing comments.  At DEDENT time
1474         # there won't be any (they would be prefixes on the preceding NEWLINE).
1475         # Emit the line then.
1476         yield from self.line()
1477
1478         # While DEDENT has no value, its prefix may contain standalone comments
1479         # that belong to the current indentation level.  Get 'em.
1480         yield from self.visit_default(node)
1481
1482         # Finally, emit the dedent.
1483         yield from self.line(-1)
1484
1485     def visit_stmt(
1486         self, node: Node, keywords: Set[str], parens: Set[str]
1487     ) -> Iterator[Line]:
1488         """Visit a statement.
1489
1490         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1491         `def`, `with`, `class`, `assert` and assignments.
1492
1493         The relevant Python language `keywords` for a given statement will be
1494         NAME leaves within it. This methods puts those on a separate line.
1495
1496         `parens` holds a set of string leaf values immediately after which
1497         invisible parens should be put.
1498         """
1499         normalize_invisible_parens(node, parens_after=parens)
1500         for child in node.children:
1501             if child.type == token.NAME and child.value in keywords:  # type: ignore
1502                 yield from self.line()
1503
1504             yield from self.visit(child)
1505
1506     def visit_suite(self, node: Node) -> Iterator[Line]:
1507         """Visit a suite."""
1508         if self.is_pyi and is_stub_suite(node):
1509             yield from self.visit(node.children[2])
1510         else:
1511             yield from self.visit_default(node)
1512
1513     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1514         """Visit a statement without nested statements."""
1515         is_suite_like = node.parent and node.parent.type in STATEMENT
1516         if is_suite_like:
1517             if self.is_pyi and is_stub_body(node):
1518                 yield from self.visit_default(node)
1519             else:
1520                 yield from self.line(+1)
1521                 yield from self.visit_default(node)
1522                 yield from self.line(-1)
1523
1524         else:
1525             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1526                 yield from self.line()
1527             yield from self.visit_default(node)
1528
1529     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1530         """Visit `async def`, `async for`, `async with`."""
1531         yield from self.line()
1532
1533         children = iter(node.children)
1534         for child in children:
1535             yield from self.visit(child)
1536
1537             if child.type == token.ASYNC:
1538                 break
1539
1540         internal_stmt = next(children)
1541         for child in internal_stmt.children:
1542             yield from self.visit(child)
1543
1544     def visit_decorators(self, node: Node) -> Iterator[Line]:
1545         """Visit decorators."""
1546         for child in node.children:
1547             yield from self.line()
1548             yield from self.visit(child)
1549
1550     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1551         """Remove a semicolon and put the other statement on a separate line."""
1552         yield from self.line()
1553
1554     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1555         """End of file. Process outstanding comments and end with a newline."""
1556         yield from self.visit_default(leaf)
1557         yield from self.line()
1558
1559     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1560         if not self.current_line.bracket_tracker.any_open_brackets():
1561             yield from self.line()
1562         yield from self.visit_default(leaf)
1563
1564     def __attrs_post_init__(self) -> None:
1565         """You are in a twisty little maze of passages."""
1566         v = self.visit_stmt
1567         Ø: Set[str] = set()
1568         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1569         self.visit_if_stmt = partial(
1570             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1571         )
1572         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1573         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1574         self.visit_try_stmt = partial(
1575             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1576         )
1577         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1578         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1579         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1580         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1581         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1582         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1583         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1584         self.visit_async_funcdef = self.visit_async_stmt
1585         self.visit_decorated = self.visit_decorators
1586
1587
1588 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1589 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1590 OPENING_BRACKETS = set(BRACKET.keys())
1591 CLOSING_BRACKETS = set(BRACKET.values())
1592 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1593 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1594
1595
1596 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1597     """Return whitespace prefix if needed for the given `leaf`.
1598
1599     `complex_subscript` signals whether the given leaf is part of a subscription
1600     which has non-trivial arguments, like arithmetic expressions or function calls.
1601     """
1602     NO = ""
1603     SPACE = " "
1604     DOUBLESPACE = "  "
1605     t = leaf.type
1606     p = leaf.parent
1607     v = leaf.value
1608     if t in ALWAYS_NO_SPACE:
1609         return NO
1610
1611     if t == token.COMMENT:
1612         return DOUBLESPACE
1613
1614     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1615     if t == token.COLON and p.type not in {
1616         syms.subscript,
1617         syms.subscriptlist,
1618         syms.sliceop,
1619     }:
1620         return NO
1621
1622     prev = leaf.prev_sibling
1623     if not prev:
1624         prevp = preceding_leaf(p)
1625         if not prevp or prevp.type in OPENING_BRACKETS:
1626             return NO
1627
1628         if t == token.COLON:
1629             if prevp.type == token.COLON:
1630                 return NO
1631
1632             elif prevp.type != token.COMMA and not complex_subscript:
1633                 return NO
1634
1635             return SPACE
1636
1637         if prevp.type == token.EQUAL:
1638             if prevp.parent:
1639                 if prevp.parent.type in {
1640                     syms.arglist,
1641                     syms.argument,
1642                     syms.parameters,
1643                     syms.varargslist,
1644                 }:
1645                     return NO
1646
1647                 elif prevp.parent.type == syms.typedargslist:
1648                     # A bit hacky: if the equal sign has whitespace, it means we
1649                     # previously found it's a typed argument.  So, we're using
1650                     # that, too.
1651                     return prevp.prefix
1652
1653         elif prevp.type in STARS:
1654             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1655                 return NO
1656
1657         elif prevp.type == token.COLON:
1658             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1659                 return SPACE if complex_subscript else NO
1660
1661         elif (
1662             prevp.parent
1663             and prevp.parent.type == syms.factor
1664             and prevp.type in MATH_OPERATORS
1665         ):
1666             return NO
1667
1668         elif (
1669             prevp.type == token.RIGHTSHIFT
1670             and prevp.parent
1671             and prevp.parent.type == syms.shift_expr
1672             and prevp.prev_sibling
1673             and prevp.prev_sibling.type == token.NAME
1674             and prevp.prev_sibling.value == "print"  # type: ignore
1675         ):
1676             # Python 2 print chevron
1677             return NO
1678
1679     elif prev.type in OPENING_BRACKETS:
1680         return NO
1681
1682     if p.type in {syms.parameters, syms.arglist}:
1683         # untyped function signatures or calls
1684         if not prev or prev.type != token.COMMA:
1685             return NO
1686
1687     elif p.type == syms.varargslist:
1688         # lambdas
1689         if prev and prev.type != token.COMMA:
1690             return NO
1691
1692     elif p.type == syms.typedargslist:
1693         # typed function signatures
1694         if not prev:
1695             return NO
1696
1697         if t == token.EQUAL:
1698             if prev.type != syms.tname:
1699                 return NO
1700
1701         elif prev.type == token.EQUAL:
1702             # A bit hacky: if the equal sign has whitespace, it means we
1703             # previously found it's a typed argument.  So, we're using that, too.
1704             return prev.prefix
1705
1706         elif prev.type != token.COMMA:
1707             return NO
1708
1709     elif p.type == syms.tname:
1710         # type names
1711         if not prev:
1712             prevp = preceding_leaf(p)
1713             if not prevp or prevp.type != token.COMMA:
1714                 return NO
1715
1716     elif p.type == syms.trailer:
1717         # attributes and calls
1718         if t == token.LPAR or t == token.RPAR:
1719             return NO
1720
1721         if not prev:
1722             if t == token.DOT:
1723                 prevp = preceding_leaf(p)
1724                 if not prevp or prevp.type != token.NUMBER:
1725                     return NO
1726
1727             elif t == token.LSQB:
1728                 return NO
1729
1730         elif prev.type != token.COMMA:
1731             return NO
1732
1733     elif p.type == syms.argument:
1734         # single argument
1735         if t == token.EQUAL:
1736             return NO
1737
1738         if not prev:
1739             prevp = preceding_leaf(p)
1740             if not prevp or prevp.type == token.LPAR:
1741                 return NO
1742
1743         elif prev.type in {token.EQUAL} | STARS:
1744             return NO
1745
1746     elif p.type == syms.decorator:
1747         # decorators
1748         return NO
1749
1750     elif p.type == syms.dotted_name:
1751         if prev:
1752             return NO
1753
1754         prevp = preceding_leaf(p)
1755         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1756             return NO
1757
1758     elif p.type == syms.classdef:
1759         if t == token.LPAR:
1760             return NO
1761
1762         if prev and prev.type == token.LPAR:
1763             return NO
1764
1765     elif p.type in {syms.subscript, syms.sliceop}:
1766         # indexing
1767         if not prev:
1768             assert p.parent is not None, "subscripts are always parented"
1769             if p.parent.type == syms.subscriptlist:
1770                 return SPACE
1771
1772             return NO
1773
1774         elif not complex_subscript:
1775             return NO
1776
1777     elif p.type == syms.atom:
1778         if prev and t == token.DOT:
1779             # dots, but not the first one.
1780             return NO
1781
1782     elif p.type == syms.dictsetmaker:
1783         # dict unpacking
1784         if prev and prev.type == token.DOUBLESTAR:
1785             return NO
1786
1787     elif p.type in {syms.factor, syms.star_expr}:
1788         # unary ops
1789         if not prev:
1790             prevp = preceding_leaf(p)
1791             if not prevp or prevp.type in OPENING_BRACKETS:
1792                 return NO
1793
1794             prevp_parent = prevp.parent
1795             assert prevp_parent is not None
1796             if prevp.type == token.COLON and prevp_parent.type in {
1797                 syms.subscript,
1798                 syms.sliceop,
1799             }:
1800                 return NO
1801
1802             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1803                 return NO
1804
1805         elif t in {token.NAME, token.NUMBER, token.STRING}:
1806             return NO
1807
1808     elif p.type == syms.import_from:
1809         if t == token.DOT:
1810             if prev and prev.type == token.DOT:
1811                 return NO
1812
1813         elif t == token.NAME:
1814             if v == "import":
1815                 return SPACE
1816
1817             if prev and prev.type == token.DOT:
1818                 return NO
1819
1820     elif p.type == syms.sliceop:
1821         return NO
1822
1823     return SPACE
1824
1825
1826 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1827     """Return the first leaf that precedes `node`, if any."""
1828     while node:
1829         res = node.prev_sibling
1830         if res:
1831             if isinstance(res, Leaf):
1832                 return res
1833
1834             try:
1835                 return list(res.leaves())[-1]
1836
1837             except IndexError:
1838                 return None
1839
1840         node = node.parent
1841     return None
1842
1843
1844 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1845     """Return the child of `ancestor` that contains `descendant`."""
1846     node: Optional[LN] = descendant
1847     while node and node.parent != ancestor:
1848         node = node.parent
1849     return node
1850
1851
1852 def container_of(leaf: Leaf) -> LN:
1853     """Return `leaf` or one of its ancestors that is the topmost container of it.
1854
1855     By "container" we mean a node where `leaf` is the very first child.
1856     """
1857     same_prefix = leaf.prefix
1858     container: LN = leaf
1859     while container:
1860         parent = container.parent
1861         if parent is None:
1862             break
1863
1864         if parent.children[0].prefix != same_prefix:
1865             break
1866
1867         if parent.type == syms.file_input:
1868             break
1869
1870         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1871             break
1872
1873         container = parent
1874     return container
1875
1876
1877 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1878     """Return the priority of the `leaf` delimiter, given a line break after it.
1879
1880     The delimiter priorities returned here are from those delimiters that would
1881     cause a line break after themselves.
1882
1883     Higher numbers are higher priority.
1884     """
1885     if leaf.type == token.COMMA:
1886         return COMMA_PRIORITY
1887
1888     return 0
1889
1890
1891 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1892     """Return the priority of the `leaf` delimiter, given a line before after it.
1893
1894     The delimiter priorities returned here are from those delimiters that would
1895     cause a line break before themselves.
1896
1897     Higher numbers are higher priority.
1898     """
1899     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1900         # * and ** might also be MATH_OPERATORS but in this case they are not.
1901         # Don't treat them as a delimiter.
1902         return 0
1903
1904     if (
1905         leaf.type == token.DOT
1906         and leaf.parent
1907         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1908         and (previous is None or previous.type in CLOSING_BRACKETS)
1909     ):
1910         return DOT_PRIORITY
1911
1912     if (
1913         leaf.type in MATH_OPERATORS
1914         and leaf.parent
1915         and leaf.parent.type not in {syms.factor, syms.star_expr}
1916     ):
1917         return MATH_PRIORITIES[leaf.type]
1918
1919     if leaf.type in COMPARATORS:
1920         return COMPARATOR_PRIORITY
1921
1922     if (
1923         leaf.type == token.STRING
1924         and previous is not None
1925         and previous.type == token.STRING
1926     ):
1927         return STRING_PRIORITY
1928
1929     if leaf.type != token.NAME:
1930         return 0
1931
1932     if (
1933         leaf.value == "for"
1934         and leaf.parent
1935         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1936     ):
1937         return COMPREHENSION_PRIORITY
1938
1939     if (
1940         leaf.value == "if"
1941         and leaf.parent
1942         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1943     ):
1944         return COMPREHENSION_PRIORITY
1945
1946     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1947         return TERNARY_PRIORITY
1948
1949     if leaf.value == "is":
1950         return COMPARATOR_PRIORITY
1951
1952     if (
1953         leaf.value == "in"
1954         and leaf.parent
1955         and leaf.parent.type in {syms.comp_op, syms.comparison}
1956         and not (
1957             previous is not None
1958             and previous.type == token.NAME
1959             and previous.value == "not"
1960         )
1961     ):
1962         return COMPARATOR_PRIORITY
1963
1964     if (
1965         leaf.value == "not"
1966         and leaf.parent
1967         and leaf.parent.type == syms.comp_op
1968         and not (
1969             previous is not None
1970             and previous.type == token.NAME
1971             and previous.value == "is"
1972         )
1973     ):
1974         return COMPARATOR_PRIORITY
1975
1976     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1977         return LOGIC_PRIORITY
1978
1979     return 0
1980
1981
1982 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
1983 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
1984
1985
1986 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1987     """Clean the prefix of the `leaf` and generate comments from it, if any.
1988
1989     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1990     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1991     move because it does away with modifying the grammar to include all the
1992     possible places in which comments can be placed.
1993
1994     The sad consequence for us though is that comments don't "belong" anywhere.
1995     This is why this function generates simple parentless Leaf objects for
1996     comments.  We simply don't know what the correct parent should be.
1997
1998     No matter though, we can live without this.  We really only need to
1999     differentiate between inline and standalone comments.  The latter don't
2000     share the line with any code.
2001
2002     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2003     are emitted with a fake STANDALONE_COMMENT token identifier.
2004     """
2005     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2006         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2007
2008
2009 @dataclass
2010 class ProtoComment:
2011     type: int  # token.COMMENT or STANDALONE_COMMENT
2012     value: str  # content of the comment
2013     newlines: int  # how many newlines before the comment
2014     consumed: int  # how many characters of the original leaf's prefix did we consume
2015
2016
2017 @lru_cache(maxsize=4096)
2018 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2019     result: List[ProtoComment] = []
2020     if not prefix or "#" not in prefix:
2021         return result
2022
2023     consumed = 0
2024     nlines = 0
2025     for index, line in enumerate(prefix.split("\n")):
2026         consumed += len(line) + 1  # adding the length of the split '\n'
2027         line = line.lstrip()
2028         if not line:
2029             nlines += 1
2030         if not line.startswith("#"):
2031             continue
2032
2033         if index == 0 and not is_endmarker:
2034             comment_type = token.COMMENT  # simple trailing comment
2035         else:
2036             comment_type = STANDALONE_COMMENT
2037         comment = make_comment(line)
2038         result.append(
2039             ProtoComment(
2040                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2041             )
2042         )
2043         nlines = 0
2044     return result
2045
2046
2047 def make_comment(content: str) -> str:
2048     """Return a consistently formatted comment from the given `content` string.
2049
2050     All comments (except for "##", "#!", "#:") should have a single space between
2051     the hash sign and the content.
2052
2053     If `content` didn't start with a hash sign, one is provided.
2054     """
2055     content = content.rstrip()
2056     if not content:
2057         return "#"
2058
2059     if content[0] == "#":
2060         content = content[1:]
2061     if content and content[0] not in " !:#":
2062         content = " " + content
2063     return "#" + content
2064
2065
2066 def split_line(
2067     line: Line, line_length: int, inner: bool = False, py36: bool = False
2068 ) -> Iterator[Line]:
2069     """Split a `line` into potentially many lines.
2070
2071     They should fit in the allotted `line_length` but might not be able to.
2072     `inner` signifies that there were a pair of brackets somewhere around the
2073     current `line`, possibly transitively. This means we can fallback to splitting
2074     by delimiters if the LHS/RHS don't yield any results.
2075
2076     If `py36` is True, splitting may generate syntax that is only compatible
2077     with Python 3.6 and later.
2078     """
2079     if line.is_comment:
2080         yield line
2081         return
2082
2083     line_str = str(line).strip("\n")
2084     if not line.should_explode and is_line_short_enough(
2085         line, line_length=line_length, line_str=line_str
2086     ):
2087         yield line
2088         return
2089
2090     split_funcs: List[SplitFunc]
2091     if line.is_def:
2092         split_funcs = [left_hand_split]
2093     else:
2094
2095         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2096             for omit in generate_trailers_to_omit(line, line_length):
2097                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2098                 if is_line_short_enough(lines[0], line_length=line_length):
2099                     yield from lines
2100                     return
2101
2102             # All splits failed, best effort split with no omits.
2103             # This mostly happens to multiline strings that are by definition
2104             # reported as not fitting a single line.
2105             yield from right_hand_split(line, py36)
2106
2107         if line.inside_brackets:
2108             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2109         else:
2110             split_funcs = [rhs]
2111     for split_func in split_funcs:
2112         # We are accumulating lines in `result` because we might want to abort
2113         # mission and return the original line in the end, or attempt a different
2114         # split altogether.
2115         result: List[Line] = []
2116         try:
2117             for l in split_func(line, py36):
2118                 if str(l).strip("\n") == line_str:
2119                     raise CannotSplit("Split function returned an unchanged result")
2120
2121                 result.extend(
2122                     split_line(l, line_length=line_length, inner=True, py36=py36)
2123                 )
2124         except CannotSplit as cs:
2125             continue
2126
2127         else:
2128             yield from result
2129             break
2130
2131     else:
2132         yield line
2133
2134
2135 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2136     """Split line into many lines, starting with the first matching bracket pair.
2137
2138     Note: this usually looks weird, only use this for function definitions.
2139     Prefer RHS otherwise.  This is why this function is not symmetrical with
2140     :func:`right_hand_split` which also handles optional parentheses.
2141     """
2142     head = Line(depth=line.depth)
2143     body = Line(depth=line.depth + 1, inside_brackets=True)
2144     tail = Line(depth=line.depth)
2145     tail_leaves: List[Leaf] = []
2146     body_leaves: List[Leaf] = []
2147     head_leaves: List[Leaf] = []
2148     current_leaves = head_leaves
2149     matching_bracket = None
2150     for leaf in line.leaves:
2151         if (
2152             current_leaves is body_leaves
2153             and leaf.type in CLOSING_BRACKETS
2154             and leaf.opening_bracket is matching_bracket
2155         ):
2156             current_leaves = tail_leaves if body_leaves else head_leaves
2157         current_leaves.append(leaf)
2158         if current_leaves is head_leaves:
2159             if leaf.type in OPENING_BRACKETS:
2160                 matching_bracket = leaf
2161                 current_leaves = body_leaves
2162     # Since body is a new indent level, remove spurious leading whitespace.
2163     if body_leaves:
2164         normalize_prefix(body_leaves[0], inside_brackets=True)
2165     # Build the new lines.
2166     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2167         for leaf in leaves:
2168             result.append(leaf, preformatted=True)
2169             for comment_after in line.comments_after(leaf):
2170                 result.append(comment_after, preformatted=True)
2171     bracket_split_succeeded_or_raise(head, body, tail)
2172     for result in (head, body, tail):
2173         if result:
2174             yield result
2175
2176
2177 def right_hand_split(
2178     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2179 ) -> Iterator[Line]:
2180     """Split line into many lines, starting with the last matching bracket pair.
2181
2182     If the split was by optional parentheses, attempt splitting without them, too.
2183     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2184     this split.
2185
2186     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2187     """
2188     head = Line(depth=line.depth)
2189     body = Line(depth=line.depth + 1, inside_brackets=True)
2190     tail = Line(depth=line.depth)
2191     tail_leaves: List[Leaf] = []
2192     body_leaves: List[Leaf] = []
2193     head_leaves: List[Leaf] = []
2194     current_leaves = tail_leaves
2195     opening_bracket = None
2196     closing_bracket = None
2197     for leaf in reversed(line.leaves):
2198         if current_leaves is body_leaves:
2199             if leaf is opening_bracket:
2200                 current_leaves = head_leaves if body_leaves else tail_leaves
2201         current_leaves.append(leaf)
2202         if current_leaves is tail_leaves:
2203             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2204                 opening_bracket = leaf.opening_bracket
2205                 closing_bracket = leaf
2206                 current_leaves = body_leaves
2207     tail_leaves.reverse()
2208     body_leaves.reverse()
2209     head_leaves.reverse()
2210     # Since body is a new indent level, remove spurious leading whitespace.
2211     if body_leaves:
2212         normalize_prefix(body_leaves[0], inside_brackets=True)
2213     if not head_leaves:
2214         # No `head` means the split failed. Either `tail` has all content or
2215         # the matching `opening_bracket` wasn't available on `line` anymore.
2216         raise CannotSplit("No brackets found")
2217
2218     # Build the new lines.
2219     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2220         for leaf in leaves:
2221             result.append(leaf, preformatted=True)
2222             for comment_after in line.comments_after(leaf):
2223                 result.append(comment_after, preformatted=True)
2224     assert opening_bracket and closing_bracket
2225     body.should_explode = should_explode(body, opening_bracket)
2226     bracket_split_succeeded_or_raise(head, body, tail)
2227     if (
2228         # the body shouldn't be exploded
2229         not body.should_explode
2230         # the opening bracket is an optional paren
2231         and opening_bracket.type == token.LPAR
2232         and not opening_bracket.value
2233         # the closing bracket is an optional paren
2234         and closing_bracket.type == token.RPAR
2235         and not closing_bracket.value
2236         # it's not an import (optional parens are the only thing we can split on
2237         # in this case; attempting a split without them is a waste of time)
2238         and not line.is_import
2239         # there are no standalone comments in the body
2240         and not body.contains_standalone_comments(0)
2241         # and we can actually remove the parens
2242         and can_omit_invisible_parens(body, line_length)
2243     ):
2244         omit = {id(closing_bracket), *omit}
2245         try:
2246             yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2247             return
2248
2249         except CannotSplit:
2250             if not (
2251                 can_be_split(body)
2252                 or is_line_short_enough(body, line_length=line_length)
2253             ):
2254                 raise CannotSplit(
2255                     "Splitting failed, body is still too long and can't be split."
2256                 )
2257
2258             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2259                 raise CannotSplit(
2260                     "The current optional pair of parentheses is bound to fail to "
2261                     "satisfy the splitting algorithm because the head or the tail "
2262                     "contains multiline strings which by definition never fit one "
2263                     "line."
2264                 )
2265
2266     ensure_visible(opening_bracket)
2267     ensure_visible(closing_bracket)
2268     for result in (head, body, tail):
2269         if result:
2270             yield result
2271
2272
2273 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2274     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2275
2276     Do nothing otherwise.
2277
2278     A left- or right-hand split is based on a pair of brackets. Content before
2279     (and including) the opening bracket is left on one line, content inside the
2280     brackets is put on a separate line, and finally content starting with and
2281     following the closing bracket is put on a separate line.
2282
2283     Those are called `head`, `body`, and `tail`, respectively. If the split
2284     produced the same line (all content in `head`) or ended up with an empty `body`
2285     and the `tail` is just the closing bracket, then it's considered failed.
2286     """
2287     tail_len = len(str(tail).strip())
2288     if not body:
2289         if tail_len == 0:
2290             raise CannotSplit("Splitting brackets produced the same line")
2291
2292         elif tail_len < 3:
2293             raise CannotSplit(
2294                 f"Splitting brackets on an empty body to save "
2295                 f"{tail_len} characters is not worth it"
2296             )
2297
2298
2299 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2300     """Normalize prefix of the first leaf in every line returned by `split_func`.
2301
2302     This is a decorator over relevant split functions.
2303     """
2304
2305     @wraps(split_func)
2306     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2307         for l in split_func(line, py36):
2308             normalize_prefix(l.leaves[0], inside_brackets=True)
2309             yield l
2310
2311     return split_wrapper
2312
2313
2314 @dont_increase_indentation
2315 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2316     """Split according to delimiters of the highest priority.
2317
2318     If `py36` is True, the split will add trailing commas also in function
2319     signatures that contain `*` and `**`.
2320     """
2321     try:
2322         last_leaf = line.leaves[-1]
2323     except IndexError:
2324         raise CannotSplit("Line empty")
2325
2326     bt = line.bracket_tracker
2327     try:
2328         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2329     except ValueError:
2330         raise CannotSplit("No delimiters found")
2331
2332     if delimiter_priority == DOT_PRIORITY:
2333         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2334             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2335
2336     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2337     lowest_depth = sys.maxsize
2338     trailing_comma_safe = True
2339
2340     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2341         """Append `leaf` to current line or to new line if appending impossible."""
2342         nonlocal current_line
2343         try:
2344             current_line.append_safe(leaf, preformatted=True)
2345         except ValueError as ve:
2346             yield current_line
2347
2348             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2349             current_line.append(leaf)
2350
2351     for index, leaf in enumerate(line.leaves):
2352         yield from append_to_line(leaf)
2353
2354         for comment_after in line.comments_after(leaf, index):
2355             yield from append_to_line(comment_after)
2356
2357         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2358         if leaf.bracket_depth == lowest_depth and is_vararg(
2359             leaf, within=VARARGS_PARENTS
2360         ):
2361             trailing_comma_safe = trailing_comma_safe and py36
2362         leaf_priority = bt.delimiters.get(id(leaf))
2363         if leaf_priority == delimiter_priority:
2364             yield current_line
2365
2366             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2367     if current_line:
2368         if (
2369             trailing_comma_safe
2370             and delimiter_priority == COMMA_PRIORITY
2371             and current_line.leaves[-1].type != token.COMMA
2372             and current_line.leaves[-1].type != STANDALONE_COMMENT
2373         ):
2374             current_line.append(Leaf(token.COMMA, ","))
2375         yield current_line
2376
2377
2378 @dont_increase_indentation
2379 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2380     """Split standalone comments from the rest of the line."""
2381     if not line.contains_standalone_comments(0):
2382         raise CannotSplit("Line does not have any standalone comments")
2383
2384     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2385
2386     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2387         """Append `leaf` to current line or to new line if appending impossible."""
2388         nonlocal current_line
2389         try:
2390             current_line.append_safe(leaf, preformatted=True)
2391         except ValueError as ve:
2392             yield current_line
2393
2394             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2395             current_line.append(leaf)
2396
2397     for index, leaf in enumerate(line.leaves):
2398         yield from append_to_line(leaf)
2399
2400         for comment_after in line.comments_after(leaf, index):
2401             yield from append_to_line(comment_after)
2402
2403     if current_line:
2404         yield current_line
2405
2406
2407 def is_import(leaf: Leaf) -> bool:
2408     """Return True if the given leaf starts an import statement."""
2409     p = leaf.parent
2410     t = leaf.type
2411     v = leaf.value
2412     return bool(
2413         t == token.NAME
2414         and (
2415             (v == "import" and p and p.type == syms.import_name)
2416             or (v == "from" and p and p.type == syms.import_from)
2417         )
2418     )
2419
2420
2421 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2422     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2423     else.
2424
2425     Note: don't use backslashes for formatting or you'll lose your voting rights.
2426     """
2427     if not inside_brackets:
2428         spl = leaf.prefix.split("#")
2429         if "\\" not in spl[0]:
2430             nl_count = spl[-1].count("\n")
2431             if len(spl) > 1:
2432                 nl_count -= 1
2433             leaf.prefix = "\n" * nl_count
2434             return
2435
2436     leaf.prefix = ""
2437
2438
2439 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2440     """Make all string prefixes lowercase.
2441
2442     If remove_u_prefix is given, also removes any u prefix from the string.
2443
2444     Note: Mutates its argument.
2445     """
2446     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2447     assert match is not None, f"failed to match string {leaf.value!r}"
2448     orig_prefix = match.group(1)
2449     new_prefix = orig_prefix.lower()
2450     if remove_u_prefix:
2451         new_prefix = new_prefix.replace("u", "")
2452     leaf.value = f"{new_prefix}{match.group(2)}"
2453
2454
2455 def normalize_string_quotes(leaf: Leaf) -> None:
2456     """Prefer double quotes but only if it doesn't cause more escaping.
2457
2458     Adds or removes backslashes as appropriate. Doesn't parse and fix
2459     strings nested in f-strings (yet).
2460
2461     Note: Mutates its argument.
2462     """
2463     value = leaf.value.lstrip("furbFURB")
2464     if value[:3] == '"""':
2465         return
2466
2467     elif value[:3] == "'''":
2468         orig_quote = "'''"
2469         new_quote = '"""'
2470     elif value[0] == '"':
2471         orig_quote = '"'
2472         new_quote = "'"
2473     else:
2474         orig_quote = "'"
2475         new_quote = '"'
2476     first_quote_pos = leaf.value.find(orig_quote)
2477     if first_quote_pos == -1:
2478         return  # There's an internal error
2479
2480     prefix = leaf.value[:first_quote_pos]
2481     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2482     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2483     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2484     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2485     if "r" in prefix.casefold():
2486         if unescaped_new_quote.search(body):
2487             # There's at least one unescaped new_quote in this raw string
2488             # so converting is impossible
2489             return
2490
2491         # Do not introduce or remove backslashes in raw strings
2492         new_body = body
2493     else:
2494         # remove unnecessary escapes
2495         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2496         if body != new_body:
2497             # Consider the string without unnecessary escapes as the original
2498             body = new_body
2499             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2500         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2501         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2502     if "f" in prefix.casefold():
2503         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2504         for m in matches:
2505             if "\\" in str(m):
2506                 # Do not introduce backslashes in interpolated expressions
2507                 return
2508     if new_quote == '"""' and new_body[-1:] == '"':
2509         # edge case:
2510         new_body = new_body[:-1] + '\\"'
2511     orig_escape_count = body.count("\\")
2512     new_escape_count = new_body.count("\\")
2513     if new_escape_count > orig_escape_count:
2514         return  # Do not introduce more escaping
2515
2516     if new_escape_count == orig_escape_count and orig_quote == '"':
2517         return  # Prefer double quotes
2518
2519     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2520
2521
2522 def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
2523     """Normalizes numeric (float, int, and complex) literals.
2524
2525     All letters used in the representation are normalized to lowercase, long number
2526     literals are split using underscores.
2527     """
2528     text = leaf.value.lower()
2529     if text.startswith(("0o", "0x", "0b")):
2530         # Leave octal, hex, and binary literals alone.
2531         pass
2532     elif "e" in text:
2533         before, after = text.split("e")
2534         sign = ""
2535         if after.startswith("-"):
2536             after = after[1:]
2537             sign = "-"
2538         elif after.startswith("+"):
2539             after = after[1:]
2540         before = format_float_or_int_string(before, allow_underscores)
2541         after = format_int_string(after, allow_underscores)
2542         text = f"{before}e{sign}{after}"
2543     elif text.endswith(("j", "l")):
2544         number = text[:-1]
2545         suffix = text[-1]
2546         text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
2547     else:
2548         text = format_float_or_int_string(text, allow_underscores)
2549     leaf.value = text
2550
2551
2552 def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
2553     """Formats a float string like "1.0"."""
2554     if "." not in text:
2555         return format_int_string(text, allow_underscores)
2556
2557     before, after = text.split(".")
2558     before = format_int_string(before, allow_underscores) if before else "0"
2559     after = format_int_string(after, allow_underscores) if after else "0"
2560     return f"{before}.{after}"
2561
2562
2563 def format_int_string(text: str, allow_underscores: bool) -> str:
2564     """Normalizes underscores in a string to e.g. 1_000_000.
2565
2566     Input must be a string of at least six digits and optional underscores.
2567     """
2568     if not allow_underscores:
2569         return text
2570
2571     text = text.replace("_", "")
2572     if len(text) <= 6:
2573         # No underscores for numbers <= 6 digits long.
2574         return text
2575
2576     # Avoid removing leading zeros, which are important if we're formatting
2577     # part of a number like "0.001".
2578     return format(int("1" + text), "3_")[1:].lstrip("_")
2579
2580
2581 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2582     """Make existing optional parentheses invisible or create new ones.
2583
2584     `parens_after` is a set of string leaf values immeditely after which parens
2585     should be put.
2586
2587     Standardizes on visible parentheses for single-element tuples, and keeps
2588     existing visible parentheses for other tuples and generator expressions.
2589     """
2590     for pc in list_comments(node.prefix, is_endmarker=False):
2591         if pc.value in FMT_OFF:
2592             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2593             return
2594
2595     check_lpar = False
2596     for index, child in enumerate(list(node.children)):
2597         if check_lpar:
2598             if child.type == syms.atom:
2599                 if maybe_make_parens_invisible_in_atom(child):
2600                     lpar = Leaf(token.LPAR, "")
2601                     rpar = Leaf(token.RPAR, "")
2602                     index = child.remove() or 0
2603                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2604             elif is_one_tuple(child):
2605                 # wrap child in visible parentheses
2606                 lpar = Leaf(token.LPAR, "(")
2607                 rpar = Leaf(token.RPAR, ")")
2608                 child.remove()
2609                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2610             elif node.type == syms.import_from:
2611                 # "import from" nodes store parentheses directly as part of
2612                 # the statement
2613                 if child.type == token.LPAR:
2614                     # make parentheses invisible
2615                     child.value = ""  # type: ignore
2616                     node.children[-1].value = ""  # type: ignore
2617                 elif child.type != token.STAR:
2618                     # insert invisible parentheses
2619                     node.insert_child(index, Leaf(token.LPAR, ""))
2620                     node.append_child(Leaf(token.RPAR, ""))
2621                 break
2622
2623             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2624                 # wrap child in invisible parentheses
2625                 lpar = Leaf(token.LPAR, "")
2626                 rpar = Leaf(token.RPAR, "")
2627                 index = child.remove() or 0
2628                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2629
2630         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2631
2632
2633 def normalize_fmt_off(node: Node) -> None:
2634     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2635     try_again = True
2636     while try_again:
2637         try_again = convert_one_fmt_off_pair(node)
2638
2639
2640 def convert_one_fmt_off_pair(node: Node) -> bool:
2641     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2642
2643     Returns True if a pair was converted.
2644     """
2645     for leaf in node.leaves():
2646         previous_consumed = 0
2647         for comment in list_comments(leaf.prefix, is_endmarker=False):
2648             if comment.value in FMT_OFF:
2649                 # We only want standalone comments. If there's no previous leaf or
2650                 # the previous leaf is indentation, it's a standalone comment in
2651                 # disguise.
2652                 if comment.type != STANDALONE_COMMENT:
2653                     prev = preceding_leaf(leaf)
2654                     if prev and prev.type not in WHITESPACE:
2655                         continue
2656
2657                 ignored_nodes = list(generate_ignored_nodes(leaf))
2658                 if not ignored_nodes:
2659                     continue
2660
2661                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2662                 parent = first.parent
2663                 prefix = first.prefix
2664                 first.prefix = prefix[comment.consumed :]
2665                 hidden_value = (
2666                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2667                 )
2668                 if hidden_value.endswith("\n"):
2669                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2670                     # leaf (possibly followed by a DEDENT).
2671                     hidden_value = hidden_value[:-1]
2672                 first_idx = None
2673                 for ignored in ignored_nodes:
2674                     index = ignored.remove()
2675                     if first_idx is None:
2676                         first_idx = index
2677                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2678                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2679                 parent.insert_child(
2680                     first_idx,
2681                     Leaf(
2682                         STANDALONE_COMMENT,
2683                         hidden_value,
2684                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2685                     ),
2686                 )
2687                 return True
2688
2689             previous_consumed = comment.consumed
2690
2691     return False
2692
2693
2694 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2695     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2696
2697     Stops at the end of the block.
2698     """
2699     container: Optional[LN] = container_of(leaf)
2700     while container is not None and container.type != token.ENDMARKER:
2701         for comment in list_comments(container.prefix, is_endmarker=False):
2702             if comment.value in FMT_ON:
2703                 return
2704
2705         yield container
2706
2707         container = container.next_sibling
2708
2709
2710 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2711     """If it's safe, make the parens in the atom `node` invisible, recursively.
2712
2713     Returns whether the node should itself be wrapped in invisible parentheses.
2714
2715     """
2716     if (
2717         node.type != syms.atom
2718         or is_empty_tuple(node)
2719         or is_one_tuple(node)
2720         or is_yield(node)
2721         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2722     ):
2723         return False
2724
2725     first = node.children[0]
2726     last = node.children[-1]
2727     if first.type == token.LPAR and last.type == token.RPAR:
2728         # make parentheses invisible
2729         first.value = ""  # type: ignore
2730         last.value = ""  # type: ignore
2731         if len(node.children) > 1:
2732             maybe_make_parens_invisible_in_atom(node.children[1])
2733         return False
2734
2735     return True
2736
2737
2738 def is_empty_tuple(node: LN) -> bool:
2739     """Return True if `node` holds an empty tuple."""
2740     return (
2741         node.type == syms.atom
2742         and len(node.children) == 2
2743         and node.children[0].type == token.LPAR
2744         and node.children[1].type == token.RPAR
2745     )
2746
2747
2748 def is_one_tuple(node: LN) -> bool:
2749     """Return True if `node` holds a tuple with one element, with or without parens."""
2750     if node.type == syms.atom:
2751         if len(node.children) != 3:
2752             return False
2753
2754         lpar, gexp, rpar = node.children
2755         if not (
2756             lpar.type == token.LPAR
2757             and gexp.type == syms.testlist_gexp
2758             and rpar.type == token.RPAR
2759         ):
2760             return False
2761
2762         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2763
2764     return (
2765         node.type in IMPLICIT_TUPLE
2766         and len(node.children) == 2
2767         and node.children[1].type == token.COMMA
2768     )
2769
2770
2771 def is_yield(node: LN) -> bool:
2772     """Return True if `node` holds a `yield` or `yield from` expression."""
2773     if node.type == syms.yield_expr:
2774         return True
2775
2776     if node.type == token.NAME and node.value == "yield":  # type: ignore
2777         return True
2778
2779     if node.type != syms.atom:
2780         return False
2781
2782     if len(node.children) != 3:
2783         return False
2784
2785     lpar, expr, rpar = node.children
2786     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2787         return is_yield(expr)
2788
2789     return False
2790
2791
2792 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2793     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2794
2795     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2796     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2797     extended iterable unpacking (PEP 3132) and additional unpacking
2798     generalizations (PEP 448).
2799     """
2800     if leaf.type not in STARS or not leaf.parent:
2801         return False
2802
2803     p = leaf.parent
2804     if p.type == syms.star_expr:
2805         # Star expressions are also used as assignment targets in extended
2806         # iterable unpacking (PEP 3132).  See what its parent is instead.
2807         if not p.parent:
2808             return False
2809
2810         p = p.parent
2811
2812     return p.type in within
2813
2814
2815 def is_multiline_string(leaf: Leaf) -> bool:
2816     """Return True if `leaf` is a multiline string that actually spans many lines."""
2817     value = leaf.value.lstrip("furbFURB")
2818     return value[:3] in {'"""', "'''"} and "\n" in value
2819
2820
2821 def is_stub_suite(node: Node) -> bool:
2822     """Return True if `node` is a suite with a stub body."""
2823     if (
2824         len(node.children) != 4
2825         or node.children[0].type != token.NEWLINE
2826         or node.children[1].type != token.INDENT
2827         or node.children[3].type != token.DEDENT
2828     ):
2829         return False
2830
2831     return is_stub_body(node.children[2])
2832
2833
2834 def is_stub_body(node: LN) -> bool:
2835     """Return True if `node` is a simple statement containing an ellipsis."""
2836     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2837         return False
2838
2839     if len(node.children) != 2:
2840         return False
2841
2842     child = node.children[0]
2843     return (
2844         child.type == syms.atom
2845         and len(child.children) == 3
2846         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2847     )
2848
2849
2850 def max_delimiter_priority_in_atom(node: LN) -> int:
2851     """Return maximum delimiter priority inside `node`.
2852
2853     This is specific to atoms with contents contained in a pair of parentheses.
2854     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2855     """
2856     if node.type != syms.atom:
2857         return 0
2858
2859     first = node.children[0]
2860     last = node.children[-1]
2861     if not (first.type == token.LPAR and last.type == token.RPAR):
2862         return 0
2863
2864     bt = BracketTracker()
2865     for c in node.children[1:-1]:
2866         if isinstance(c, Leaf):
2867             bt.mark(c)
2868         else:
2869             for leaf in c.leaves():
2870                 bt.mark(leaf)
2871     try:
2872         return bt.max_delimiter_priority()
2873
2874     except ValueError:
2875         return 0
2876
2877
2878 def ensure_visible(leaf: Leaf) -> None:
2879     """Make sure parentheses are visible.
2880
2881     They could be invisible as part of some statements (see
2882     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2883     """
2884     if leaf.type == token.LPAR:
2885         leaf.value = "("
2886     elif leaf.type == token.RPAR:
2887         leaf.value = ")"
2888
2889
2890 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2891     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2892     if not (
2893         opening_bracket.parent
2894         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2895         and opening_bracket.value in "[{("
2896     ):
2897         return False
2898
2899     try:
2900         last_leaf = line.leaves[-1]
2901         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2902         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2903     except (IndexError, ValueError):
2904         return False
2905
2906     return max_priority == COMMA_PRIORITY
2907
2908
2909 def is_python36(node: Node) -> bool:
2910     """Return True if the current file is using Python 3.6+ features.
2911
2912     Currently looking for:
2913     - f-strings;
2914     - underscores in numeric literals; and
2915     - trailing commas after * or ** in function signatures and calls.
2916     """
2917     for n in node.pre_order():
2918         if n.type == token.STRING:
2919             value_head = n.value[:2]  # type: ignore
2920             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2921                 return True
2922
2923         elif n.type == token.NUMBER:
2924             if "_" in n.value:  # type: ignore
2925                 return True
2926
2927         elif (
2928             n.type in {syms.typedargslist, syms.arglist}
2929             and n.children
2930             and n.children[-1].type == token.COMMA
2931         ):
2932             for ch in n.children:
2933                 if ch.type in STARS:
2934                     return True
2935
2936                 if ch.type == syms.argument:
2937                     for argch in ch.children:
2938                         if argch.type in STARS:
2939                             return True
2940
2941     return False
2942
2943
2944 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2945     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2946
2947     Brackets can be omitted if the entire trailer up to and including
2948     a preceding closing bracket fits in one line.
2949
2950     Yielded sets are cumulative (contain results of previous yields, too).  First
2951     set is empty.
2952     """
2953
2954     omit: Set[LeafID] = set()
2955     yield omit
2956
2957     length = 4 * line.depth
2958     opening_bracket = None
2959     closing_bracket = None
2960     optional_brackets: Set[LeafID] = set()
2961     inner_brackets: Set[LeafID] = set()
2962     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2963         length += leaf_length
2964         if length > line_length:
2965             break
2966
2967         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2968         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2969             break
2970
2971         optional_brackets.discard(id(leaf))
2972         if opening_bracket:
2973             if leaf is opening_bracket:
2974                 opening_bracket = None
2975             elif leaf.type in CLOSING_BRACKETS:
2976                 inner_brackets.add(id(leaf))
2977         elif leaf.type in CLOSING_BRACKETS:
2978             if not leaf.value:
2979                 optional_brackets.add(id(opening_bracket))
2980                 continue
2981
2982             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2983                 # Empty brackets would fail a split so treat them as "inner"
2984                 # brackets (e.g. only add them to the `omit` set if another
2985                 # pair of brackets was good enough.
2986                 inner_brackets.add(id(leaf))
2987                 continue
2988
2989             opening_bracket = leaf.opening_bracket
2990             if closing_bracket:
2991                 omit.add(id(closing_bracket))
2992                 omit.update(inner_brackets)
2993                 inner_brackets.clear()
2994                 yield omit
2995             closing_bracket = leaf
2996
2997
2998 def get_future_imports(node: Node) -> Set[str]:
2999     """Return a set of __future__ imports in the file."""
3000     imports: Set[str] = set()
3001
3002     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3003         for child in children:
3004             if isinstance(child, Leaf):
3005                 if child.type == token.NAME:
3006                     yield child.value
3007             elif child.type == syms.import_as_name:
3008                 orig_name = child.children[0]
3009                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3010                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3011                 yield orig_name.value
3012             elif child.type == syms.import_as_names:
3013                 yield from get_imports_from_children(child.children)
3014             else:
3015                 assert False, "Invalid syntax parsing imports"
3016
3017     for child in node.children:
3018         if child.type != syms.simple_stmt:
3019             break
3020         first_child = child.children[0]
3021         if isinstance(first_child, Leaf):
3022             # Continue looking if we see a docstring; otherwise stop.
3023             if (
3024                 len(child.children) == 2
3025                 and first_child.type == token.STRING
3026                 and child.children[1].type == token.NEWLINE
3027             ):
3028                 continue
3029             else:
3030                 break
3031         elif first_child.type == syms.import_from:
3032             module_name = first_child.children[1]
3033             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3034                 break
3035             imports |= set(get_imports_from_children(first_child.children[3:]))
3036         else:
3037             break
3038     return imports
3039
3040
3041 def gen_python_files_in_dir(
3042     path: Path,
3043     root: Path,
3044     include: Pattern[str],
3045     exclude: Pattern[str],
3046     report: "Report",
3047 ) -> Iterator[Path]:
3048     """Generate all files under `path` whose paths are not excluded by the
3049     `exclude` regex, but are included by the `include` regex.
3050
3051     Symbolic links pointing outside of the `root` directory are ignored.
3052
3053     `report` is where output about exclusions goes.
3054     """
3055     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3056     for child in path.iterdir():
3057         try:
3058             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3059         except ValueError:
3060             if child.is_symlink():
3061                 report.path_ignored(
3062                     child, f"is a symbolic link that points outside {root}"
3063                 )
3064                 continue
3065
3066             raise
3067
3068         if child.is_dir():
3069             normalized_path += "/"
3070         exclude_match = exclude.search(normalized_path)
3071         if exclude_match and exclude_match.group(0):
3072             report.path_ignored(child, f"matches the --exclude regular expression")
3073             continue
3074
3075         if child.is_dir():
3076             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3077
3078         elif child.is_file():
3079             include_match = include.search(normalized_path)
3080             if include_match:
3081                 yield child
3082
3083
3084 @lru_cache()
3085 def find_project_root(srcs: Iterable[str]) -> Path:
3086     """Return a directory containing .git, .hg, or pyproject.toml.
3087
3088     That directory can be one of the directories passed in `srcs` or their
3089     common parent.
3090
3091     If no directory in the tree contains a marker that would specify it's the
3092     project root, the root of the file system is returned.
3093     """
3094     if not srcs:
3095         return Path("/").resolve()
3096
3097     common_base = min(Path(src).resolve() for src in srcs)
3098     if common_base.is_dir():
3099         # Append a fake file so `parents` below returns `common_base_dir`, too.
3100         common_base /= "fake-file"
3101     for directory in common_base.parents:
3102         if (directory / ".git").is_dir():
3103             return directory
3104
3105         if (directory / ".hg").is_dir():
3106             return directory
3107
3108         if (directory / "pyproject.toml").is_file():
3109             return directory
3110
3111     return directory
3112
3113
3114 @dataclass
3115 class Report:
3116     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3117
3118     check: bool = False
3119     quiet: bool = False
3120     verbose: bool = False
3121     change_count: int = 0
3122     same_count: int = 0
3123     failure_count: int = 0
3124
3125     def done(self, src: Path, changed: Changed) -> None:
3126         """Increment the counter for successful reformatting. Write out a message."""
3127         if changed is Changed.YES:
3128             reformatted = "would reformat" if self.check else "reformatted"
3129             if self.verbose or not self.quiet:
3130                 out(f"{reformatted} {src}")
3131             self.change_count += 1
3132         else:
3133             if self.verbose:
3134                 if changed is Changed.NO:
3135                     msg = f"{src} already well formatted, good job."
3136                 else:
3137                     msg = f"{src} wasn't modified on disk since last run."
3138                 out(msg, bold=False)
3139             self.same_count += 1
3140
3141     def failed(self, src: Path, message: str) -> None:
3142         """Increment the counter for failed reformatting. Write out a message."""
3143         err(f"error: cannot format {src}: {message}")
3144         self.failure_count += 1
3145
3146     def path_ignored(self, path: Path, message: str) -> None:
3147         if self.verbose:
3148             out(f"{path} ignored: {message}", bold=False)
3149
3150     @property
3151     def return_code(self) -> int:
3152         """Return the exit code that the app should use.
3153
3154         This considers the current state of changed files and failures:
3155         - if there were any failures, return 123;
3156         - if any files were changed and --check is being used, return 1;
3157         - otherwise return 0.
3158         """
3159         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3160         # 126 we have special return codes reserved by the shell.
3161         if self.failure_count:
3162             return 123
3163
3164         elif self.change_count and self.check:
3165             return 1
3166
3167         return 0
3168
3169     def __str__(self) -> str:
3170         """Render a color report of the current state.
3171
3172         Use `click.unstyle` to remove colors.
3173         """
3174         if self.check:
3175             reformatted = "would be reformatted"
3176             unchanged = "would be left unchanged"
3177             failed = "would fail to reformat"
3178         else:
3179             reformatted = "reformatted"
3180             unchanged = "left unchanged"
3181             failed = "failed to reformat"
3182         report = []
3183         if self.change_count:
3184             s = "s" if self.change_count > 1 else ""
3185             report.append(
3186                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3187             )
3188         if self.same_count:
3189             s = "s" if self.same_count > 1 else ""
3190             report.append(f"{self.same_count} file{s} {unchanged}")
3191         if self.failure_count:
3192             s = "s" if self.failure_count > 1 else ""
3193             report.append(
3194                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3195             )
3196         return ", ".join(report) + "."
3197
3198
3199 def assert_equivalent(src: str, dst: str) -> None:
3200     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3201
3202     import ast
3203     import traceback
3204
3205     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3206         """Simple visitor generating strings to compare ASTs by content."""
3207         yield f"{'  ' * depth}{node.__class__.__name__}("
3208
3209         for field in sorted(node._fields):
3210             try:
3211                 value = getattr(node, field)
3212             except AttributeError:
3213                 continue
3214
3215             yield f"{'  ' * (depth+1)}{field}="
3216
3217             if isinstance(value, list):
3218                 for item in value:
3219                     if isinstance(item, ast.AST):
3220                         yield from _v(item, depth + 2)
3221
3222             elif isinstance(value, ast.AST):
3223                 yield from _v(value, depth + 2)
3224
3225             else:
3226                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3227
3228         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3229
3230     try:
3231         src_ast = ast.parse(src)
3232     except Exception as exc:
3233         major, minor = sys.version_info[:2]
3234         raise AssertionError(
3235             f"cannot use --safe with this file; failed to parse source file "
3236             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3237             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3238         )
3239
3240     try:
3241         dst_ast = ast.parse(dst)
3242     except Exception as exc:
3243         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3244         raise AssertionError(
3245             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3246             f"Please report a bug on https://github.com/ambv/black/issues.  "
3247             f"This invalid output might be helpful: {log}"
3248         ) from None
3249
3250     src_ast_str = "\n".join(_v(src_ast))
3251     dst_ast_str = "\n".join(_v(dst_ast))
3252     if src_ast_str != dst_ast_str:
3253         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3254         raise AssertionError(
3255             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3256             f"the source.  "
3257             f"Please report a bug on https://github.com/ambv/black/issues.  "
3258             f"This diff might be helpful: {log}"
3259         ) from None
3260
3261
3262 def assert_stable(
3263     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3264 ) -> None:
3265     """Raise AssertionError if `dst` reformats differently the second time."""
3266     newdst = format_str(dst, line_length=line_length, mode=mode)
3267     if dst != newdst:
3268         log = dump_to_file(
3269             diff(src, dst, "source", "first pass"),
3270             diff(dst, newdst, "first pass", "second pass"),
3271         )
3272         raise AssertionError(
3273             f"INTERNAL ERROR: Black produced different code on the second pass "
3274             f"of the formatter.  "
3275             f"Please report a bug on https://github.com/ambv/black/issues.  "
3276             f"This diff might be helpful: {log}"
3277         ) from None
3278
3279
3280 def dump_to_file(*output: str) -> str:
3281     """Dump `output` to a temporary file. Return path to the file."""
3282     import tempfile
3283
3284     with tempfile.NamedTemporaryFile(
3285         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3286     ) as f:
3287         for lines in output:
3288             f.write(lines)
3289             if lines and lines[-1] != "\n":
3290                 f.write("\n")
3291     return f.name
3292
3293
3294 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3295     """Return a unified diff string between strings `a` and `b`."""
3296     import difflib
3297
3298     a_lines = [line + "\n" for line in a.split("\n")]
3299     b_lines = [line + "\n" for line in b.split("\n")]
3300     return "".join(
3301         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3302     )
3303
3304
3305 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3306     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3307     err("Aborted!")
3308     for task in tasks:
3309         task.cancel()
3310
3311
3312 def shutdown(loop: BaseEventLoop) -> None:
3313     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3314     try:
3315         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3316         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3317         if not to_cancel:
3318             return
3319
3320         for task in to_cancel:
3321             task.cancel()
3322         loop.run_until_complete(
3323             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3324         )
3325     finally:
3326         # `concurrent.futures.Future` objects cannot be cancelled once they
3327         # are already running. There might be some when the `shutdown()` happened.
3328         # Silence their logger's spew about the event loop being closed.
3329         cf_logger = logging.getLogger("concurrent.futures")
3330         cf_logger.setLevel(logging.CRITICAL)
3331         loop.close()
3332
3333
3334 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3335     """Replace `regex` with `replacement` twice on `original`.
3336
3337     This is used by string normalization to perform replaces on
3338     overlapping matches.
3339     """
3340     return regex.sub(replacement, regex.sub(replacement, original))
3341
3342
3343 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3344     """Compile a regular expression string in `regex`.
3345
3346     If it contains newlines, use verbose mode.
3347     """
3348     if "\n" in regex:
3349         regex = "(?x)" + regex
3350     return re.compile(regex)
3351
3352
3353 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3354     """Like `reversed(enumerate(sequence))` if that were possible."""
3355     index = len(sequence) - 1
3356     for element in reversed(sequence):
3357         yield (index, element)
3358         index -= 1
3359
3360
3361 def enumerate_with_length(
3362     line: Line, reversed: bool = False
3363 ) -> Iterator[Tuple[Index, Leaf, int]]:
3364     """Return an enumeration of leaves with their length.
3365
3366     Stops prematurely on multiline strings and standalone comments.
3367     """
3368     op = cast(
3369         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3370         enumerate_reversed if reversed else enumerate,
3371     )
3372     for index, leaf in op(line.leaves):
3373         length = len(leaf.prefix) + len(leaf.value)
3374         if "\n" in leaf.value:
3375             return  # Multiline strings, we can't continue.
3376
3377         comment: Optional[Leaf]
3378         for comment in line.comments_after(leaf, index):
3379             length += len(comment.value)
3380
3381         yield index, leaf, length
3382
3383
3384 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3385     """Return True if `line` is no longer than `line_length`.
3386
3387     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3388     """
3389     if not line_str:
3390         line_str = str(line).strip("\n")
3391     return (
3392         len(line_str) <= line_length
3393         and "\n" not in line_str  # multiline strings
3394         and not line.contains_standalone_comments()
3395     )
3396
3397
3398 def can_be_split(line: Line) -> bool:
3399     """Return False if the line cannot be split *for sure*.
3400
3401     This is not an exhaustive search but a cheap heuristic that we can use to
3402     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3403     in unnecessary parentheses).
3404     """
3405     leaves = line.leaves
3406     if len(leaves) < 2:
3407         return False
3408
3409     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3410         call_count = 0
3411         dot_count = 0
3412         next = leaves[-1]
3413         for leaf in leaves[-2::-1]:
3414             if leaf.type in OPENING_BRACKETS:
3415                 if next.type not in CLOSING_BRACKETS:
3416                     return False
3417
3418                 call_count += 1
3419             elif leaf.type == token.DOT:
3420                 dot_count += 1
3421             elif leaf.type == token.NAME:
3422                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3423                     return False
3424
3425             elif leaf.type not in CLOSING_BRACKETS:
3426                 return False
3427
3428             if dot_count > 1 and call_count > 1:
3429                 return False
3430
3431     return True
3432
3433
3434 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3435     """Does `line` have a shape safe to reformat without optional parens around it?
3436
3437     Returns True for only a subset of potentially nice looking formattings but
3438     the point is to not return false positives that end up producing lines that
3439     are too long.
3440     """
3441     bt = line.bracket_tracker
3442     if not bt.delimiters:
3443         # Without delimiters the optional parentheses are useless.
3444         return True
3445
3446     max_priority = bt.max_delimiter_priority()
3447     if bt.delimiter_count_with_priority(max_priority) > 1:
3448         # With more than one delimiter of a kind the optional parentheses read better.
3449         return False
3450
3451     if max_priority == DOT_PRIORITY:
3452         # A single stranded method call doesn't require optional parentheses.
3453         return True
3454
3455     assert len(line.leaves) >= 2, "Stranded delimiter"
3456
3457     first = line.leaves[0]
3458     second = line.leaves[1]
3459     penultimate = line.leaves[-2]
3460     last = line.leaves[-1]
3461
3462     # With a single delimiter, omit if the expression starts or ends with
3463     # a bracket.
3464     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3465         remainder = False
3466         length = 4 * line.depth
3467         for _index, leaf, leaf_length in enumerate_with_length(line):
3468             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3469                 remainder = True
3470             if remainder:
3471                 length += leaf_length
3472                 if length > line_length:
3473                     break
3474
3475                 if leaf.type in OPENING_BRACKETS:
3476                     # There are brackets we can further split on.
3477                     remainder = False
3478
3479         else:
3480             # checked the entire string and line length wasn't exceeded
3481             if len(line.leaves) == _index + 1:
3482                 return True
3483
3484         # Note: we are not returning False here because a line might have *both*
3485         # a leading opening bracket and a trailing closing bracket.  If the
3486         # opening bracket doesn't match our rule, maybe the closing will.
3487
3488     if (
3489         last.type == token.RPAR
3490         or last.type == token.RBRACE
3491         or (
3492             # don't use indexing for omitting optional parentheses;
3493             # it looks weird
3494             last.type == token.RSQB
3495             and last.parent
3496             and last.parent.type != syms.trailer
3497         )
3498     ):
3499         if penultimate.type in OPENING_BRACKETS:
3500             # Empty brackets don't help.
3501             return False
3502
3503         if is_multiline_string(first):
3504             # Additional wrapping of a multiline string in this situation is
3505             # unnecessary.
3506             return True
3507
3508         length = 4 * line.depth
3509         seen_other_brackets = False
3510         for _index, leaf, leaf_length in enumerate_with_length(line):
3511             length += leaf_length
3512             if leaf is last.opening_bracket:
3513                 if seen_other_brackets or length <= line_length:
3514                     return True
3515
3516             elif leaf.type in OPENING_BRACKETS:
3517                 # There are brackets we can further split on.
3518                 seen_other_brackets = True
3519
3520     return False
3521
3522
3523 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3524     return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
3525
3526
3527 def read_cache(line_length: int, mode: FileMode) -> Cache:
3528     """Read the cache if it exists and is well formed.
3529
3530     If it is not well formed, the call to write_cache later should resolve the issue.
3531     """
3532     cache_file = get_cache_file(line_length, mode)
3533     if not cache_file.exists():
3534         return {}
3535
3536     with cache_file.open("rb") as fobj:
3537         try:
3538             cache: Cache = pickle.load(fobj)
3539         except pickle.UnpicklingError:
3540             return {}
3541
3542     return cache
3543
3544
3545 def get_cache_info(path: Path) -> CacheInfo:
3546     """Return the information used to check if a file is already formatted or not."""
3547     stat = path.stat()
3548     return stat.st_mtime, stat.st_size
3549
3550
3551 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3552     """Split an iterable of paths in `sources` into two sets.
3553
3554     The first contains paths of files that modified on disk or are not in the
3555     cache. The other contains paths to non-modified files.
3556     """
3557     todo, done = set(), set()
3558     for src in sources:
3559         src = src.resolve()
3560         if cache.get(src) != get_cache_info(src):
3561             todo.add(src)
3562         else:
3563             done.add(src)
3564     return todo, done
3565
3566
3567 def write_cache(
3568     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3569 ) -> None:
3570     """Update the cache file."""
3571     cache_file = get_cache_file(line_length, mode)
3572     try:
3573         if not CACHE_DIR.exists():
3574             CACHE_DIR.mkdir(parents=True)
3575         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3576         with cache_file.open("wb") as fobj:
3577             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3578     except OSError:
3579         pass
3580
3581
3582 def patch_click() -> None:
3583     """Make Click not crash.
3584
3585     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3586     default which restricts paths that it can access during the lifetime of the
3587     application.  Click refuses to work in this scenario by raising a RuntimeError.
3588
3589     In case of Black the likelihood that non-ASCII characters are going to be used in
3590     file paths is minimal since it's Python source code.  Moreover, this crash was
3591     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3592     """
3593     try:
3594         from click import core
3595         from click import _unicodefun  # type: ignore
3596     except ModuleNotFoundError:
3597         return
3598
3599     for module in (core, _unicodefun):
3600         if hasattr(module, "_verify_python3_env"):
3601             module._verify_python3_env = lambda: None
3602
3603
3604 if __name__ == "__main__":
3605     patch_click()
3606     main()