]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

7823ae0afe2e809bdb2f5ca208ec9cc0565b4ced
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33 )
34
35 from appdirs import user_cache_dir
36 from attr import dataclass, Factory
37 import click
38
39 # lib2to3 fork
40 from blib2to3.pytree import Node, Leaf, type_repr
41 from blib2to3 import pygram, pytree
42 from blib2to3.pgen2 import driver, token
43 from blib2to3.pgen2.parse import ParseError
44
45
46 __version__ = "18.4a6"
47 DEFAULT_LINE_LENGTH = 88
48
49 # types
50 syms = pygram.python_symbols
51 FileContent = str
52 Encoding = str
53 Depth = int
54 NodeType = int
55 LeafID = int
56 Priority = int
57 Index = int
58 LN = Union[Leaf, Node]
59 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
60 Timestamp = float
61 FileSize = int
62 CacheInfo = Tuple[Timestamp, FileSize]
63 Cache = Dict[Path, CacheInfo]
64 out = partial(click.secho, bold=True, err=True)
65 err = partial(click.secho, fg="red", err=True)
66
67
68 class NothingChanged(UserWarning):
69     """Raised by :func:`format_file` when reformatted code is the same as source."""
70
71
72 class CannotSplit(Exception):
73     """A readable split that fits the allotted line length is impossible.
74
75     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
76     :func:`delimiter_split`.
77     """
78
79
80 class FormatError(Exception):
81     """Base exception for `# fmt: on` and `# fmt: off` handling.
82
83     It holds the number of bytes of the prefix consumed before the format
84     control comment appeared.
85     """
86
87     def __init__(self, consumed: int) -> None:
88         super().__init__(consumed)
89         self.consumed = consumed
90
91     def trim_prefix(self, leaf: Leaf) -> None:
92         leaf.prefix = leaf.prefix[self.consumed :]
93
94     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
95         """Returns a new Leaf from the consumed part of the prefix."""
96         unformatted_prefix = leaf.prefix[: self.consumed]
97         return Leaf(token.NEWLINE, unformatted_prefix)
98
99
100 class FormatOn(FormatError):
101     """Found a comment like `# fmt: on` in the file."""
102
103
104 class FormatOff(FormatError):
105     """Found a comment like `# fmt: off` in the file."""
106
107
108 class WriteBack(Enum):
109     NO = 0
110     YES = 1
111     DIFF = 2
112
113
114 class Changed(Enum):
115     NO = 0
116     CACHED = 1
117     YES = 2
118
119
120 @click.command()
121 @click.option(
122     "-l",
123     "--line-length",
124     type=int,
125     default=DEFAULT_LINE_LENGTH,
126     help="How many character per line to allow.",
127     show_default=True,
128 )
129 @click.option(
130     "--check",
131     is_flag=True,
132     help=(
133         "Don't write the files back, just return the status.  Return code 0 "
134         "means nothing would change.  Return code 1 means some files would be "
135         "reformatted.  Return code 123 means there was an internal error."
136     ),
137 )
138 @click.option(
139     "--diff",
140     is_flag=True,
141     help="Don't write the files back, just output a diff for each file on stdout.",
142 )
143 @click.option(
144     "--fast/--safe",
145     is_flag=True,
146     help="If --fast given, skip temporary sanity checks. [default: --safe]",
147 )
148 @click.option(
149     "-q",
150     "--quiet",
151     is_flag=True,
152     help=(
153         "Don't emit non-error messages to stderr. Errors are still emitted, "
154         "silence those with 2>/dev/null."
155     ),
156 )
157 @click.version_option(version=__version__)
158 @click.argument(
159     "src",
160     nargs=-1,
161     type=click.Path(
162         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
163     ),
164 )
165 @click.pass_context
166 def main(
167     ctx: click.Context,
168     line_length: int,
169     check: bool,
170     diff: bool,
171     fast: bool,
172     quiet: bool,
173     src: List[str],
174 ) -> None:
175     """The uncompromising code formatter."""
176     sources: List[Path] = []
177     for s in src:
178         p = Path(s)
179         if p.is_dir():
180             sources.extend(gen_python_files_in_dir(p))
181         elif p.is_file():
182             # if a file was explicitly given, we don't care about its extension
183             sources.append(p)
184         elif s == "-":
185             sources.append(Path("-"))
186         else:
187             err(f"invalid path: {s}")
188
189     if check and not diff:
190         write_back = WriteBack.NO
191     elif diff:
192         write_back = WriteBack.DIFF
193     else:
194         write_back = WriteBack.YES
195     report = Report(check=check, quiet=quiet)
196     if len(sources) == 0:
197         out("No paths given. Nothing to do 😴")
198         ctx.exit(0)
199         return
200
201     elif len(sources) == 1:
202         reformat_one(sources[0], line_length, fast, write_back, report)
203     else:
204         loop = asyncio.get_event_loop()
205         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
206         try:
207             loop.run_until_complete(
208                 schedule_formatting(
209                     sources, line_length, fast, write_back, report, loop, executor
210                 )
211             )
212         finally:
213             shutdown(loop)
214         if not quiet:
215             out("All done! ✨ 🍰 ✨")
216             click.echo(str(report))
217     ctx.exit(report.return_code)
218
219
220 def reformat_one(
221     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
222 ) -> None:
223     """Reformat a single file under `src` without spawning child processes.
224
225     If `quiet` is True, non-error messages are not output. `line_length`,
226     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
227     """
228     try:
229         changed = Changed.NO
230         if not src.is_file() and str(src) == "-":
231             if format_stdin_to_stdout(
232                 line_length=line_length, fast=fast, write_back=write_back
233             ):
234                 changed = Changed.YES
235         else:
236             cache: Cache = {}
237             if write_back != WriteBack.DIFF:
238                 cache = read_cache(line_length)
239                 src = src.resolve()
240                 if src in cache and cache[src] == get_cache_info(src):
241                     changed = Changed.CACHED
242             if (
243                 changed is not Changed.CACHED
244                 and format_file_in_place(
245                     src, line_length=line_length, fast=fast, write_back=write_back
246                 )
247             ):
248                 changed = Changed.YES
249             if write_back == WriteBack.YES and changed is not Changed.NO:
250                 write_cache(cache, [src], line_length)
251         report.done(src, changed)
252     except Exception as exc:
253         report.failed(src, str(exc))
254
255
256 async def schedule_formatting(
257     sources: List[Path],
258     line_length: int,
259     fast: bool,
260     write_back: WriteBack,
261     report: "Report",
262     loop: BaseEventLoop,
263     executor: Executor,
264 ) -> None:
265     """Run formatting of `sources` in parallel using the provided `executor`.
266
267     (Use ProcessPoolExecutors for actual parallelism.)
268
269     `line_length`, `write_back`, and `fast` options are passed to
270     :func:`format_file_in_place`.
271     """
272     cache: Cache = {}
273     if write_back != WriteBack.DIFF:
274         cache = read_cache(line_length)
275         sources, cached = filter_cached(cache, sources)
276         for src in cached:
277             report.done(src, Changed.CACHED)
278     cancelled = []
279     formatted = []
280     if sources:
281         lock = None
282         if write_back == WriteBack.DIFF:
283             # For diff output, we need locks to ensure we don't interleave output
284             # from different processes.
285             manager = Manager()
286             lock = manager.Lock()
287         tasks = {
288             src: loop.run_in_executor(
289                 executor, format_file_in_place, src, line_length, fast, write_back, lock
290             )
291             for src in sources
292         }
293         _task_values = list(tasks.values())
294         try:
295             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
296             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
297         except NotImplementedError:
298             # There are no good alternatives for these on Windows
299             pass
300         await asyncio.wait(_task_values)
301         for src, task in tasks.items():
302             if not task.done():
303                 report.failed(src, "timed out, cancelling")
304                 task.cancel()
305                 cancelled.append(task)
306             elif task.cancelled():
307                 cancelled.append(task)
308             elif task.exception():
309                 report.failed(src, str(task.exception()))
310             else:
311                 formatted.append(src)
312                 report.done(src, Changed.YES if task.result() else Changed.NO)
313
314     if cancelled:
315         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
316     if write_back == WriteBack.YES and formatted:
317         write_cache(cache, formatted, line_length)
318
319
320 def format_file_in_place(
321     src: Path,
322     line_length: int,
323     fast: bool,
324     write_back: WriteBack = WriteBack.NO,
325     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
326 ) -> bool:
327     """Format file under `src` path. Return True if changed.
328
329     If `write_back` is True, write reformatted code back to stdout.
330     `line_length` and `fast` options are passed to :func:`format_file_contents`.
331     """
332
333     with tokenize.open(src) as src_buffer:
334         src_contents = src_buffer.read()
335     try:
336         dst_contents = format_file_contents(
337             src_contents, line_length=line_length, fast=fast
338         )
339     except NothingChanged:
340         return False
341
342     if write_back == write_back.YES:
343         with open(src, "w", encoding=src_buffer.encoding) as f:
344             f.write(dst_contents)
345     elif write_back == write_back.DIFF:
346         src_name = f"{src}  (original)"
347         dst_name = f"{src}  (formatted)"
348         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
349         if lock:
350             lock.acquire()
351         try:
352             sys.stdout.write(diff_contents)
353         finally:
354             if lock:
355                 lock.release()
356     return True
357
358
359 def format_stdin_to_stdout(
360     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
361 ) -> bool:
362     """Format file on stdin. Return True if changed.
363
364     If `write_back` is True, write reformatted code back to stdout.
365     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
366     """
367     src = sys.stdin.read()
368     dst = src
369     try:
370         dst = format_file_contents(src, line_length=line_length, fast=fast)
371         return True
372
373     except NothingChanged:
374         return False
375
376     finally:
377         if write_back == WriteBack.YES:
378             sys.stdout.write(dst)
379         elif write_back == WriteBack.DIFF:
380             src_name = "<stdin>  (original)"
381             dst_name = "<stdin>  (formatted)"
382             sys.stdout.write(diff(src, dst, src_name, dst_name))
383
384
385 def format_file_contents(
386     src_contents: str, line_length: int, fast: bool
387 ) -> FileContent:
388     """Reformat contents a file and return new contents.
389
390     If `fast` is False, additionally confirm that the reformatted code is
391     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
392     `line_length` is passed to :func:`format_str`.
393     """
394     if src_contents.strip() == "":
395         raise NothingChanged
396
397     dst_contents = format_str(src_contents, line_length=line_length)
398     if src_contents == dst_contents:
399         raise NothingChanged
400
401     if not fast:
402         assert_equivalent(src_contents, dst_contents)
403         assert_stable(src_contents, dst_contents, line_length=line_length)
404     return dst_contents
405
406
407 def format_str(src_contents: str, line_length: int) -> FileContent:
408     """Reformat a string and return new contents.
409
410     `line_length` determines how many characters per line are allowed.
411     """
412     src_node = lib2to3_parse(src_contents)
413     dst_contents = ""
414     future_imports = get_future_imports(src_node)
415     py36 = is_python36(src_node)
416     lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports)
417     elt = EmptyLineTracker()
418     empty_line = Line()
419     after = 0
420     for current_line in lines.visit(src_node):
421         for _ in range(after):
422             dst_contents += str(empty_line)
423         before, after = elt.maybe_empty_lines(current_line)
424         for _ in range(before):
425             dst_contents += str(empty_line)
426         for line in split_line(current_line, line_length=line_length, py36=py36):
427             dst_contents += str(line)
428     return dst_contents
429
430
431 GRAMMARS = [
432     pygram.python_grammar_no_print_statement_no_exec_statement,
433     pygram.python_grammar_no_print_statement,
434     pygram.python_grammar,
435 ]
436
437
438 def lib2to3_parse(src_txt: str) -> Node:
439     """Given a string with source, return the lib2to3 Node."""
440     grammar = pygram.python_grammar_no_print_statement
441     if src_txt[-1] != "\n":
442         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
443         src_txt += nl
444     for grammar in GRAMMARS:
445         drv = driver.Driver(grammar, pytree.convert)
446         try:
447             result = drv.parse_string(src_txt, True)
448             break
449
450         except ParseError as pe:
451             lineno, column = pe.context[1]
452             lines = src_txt.splitlines()
453             try:
454                 faulty_line = lines[lineno - 1]
455             except IndexError:
456                 faulty_line = "<line number missing in source>"
457             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
458     else:
459         raise exc from None
460
461     if isinstance(result, Leaf):
462         result = Node(syms.file_input, [result])
463     return result
464
465
466 def lib2to3_unparse(node: Node) -> str:
467     """Given a lib2to3 node, return its string representation."""
468     code = str(node)
469     return code
470
471
472 T = TypeVar("T")
473
474
475 class Visitor(Generic[T]):
476     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
477
478     def visit(self, node: LN) -> Iterator[T]:
479         """Main method to visit `node` and its children.
480
481         It tries to find a `visit_*()` method for the given `node.type`, like
482         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
483         If no dedicated `visit_*()` method is found, chooses `visit_default()`
484         instead.
485
486         Then yields objects of type `T` from the selected visitor.
487         """
488         if node.type < 256:
489             name = token.tok_name[node.type]
490         else:
491             name = type_repr(node.type)
492         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
493
494     def visit_default(self, node: LN) -> Iterator[T]:
495         """Default `visit_*()` implementation. Recurses to children of `node`."""
496         if isinstance(node, Node):
497             for child in node.children:
498                 yield from self.visit(child)
499
500
501 @dataclass
502 class DebugVisitor(Visitor[T]):
503     tree_depth: int = 0
504
505     def visit_default(self, node: LN) -> Iterator[T]:
506         indent = " " * (2 * self.tree_depth)
507         if isinstance(node, Node):
508             _type = type_repr(node.type)
509             out(f"{indent}{_type}", fg="yellow")
510             self.tree_depth += 1
511             for child in node.children:
512                 yield from self.visit(child)
513
514             self.tree_depth -= 1
515             out(f"{indent}/{_type}", fg="yellow", bold=False)
516         else:
517             _type = token.tok_name.get(node.type, str(node.type))
518             out(f"{indent}{_type}", fg="blue", nl=False)
519             if node.prefix:
520                 # We don't have to handle prefixes for `Node` objects since
521                 # that delegates to the first child anyway.
522                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
523             out(f" {node.value!r}", fg="blue", bold=False)
524
525     @classmethod
526     def show(cls, code: str) -> None:
527         """Pretty-print the lib2to3 AST of a given string of `code`.
528
529         Convenience method for debugging.
530         """
531         v: DebugVisitor[None] = DebugVisitor()
532         list(v.visit(lib2to3_parse(code)))
533
534
535 KEYWORDS = set(keyword.kwlist)
536 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
537 FLOW_CONTROL = {"return", "raise", "break", "continue"}
538 STATEMENT = {
539     syms.if_stmt,
540     syms.while_stmt,
541     syms.for_stmt,
542     syms.try_stmt,
543     syms.except_clause,
544     syms.with_stmt,
545     syms.funcdef,
546     syms.classdef,
547 }
548 STANDALONE_COMMENT = 153
549 LOGIC_OPERATORS = {"and", "or"}
550 COMPARATORS = {
551     token.LESS,
552     token.GREATER,
553     token.EQEQUAL,
554     token.NOTEQUAL,
555     token.LESSEQUAL,
556     token.GREATEREQUAL,
557 }
558 MATH_OPERATORS = {
559     token.VBAR,
560     token.CIRCUMFLEX,
561     token.AMPER,
562     token.LEFTSHIFT,
563     token.RIGHTSHIFT,
564     token.PLUS,
565     token.MINUS,
566     token.STAR,
567     token.SLASH,
568     token.DOUBLESLASH,
569     token.PERCENT,
570     token.AT,
571     token.TILDE,
572     token.DOUBLESTAR,
573 }
574 STARS = {token.STAR, token.DOUBLESTAR}
575 VARARGS_PARENTS = {
576     syms.arglist,
577     syms.argument,  # double star in arglist
578     syms.trailer,  # single argument to call
579     syms.typedargslist,
580     syms.varargslist,  # lambdas
581 }
582 UNPACKING_PARENTS = {
583     syms.atom,  # single element of a list or set literal
584     syms.dictsetmaker,
585     syms.listmaker,
586     syms.testlist_gexp,
587 }
588 TEST_DESCENDANTS = {
589     syms.test,
590     syms.lambdef,
591     syms.or_test,
592     syms.and_test,
593     syms.not_test,
594     syms.comparison,
595     syms.star_expr,
596     syms.expr,
597     syms.xor_expr,
598     syms.and_expr,
599     syms.shift_expr,
600     syms.arith_expr,
601     syms.trailer,
602     syms.term,
603     syms.power,
604 }
605 ASSIGNMENTS = {
606     "=",
607     "+=",
608     "-=",
609     "*=",
610     "@=",
611     "/=",
612     "%=",
613     "&=",
614     "|=",
615     "^=",
616     "<<=",
617     ">>=",
618     "**=",
619     "//=",
620 }
621 COMPREHENSION_PRIORITY = 20
622 COMMA_PRIORITY = 18
623 TERNARY_PRIORITY = 16
624 LOGIC_PRIORITY = 14
625 STRING_PRIORITY = 12
626 COMPARATOR_PRIORITY = 10
627 MATH_PRIORITIES = {
628     token.VBAR: 8,
629     token.CIRCUMFLEX: 7,
630     token.AMPER: 6,
631     token.LEFTSHIFT: 5,
632     token.RIGHTSHIFT: 5,
633     token.PLUS: 4,
634     token.MINUS: 4,
635     token.STAR: 3,
636     token.SLASH: 3,
637     token.DOUBLESLASH: 3,
638     token.PERCENT: 3,
639     token.AT: 3,
640     token.TILDE: 2,
641     token.DOUBLESTAR: 1,
642 }
643
644
645 @dataclass
646 class BracketTracker:
647     """Keeps track of brackets on a line."""
648
649     depth: int = 0
650     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
651     delimiters: Dict[LeafID, Priority] = Factory(dict)
652     previous: Optional[Leaf] = None
653     _for_loop_variable: int = 0
654     _lambda_arguments: int = 0
655
656     def mark(self, leaf: Leaf) -> None:
657         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
658
659         All leaves receive an int `bracket_depth` field that stores how deep
660         within brackets a given leaf is. 0 means there are no enclosing brackets
661         that started on this line.
662
663         If a leaf is itself a closing bracket, it receives an `opening_bracket`
664         field that it forms a pair with. This is a one-directional link to
665         avoid reference cycles.
666
667         If a leaf is a delimiter (a token on which Black can split the line if
668         needed) and it's on depth 0, its `id()` is stored in the tracker's
669         `delimiters` field.
670         """
671         if leaf.type == token.COMMENT:
672             return
673
674         self.maybe_decrement_after_for_loop_variable(leaf)
675         self.maybe_decrement_after_lambda_arguments(leaf)
676         if leaf.type in CLOSING_BRACKETS:
677             self.depth -= 1
678             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
679             leaf.opening_bracket = opening_bracket
680         leaf.bracket_depth = self.depth
681         if self.depth == 0:
682             delim = is_split_before_delimiter(leaf, self.previous)
683             if delim and self.previous is not None:
684                 self.delimiters[id(self.previous)] = delim
685             else:
686                 delim = is_split_after_delimiter(leaf, self.previous)
687                 if delim:
688                     self.delimiters[id(leaf)] = delim
689         if leaf.type in OPENING_BRACKETS:
690             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
691             self.depth += 1
692         self.previous = leaf
693         self.maybe_increment_lambda_arguments(leaf)
694         self.maybe_increment_for_loop_variable(leaf)
695
696     def any_open_brackets(self) -> bool:
697         """Return True if there is an yet unmatched open bracket on the line."""
698         return bool(self.bracket_match)
699
700     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
701         """Return the highest priority of a delimiter found on the line.
702
703         Values are consistent with what `is_split_*_delimiter()` return.
704         Raises ValueError on no delimiters.
705         """
706         return max(v for k, v in self.delimiters.items() if k not in exclude)
707
708     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
709         """In a for loop, or comprehension, the variables are often unpacks.
710
711         To avoid splitting on the comma in this situation, increase the depth of
712         tokens between `for` and `in`.
713         """
714         if leaf.type == token.NAME and leaf.value == "for":
715             self.depth += 1
716             self._for_loop_variable += 1
717             return True
718
719         return False
720
721     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
722         """See `maybe_increment_for_loop_variable` above for explanation."""
723         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
724             self.depth -= 1
725             self._for_loop_variable -= 1
726             return True
727
728         return False
729
730     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
731         """In a lambda expression, there might be more than one argument.
732
733         To avoid splitting on the comma in this situation, increase the depth of
734         tokens between `lambda` and `:`.
735         """
736         if leaf.type == token.NAME and leaf.value == "lambda":
737             self.depth += 1
738             self._lambda_arguments += 1
739             return True
740
741         return False
742
743     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
744         """See `maybe_increment_lambda_arguments` above for explanation."""
745         if self._lambda_arguments and leaf.type == token.COLON:
746             self.depth -= 1
747             self._lambda_arguments -= 1
748             return True
749
750         return False
751
752     def get_open_lsqb(self) -> Optional[Leaf]:
753         """Return the most recent opening square bracket (if any)."""
754         return self.bracket_match.get((self.depth - 1, token.RSQB))
755
756
757 @dataclass
758 class Line:
759     """Holds leaves and comments. Can be printed with `str(line)`."""
760
761     depth: int = 0
762     leaves: List[Leaf] = Factory(list)
763     comments: List[Tuple[Index, Leaf]] = Factory(list)
764     bracket_tracker: BracketTracker = Factory(BracketTracker)
765     inside_brackets: bool = False
766
767     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
768         """Add a new `leaf` to the end of the line.
769
770         Unless `preformatted` is True, the `leaf` will receive a new consistent
771         whitespace prefix and metadata applied by :class:`BracketTracker`.
772         Trailing commas are maybe removed, unpacked for loop variables are
773         demoted from being delimiters.
774
775         Inline comments are put aside.
776         """
777         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
778         if not has_value:
779             return
780
781         if token.COLON == leaf.type and self.is_class_paren_empty:
782             del self.leaves[-2:]
783         if self.leaves and not preformatted:
784             # Note: at this point leaf.prefix should be empty except for
785             # imports, for which we only preserve newlines.
786             leaf.prefix += whitespace(
787                 leaf, complex_subscript=self.is_complex_subscript(leaf)
788             )
789         if self.inside_brackets or not preformatted:
790             self.bracket_tracker.mark(leaf)
791             self.maybe_remove_trailing_comma(leaf)
792         if not self.append_comment(leaf):
793             self.leaves.append(leaf)
794
795     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
796         """Like :func:`append()` but disallow invalid standalone comment structure.
797
798         Raises ValueError when any `leaf` is appended after a standalone comment
799         or when a standalone comment is not the first leaf on the line.
800         """
801         if self.bracket_tracker.depth == 0:
802             if self.is_comment:
803                 raise ValueError("cannot append to standalone comments")
804
805             if self.leaves and leaf.type == STANDALONE_COMMENT:
806                 raise ValueError(
807                     "cannot append standalone comments to a populated line"
808                 )
809
810         self.append(leaf, preformatted=preformatted)
811
812     @property
813     def is_comment(self) -> bool:
814         """Is this line a standalone comment?"""
815         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
816
817     @property
818     def is_decorator(self) -> bool:
819         """Is this line a decorator?"""
820         return bool(self) and self.leaves[0].type == token.AT
821
822     @property
823     def is_import(self) -> bool:
824         """Is this an import line?"""
825         return bool(self) and is_import(self.leaves[0])
826
827     @property
828     def is_class(self) -> bool:
829         """Is this line a class definition?"""
830         return (
831             bool(self)
832             and self.leaves[0].type == token.NAME
833             and self.leaves[0].value == "class"
834         )
835
836     @property
837     def is_def(self) -> bool:
838         """Is this a function definition? (Also returns True for async defs.)"""
839         try:
840             first_leaf = self.leaves[0]
841         except IndexError:
842             return False
843
844         try:
845             second_leaf: Optional[Leaf] = self.leaves[1]
846         except IndexError:
847             second_leaf = None
848         return (
849             (first_leaf.type == token.NAME and first_leaf.value == "def")
850             or (
851                 first_leaf.type == token.ASYNC
852                 and second_leaf is not None
853                 and second_leaf.type == token.NAME
854                 and second_leaf.value == "def"
855             )
856         )
857
858     @property
859     def is_flow_control(self) -> bool:
860         """Is this line a flow control statement?
861
862         Those are `return`, `raise`, `break`, and `continue`.
863         """
864         return (
865             bool(self)
866             and self.leaves[0].type == token.NAME
867             and self.leaves[0].value in FLOW_CONTROL
868         )
869
870     @property
871     def is_yield(self) -> bool:
872         """Is this line a yield statement?"""
873         return (
874             bool(self)
875             and self.leaves[0].type == token.NAME
876             and self.leaves[0].value == "yield"
877         )
878
879     @property
880     def is_class_paren_empty(self) -> bool:
881         """Is this a class with no base classes but using parentheses?
882
883         Those are unnecessary and should be removed.
884         """
885         return (
886             bool(self)
887             and len(self.leaves) == 4
888             and self.is_class
889             and self.leaves[2].type == token.LPAR
890             and self.leaves[2].value == "("
891             and self.leaves[3].type == token.RPAR
892             and self.leaves[3].value == ")"
893         )
894
895     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
896         """If so, needs to be split before emitting."""
897         for leaf in self.leaves:
898             if leaf.type == STANDALONE_COMMENT:
899                 if leaf.bracket_depth <= depth_limit:
900                     return True
901
902         return False
903
904     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
905         """Remove trailing comma if there is one and it's safe."""
906         if not (
907             self.leaves
908             and self.leaves[-1].type == token.COMMA
909             and closing.type in CLOSING_BRACKETS
910         ):
911             return False
912
913         if closing.type == token.RBRACE:
914             self.remove_trailing_comma()
915             return True
916
917         if closing.type == token.RSQB:
918             comma = self.leaves[-1]
919             if comma.parent and comma.parent.type == syms.listmaker:
920                 self.remove_trailing_comma()
921                 return True
922
923         # For parens let's check if it's safe to remove the comma.
924         # Imports are always safe.
925         if self.is_import:
926             self.remove_trailing_comma()
927             return True
928
929         # Otheriwsse, if the trailing one is the only one, we might mistakenly
930         # change a tuple into a different type by removing the comma.
931         depth = closing.bracket_depth + 1
932         commas = 0
933         opening = closing.opening_bracket
934         for _opening_index, leaf in enumerate(self.leaves):
935             if leaf is opening:
936                 break
937
938         else:
939             return False
940
941         for leaf in self.leaves[_opening_index + 1 :]:
942             if leaf is closing:
943                 break
944
945             bracket_depth = leaf.bracket_depth
946             if bracket_depth == depth and leaf.type == token.COMMA:
947                 commas += 1
948                 if leaf.parent and leaf.parent.type == syms.arglist:
949                     commas += 1
950                     break
951
952         if commas > 1:
953             self.remove_trailing_comma()
954             return True
955
956         return False
957
958     def append_comment(self, comment: Leaf) -> bool:
959         """Add an inline or standalone comment to the line."""
960         if (
961             comment.type == STANDALONE_COMMENT
962             and self.bracket_tracker.any_open_brackets()
963         ):
964             comment.prefix = ""
965             return False
966
967         if comment.type != token.COMMENT:
968             return False
969
970         after = len(self.leaves) - 1
971         if after == -1:
972             comment.type = STANDALONE_COMMENT
973             comment.prefix = ""
974             return False
975
976         else:
977             self.comments.append((after, comment))
978             return True
979
980     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
981         """Generate comments that should appear directly after `leaf`.
982
983         Provide a non-negative leaf `_index` to speed up the function.
984         """
985         if _index == -1:
986             for _index, _leaf in enumerate(self.leaves):
987                 if leaf is _leaf:
988                     break
989
990             else:
991                 return
992
993         for index, comment_after in self.comments:
994             if _index == index:
995                 yield comment_after
996
997     def remove_trailing_comma(self) -> None:
998         """Remove the trailing comma and moves the comments attached to it."""
999         comma_index = len(self.leaves) - 1
1000         for i in range(len(self.comments)):
1001             comment_index, comment = self.comments[i]
1002             if comment_index == comma_index:
1003                 self.comments[i] = (comma_index - 1, comment)
1004         self.leaves.pop()
1005
1006     def is_complex_subscript(self, leaf: Leaf) -> bool:
1007         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1008         open_lsqb = (
1009             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1010         )
1011         if open_lsqb is None:
1012             return False
1013
1014         subscript_start = open_lsqb.next_sibling
1015         if (
1016             isinstance(subscript_start, Node)
1017             and subscript_start.type == syms.subscriptlist
1018         ):
1019             subscript_start = child_towards(subscript_start, leaf)
1020         return (
1021             subscript_start is not None
1022             and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order())
1023         )
1024
1025     def __str__(self) -> str:
1026         """Render the line."""
1027         if not self:
1028             return "\n"
1029
1030         indent = "    " * self.depth
1031         leaves = iter(self.leaves)
1032         first = next(leaves)
1033         res = f"{first.prefix}{indent}{first.value}"
1034         for leaf in leaves:
1035             res += str(leaf)
1036         for _, comment in self.comments:
1037             res += str(comment)
1038         return res + "\n"
1039
1040     def __bool__(self) -> bool:
1041         """Return True if the line has leaves or comments."""
1042         return bool(self.leaves or self.comments)
1043
1044
1045 class UnformattedLines(Line):
1046     """Just like :class:`Line` but stores lines which aren't reformatted."""
1047
1048     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1049         """Just add a new `leaf` to the end of the lines.
1050
1051         The `preformatted` argument is ignored.
1052
1053         Keeps track of indentation `depth`, which is useful when the user
1054         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1055         """
1056         try:
1057             list(generate_comments(leaf))
1058         except FormatOn as f_on:
1059             self.leaves.append(f_on.leaf_from_consumed(leaf))
1060             raise
1061
1062         self.leaves.append(leaf)
1063         if leaf.type == token.INDENT:
1064             self.depth += 1
1065         elif leaf.type == token.DEDENT:
1066             self.depth -= 1
1067
1068     def __str__(self) -> str:
1069         """Render unformatted lines from leaves which were added with `append()`.
1070
1071         `depth` is not used for indentation in this case.
1072         """
1073         if not self:
1074             return "\n"
1075
1076         res = ""
1077         for leaf in self.leaves:
1078             res += str(leaf)
1079         return res
1080
1081     def append_comment(self, comment: Leaf) -> bool:
1082         """Not implemented in this class. Raises `NotImplementedError`."""
1083         raise NotImplementedError("Unformatted lines don't store comments separately.")
1084
1085     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1086         """Does nothing and returns False."""
1087         return False
1088
1089     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1090         """Does nothing and returns False."""
1091         return False
1092
1093
1094 @dataclass
1095 class EmptyLineTracker:
1096     """Provides a stateful method that returns the number of potential extra
1097     empty lines needed before and after the currently processed line.
1098
1099     Note: this tracker works on lines that haven't been split yet.  It assumes
1100     the prefix of the first leaf consists of optional newlines.  Those newlines
1101     are consumed by `maybe_empty_lines()` and included in the computation.
1102     """
1103     previous_line: Optional[Line] = None
1104     previous_after: int = 0
1105     previous_defs: List[int] = Factory(list)
1106
1107     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1108         """Return the number of extra empty lines before and after the `current_line`.
1109
1110         This is for separating `def`, `async def` and `class` with extra empty
1111         lines (two on module-level), as well as providing an extra empty line
1112         after flow control keywords to make them more prominent.
1113         """
1114         if isinstance(current_line, UnformattedLines):
1115             return 0, 0
1116
1117         before, after = self._maybe_empty_lines(current_line)
1118         before -= self.previous_after
1119         self.previous_after = after
1120         self.previous_line = current_line
1121         return before, after
1122
1123     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1124         max_allowed = 1
1125         if current_line.depth == 0:
1126             max_allowed = 2
1127         if current_line.leaves:
1128             # Consume the first leaf's extra newlines.
1129             first_leaf = current_line.leaves[0]
1130             before = first_leaf.prefix.count("\n")
1131             before = min(before, max_allowed)
1132             first_leaf.prefix = ""
1133         else:
1134             before = 0
1135         depth = current_line.depth
1136         while self.previous_defs and self.previous_defs[-1] >= depth:
1137             self.previous_defs.pop()
1138             before = 1 if depth else 2
1139         is_decorator = current_line.is_decorator
1140         if is_decorator or current_line.is_def or current_line.is_class:
1141             if not is_decorator:
1142                 self.previous_defs.append(depth)
1143             if self.previous_line is None:
1144                 # Don't insert empty lines before the first line in the file.
1145                 return 0, 0
1146
1147             if self.previous_line.is_decorator:
1148                 return 0, 0
1149
1150             if (
1151                 self.previous_line.is_comment
1152                 and self.previous_line.depth == current_line.depth
1153                 and before == 0
1154             ):
1155                 return 0, 0
1156
1157             newlines = 2
1158             if current_line.depth:
1159                 newlines -= 1
1160             return newlines, 0
1161
1162         if (
1163             self.previous_line
1164             and self.previous_line.is_import
1165             and not current_line.is_import
1166             and depth == self.previous_line.depth
1167         ):
1168             return (before or 1), 0
1169
1170         return before, 0
1171
1172
1173 @dataclass
1174 class LineGenerator(Visitor[Line]):
1175     """Generates reformatted Line objects.  Empty lines are not emitted.
1176
1177     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1178     in ways that will no longer stringify to valid Python code on the tree.
1179     """
1180     current_line: Line = Factory(Line)
1181     remove_u_prefix: bool = False
1182
1183     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1184         """Generate a line.
1185
1186         If the line is empty, only emit if it makes sense.
1187         If the line is too long, split it first and then generate.
1188
1189         If any lines were generated, set up a new current_line.
1190         """
1191         if not self.current_line:
1192             if self.current_line.__class__ == type:
1193                 self.current_line.depth += indent
1194             else:
1195                 self.current_line = type(depth=self.current_line.depth + indent)
1196             return  # Line is empty, don't emit. Creating a new one unnecessary.
1197
1198         complete_line = self.current_line
1199         self.current_line = type(depth=complete_line.depth + indent)
1200         yield complete_line
1201
1202     def visit(self, node: LN) -> Iterator[Line]:
1203         """Main method to visit `node` and its children.
1204
1205         Yields :class:`Line` objects.
1206         """
1207         if isinstance(self.current_line, UnformattedLines):
1208             # File contained `# fmt: off`
1209             yield from self.visit_unformatted(node)
1210
1211         else:
1212             yield from super().visit(node)
1213
1214     def visit_default(self, node: LN) -> Iterator[Line]:
1215         """Default `visit_*()` implementation. Recurses to children of `node`."""
1216         if isinstance(node, Leaf):
1217             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1218             try:
1219                 for comment in generate_comments(node):
1220                     if any_open_brackets:
1221                         # any comment within brackets is subject to splitting
1222                         self.current_line.append(comment)
1223                     elif comment.type == token.COMMENT:
1224                         # regular trailing comment
1225                         self.current_line.append(comment)
1226                         yield from self.line()
1227
1228                     else:
1229                         # regular standalone comment
1230                         yield from self.line()
1231
1232                         self.current_line.append(comment)
1233                         yield from self.line()
1234
1235             except FormatOff as f_off:
1236                 f_off.trim_prefix(node)
1237                 yield from self.line(type=UnformattedLines)
1238                 yield from self.visit(node)
1239
1240             except FormatOn as f_on:
1241                 # This only happens here if somebody says "fmt: on" multiple
1242                 # times in a row.
1243                 f_on.trim_prefix(node)
1244                 yield from self.visit_default(node)
1245
1246             else:
1247                 normalize_prefix(node, inside_brackets=any_open_brackets)
1248                 if node.type == token.STRING:
1249                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1250                     normalize_string_quotes(node)
1251                 if node.type not in WHITESPACE:
1252                     self.current_line.append(node)
1253         yield from super().visit_default(node)
1254
1255     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1256         """Increase indentation level, maybe yield a line."""
1257         # In blib2to3 INDENT never holds comments.
1258         yield from self.line(+1)
1259         yield from self.visit_default(node)
1260
1261     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1262         """Decrease indentation level, maybe yield a line."""
1263         # The current line might still wait for trailing comments.  At DEDENT time
1264         # there won't be any (they would be prefixes on the preceding NEWLINE).
1265         # Emit the line then.
1266         yield from self.line()
1267
1268         # While DEDENT has no value, its prefix may contain standalone comments
1269         # that belong to the current indentation level.  Get 'em.
1270         yield from self.visit_default(node)
1271
1272         # Finally, emit the dedent.
1273         yield from self.line(-1)
1274
1275     def visit_stmt(
1276         self, node: Node, keywords: Set[str], parens: Set[str]
1277     ) -> Iterator[Line]:
1278         """Visit a statement.
1279
1280         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1281         `def`, `with`, `class`, `assert` and assignments.
1282
1283         The relevant Python language `keywords` for a given statement will be
1284         NAME leaves within it. This methods puts those on a separate line.
1285
1286         `parens` holds a set of string leaf values immeditely after which
1287         invisible parens should be put.
1288         """
1289         normalize_invisible_parens(node, parens_after=parens)
1290         for child in node.children:
1291             if child.type == token.NAME and child.value in keywords:  # type: ignore
1292                 yield from self.line()
1293
1294             yield from self.visit(child)
1295
1296     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1297         """Visit a statement without nested statements."""
1298         is_suite_like = node.parent and node.parent.type in STATEMENT
1299         if is_suite_like:
1300             yield from self.line(+1)
1301             yield from self.visit_default(node)
1302             yield from self.line(-1)
1303
1304         else:
1305             yield from self.line()
1306             yield from self.visit_default(node)
1307
1308     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1309         """Visit `async def`, `async for`, `async with`."""
1310         yield from self.line()
1311
1312         children = iter(node.children)
1313         for child in children:
1314             yield from self.visit(child)
1315
1316             if child.type == token.ASYNC:
1317                 break
1318
1319         internal_stmt = next(children)
1320         for child in internal_stmt.children:
1321             yield from self.visit(child)
1322
1323     def visit_decorators(self, node: Node) -> Iterator[Line]:
1324         """Visit decorators."""
1325         for child in node.children:
1326             yield from self.line()
1327             yield from self.visit(child)
1328
1329     def visit_import_from(self, node: Node) -> Iterator[Line]:
1330         """Visit import_from and maybe put invisible parentheses.
1331
1332         This is separate from `visit_stmt` because import statements don't
1333         support arbitrary atoms and thus handling of parentheses is custom.
1334         """
1335         check_lpar = False
1336         for index, child in enumerate(node.children):
1337             if check_lpar:
1338                 if child.type == token.LPAR:
1339                     # make parentheses invisible
1340                     child.value = ""  # type: ignore
1341                     node.children[-1].value = ""  # type: ignore
1342                 else:
1343                     # insert invisible parentheses
1344                     node.insert_child(index, Leaf(token.LPAR, ""))
1345                     node.append_child(Leaf(token.RPAR, ""))
1346                 break
1347
1348             check_lpar = (
1349                 child.type == token.NAME and child.value == "import"  # type: ignore
1350             )
1351
1352         for child in node.children:
1353             yield from self.visit(child)
1354
1355     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1356         """Remove a semicolon and put the other statement on a separate line."""
1357         yield from self.line()
1358
1359     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1360         """End of file. Process outstanding comments and end with a newline."""
1361         yield from self.visit_default(leaf)
1362         yield from self.line()
1363
1364     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1365         """Used when file contained a `# fmt: off`."""
1366         if isinstance(node, Node):
1367             for child in node.children:
1368                 yield from self.visit(child)
1369
1370         else:
1371             try:
1372                 self.current_line.append(node)
1373             except FormatOn as f_on:
1374                 f_on.trim_prefix(node)
1375                 yield from self.line()
1376                 yield from self.visit(node)
1377
1378             if node.type == token.ENDMARKER:
1379                 # somebody decided not to put a final `# fmt: on`
1380                 yield from self.line()
1381
1382     def __attrs_post_init__(self) -> None:
1383         """You are in a twisty little maze of passages."""
1384         v = self.visit_stmt
1385         Ø: Set[str] = set()
1386         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1387         self.visit_if_stmt = partial(
1388             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1389         )
1390         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1391         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1392         self.visit_try_stmt = partial(
1393             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1394         )
1395         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1396         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1397         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1398         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1399         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1400         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1401         self.visit_async_funcdef = self.visit_async_stmt
1402         self.visit_decorated = self.visit_decorators
1403
1404
1405 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1406 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1407 OPENING_BRACKETS = set(BRACKET.keys())
1408 CLOSING_BRACKETS = set(BRACKET.values())
1409 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1410 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1411
1412
1413 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1414     """Return whitespace prefix if needed for the given `leaf`.
1415
1416     `complex_subscript` signals whether the given leaf is part of a subscription
1417     which has non-trivial arguments, like arithmetic expressions or function calls.
1418     """
1419     NO = ""
1420     SPACE = " "
1421     DOUBLESPACE = "  "
1422     t = leaf.type
1423     p = leaf.parent
1424     v = leaf.value
1425     if t in ALWAYS_NO_SPACE:
1426         return NO
1427
1428     if t == token.COMMENT:
1429         return DOUBLESPACE
1430
1431     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1432     if (
1433         t == token.COLON
1434         and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
1435     ):
1436         return NO
1437
1438     prev = leaf.prev_sibling
1439     if not prev:
1440         prevp = preceding_leaf(p)
1441         if not prevp or prevp.type in OPENING_BRACKETS:
1442             return NO
1443
1444         if t == token.COLON:
1445             if prevp.type == token.COLON:
1446                 return NO
1447
1448             elif prevp.type != token.COMMA and not complex_subscript:
1449                 return NO
1450
1451             return SPACE
1452
1453         if prevp.type == token.EQUAL:
1454             if prevp.parent:
1455                 if prevp.parent.type in {
1456                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1457                 }:
1458                     return NO
1459
1460                 elif prevp.parent.type == syms.typedargslist:
1461                     # A bit hacky: if the equal sign has whitespace, it means we
1462                     # previously found it's a typed argument.  So, we're using
1463                     # that, too.
1464                     return prevp.prefix
1465
1466         elif prevp.type in STARS:
1467             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1468                 return NO
1469
1470         elif prevp.type == token.COLON:
1471             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1472                 return SPACE if complex_subscript else NO
1473
1474         elif (
1475             prevp.parent
1476             and prevp.parent.type == syms.factor
1477             and prevp.type in MATH_OPERATORS
1478         ):
1479             return NO
1480
1481         elif (
1482             prevp.type == token.RIGHTSHIFT
1483             and prevp.parent
1484             and prevp.parent.type == syms.shift_expr
1485             and prevp.prev_sibling
1486             and prevp.prev_sibling.type == token.NAME
1487             and prevp.prev_sibling.value == "print"  # type: ignore
1488         ):
1489             # Python 2 print chevron
1490             return NO
1491
1492     elif prev.type in OPENING_BRACKETS:
1493         return NO
1494
1495     if p.type in {syms.parameters, syms.arglist}:
1496         # untyped function signatures or calls
1497         if not prev or prev.type != token.COMMA:
1498             return NO
1499
1500     elif p.type == syms.varargslist:
1501         # lambdas
1502         if prev and prev.type != token.COMMA:
1503             return NO
1504
1505     elif p.type == syms.typedargslist:
1506         # typed function signatures
1507         if not prev:
1508             return NO
1509
1510         if t == token.EQUAL:
1511             if prev.type != syms.tname:
1512                 return NO
1513
1514         elif prev.type == token.EQUAL:
1515             # A bit hacky: if the equal sign has whitespace, it means we
1516             # previously found it's a typed argument.  So, we're using that, too.
1517             return prev.prefix
1518
1519         elif prev.type != token.COMMA:
1520             return NO
1521
1522     elif p.type == syms.tname:
1523         # type names
1524         if not prev:
1525             prevp = preceding_leaf(p)
1526             if not prevp or prevp.type != token.COMMA:
1527                 return NO
1528
1529     elif p.type == syms.trailer:
1530         # attributes and calls
1531         if t == token.LPAR or t == token.RPAR:
1532             return NO
1533
1534         if not prev:
1535             if t == token.DOT:
1536                 prevp = preceding_leaf(p)
1537                 if not prevp or prevp.type != token.NUMBER:
1538                     return NO
1539
1540             elif t == token.LSQB:
1541                 return NO
1542
1543         elif prev.type != token.COMMA:
1544             return NO
1545
1546     elif p.type == syms.argument:
1547         # single argument
1548         if t == token.EQUAL:
1549             return NO
1550
1551         if not prev:
1552             prevp = preceding_leaf(p)
1553             if not prevp or prevp.type == token.LPAR:
1554                 return NO
1555
1556         elif prev.type in {token.EQUAL} | STARS:
1557             return NO
1558
1559     elif p.type == syms.decorator:
1560         # decorators
1561         return NO
1562
1563     elif p.type == syms.dotted_name:
1564         if prev:
1565             return NO
1566
1567         prevp = preceding_leaf(p)
1568         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1569             return NO
1570
1571     elif p.type == syms.classdef:
1572         if t == token.LPAR:
1573             return NO
1574
1575         if prev and prev.type == token.LPAR:
1576             return NO
1577
1578     elif p.type in {syms.subscript, syms.sliceop}:
1579         # indexing
1580         if not prev:
1581             assert p.parent is not None, "subscripts are always parented"
1582             if p.parent.type == syms.subscriptlist:
1583                 return SPACE
1584
1585             return NO
1586
1587         elif not complex_subscript:
1588             return NO
1589
1590     elif p.type == syms.atom:
1591         if prev and t == token.DOT:
1592             # dots, but not the first one.
1593             return NO
1594
1595     elif p.type == syms.dictsetmaker:
1596         # dict unpacking
1597         if prev and prev.type == token.DOUBLESTAR:
1598             return NO
1599
1600     elif p.type in {syms.factor, syms.star_expr}:
1601         # unary ops
1602         if not prev:
1603             prevp = preceding_leaf(p)
1604             if not prevp or prevp.type in OPENING_BRACKETS:
1605                 return NO
1606
1607             prevp_parent = prevp.parent
1608             assert prevp_parent is not None
1609             if (
1610                 prevp.type == token.COLON
1611                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1612             ):
1613                 return NO
1614
1615             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1616                 return NO
1617
1618         elif t == token.NAME or t == token.NUMBER:
1619             return NO
1620
1621     elif p.type == syms.import_from:
1622         if t == token.DOT:
1623             if prev and prev.type == token.DOT:
1624                 return NO
1625
1626         elif t == token.NAME:
1627             if v == "import":
1628                 return SPACE
1629
1630             if prev and prev.type == token.DOT:
1631                 return NO
1632
1633     elif p.type == syms.sliceop:
1634         return NO
1635
1636     return SPACE
1637
1638
1639 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1640     """Return the first leaf that precedes `node`, if any."""
1641     while node:
1642         res = node.prev_sibling
1643         if res:
1644             if isinstance(res, Leaf):
1645                 return res
1646
1647             try:
1648                 return list(res.leaves())[-1]
1649
1650             except IndexError:
1651                 return None
1652
1653         node = node.parent
1654     return None
1655
1656
1657 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1658     """Return the child of `ancestor` that contains `descendant`."""
1659     node: Optional[LN] = descendant
1660     while node and node.parent != ancestor:
1661         node = node.parent
1662     return node
1663
1664
1665 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1666     """Return the priority of the `leaf` delimiter, given a line break after it.
1667
1668     The delimiter priorities returned here are from those delimiters that would
1669     cause a line break after themselves.
1670
1671     Higher numbers are higher priority.
1672     """
1673     if leaf.type == token.COMMA:
1674         return COMMA_PRIORITY
1675
1676     return 0
1677
1678
1679 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1680     """Return the priority of the `leaf` delimiter, given a line before after it.
1681
1682     The delimiter priorities returned here are from those delimiters that would
1683     cause a line break before themselves.
1684
1685     Higher numbers are higher priority.
1686     """
1687     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1688         # * and ** might also be MATH_OPERATORS but in this case they are not.
1689         # Don't treat them as a delimiter.
1690         return 0
1691
1692     if (
1693         leaf.type in MATH_OPERATORS
1694         and leaf.parent
1695         and leaf.parent.type not in {syms.factor, syms.star_expr}
1696     ):
1697         return MATH_PRIORITIES[leaf.type]
1698
1699     if leaf.type in COMPARATORS:
1700         return COMPARATOR_PRIORITY
1701
1702     if (
1703         leaf.type == token.STRING
1704         and previous is not None
1705         and previous.type == token.STRING
1706     ):
1707         return STRING_PRIORITY
1708
1709     if (
1710         leaf.type == token.NAME
1711         and leaf.value == "for"
1712         and leaf.parent
1713         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1714     ):
1715         return COMPREHENSION_PRIORITY
1716
1717     if (
1718         leaf.type == token.NAME
1719         and leaf.value == "if"
1720         and leaf.parent
1721         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1722     ):
1723         return COMPREHENSION_PRIORITY
1724
1725     if (
1726         leaf.type == token.NAME
1727         and leaf.value in {"if", "else"}
1728         and leaf.parent
1729         and leaf.parent.type == syms.test
1730     ):
1731         return TERNARY_PRIORITY
1732
1733     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1734         return LOGIC_PRIORITY
1735
1736     return 0
1737
1738
1739 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1740     """Clean the prefix of the `leaf` and generate comments from it, if any.
1741
1742     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1743     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1744     move because it does away with modifying the grammar to include all the
1745     possible places in which comments can be placed.
1746
1747     The sad consequence for us though is that comments don't "belong" anywhere.
1748     This is why this function generates simple parentless Leaf objects for
1749     comments.  We simply don't know what the correct parent should be.
1750
1751     No matter though, we can live without this.  We really only need to
1752     differentiate between inline and standalone comments.  The latter don't
1753     share the line with any code.
1754
1755     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1756     are emitted with a fake STANDALONE_COMMENT token identifier.
1757     """
1758     p = leaf.prefix
1759     if not p:
1760         return
1761
1762     if "#" not in p:
1763         return
1764
1765     consumed = 0
1766     nlines = 0
1767     for index, line in enumerate(p.split("\n")):
1768         consumed += len(line) + 1  # adding the length of the split '\n'
1769         line = line.lstrip()
1770         if not line:
1771             nlines += 1
1772         if not line.startswith("#"):
1773             continue
1774
1775         if index == 0 and leaf.type != token.ENDMARKER:
1776             comment_type = token.COMMENT  # simple trailing comment
1777         else:
1778             comment_type = STANDALONE_COMMENT
1779         comment = make_comment(line)
1780         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1781
1782         if comment in {"# fmt: on", "# yapf: enable"}:
1783             raise FormatOn(consumed)
1784
1785         if comment in {"# fmt: off", "# yapf: disable"}:
1786             if comment_type == STANDALONE_COMMENT:
1787                 raise FormatOff(consumed)
1788
1789             prev = preceding_leaf(leaf)
1790             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1791                 raise FormatOff(consumed)
1792
1793         nlines = 0
1794
1795
1796 def make_comment(content: str) -> str:
1797     """Return a consistently formatted comment from the given `content` string.
1798
1799     All comments (except for "##", "#!", "#:") should have a single space between
1800     the hash sign and the content.
1801
1802     If `content` didn't start with a hash sign, one is provided.
1803     """
1804     content = content.rstrip()
1805     if not content:
1806         return "#"
1807
1808     if content[0] == "#":
1809         content = content[1:]
1810     if content and content[0] not in " !:#":
1811         content = " " + content
1812     return "#" + content
1813
1814
1815 def split_line(
1816     line: Line, line_length: int, inner: bool = False, py36: bool = False
1817 ) -> Iterator[Line]:
1818     """Split a `line` into potentially many lines.
1819
1820     They should fit in the allotted `line_length` but might not be able to.
1821     `inner` signifies that there were a pair of brackets somewhere around the
1822     current `line`, possibly transitively. This means we can fallback to splitting
1823     by delimiters if the LHS/RHS don't yield any results.
1824
1825     If `py36` is True, splitting may generate syntax that is only compatible
1826     with Python 3.6 and later.
1827     """
1828     if isinstance(line, UnformattedLines) or line.is_comment:
1829         yield line
1830         return
1831
1832     line_str = str(line).strip("\n")
1833     if is_line_short_enough(line, line_length=line_length, line_str=line_str):
1834         yield line
1835         return
1836
1837     split_funcs: List[SplitFunc]
1838     if line.is_def:
1839         split_funcs = [left_hand_split]
1840     elif line.is_import:
1841         split_funcs = [explode_split]
1842     else:
1843
1844         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
1845             for omit in generate_trailers_to_omit(line, line_length):
1846                 lines = list(right_hand_split(line, py36, omit=omit))
1847                 if is_line_short_enough(lines[0], line_length=line_length):
1848                     yield from lines
1849                     return
1850
1851             # All splits failed, best effort split with no omits.
1852             yield from right_hand_split(line, py36)
1853
1854         if line.inside_brackets:
1855             split_funcs = [delimiter_split, standalone_comment_split, rhs]
1856         else:
1857             split_funcs = [rhs]
1858     for split_func in split_funcs:
1859         # We are accumulating lines in `result` because we might want to abort
1860         # mission and return the original line in the end, or attempt a different
1861         # split altogether.
1862         result: List[Line] = []
1863         try:
1864             for l in split_func(line, py36):
1865                 if str(l).strip("\n") == line_str:
1866                     raise CannotSplit("Split function returned an unchanged result")
1867
1868                 result.extend(
1869                     split_line(l, line_length=line_length, inner=True, py36=py36)
1870                 )
1871         except CannotSplit as cs:
1872             continue
1873
1874         else:
1875             yield from result
1876             break
1877
1878     else:
1879         yield line
1880
1881
1882 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1883     """Split line into many lines, starting with the first matching bracket pair.
1884
1885     Note: this usually looks weird, only use this for function definitions.
1886     Prefer RHS otherwise.  This is why this function is not symmetrical with
1887     :func:`right_hand_split` which also handles optional parentheses.
1888     """
1889     head = Line(depth=line.depth)
1890     body = Line(depth=line.depth + 1, inside_brackets=True)
1891     tail = Line(depth=line.depth)
1892     tail_leaves: List[Leaf] = []
1893     body_leaves: List[Leaf] = []
1894     head_leaves: List[Leaf] = []
1895     current_leaves = head_leaves
1896     matching_bracket = None
1897     for leaf in line.leaves:
1898         if (
1899             current_leaves is body_leaves
1900             and leaf.type in CLOSING_BRACKETS
1901             and leaf.opening_bracket is matching_bracket
1902         ):
1903             current_leaves = tail_leaves if body_leaves else head_leaves
1904         current_leaves.append(leaf)
1905         if current_leaves is head_leaves:
1906             if leaf.type in OPENING_BRACKETS:
1907                 matching_bracket = leaf
1908                 current_leaves = body_leaves
1909     # Since body is a new indent level, remove spurious leading whitespace.
1910     if body_leaves:
1911         normalize_prefix(body_leaves[0], inside_brackets=True)
1912     # Build the new lines.
1913     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1914         for leaf in leaves:
1915             result.append(leaf, preformatted=True)
1916             for comment_after in line.comments_after(leaf):
1917                 result.append(comment_after, preformatted=True)
1918     bracket_split_succeeded_or_raise(head, body, tail)
1919     for result in (head, body, tail):
1920         if result:
1921             yield result
1922
1923
1924 def right_hand_split(
1925     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1926 ) -> Iterator[Line]:
1927     """Split line into many lines, starting with the last matching bracket pair.
1928
1929     If the split was by optional parentheses, attempt splitting without them, too.
1930     `omit` is a collection of closing bracket IDs that shouldn't be considered for
1931     this split.
1932     """
1933     head = Line(depth=line.depth)
1934     body = Line(depth=line.depth + 1, inside_brackets=True)
1935     tail = Line(depth=line.depth)
1936     tail_leaves: List[Leaf] = []
1937     body_leaves: List[Leaf] = []
1938     head_leaves: List[Leaf] = []
1939     current_leaves = tail_leaves
1940     opening_bracket = None
1941     closing_bracket = None
1942     for leaf in reversed(line.leaves):
1943         if current_leaves is body_leaves:
1944             if leaf is opening_bracket:
1945                 current_leaves = head_leaves if body_leaves else tail_leaves
1946         current_leaves.append(leaf)
1947         if current_leaves is tail_leaves:
1948             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1949                 opening_bracket = leaf.opening_bracket
1950                 closing_bracket = leaf
1951                 current_leaves = body_leaves
1952     tail_leaves.reverse()
1953     body_leaves.reverse()
1954     head_leaves.reverse()
1955     # Since body is a new indent level, remove spurious leading whitespace.
1956     if body_leaves:
1957         normalize_prefix(body_leaves[0], inside_brackets=True)
1958     elif not head_leaves:
1959         # No `head` and no `body` means the split failed. `tail` has all content.
1960         raise CannotSplit("No brackets found")
1961
1962     # Build the new lines.
1963     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1964         for leaf in leaves:
1965             result.append(leaf, preformatted=True)
1966             for comment_after in line.comments_after(leaf):
1967                 result.append(comment_after, preformatted=True)
1968     bracket_split_succeeded_or_raise(head, body, tail)
1969     assert opening_bracket and closing_bracket
1970     if (
1971         # the opening bracket is an optional paren
1972         opening_bracket.type == token.LPAR
1973         and not opening_bracket.value
1974         # the closing bracket is an optional paren
1975         and closing_bracket.type == token.RPAR
1976         and not closing_bracket.value
1977         # there are no delimiters or standalone comments in the body
1978         and not body.bracket_tracker.delimiters
1979         and not line.contains_standalone_comments(0)
1980         # and it's not an import (optional parens are the only thing we can split
1981         # on in this case; attempting a split without them is a waste of time)
1982         and not line.is_import
1983     ):
1984         omit = {id(closing_bracket), *omit}
1985         try:
1986             yield from right_hand_split(line, py36=py36, omit=omit)
1987             return
1988         except CannotSplit:
1989             pass
1990
1991     ensure_visible(opening_bracket)
1992     ensure_visible(closing_bracket)
1993     for result in (head, body, tail):
1994         if result:
1995             yield result
1996
1997
1998 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1999     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2000
2001     Do nothing otherwise.
2002
2003     A left- or right-hand split is based on a pair of brackets. Content before
2004     (and including) the opening bracket is left on one line, content inside the
2005     brackets is put on a separate line, and finally content starting with and
2006     following the closing bracket is put on a separate line.
2007
2008     Those are called `head`, `body`, and `tail`, respectively. If the split
2009     produced the same line (all content in `head`) or ended up with an empty `body`
2010     and the `tail` is just the closing bracket, then it's considered failed.
2011     """
2012     tail_len = len(str(tail).strip())
2013     if not body:
2014         if tail_len == 0:
2015             raise CannotSplit("Splitting brackets produced the same line")
2016
2017         elif tail_len < 3:
2018             raise CannotSplit(
2019                 f"Splitting brackets on an empty body to save "
2020                 f"{tail_len} characters is not worth it"
2021             )
2022
2023
2024 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2025     """Normalize prefix of the first leaf in every line returned by `split_func`.
2026
2027     This is a decorator over relevant split functions.
2028     """
2029
2030     @wraps(split_func)
2031     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2032         for l in split_func(line, py36):
2033             normalize_prefix(l.leaves[0], inside_brackets=True)
2034             yield l
2035
2036     return split_wrapper
2037
2038
2039 @dont_increase_indentation
2040 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2041     """Split according to delimiters of the highest priority.
2042
2043     If `py36` is True, the split will add trailing commas also in function
2044     signatures that contain `*` and `**`.
2045     """
2046     try:
2047         last_leaf = line.leaves[-1]
2048     except IndexError:
2049         raise CannotSplit("Line empty")
2050
2051     delimiters = line.bracket_tracker.delimiters
2052     try:
2053         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
2054             exclude={id(last_leaf)}
2055         )
2056     except ValueError:
2057         raise CannotSplit("No delimiters found")
2058
2059     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2060     lowest_depth = sys.maxsize
2061     trailing_comma_safe = True
2062
2063     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2064         """Append `leaf` to current line or to new line if appending impossible."""
2065         nonlocal current_line
2066         try:
2067             current_line.append_safe(leaf, preformatted=True)
2068         except ValueError as ve:
2069             yield current_line
2070
2071             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2072             current_line.append(leaf)
2073
2074     for index, leaf in enumerate(line.leaves):
2075         yield from append_to_line(leaf)
2076
2077         for comment_after in line.comments_after(leaf, index):
2078             yield from append_to_line(comment_after)
2079
2080         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2081         if (
2082             leaf.bracket_depth == lowest_depth
2083             and is_vararg(leaf, within=VARARGS_PARENTS)
2084         ):
2085             trailing_comma_safe = trailing_comma_safe and py36
2086         leaf_priority = delimiters.get(id(leaf))
2087         if leaf_priority == delimiter_priority:
2088             yield current_line
2089
2090             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2091     if current_line:
2092         if (
2093             trailing_comma_safe
2094             and delimiter_priority == COMMA_PRIORITY
2095             and current_line.leaves[-1].type != token.COMMA
2096             and current_line.leaves[-1].type != STANDALONE_COMMENT
2097         ):
2098             current_line.append(Leaf(token.COMMA, ","))
2099         yield current_line
2100
2101
2102 @dont_increase_indentation
2103 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2104     """Split standalone comments from the rest of the line."""
2105     if not line.contains_standalone_comments(0):
2106         raise CannotSplit("Line does not have any standalone comments")
2107
2108     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2109
2110     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2111         """Append `leaf` to current line or to new line if appending impossible."""
2112         nonlocal current_line
2113         try:
2114             current_line.append_safe(leaf, preformatted=True)
2115         except ValueError as ve:
2116             yield current_line
2117
2118             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2119             current_line.append(leaf)
2120
2121     for index, leaf in enumerate(line.leaves):
2122         yield from append_to_line(leaf)
2123
2124         for comment_after in line.comments_after(leaf, index):
2125             yield from append_to_line(comment_after)
2126
2127     if current_line:
2128         yield current_line
2129
2130
2131 def explode_split(
2132     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2133 ) -> Iterator[Line]:
2134     """Split by rightmost bracket and immediately split contents by a delimiter."""
2135     new_lines = list(right_hand_split(line, py36, omit))
2136     if len(new_lines) != 3:
2137         yield from new_lines
2138         return
2139
2140     yield new_lines[0]
2141
2142     try:
2143         yield from delimiter_split(new_lines[1], py36)
2144
2145     except CannotSplit:
2146         yield new_lines[1]
2147
2148     yield new_lines[2]
2149
2150
2151 def is_import(leaf: Leaf) -> bool:
2152     """Return True if the given leaf starts an import statement."""
2153     p = leaf.parent
2154     t = leaf.type
2155     v = leaf.value
2156     return bool(
2157         t == token.NAME
2158         and (
2159             (v == "import" and p and p.type == syms.import_name)
2160             or (v == "from" and p and p.type == syms.import_from)
2161         )
2162     )
2163
2164
2165 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2166     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2167     else.
2168
2169     Note: don't use backslashes for formatting or you'll lose your voting rights.
2170     """
2171     if not inside_brackets:
2172         spl = leaf.prefix.split("#")
2173         if "\\" not in spl[0]:
2174             nl_count = spl[-1].count("\n")
2175             if len(spl) > 1:
2176                 nl_count -= 1
2177             leaf.prefix = "\n" * nl_count
2178             return
2179
2180     leaf.prefix = ""
2181
2182
2183 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2184     """Make all string prefixes lowercase.
2185
2186     If remove_u_prefix is given, also removes any u prefix from the string.
2187
2188     Note: Mutates its argument.
2189     """
2190     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2191     assert match is not None, f"failed to match string {leaf.value!r}"
2192     orig_prefix = match.group(1)
2193     new_prefix = orig_prefix.lower()
2194     if remove_u_prefix:
2195         new_prefix = new_prefix.replace("u", "")
2196     leaf.value = f"{new_prefix}{match.group(2)}"
2197
2198
2199 def normalize_string_quotes(leaf: Leaf) -> None:
2200     """Prefer double quotes but only if it doesn't cause more escaping.
2201
2202     Adds or removes backslashes as appropriate. Doesn't parse and fix
2203     strings nested in f-strings (yet).
2204
2205     Note: Mutates its argument.
2206     """
2207     value = leaf.value.lstrip("furbFURB")
2208     if value[:3] == '"""':
2209         return
2210
2211     elif value[:3] == "'''":
2212         orig_quote = "'''"
2213         new_quote = '"""'
2214     elif value[0] == '"':
2215         orig_quote = '"'
2216         new_quote = "'"
2217     else:
2218         orig_quote = "'"
2219         new_quote = '"'
2220     first_quote_pos = leaf.value.find(orig_quote)
2221     if first_quote_pos == -1:
2222         return  # There's an internal error
2223
2224     prefix = leaf.value[:first_quote_pos]
2225     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2226     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2227     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2228     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2229     if "r" in prefix.casefold():
2230         if unescaped_new_quote.search(body):
2231             # There's at least one unescaped new_quote in this raw string
2232             # so converting is impossible
2233             return
2234
2235         # Do not introduce or remove backslashes in raw strings
2236         new_body = body
2237     else:
2238         # remove unnecessary quotes
2239         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2240         if body != new_body:
2241             # Consider the string without unnecessary quotes as the original
2242             body = new_body
2243             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2244         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2245         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2246     if new_quote == '"""' and new_body[-1] == '"':
2247         # edge case:
2248         new_body = new_body[:-1] + '\\"'
2249     orig_escape_count = body.count("\\")
2250     new_escape_count = new_body.count("\\")
2251     if new_escape_count > orig_escape_count:
2252         return  # Do not introduce more escaping
2253
2254     if new_escape_count == orig_escape_count and orig_quote == '"':
2255         return  # Prefer double quotes
2256
2257     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2258
2259
2260 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2261     """Make existing optional parentheses invisible or create new ones.
2262
2263     `parens_after` is a set of string leaf values immeditely after which parens
2264     should be put.
2265
2266     Standardizes on visible parentheses for single-element tuples, and keeps
2267     existing visible parentheses for other tuples and generator expressions.
2268     """
2269     check_lpar = False
2270     for child in list(node.children):
2271         if check_lpar:
2272             if child.type == syms.atom:
2273                 maybe_make_parens_invisible_in_atom(child)
2274             elif is_one_tuple(child):
2275                 # wrap child in visible parentheses
2276                 lpar = Leaf(token.LPAR, "(")
2277                 rpar = Leaf(token.RPAR, ")")
2278                 index = child.remove() or 0
2279                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2280             else:
2281                 # wrap child in invisible parentheses
2282                 lpar = Leaf(token.LPAR, "")
2283                 rpar = Leaf(token.RPAR, "")
2284                 index = child.remove() or 0
2285                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2286
2287         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2288
2289
2290 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2291     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2292     if (
2293         node.type != syms.atom
2294         or is_empty_tuple(node)
2295         or is_one_tuple(node)
2296         or is_yield(node)
2297         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2298     ):
2299         return False
2300
2301     first = node.children[0]
2302     last = node.children[-1]
2303     if first.type == token.LPAR and last.type == token.RPAR:
2304         # make parentheses invisible
2305         first.value = ""  # type: ignore
2306         last.value = ""  # type: ignore
2307         if len(node.children) > 1:
2308             maybe_make_parens_invisible_in_atom(node.children[1])
2309         return True
2310
2311     return False
2312
2313
2314 def is_empty_tuple(node: LN) -> bool:
2315     """Return True if `node` holds an empty tuple."""
2316     return (
2317         node.type == syms.atom
2318         and len(node.children) == 2
2319         and node.children[0].type == token.LPAR
2320         and node.children[1].type == token.RPAR
2321     )
2322
2323
2324 def is_one_tuple(node: LN) -> bool:
2325     """Return True if `node` holds a tuple with one element, with or without parens."""
2326     if node.type == syms.atom:
2327         if len(node.children) != 3:
2328             return False
2329
2330         lpar, gexp, rpar = node.children
2331         if not (
2332             lpar.type == token.LPAR
2333             and gexp.type == syms.testlist_gexp
2334             and rpar.type == token.RPAR
2335         ):
2336             return False
2337
2338         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2339
2340     return (
2341         node.type in IMPLICIT_TUPLE
2342         and len(node.children) == 2
2343         and node.children[1].type == token.COMMA
2344     )
2345
2346
2347 def is_yield(node: LN) -> bool:
2348     """Return True if `node` holds a `yield` or `yield from` expression."""
2349     if node.type == syms.yield_expr:
2350         return True
2351
2352     if node.type == token.NAME and node.value == "yield":  # type: ignore
2353         return True
2354
2355     if node.type != syms.atom:
2356         return False
2357
2358     if len(node.children) != 3:
2359         return False
2360
2361     lpar, expr, rpar = node.children
2362     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2363         return is_yield(expr)
2364
2365     return False
2366
2367
2368 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2369     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2370
2371     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2372     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2373     extended iterable unpacking (PEP 3132) and additional unpacking
2374     generalizations (PEP 448).
2375     """
2376     if leaf.type not in STARS or not leaf.parent:
2377         return False
2378
2379     p = leaf.parent
2380     if p.type == syms.star_expr:
2381         # Star expressions are also used as assignment targets in extended
2382         # iterable unpacking (PEP 3132).  See what its parent is instead.
2383         if not p.parent:
2384             return False
2385
2386         p = p.parent
2387
2388     return p.type in within
2389
2390
2391 def max_delimiter_priority_in_atom(node: LN) -> int:
2392     """Return maximum delimiter priority inside `node`.
2393
2394     This is specific to atoms with contents contained in a pair of parentheses.
2395     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2396     """
2397     if node.type != syms.atom:
2398         return 0
2399
2400     first = node.children[0]
2401     last = node.children[-1]
2402     if not (first.type == token.LPAR and last.type == token.RPAR):
2403         return 0
2404
2405     bt = BracketTracker()
2406     for c in node.children[1:-1]:
2407         if isinstance(c, Leaf):
2408             bt.mark(c)
2409         else:
2410             for leaf in c.leaves():
2411                 bt.mark(leaf)
2412     try:
2413         return bt.max_delimiter_priority()
2414
2415     except ValueError:
2416         return 0
2417
2418
2419 def ensure_visible(leaf: Leaf) -> None:
2420     """Make sure parentheses are visible.
2421
2422     They could be invisible as part of some statements (see
2423     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2424     """
2425     if leaf.type == token.LPAR:
2426         leaf.value = "("
2427     elif leaf.type == token.RPAR:
2428         leaf.value = ")"
2429
2430
2431 def is_python36(node: Node) -> bool:
2432     """Return True if the current file is using Python 3.6+ features.
2433
2434     Currently looking for:
2435     - f-strings; and
2436     - trailing commas after * or ** in function signatures and calls.
2437     """
2438     for n in node.pre_order():
2439         if n.type == token.STRING:
2440             value_head = n.value[:2]  # type: ignore
2441             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2442                 return True
2443
2444         elif (
2445             n.type in {syms.typedargslist, syms.arglist}
2446             and n.children
2447             and n.children[-1].type == token.COMMA
2448         ):
2449             for ch in n.children:
2450                 if ch.type in STARS:
2451                     return True
2452
2453                 if ch.type == syms.argument:
2454                     for argch in ch.children:
2455                         if argch.type in STARS:
2456                             return True
2457
2458     return False
2459
2460
2461 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2462     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2463
2464     Brackets can be omitted if the entire trailer up to and including
2465     a preceding closing bracket fits in one line.
2466
2467     Yielded sets are cumulative (contain results of previous yields, too).  First
2468     set is empty.
2469     """
2470
2471     omit: Set[LeafID] = set()
2472     yield omit
2473
2474     length = 4 * line.depth
2475     opening_bracket = None
2476     closing_bracket = None
2477     optional_brackets: Set[LeafID] = set()
2478     inner_brackets: Set[LeafID] = set()
2479     for index, leaf in enumerate_reversed(line.leaves):
2480         length += len(leaf.prefix) + len(leaf.value)
2481         if length > line_length:
2482             break
2483
2484         comment: Optional[Leaf]
2485         for comment in line.comments_after(leaf, index):
2486             if "\n" in comment.prefix:
2487                 break  # Oops, standalone comment!
2488
2489             length += len(comment.value)
2490         else:
2491             comment = None
2492         if comment is not None:
2493             break  # There was a standalone comment, we can't continue.
2494
2495         optional_brackets.discard(id(leaf))
2496         if opening_bracket:
2497             if leaf is opening_bracket:
2498                 opening_bracket = None
2499             elif leaf.type in CLOSING_BRACKETS:
2500                 inner_brackets.add(id(leaf))
2501         elif leaf.type in CLOSING_BRACKETS:
2502             if not leaf.value:
2503                 optional_brackets.add(id(opening_bracket))
2504                 continue
2505
2506             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2507                 # Empty brackets would fail a split so treat them as "inner"
2508                 # brackets (e.g. only add them to the `omit` set if another
2509                 # pair of brackets was good enough.
2510                 inner_brackets.add(id(leaf))
2511                 continue
2512
2513             opening_bracket = leaf.opening_bracket
2514             if closing_bracket:
2515                 omit.add(id(closing_bracket))
2516                 omit.update(inner_brackets)
2517                 inner_brackets.clear()
2518                 yield omit
2519             closing_bracket = leaf
2520
2521
2522 def get_future_imports(node: Node) -> Set[str]:
2523     """Return a set of __future__ imports in the file."""
2524     imports = set()
2525     for child in node.children:
2526         if child.type != syms.simple_stmt:
2527             break
2528         first_child = child.children[0]
2529         if isinstance(first_child, Leaf):
2530             # Continue looking if we see a docstring; otherwise stop.
2531             if (
2532                 len(child.children) == 2
2533                 and first_child.type == token.STRING
2534                 and child.children[1].type == token.NEWLINE
2535             ):
2536                 continue
2537             else:
2538                 break
2539         elif first_child.type == syms.import_from:
2540             module_name = first_child.children[1]
2541             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2542                 break
2543             for import_from_child in first_child.children[3:]:
2544                 if isinstance(import_from_child, Leaf):
2545                     if import_from_child.type == token.NAME:
2546                         imports.add(import_from_child.value)
2547                 else:
2548                     assert import_from_child.type == syms.import_as_names
2549                     for leaf in import_from_child.children:
2550                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2551                             imports.add(leaf.value)
2552         else:
2553             break
2554     return imports
2555
2556
2557 PYTHON_EXTENSIONS = {".py"}
2558 BLACKLISTED_DIRECTORIES = {
2559     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2560 }
2561
2562
2563 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2564     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2565     and have one of the PYTHON_EXTENSIONS.
2566     """
2567     for child in path.iterdir():
2568         if child.is_dir():
2569             if child.name in BLACKLISTED_DIRECTORIES:
2570                 continue
2571
2572             yield from gen_python_files_in_dir(child)
2573
2574         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2575             yield child
2576
2577
2578 @dataclass
2579 class Report:
2580     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2581     check: bool = False
2582     quiet: bool = False
2583     change_count: int = 0
2584     same_count: int = 0
2585     failure_count: int = 0
2586
2587     def done(self, src: Path, changed: Changed) -> None:
2588         """Increment the counter for successful reformatting. Write out a message."""
2589         if changed is Changed.YES:
2590             reformatted = "would reformat" if self.check else "reformatted"
2591             if not self.quiet:
2592                 out(f"{reformatted} {src}")
2593             self.change_count += 1
2594         else:
2595             if not self.quiet:
2596                 if changed is Changed.NO:
2597                     msg = f"{src} already well formatted, good job."
2598                 else:
2599                     msg = f"{src} wasn't modified on disk since last run."
2600                 out(msg, bold=False)
2601             self.same_count += 1
2602
2603     def failed(self, src: Path, message: str) -> None:
2604         """Increment the counter for failed reformatting. Write out a message."""
2605         err(f"error: cannot format {src}: {message}")
2606         self.failure_count += 1
2607
2608     @property
2609     def return_code(self) -> int:
2610         """Return the exit code that the app should use.
2611
2612         This considers the current state of changed files and failures:
2613         - if there were any failures, return 123;
2614         - if any files were changed and --check is being used, return 1;
2615         - otherwise return 0.
2616         """
2617         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2618         # 126 we have special returncodes reserved by the shell.
2619         if self.failure_count:
2620             return 123
2621
2622         elif self.change_count and self.check:
2623             return 1
2624
2625         return 0
2626
2627     def __str__(self) -> str:
2628         """Render a color report of the current state.
2629
2630         Use `click.unstyle` to remove colors.
2631         """
2632         if self.check:
2633             reformatted = "would be reformatted"
2634             unchanged = "would be left unchanged"
2635             failed = "would fail to reformat"
2636         else:
2637             reformatted = "reformatted"
2638             unchanged = "left unchanged"
2639             failed = "failed to reformat"
2640         report = []
2641         if self.change_count:
2642             s = "s" if self.change_count > 1 else ""
2643             report.append(
2644                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2645             )
2646         if self.same_count:
2647             s = "s" if self.same_count > 1 else ""
2648             report.append(f"{self.same_count} file{s} {unchanged}")
2649         if self.failure_count:
2650             s = "s" if self.failure_count > 1 else ""
2651             report.append(
2652                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2653             )
2654         return ", ".join(report) + "."
2655
2656
2657 def assert_equivalent(src: str, dst: str) -> None:
2658     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2659
2660     import ast
2661     import traceback
2662
2663     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2664         """Simple visitor generating strings to compare ASTs by content."""
2665         yield f"{'  ' * depth}{node.__class__.__name__}("
2666
2667         for field in sorted(node._fields):
2668             try:
2669                 value = getattr(node, field)
2670             except AttributeError:
2671                 continue
2672
2673             yield f"{'  ' * (depth+1)}{field}="
2674
2675             if isinstance(value, list):
2676                 for item in value:
2677                     if isinstance(item, ast.AST):
2678                         yield from _v(item, depth + 2)
2679
2680             elif isinstance(value, ast.AST):
2681                 yield from _v(value, depth + 2)
2682
2683             else:
2684                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2685
2686         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2687
2688     try:
2689         src_ast = ast.parse(src)
2690     except Exception as exc:
2691         major, minor = sys.version_info[:2]
2692         raise AssertionError(
2693             f"cannot use --safe with this file; failed to parse source file "
2694             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2695             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2696         )
2697
2698     try:
2699         dst_ast = ast.parse(dst)
2700     except Exception as exc:
2701         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2702         raise AssertionError(
2703             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2704             f"Please report a bug on https://github.com/ambv/black/issues.  "
2705             f"This invalid output might be helpful: {log}"
2706         ) from None
2707
2708     src_ast_str = "\n".join(_v(src_ast))
2709     dst_ast_str = "\n".join(_v(dst_ast))
2710     if src_ast_str != dst_ast_str:
2711         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2712         raise AssertionError(
2713             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2714             f"the source.  "
2715             f"Please report a bug on https://github.com/ambv/black/issues.  "
2716             f"This diff might be helpful: {log}"
2717         ) from None
2718
2719
2720 def assert_stable(src: str, dst: str, line_length: int) -> None:
2721     """Raise AssertionError if `dst` reformats differently the second time."""
2722     newdst = format_str(dst, line_length=line_length)
2723     if dst != newdst:
2724         log = dump_to_file(
2725             diff(src, dst, "source", "first pass"),
2726             diff(dst, newdst, "first pass", "second pass"),
2727         )
2728         raise AssertionError(
2729             f"INTERNAL ERROR: Black produced different code on the second pass "
2730             f"of the formatter.  "
2731             f"Please report a bug on https://github.com/ambv/black/issues.  "
2732             f"This diff might be helpful: {log}"
2733         ) from None
2734
2735
2736 def dump_to_file(*output: str) -> str:
2737     """Dump `output` to a temporary file. Return path to the file."""
2738     import tempfile
2739
2740     with tempfile.NamedTemporaryFile(
2741         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2742     ) as f:
2743         for lines in output:
2744             f.write(lines)
2745             if lines and lines[-1] != "\n":
2746                 f.write("\n")
2747     return f.name
2748
2749
2750 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2751     """Return a unified diff string between strings `a` and `b`."""
2752     import difflib
2753
2754     a_lines = [line + "\n" for line in a.split("\n")]
2755     b_lines = [line + "\n" for line in b.split("\n")]
2756     return "".join(
2757         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2758     )
2759
2760
2761 def cancel(tasks: List[asyncio.Task]) -> None:
2762     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2763     err("Aborted!")
2764     for task in tasks:
2765         task.cancel()
2766
2767
2768 def shutdown(loop: BaseEventLoop) -> None:
2769     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2770     try:
2771         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2772         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2773         if not to_cancel:
2774             return
2775
2776         for task in to_cancel:
2777             task.cancel()
2778         loop.run_until_complete(
2779             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2780         )
2781     finally:
2782         # `concurrent.futures.Future` objects cannot be cancelled once they
2783         # are already running. There might be some when the `shutdown()` happened.
2784         # Silence their logger's spew about the event loop being closed.
2785         cf_logger = logging.getLogger("concurrent.futures")
2786         cf_logger.setLevel(logging.CRITICAL)
2787         loop.close()
2788
2789
2790 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2791     """Replace `regex` with `replacement` twice on `original`.
2792
2793     This is used by string normalization to perform replaces on
2794     overlapping matches.
2795     """
2796     return regex.sub(replacement, regex.sub(replacement, original))
2797
2798
2799 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
2800     """Like `reversed(enumerate(sequence))` if that were possible."""
2801     index = len(sequence) - 1
2802     for element in reversed(sequence):
2803         yield (index, element)
2804         index -= 1
2805
2806
2807 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
2808     """Return True if `line` is no longer than `line_length`.
2809
2810     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
2811     """
2812     if not line_str:
2813         line_str = str(line).strip("\n")
2814     return (
2815         len(line_str) <= line_length
2816         and "\n" not in line_str  # multiline strings
2817         and not line.contains_standalone_comments()
2818     )
2819
2820
2821 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2822
2823
2824 def get_cache_file(line_length: int) -> Path:
2825     return CACHE_DIR / f"cache.{line_length}.pickle"
2826
2827
2828 def read_cache(line_length: int) -> Cache:
2829     """Read the cache if it exists and is well formed.
2830
2831     If it is not well formed, the call to write_cache later should resolve the issue.
2832     """
2833     cache_file = get_cache_file(line_length)
2834     if not cache_file.exists():
2835         return {}
2836
2837     with cache_file.open("rb") as fobj:
2838         try:
2839             cache: Cache = pickle.load(fobj)
2840         except pickle.UnpicklingError:
2841             return {}
2842
2843     return cache
2844
2845
2846 def get_cache_info(path: Path) -> CacheInfo:
2847     """Return the information used to check if a file is already formatted or not."""
2848     stat = path.stat()
2849     return stat.st_mtime, stat.st_size
2850
2851
2852 def filter_cached(
2853     cache: Cache, sources: Iterable[Path]
2854 ) -> Tuple[List[Path], List[Path]]:
2855     """Split a list of paths into two.
2856
2857     The first list contains paths of files that modified on disk or are not in the
2858     cache. The other list contains paths to non-modified files.
2859     """
2860     todo, done = [], []
2861     for src in sources:
2862         src = src.resolve()
2863         if cache.get(src) != get_cache_info(src):
2864             todo.append(src)
2865         else:
2866             done.append(src)
2867     return todo, done
2868
2869
2870 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2871     """Update the cache file."""
2872     cache_file = get_cache_file(line_length)
2873     try:
2874         if not CACHE_DIR.exists():
2875             CACHE_DIR.mkdir(parents=True)
2876         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2877         with cache_file.open("wb") as fobj:
2878             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2879     except OSError:
2880         pass
2881
2882
2883 if __name__ == "__main__":
2884     main()