]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add support for pyi files (#210)
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33 )
34
35 from appdirs import user_cache_dir
36 from attr import dataclass, Factory
37 import click
38
39 # lib2to3 fork
40 from blib2to3.pytree import Node, Leaf, type_repr
41 from blib2to3 import pygram, pytree
42 from blib2to3.pgen2 import driver, token
43 from blib2to3.pgen2.parse import ParseError
44
45
46 __version__ = "18.4a6"
47 DEFAULT_LINE_LENGTH = 88
48
49 # types
50 syms = pygram.python_symbols
51 FileContent = str
52 Encoding = str
53 Depth = int
54 NodeType = int
55 LeafID = int
56 Priority = int
57 Index = int
58 LN = Union[Leaf, Node]
59 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
60 Timestamp = float
61 FileSize = int
62 CacheInfo = Tuple[Timestamp, FileSize]
63 Cache = Dict[Path, CacheInfo]
64 out = partial(click.secho, bold=True, err=True)
65 err = partial(click.secho, fg="red", err=True)
66
67
68 class NothingChanged(UserWarning):
69     """Raised by :func:`format_file` when reformatted code is the same as source."""
70
71
72 class CannotSplit(Exception):
73     """A readable split that fits the allotted line length is impossible.
74
75     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
76     :func:`delimiter_split`.
77     """
78
79
80 class FormatError(Exception):
81     """Base exception for `# fmt: on` and `# fmt: off` handling.
82
83     It holds the number of bytes of the prefix consumed before the format
84     control comment appeared.
85     """
86
87     def __init__(self, consumed: int) -> None:
88         super().__init__(consumed)
89         self.consumed = consumed
90
91     def trim_prefix(self, leaf: Leaf) -> None:
92         leaf.prefix = leaf.prefix[self.consumed :]
93
94     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
95         """Returns a new Leaf from the consumed part of the prefix."""
96         unformatted_prefix = leaf.prefix[: self.consumed]
97         return Leaf(token.NEWLINE, unformatted_prefix)
98
99
100 class FormatOn(FormatError):
101     """Found a comment like `# fmt: on` in the file."""
102
103
104 class FormatOff(FormatError):
105     """Found a comment like `# fmt: off` in the file."""
106
107
108 class WriteBack(Enum):
109     NO = 0
110     YES = 1
111     DIFF = 2
112
113
114 class Changed(Enum):
115     NO = 0
116     CACHED = 1
117     YES = 2
118
119
120 @click.command()
121 @click.option(
122     "-l",
123     "--line-length",
124     type=int,
125     default=DEFAULT_LINE_LENGTH,
126     help="How many character per line to allow.",
127     show_default=True,
128 )
129 @click.option(
130     "--check",
131     is_flag=True,
132     help=(
133         "Don't write the files back, just return the status.  Return code 0 "
134         "means nothing would change.  Return code 1 means some files would be "
135         "reformatted.  Return code 123 means there was an internal error."
136     ),
137 )
138 @click.option(
139     "--diff",
140     is_flag=True,
141     help="Don't write the files back, just output a diff for each file on stdout.",
142 )
143 @click.option(
144     "--fast/--safe",
145     is_flag=True,
146     help="If --fast given, skip temporary sanity checks. [default: --safe]",
147 )
148 @click.option(
149     "-q",
150     "--quiet",
151     is_flag=True,
152     help=(
153         "Don't emit non-error messages to stderr. Errors are still emitted, "
154         "silence those with 2>/dev/null."
155     ),
156 )
157 @click.version_option(version=__version__)
158 @click.argument(
159     "src",
160     nargs=-1,
161     type=click.Path(
162         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
163     ),
164 )
165 @click.pass_context
166 def main(
167     ctx: click.Context,
168     line_length: int,
169     check: bool,
170     diff: bool,
171     fast: bool,
172     quiet: bool,
173     src: List[str],
174 ) -> None:
175     """The uncompromising code formatter."""
176     sources: List[Path] = []
177     for s in src:
178         p = Path(s)
179         if p.is_dir():
180             sources.extend(gen_python_files_in_dir(p))
181         elif p.is_file():
182             # if a file was explicitly given, we don't care about its extension
183             sources.append(p)
184         elif s == "-":
185             sources.append(Path("-"))
186         else:
187             err(f"invalid path: {s}")
188
189     if check and not diff:
190         write_back = WriteBack.NO
191     elif diff:
192         write_back = WriteBack.DIFF
193     else:
194         write_back = WriteBack.YES
195     report = Report(check=check, quiet=quiet)
196     if len(sources) == 0:
197         out("No paths given. Nothing to do 😴")
198         ctx.exit(0)
199         return
200
201     elif len(sources) == 1:
202         reformat_one(sources[0], line_length, fast, write_back, report)
203     else:
204         loop = asyncio.get_event_loop()
205         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
206         try:
207             loop.run_until_complete(
208                 schedule_formatting(
209                     sources, line_length, fast, write_back, report, loop, executor
210                 )
211             )
212         finally:
213             shutdown(loop)
214         if not quiet:
215             out("All done! ✨ 🍰 ✨")
216             click.echo(str(report))
217     ctx.exit(report.return_code)
218
219
220 def reformat_one(
221     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
222 ) -> None:
223     """Reformat a single file under `src` without spawning child processes.
224
225     If `quiet` is True, non-error messages are not output. `line_length`,
226     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
227     """
228     try:
229         changed = Changed.NO
230         if not src.is_file() and str(src) == "-":
231             if format_stdin_to_stdout(
232                 line_length=line_length, fast=fast, write_back=write_back
233             ):
234                 changed = Changed.YES
235         else:
236             cache: Cache = {}
237             if write_back != WriteBack.DIFF:
238                 cache = read_cache(line_length)
239                 src = src.resolve()
240                 if src in cache and cache[src] == get_cache_info(src):
241                     changed = Changed.CACHED
242             if (
243                 changed is not Changed.CACHED
244                 and format_file_in_place(
245                     src, line_length=line_length, fast=fast, write_back=write_back
246                 )
247             ):
248                 changed = Changed.YES
249             if write_back == WriteBack.YES and changed is not Changed.NO:
250                 write_cache(cache, [src], line_length)
251         report.done(src, changed)
252     except Exception as exc:
253         report.failed(src, str(exc))
254
255
256 async def schedule_formatting(
257     sources: List[Path],
258     line_length: int,
259     fast: bool,
260     write_back: WriteBack,
261     report: "Report",
262     loop: BaseEventLoop,
263     executor: Executor,
264 ) -> None:
265     """Run formatting of `sources` in parallel using the provided `executor`.
266
267     (Use ProcessPoolExecutors for actual parallelism.)
268
269     `line_length`, `write_back`, and `fast` options are passed to
270     :func:`format_file_in_place`.
271     """
272     cache: Cache = {}
273     if write_back != WriteBack.DIFF:
274         cache = read_cache(line_length)
275         sources, cached = filter_cached(cache, sources)
276         for src in cached:
277             report.done(src, Changed.CACHED)
278     cancelled = []
279     formatted = []
280     if sources:
281         lock = None
282         if write_back == WriteBack.DIFF:
283             # For diff output, we need locks to ensure we don't interleave output
284             # from different processes.
285             manager = Manager()
286             lock = manager.Lock()
287         tasks = {
288             src: loop.run_in_executor(
289                 executor, format_file_in_place, src, line_length, fast, write_back, lock
290             )
291             for src in sources
292         }
293         _task_values = list(tasks.values())
294         try:
295             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
296             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
297         except NotImplementedError:
298             # There are no good alternatives for these on Windows
299             pass
300         await asyncio.wait(_task_values)
301         for src, task in tasks.items():
302             if not task.done():
303                 report.failed(src, "timed out, cancelling")
304                 task.cancel()
305                 cancelled.append(task)
306             elif task.cancelled():
307                 cancelled.append(task)
308             elif task.exception():
309                 report.failed(src, str(task.exception()))
310             else:
311                 formatted.append(src)
312                 report.done(src, Changed.YES if task.result() else Changed.NO)
313
314     if cancelled:
315         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
316     if write_back == WriteBack.YES and formatted:
317         write_cache(cache, formatted, line_length)
318
319
320 def format_file_in_place(
321     src: Path,
322     line_length: int,
323     fast: bool,
324     write_back: WriteBack = WriteBack.NO,
325     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
326 ) -> bool:
327     """Format file under `src` path. Return True if changed.
328
329     If `write_back` is True, write reformatted code back to stdout.
330     `line_length` and `fast` options are passed to :func:`format_file_contents`.
331     """
332     is_pyi = src.suffix == ".pyi"
333
334     with tokenize.open(src) as src_buffer:
335         src_contents = src_buffer.read()
336     try:
337         dst_contents = format_file_contents(
338             src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
339         )
340     except NothingChanged:
341         return False
342
343     if write_back == write_back.YES:
344         with open(src, "w", encoding=src_buffer.encoding) as f:
345             f.write(dst_contents)
346     elif write_back == write_back.DIFF:
347         src_name = f"{src}  (original)"
348         dst_name = f"{src}  (formatted)"
349         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
350         if lock:
351             lock.acquire()
352         try:
353             sys.stdout.write(diff_contents)
354         finally:
355             if lock:
356                 lock.release()
357     return True
358
359
360 def format_stdin_to_stdout(
361     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
362 ) -> bool:
363     """Format file on stdin. Return True if changed.
364
365     If `write_back` is True, write reformatted code back to stdout.
366     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
367     """
368     src = sys.stdin.read()
369     dst = src
370     try:
371         dst = format_file_contents(src, line_length=line_length, fast=fast)
372         return True
373
374     except NothingChanged:
375         return False
376
377     finally:
378         if write_back == WriteBack.YES:
379             sys.stdout.write(dst)
380         elif write_back == WriteBack.DIFF:
381             src_name = "<stdin>  (original)"
382             dst_name = "<stdin>  (formatted)"
383             sys.stdout.write(diff(src, dst, src_name, dst_name))
384
385
386 def format_file_contents(
387     src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
388 ) -> FileContent:
389     """Reformat contents a file and return new contents.
390
391     If `fast` is False, additionally confirm that the reformatted code is
392     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
393     `line_length` is passed to :func:`format_str`.
394     """
395     if src_contents.strip() == "":
396         raise NothingChanged
397
398     dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
399     if src_contents == dst_contents:
400         raise NothingChanged
401
402     if not fast:
403         assert_equivalent(src_contents, dst_contents)
404         assert_stable(
405             src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
406         )
407     return dst_contents
408
409
410 def format_str(
411     src_contents: str, line_length: int, *, is_pyi: bool = False
412 ) -> FileContent:
413     """Reformat a string and return new contents.
414
415     `line_length` determines how many characters per line are allowed.
416     """
417     src_node = lib2to3_parse(src_contents)
418     dst_contents = ""
419     future_imports = get_future_imports(src_node)
420     elt = EmptyLineTracker(is_pyi=is_pyi)
421     py36 = is_python36(src_node)
422     lines = LineGenerator(
423         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
424     )
425     empty_line = Line()
426     after = 0
427     for current_line in lines.visit(src_node):
428         for _ in range(after):
429             dst_contents += str(empty_line)
430         before, after = elt.maybe_empty_lines(current_line)
431         for _ in range(before):
432             dst_contents += str(empty_line)
433         for line in split_line(current_line, line_length=line_length, py36=py36):
434             dst_contents += str(line)
435     return dst_contents
436
437
438 GRAMMARS = [
439     pygram.python_grammar_no_print_statement_no_exec_statement,
440     pygram.python_grammar_no_print_statement,
441     pygram.python_grammar,
442 ]
443
444
445 def lib2to3_parse(src_txt: str) -> Node:
446     """Given a string with source, return the lib2to3 Node."""
447     grammar = pygram.python_grammar_no_print_statement
448     if src_txt[-1] != "\n":
449         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
450         src_txt += nl
451     for grammar in GRAMMARS:
452         drv = driver.Driver(grammar, pytree.convert)
453         try:
454             result = drv.parse_string(src_txt, True)
455             break
456
457         except ParseError as pe:
458             lineno, column = pe.context[1]
459             lines = src_txt.splitlines()
460             try:
461                 faulty_line = lines[lineno - 1]
462             except IndexError:
463                 faulty_line = "<line number missing in source>"
464             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
465     else:
466         raise exc from None
467
468     if isinstance(result, Leaf):
469         result = Node(syms.file_input, [result])
470     return result
471
472
473 def lib2to3_unparse(node: Node) -> str:
474     """Given a lib2to3 node, return its string representation."""
475     code = str(node)
476     return code
477
478
479 T = TypeVar("T")
480
481
482 class Visitor(Generic[T]):
483     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
484
485     def visit(self, node: LN) -> Iterator[T]:
486         """Main method to visit `node` and its children.
487
488         It tries to find a `visit_*()` method for the given `node.type`, like
489         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
490         If no dedicated `visit_*()` method is found, chooses `visit_default()`
491         instead.
492
493         Then yields objects of type `T` from the selected visitor.
494         """
495         if node.type < 256:
496             name = token.tok_name[node.type]
497         else:
498             name = type_repr(node.type)
499         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
500
501     def visit_default(self, node: LN) -> Iterator[T]:
502         """Default `visit_*()` implementation. Recurses to children of `node`."""
503         if isinstance(node, Node):
504             for child in node.children:
505                 yield from self.visit(child)
506
507
508 @dataclass
509 class DebugVisitor(Visitor[T]):
510     tree_depth: int = 0
511
512     def visit_default(self, node: LN) -> Iterator[T]:
513         indent = " " * (2 * self.tree_depth)
514         if isinstance(node, Node):
515             _type = type_repr(node.type)
516             out(f"{indent}{_type}", fg="yellow")
517             self.tree_depth += 1
518             for child in node.children:
519                 yield from self.visit(child)
520
521             self.tree_depth -= 1
522             out(f"{indent}/{_type}", fg="yellow", bold=False)
523         else:
524             _type = token.tok_name.get(node.type, str(node.type))
525             out(f"{indent}{_type}", fg="blue", nl=False)
526             if node.prefix:
527                 # We don't have to handle prefixes for `Node` objects since
528                 # that delegates to the first child anyway.
529                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
530             out(f" {node.value!r}", fg="blue", bold=False)
531
532     @classmethod
533     def show(cls, code: str) -> None:
534         """Pretty-print the lib2to3 AST of a given string of `code`.
535
536         Convenience method for debugging.
537         """
538         v: DebugVisitor[None] = DebugVisitor()
539         list(v.visit(lib2to3_parse(code)))
540
541
542 KEYWORDS = set(keyword.kwlist)
543 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
544 FLOW_CONTROL = {"return", "raise", "break", "continue"}
545 STATEMENT = {
546     syms.if_stmt,
547     syms.while_stmt,
548     syms.for_stmt,
549     syms.try_stmt,
550     syms.except_clause,
551     syms.with_stmt,
552     syms.funcdef,
553     syms.classdef,
554 }
555 STANDALONE_COMMENT = 153
556 LOGIC_OPERATORS = {"and", "or"}
557 COMPARATORS = {
558     token.LESS,
559     token.GREATER,
560     token.EQEQUAL,
561     token.NOTEQUAL,
562     token.LESSEQUAL,
563     token.GREATEREQUAL,
564 }
565 MATH_OPERATORS = {
566     token.VBAR,
567     token.CIRCUMFLEX,
568     token.AMPER,
569     token.LEFTSHIFT,
570     token.RIGHTSHIFT,
571     token.PLUS,
572     token.MINUS,
573     token.STAR,
574     token.SLASH,
575     token.DOUBLESLASH,
576     token.PERCENT,
577     token.AT,
578     token.TILDE,
579     token.DOUBLESTAR,
580 }
581 STARS = {token.STAR, token.DOUBLESTAR}
582 VARARGS_PARENTS = {
583     syms.arglist,
584     syms.argument,  # double star in arglist
585     syms.trailer,  # single argument to call
586     syms.typedargslist,
587     syms.varargslist,  # lambdas
588 }
589 UNPACKING_PARENTS = {
590     syms.atom,  # single element of a list or set literal
591     syms.dictsetmaker,
592     syms.listmaker,
593     syms.testlist_gexp,
594 }
595 TEST_DESCENDANTS = {
596     syms.test,
597     syms.lambdef,
598     syms.or_test,
599     syms.and_test,
600     syms.not_test,
601     syms.comparison,
602     syms.star_expr,
603     syms.expr,
604     syms.xor_expr,
605     syms.and_expr,
606     syms.shift_expr,
607     syms.arith_expr,
608     syms.trailer,
609     syms.term,
610     syms.power,
611 }
612 ASSIGNMENTS = {
613     "=",
614     "+=",
615     "-=",
616     "*=",
617     "@=",
618     "/=",
619     "%=",
620     "&=",
621     "|=",
622     "^=",
623     "<<=",
624     ">>=",
625     "**=",
626     "//=",
627 }
628 COMPREHENSION_PRIORITY = 20
629 COMMA_PRIORITY = 18
630 TERNARY_PRIORITY = 16
631 LOGIC_PRIORITY = 14
632 STRING_PRIORITY = 12
633 COMPARATOR_PRIORITY = 10
634 MATH_PRIORITIES = {
635     token.VBAR: 8,
636     token.CIRCUMFLEX: 7,
637     token.AMPER: 6,
638     token.LEFTSHIFT: 5,
639     token.RIGHTSHIFT: 5,
640     token.PLUS: 4,
641     token.MINUS: 4,
642     token.STAR: 3,
643     token.SLASH: 3,
644     token.DOUBLESLASH: 3,
645     token.PERCENT: 3,
646     token.AT: 3,
647     token.TILDE: 2,
648     token.DOUBLESTAR: 1,
649 }
650
651
652 @dataclass
653 class BracketTracker:
654     """Keeps track of brackets on a line."""
655
656     depth: int = 0
657     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
658     delimiters: Dict[LeafID, Priority] = Factory(dict)
659     previous: Optional[Leaf] = None
660     _for_loop_variable: int = 0
661     _lambda_arguments: int = 0
662
663     def mark(self, leaf: Leaf) -> None:
664         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
665
666         All leaves receive an int `bracket_depth` field that stores how deep
667         within brackets a given leaf is. 0 means there are no enclosing brackets
668         that started on this line.
669
670         If a leaf is itself a closing bracket, it receives an `opening_bracket`
671         field that it forms a pair with. This is a one-directional link to
672         avoid reference cycles.
673
674         If a leaf is a delimiter (a token on which Black can split the line if
675         needed) and it's on depth 0, its `id()` is stored in the tracker's
676         `delimiters` field.
677         """
678         if leaf.type == token.COMMENT:
679             return
680
681         self.maybe_decrement_after_for_loop_variable(leaf)
682         self.maybe_decrement_after_lambda_arguments(leaf)
683         if leaf.type in CLOSING_BRACKETS:
684             self.depth -= 1
685             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
686             leaf.opening_bracket = opening_bracket
687         leaf.bracket_depth = self.depth
688         if self.depth == 0:
689             delim = is_split_before_delimiter(leaf, self.previous)
690             if delim and self.previous is not None:
691                 self.delimiters[id(self.previous)] = delim
692             else:
693                 delim = is_split_after_delimiter(leaf, self.previous)
694                 if delim:
695                     self.delimiters[id(leaf)] = delim
696         if leaf.type in OPENING_BRACKETS:
697             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
698             self.depth += 1
699         self.previous = leaf
700         self.maybe_increment_lambda_arguments(leaf)
701         self.maybe_increment_for_loop_variable(leaf)
702
703     def any_open_brackets(self) -> bool:
704         """Return True if there is an yet unmatched open bracket on the line."""
705         return bool(self.bracket_match)
706
707     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
708         """Return the highest priority of a delimiter found on the line.
709
710         Values are consistent with what `is_split_*_delimiter()` return.
711         Raises ValueError on no delimiters.
712         """
713         return max(v for k, v in self.delimiters.items() if k not in exclude)
714
715     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
716         """In a for loop, or comprehension, the variables are often unpacks.
717
718         To avoid splitting on the comma in this situation, increase the depth of
719         tokens between `for` and `in`.
720         """
721         if leaf.type == token.NAME and leaf.value == "for":
722             self.depth += 1
723             self._for_loop_variable += 1
724             return True
725
726         return False
727
728     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
729         """See `maybe_increment_for_loop_variable` above for explanation."""
730         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
731             self.depth -= 1
732             self._for_loop_variable -= 1
733             return True
734
735         return False
736
737     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
738         """In a lambda expression, there might be more than one argument.
739
740         To avoid splitting on the comma in this situation, increase the depth of
741         tokens between `lambda` and `:`.
742         """
743         if leaf.type == token.NAME and leaf.value == "lambda":
744             self.depth += 1
745             self._lambda_arguments += 1
746             return True
747
748         return False
749
750     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
751         """See `maybe_increment_lambda_arguments` above for explanation."""
752         if self._lambda_arguments and leaf.type == token.COLON:
753             self.depth -= 1
754             self._lambda_arguments -= 1
755             return True
756
757         return False
758
759     def get_open_lsqb(self) -> Optional[Leaf]:
760         """Return the most recent opening square bracket (if any)."""
761         return self.bracket_match.get((self.depth - 1, token.RSQB))
762
763
764 @dataclass
765 class Line:
766     """Holds leaves and comments. Can be printed with `str(line)`."""
767
768     depth: int = 0
769     leaves: List[Leaf] = Factory(list)
770     comments: List[Tuple[Index, Leaf]] = Factory(list)
771     bracket_tracker: BracketTracker = Factory(BracketTracker)
772     inside_brackets: bool = False
773
774     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
775         """Add a new `leaf` to the end of the line.
776
777         Unless `preformatted` is True, the `leaf` will receive a new consistent
778         whitespace prefix and metadata applied by :class:`BracketTracker`.
779         Trailing commas are maybe removed, unpacked for loop variables are
780         demoted from being delimiters.
781
782         Inline comments are put aside.
783         """
784         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
785         if not has_value:
786             return
787
788         if token.COLON == leaf.type and self.is_class_paren_empty:
789             del self.leaves[-2:]
790         if self.leaves and not preformatted:
791             # Note: at this point leaf.prefix should be empty except for
792             # imports, for which we only preserve newlines.
793             leaf.prefix += whitespace(
794                 leaf, complex_subscript=self.is_complex_subscript(leaf)
795             )
796         if self.inside_brackets or not preformatted:
797             self.bracket_tracker.mark(leaf)
798             self.maybe_remove_trailing_comma(leaf)
799         if not self.append_comment(leaf):
800             self.leaves.append(leaf)
801
802     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
803         """Like :func:`append()` but disallow invalid standalone comment structure.
804
805         Raises ValueError when any `leaf` is appended after a standalone comment
806         or when a standalone comment is not the first leaf on the line.
807         """
808         if self.bracket_tracker.depth == 0:
809             if self.is_comment:
810                 raise ValueError("cannot append to standalone comments")
811
812             if self.leaves and leaf.type == STANDALONE_COMMENT:
813                 raise ValueError(
814                     "cannot append standalone comments to a populated line"
815                 )
816
817         self.append(leaf, preformatted=preformatted)
818
819     @property
820     def is_comment(self) -> bool:
821         """Is this line a standalone comment?"""
822         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
823
824     @property
825     def is_decorator(self) -> bool:
826         """Is this line a decorator?"""
827         return bool(self) and self.leaves[0].type == token.AT
828
829     @property
830     def is_import(self) -> bool:
831         """Is this an import line?"""
832         return bool(self) and is_import(self.leaves[0])
833
834     @property
835     def is_class(self) -> bool:
836         """Is this line a class definition?"""
837         return (
838             bool(self)
839             and self.leaves[0].type == token.NAME
840             and self.leaves[0].value == "class"
841         )
842
843     @property
844     def is_trivial_class(self) -> bool:
845         """Is this line a class definition with a body consisting only of "..."?"""
846         return (
847             self.is_class
848             and self.leaves[-3:] == [Leaf(token.DOT, ".") for _ in range(3)]
849         )
850
851     @property
852     def is_def(self) -> bool:
853         """Is this a function definition? (Also returns True for async defs.)"""
854         try:
855             first_leaf = self.leaves[0]
856         except IndexError:
857             return False
858
859         try:
860             second_leaf: Optional[Leaf] = self.leaves[1]
861         except IndexError:
862             second_leaf = None
863         return (
864             (first_leaf.type == token.NAME and first_leaf.value == "def")
865             or (
866                 first_leaf.type == token.ASYNC
867                 and second_leaf is not None
868                 and second_leaf.type == token.NAME
869                 and second_leaf.value == "def"
870             )
871         )
872
873     @property
874     def is_flow_control(self) -> bool:
875         """Is this line a flow control statement?
876
877         Those are `return`, `raise`, `break`, and `continue`.
878         """
879         return (
880             bool(self)
881             and self.leaves[0].type == token.NAME
882             and self.leaves[0].value in FLOW_CONTROL
883         )
884
885     @property
886     def is_yield(self) -> bool:
887         """Is this line a yield statement?"""
888         return (
889             bool(self)
890             and self.leaves[0].type == token.NAME
891             and self.leaves[0].value == "yield"
892         )
893
894     @property
895     def is_class_paren_empty(self) -> bool:
896         """Is this a class with no base classes but using parentheses?
897
898         Those are unnecessary and should be removed.
899         """
900         return (
901             bool(self)
902             and len(self.leaves) == 4
903             and self.is_class
904             and self.leaves[2].type == token.LPAR
905             and self.leaves[2].value == "("
906             and self.leaves[3].type == token.RPAR
907             and self.leaves[3].value == ")"
908         )
909
910     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
911         """If so, needs to be split before emitting."""
912         for leaf in self.leaves:
913             if leaf.type == STANDALONE_COMMENT:
914                 if leaf.bracket_depth <= depth_limit:
915                     return True
916
917         return False
918
919     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
920         """Remove trailing comma if there is one and it's safe."""
921         if not (
922             self.leaves
923             and self.leaves[-1].type == token.COMMA
924             and closing.type in CLOSING_BRACKETS
925         ):
926             return False
927
928         if closing.type == token.RBRACE:
929             self.remove_trailing_comma()
930             return True
931
932         if closing.type == token.RSQB:
933             comma = self.leaves[-1]
934             if comma.parent and comma.parent.type == syms.listmaker:
935                 self.remove_trailing_comma()
936                 return True
937
938         # For parens let's check if it's safe to remove the comma.
939         # Imports are always safe.
940         if self.is_import:
941             self.remove_trailing_comma()
942             return True
943
944         # Otheriwsse, if the trailing one is the only one, we might mistakenly
945         # change a tuple into a different type by removing the comma.
946         depth = closing.bracket_depth + 1
947         commas = 0
948         opening = closing.opening_bracket
949         for _opening_index, leaf in enumerate(self.leaves):
950             if leaf is opening:
951                 break
952
953         else:
954             return False
955
956         for leaf in self.leaves[_opening_index + 1 :]:
957             if leaf is closing:
958                 break
959
960             bracket_depth = leaf.bracket_depth
961             if bracket_depth == depth and leaf.type == token.COMMA:
962                 commas += 1
963                 if leaf.parent and leaf.parent.type == syms.arglist:
964                     commas += 1
965                     break
966
967         if commas > 1:
968             self.remove_trailing_comma()
969             return True
970
971         return False
972
973     def append_comment(self, comment: Leaf) -> bool:
974         """Add an inline or standalone comment to the line."""
975         if (
976             comment.type == STANDALONE_COMMENT
977             and self.bracket_tracker.any_open_brackets()
978         ):
979             comment.prefix = ""
980             return False
981
982         if comment.type != token.COMMENT:
983             return False
984
985         after = len(self.leaves) - 1
986         if after == -1:
987             comment.type = STANDALONE_COMMENT
988             comment.prefix = ""
989             return False
990
991         else:
992             self.comments.append((after, comment))
993             return True
994
995     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
996         """Generate comments that should appear directly after `leaf`.
997
998         Provide a non-negative leaf `_index` to speed up the function.
999         """
1000         if _index == -1:
1001             for _index, _leaf in enumerate(self.leaves):
1002                 if leaf is _leaf:
1003                     break
1004
1005             else:
1006                 return
1007
1008         for index, comment_after in self.comments:
1009             if _index == index:
1010                 yield comment_after
1011
1012     def remove_trailing_comma(self) -> None:
1013         """Remove the trailing comma and moves the comments attached to it."""
1014         comma_index = len(self.leaves) - 1
1015         for i in range(len(self.comments)):
1016             comment_index, comment = self.comments[i]
1017             if comment_index == comma_index:
1018                 self.comments[i] = (comma_index - 1, comment)
1019         self.leaves.pop()
1020
1021     def is_complex_subscript(self, leaf: Leaf) -> bool:
1022         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1023         open_lsqb = (
1024             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1025         )
1026         if open_lsqb is None:
1027             return False
1028
1029         subscript_start = open_lsqb.next_sibling
1030         if (
1031             isinstance(subscript_start, Node)
1032             and subscript_start.type == syms.subscriptlist
1033         ):
1034             subscript_start = child_towards(subscript_start, leaf)
1035         return (
1036             subscript_start is not None
1037             and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order())
1038         )
1039
1040     def __str__(self) -> str:
1041         """Render the line."""
1042         if not self:
1043             return "\n"
1044
1045         indent = "    " * self.depth
1046         leaves = iter(self.leaves)
1047         first = next(leaves)
1048         res = f"{first.prefix}{indent}{first.value}"
1049         for leaf in leaves:
1050             res += str(leaf)
1051         for _, comment in self.comments:
1052             res += str(comment)
1053         return res + "\n"
1054
1055     def __bool__(self) -> bool:
1056         """Return True if the line has leaves or comments."""
1057         return bool(self.leaves or self.comments)
1058
1059
1060 class UnformattedLines(Line):
1061     """Just like :class:`Line` but stores lines which aren't reformatted."""
1062
1063     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1064         """Just add a new `leaf` to the end of the lines.
1065
1066         The `preformatted` argument is ignored.
1067
1068         Keeps track of indentation `depth`, which is useful when the user
1069         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1070         """
1071         try:
1072             list(generate_comments(leaf))
1073         except FormatOn as f_on:
1074             self.leaves.append(f_on.leaf_from_consumed(leaf))
1075             raise
1076
1077         self.leaves.append(leaf)
1078         if leaf.type == token.INDENT:
1079             self.depth += 1
1080         elif leaf.type == token.DEDENT:
1081             self.depth -= 1
1082
1083     def __str__(self) -> str:
1084         """Render unformatted lines from leaves which were added with `append()`.
1085
1086         `depth` is not used for indentation in this case.
1087         """
1088         if not self:
1089             return "\n"
1090
1091         res = ""
1092         for leaf in self.leaves:
1093             res += str(leaf)
1094         return res
1095
1096     def append_comment(self, comment: Leaf) -> bool:
1097         """Not implemented in this class. Raises `NotImplementedError`."""
1098         raise NotImplementedError("Unformatted lines don't store comments separately.")
1099
1100     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1101         """Does nothing and returns False."""
1102         return False
1103
1104     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1105         """Does nothing and returns False."""
1106         return False
1107
1108
1109 @dataclass
1110 class EmptyLineTracker:
1111     """Provides a stateful method that returns the number of potential extra
1112     empty lines needed before and after the currently processed line.
1113
1114     Note: this tracker works on lines that haven't been split yet.  It assumes
1115     the prefix of the first leaf consists of optional newlines.  Those newlines
1116     are consumed by `maybe_empty_lines()` and included in the computation.
1117     """
1118     is_pyi: bool = False
1119     previous_line: Optional[Line] = None
1120     previous_after: int = 0
1121     previous_defs: List[int] = Factory(list)
1122
1123     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1124         """Return the number of extra empty lines before and after the `current_line`.
1125
1126         This is for separating `def`, `async def` and `class` with extra empty
1127         lines (two on module-level), as well as providing an extra empty line
1128         after flow control keywords to make them more prominent.
1129         """
1130         if isinstance(current_line, UnformattedLines):
1131             return 0, 0
1132
1133         before, after = self._maybe_empty_lines(current_line)
1134         before -= self.previous_after
1135         self.previous_after = after
1136         self.previous_line = current_line
1137         return before, after
1138
1139     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1140         max_allowed = 1
1141         if current_line.depth == 0:
1142             max_allowed = 1 if self.is_pyi else 2
1143         if current_line.leaves:
1144             # Consume the first leaf's extra newlines.
1145             first_leaf = current_line.leaves[0]
1146             before = first_leaf.prefix.count("\n")
1147             before = min(before, max_allowed)
1148             first_leaf.prefix = ""
1149         else:
1150             before = 0
1151         depth = current_line.depth
1152         while self.previous_defs and self.previous_defs[-1] >= depth:
1153             self.previous_defs.pop()
1154             if self.is_pyi:
1155                 before = 0 if depth else 1
1156             else:
1157                 before = 1 if depth else 2
1158         is_decorator = current_line.is_decorator
1159         if is_decorator or current_line.is_def or current_line.is_class:
1160             if not is_decorator:
1161                 self.previous_defs.append(depth)
1162             if self.previous_line is None:
1163                 # Don't insert empty lines before the first line in the file.
1164                 return 0, 0
1165
1166             if self.previous_line.is_decorator:
1167                 return 0, 0
1168
1169             if (
1170                 self.previous_line.is_comment
1171                 and self.previous_line.depth == current_line.depth
1172                 and before == 0
1173             ):
1174                 return 0, 0
1175
1176             if self.is_pyi:
1177                 if self.previous_line.depth > current_line.depth:
1178                     newlines = 1
1179                 elif current_line.is_class or self.previous_line.is_class:
1180                     if (
1181                         current_line.is_trivial_class
1182                         and self.previous_line.is_trivial_class
1183                     ):
1184                         newlines = 0
1185                     else:
1186                         newlines = 1
1187                 else:
1188                     newlines = 0
1189             else:
1190                 newlines = 2
1191             if current_line.depth and newlines:
1192                 newlines -= 1
1193             return newlines, 0
1194
1195         if (
1196             self.previous_line
1197             and self.previous_line.is_import
1198             and not current_line.is_import
1199             and depth == self.previous_line.depth
1200         ):
1201             return (before or 1), 0
1202
1203         return before, 0
1204
1205
1206 @dataclass
1207 class LineGenerator(Visitor[Line]):
1208     """Generates reformatted Line objects.  Empty lines are not emitted.
1209
1210     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1211     in ways that will no longer stringify to valid Python code on the tree.
1212     """
1213     is_pyi: bool = False
1214     current_line: Line = Factory(Line)
1215     remove_u_prefix: bool = False
1216
1217     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1218         """Generate a line.
1219
1220         If the line is empty, only emit if it makes sense.
1221         If the line is too long, split it first and then generate.
1222
1223         If any lines were generated, set up a new current_line.
1224         """
1225         if not self.current_line:
1226             if self.current_line.__class__ == type:
1227                 self.current_line.depth += indent
1228             else:
1229                 self.current_line = type(depth=self.current_line.depth + indent)
1230             return  # Line is empty, don't emit. Creating a new one unnecessary.
1231
1232         complete_line = self.current_line
1233         self.current_line = type(depth=complete_line.depth + indent)
1234         yield complete_line
1235
1236     def visit(self, node: LN) -> Iterator[Line]:
1237         """Main method to visit `node` and its children.
1238
1239         Yields :class:`Line` objects.
1240         """
1241         if isinstance(self.current_line, UnformattedLines):
1242             # File contained `# fmt: off`
1243             yield from self.visit_unformatted(node)
1244
1245         else:
1246             yield from super().visit(node)
1247
1248     def visit_default(self, node: LN) -> Iterator[Line]:
1249         """Default `visit_*()` implementation. Recurses to children of `node`."""
1250         if isinstance(node, Leaf):
1251             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1252             try:
1253                 for comment in generate_comments(node):
1254                     if any_open_brackets:
1255                         # any comment within brackets is subject to splitting
1256                         self.current_line.append(comment)
1257                     elif comment.type == token.COMMENT:
1258                         # regular trailing comment
1259                         self.current_line.append(comment)
1260                         yield from self.line()
1261
1262                     else:
1263                         # regular standalone comment
1264                         yield from self.line()
1265
1266                         self.current_line.append(comment)
1267                         yield from self.line()
1268
1269             except FormatOff as f_off:
1270                 f_off.trim_prefix(node)
1271                 yield from self.line(type=UnformattedLines)
1272                 yield from self.visit(node)
1273
1274             except FormatOn as f_on:
1275                 # This only happens here if somebody says "fmt: on" multiple
1276                 # times in a row.
1277                 f_on.trim_prefix(node)
1278                 yield from self.visit_default(node)
1279
1280             else:
1281                 normalize_prefix(node, inside_brackets=any_open_brackets)
1282                 if node.type == token.STRING:
1283                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1284                     normalize_string_quotes(node)
1285                 if node.type not in WHITESPACE:
1286                     self.current_line.append(node)
1287         yield from super().visit_default(node)
1288
1289     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1290         """Increase indentation level, maybe yield a line."""
1291         # In blib2to3 INDENT never holds comments.
1292         yield from self.line(+1)
1293         yield from self.visit_default(node)
1294
1295     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1296         """Decrease indentation level, maybe yield a line."""
1297         # The current line might still wait for trailing comments.  At DEDENT time
1298         # there won't be any (they would be prefixes on the preceding NEWLINE).
1299         # Emit the line then.
1300         yield from self.line()
1301
1302         # While DEDENT has no value, its prefix may contain standalone comments
1303         # that belong to the current indentation level.  Get 'em.
1304         yield from self.visit_default(node)
1305
1306         # Finally, emit the dedent.
1307         yield from self.line(-1)
1308
1309     def visit_stmt(
1310         self, node: Node, keywords: Set[str], parens: Set[str]
1311     ) -> Iterator[Line]:
1312         """Visit a statement.
1313
1314         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1315         `def`, `with`, `class`, `assert` and assignments.
1316
1317         The relevant Python language `keywords` for a given statement will be
1318         NAME leaves within it. This methods puts those on a separate line.
1319
1320         `parens` holds a set of string leaf values immeditely after which
1321         invisible parens should be put.
1322         """
1323         normalize_invisible_parens(node, parens_after=parens)
1324         for child in node.children:
1325             if child.type == token.NAME and child.value in keywords:  # type: ignore
1326                 yield from self.line()
1327
1328             yield from self.visit(child)
1329
1330     def visit_suite(self, node: Node) -> Iterator[Line]:
1331         """Visit a suite."""
1332         if self.is_pyi and self.is_trivial_suite(node):
1333             yield from self.visit(node.children[2])
1334         else:
1335             yield from self.visit_default(node)
1336
1337     def is_trivial_suite(self, node: Node) -> bool:
1338         if len(node.children) != 4:
1339             return False
1340         if (
1341             not isinstance(node.children[0], Leaf)
1342             or node.children[0].type != token.NEWLINE
1343         ):
1344             return False
1345         if (
1346             not isinstance(node.children[1], Leaf)
1347             or node.children[1].type != token.INDENT
1348         ):
1349             return False
1350         if (
1351             not isinstance(node.children[3], Leaf)
1352             or node.children[3].type != token.DEDENT
1353         ):
1354             return False
1355         stmt = node.children[2]
1356         if not isinstance(stmt, Node):
1357             return False
1358         return self.is_trivial_body(stmt)
1359
1360     def is_trivial_body(self, stmt: Node) -> bool:
1361         if not isinstance(stmt, Node) or stmt.type != syms.simple_stmt:
1362             return False
1363         if len(stmt.children) != 2:
1364             return False
1365         child = stmt.children[0]
1366         return (
1367             child.type == syms.atom
1368             and len(child.children) == 3
1369             and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
1370         )
1371
1372     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1373         """Visit a statement without nested statements."""
1374         is_suite_like = node.parent and node.parent.type in STATEMENT
1375         if is_suite_like:
1376             if self.is_pyi and self.is_trivial_body(node):
1377                 yield from self.visit_default(node)
1378             else:
1379                 yield from self.line(+1)
1380                 yield from self.visit_default(node)
1381                 yield from self.line(-1)
1382
1383         else:
1384             if (
1385                 not self.is_pyi
1386                 or not node.parent
1387                 or not self.is_trivial_suite(node.parent)
1388             ):
1389                 yield from self.line()
1390             yield from self.visit_default(node)
1391
1392     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1393         """Visit `async def`, `async for`, `async with`."""
1394         yield from self.line()
1395
1396         children = iter(node.children)
1397         for child in children:
1398             yield from self.visit(child)
1399
1400             if child.type == token.ASYNC:
1401                 break
1402
1403         internal_stmt = next(children)
1404         for child in internal_stmt.children:
1405             yield from self.visit(child)
1406
1407     def visit_decorators(self, node: Node) -> Iterator[Line]:
1408         """Visit decorators."""
1409         for child in node.children:
1410             yield from self.line()
1411             yield from self.visit(child)
1412
1413     def visit_import_from(self, node: Node) -> Iterator[Line]:
1414         """Visit import_from and maybe put invisible parentheses.
1415
1416         This is separate from `visit_stmt` because import statements don't
1417         support arbitrary atoms and thus handling of parentheses is custom.
1418         """
1419         check_lpar = False
1420         for index, child in enumerate(node.children):
1421             if check_lpar:
1422                 if child.type == token.LPAR:
1423                     # make parentheses invisible
1424                     child.value = ""  # type: ignore
1425                     node.children[-1].value = ""  # type: ignore
1426                 else:
1427                     # insert invisible parentheses
1428                     node.insert_child(index, Leaf(token.LPAR, ""))
1429                     node.append_child(Leaf(token.RPAR, ""))
1430                 break
1431
1432             check_lpar = (
1433                 child.type == token.NAME and child.value == "import"  # type: ignore
1434             )
1435
1436         for child in node.children:
1437             yield from self.visit(child)
1438
1439     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1440         """Remove a semicolon and put the other statement on a separate line."""
1441         yield from self.line()
1442
1443     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1444         """End of file. Process outstanding comments and end with a newline."""
1445         yield from self.visit_default(leaf)
1446         yield from self.line()
1447
1448     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1449         """Used when file contained a `# fmt: off`."""
1450         if isinstance(node, Node):
1451             for child in node.children:
1452                 yield from self.visit(child)
1453
1454         else:
1455             try:
1456                 self.current_line.append(node)
1457             except FormatOn as f_on:
1458                 f_on.trim_prefix(node)
1459                 yield from self.line()
1460                 yield from self.visit(node)
1461
1462             if node.type == token.ENDMARKER:
1463                 # somebody decided not to put a final `# fmt: on`
1464                 yield from self.line()
1465
1466     def __attrs_post_init__(self) -> None:
1467         """You are in a twisty little maze of passages."""
1468         v = self.visit_stmt
1469         Ø: Set[str] = set()
1470         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1471         self.visit_if_stmt = partial(
1472             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1473         )
1474         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1475         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1476         self.visit_try_stmt = partial(
1477             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1478         )
1479         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1480         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1481         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1482         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1483         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1484         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1485         self.visit_async_funcdef = self.visit_async_stmt
1486         self.visit_decorated = self.visit_decorators
1487
1488
1489 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1490 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1491 OPENING_BRACKETS = set(BRACKET.keys())
1492 CLOSING_BRACKETS = set(BRACKET.values())
1493 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1494 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1495
1496
1497 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1498     """Return whitespace prefix if needed for the given `leaf`.
1499
1500     `complex_subscript` signals whether the given leaf is part of a subscription
1501     which has non-trivial arguments, like arithmetic expressions or function calls.
1502     """
1503     NO = ""
1504     SPACE = " "
1505     DOUBLESPACE = "  "
1506     t = leaf.type
1507     p = leaf.parent
1508     v = leaf.value
1509     if t in ALWAYS_NO_SPACE:
1510         return NO
1511
1512     if t == token.COMMENT:
1513         return DOUBLESPACE
1514
1515     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1516     if (
1517         t == token.COLON
1518         and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
1519     ):
1520         return NO
1521
1522     prev = leaf.prev_sibling
1523     if not prev:
1524         prevp = preceding_leaf(p)
1525         if not prevp or prevp.type in OPENING_BRACKETS:
1526             return NO
1527
1528         if t == token.COLON:
1529             if prevp.type == token.COLON:
1530                 return NO
1531
1532             elif prevp.type != token.COMMA and not complex_subscript:
1533                 return NO
1534
1535             return SPACE
1536
1537         if prevp.type == token.EQUAL:
1538             if prevp.parent:
1539                 if prevp.parent.type in {
1540                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1541                 }:
1542                     return NO
1543
1544                 elif prevp.parent.type == syms.typedargslist:
1545                     # A bit hacky: if the equal sign has whitespace, it means we
1546                     # previously found it's a typed argument.  So, we're using
1547                     # that, too.
1548                     return prevp.prefix
1549
1550         elif prevp.type in STARS:
1551             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1552                 return NO
1553
1554         elif prevp.type == token.COLON:
1555             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1556                 return SPACE if complex_subscript else NO
1557
1558         elif (
1559             prevp.parent
1560             and prevp.parent.type == syms.factor
1561             and prevp.type in MATH_OPERATORS
1562         ):
1563             return NO
1564
1565         elif (
1566             prevp.type == token.RIGHTSHIFT
1567             and prevp.parent
1568             and prevp.parent.type == syms.shift_expr
1569             and prevp.prev_sibling
1570             and prevp.prev_sibling.type == token.NAME
1571             and prevp.prev_sibling.value == "print"  # type: ignore
1572         ):
1573             # Python 2 print chevron
1574             return NO
1575
1576     elif prev.type in OPENING_BRACKETS:
1577         return NO
1578
1579     if p.type in {syms.parameters, syms.arglist}:
1580         # untyped function signatures or calls
1581         if not prev or prev.type != token.COMMA:
1582             return NO
1583
1584     elif p.type == syms.varargslist:
1585         # lambdas
1586         if prev and prev.type != token.COMMA:
1587             return NO
1588
1589     elif p.type == syms.typedargslist:
1590         # typed function signatures
1591         if not prev:
1592             return NO
1593
1594         if t == token.EQUAL:
1595             if prev.type != syms.tname:
1596                 return NO
1597
1598         elif prev.type == token.EQUAL:
1599             # A bit hacky: if the equal sign has whitespace, it means we
1600             # previously found it's a typed argument.  So, we're using that, too.
1601             return prev.prefix
1602
1603         elif prev.type != token.COMMA:
1604             return NO
1605
1606     elif p.type == syms.tname:
1607         # type names
1608         if not prev:
1609             prevp = preceding_leaf(p)
1610             if not prevp or prevp.type != token.COMMA:
1611                 return NO
1612
1613     elif p.type == syms.trailer:
1614         # attributes and calls
1615         if t == token.LPAR or t == token.RPAR:
1616             return NO
1617
1618         if not prev:
1619             if t == token.DOT:
1620                 prevp = preceding_leaf(p)
1621                 if not prevp or prevp.type != token.NUMBER:
1622                     return NO
1623
1624             elif t == token.LSQB:
1625                 return NO
1626
1627         elif prev.type != token.COMMA:
1628             return NO
1629
1630     elif p.type == syms.argument:
1631         # single argument
1632         if t == token.EQUAL:
1633             return NO
1634
1635         if not prev:
1636             prevp = preceding_leaf(p)
1637             if not prevp or prevp.type == token.LPAR:
1638                 return NO
1639
1640         elif prev.type in {token.EQUAL} | STARS:
1641             return NO
1642
1643     elif p.type == syms.decorator:
1644         # decorators
1645         return NO
1646
1647     elif p.type == syms.dotted_name:
1648         if prev:
1649             return NO
1650
1651         prevp = preceding_leaf(p)
1652         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1653             return NO
1654
1655     elif p.type == syms.classdef:
1656         if t == token.LPAR:
1657             return NO
1658
1659         if prev and prev.type == token.LPAR:
1660             return NO
1661
1662     elif p.type in {syms.subscript, syms.sliceop}:
1663         # indexing
1664         if not prev:
1665             assert p.parent is not None, "subscripts are always parented"
1666             if p.parent.type == syms.subscriptlist:
1667                 return SPACE
1668
1669             return NO
1670
1671         elif not complex_subscript:
1672             return NO
1673
1674     elif p.type == syms.atom:
1675         if prev and t == token.DOT:
1676             # dots, but not the first one.
1677             return NO
1678
1679     elif p.type == syms.dictsetmaker:
1680         # dict unpacking
1681         if prev and prev.type == token.DOUBLESTAR:
1682             return NO
1683
1684     elif p.type in {syms.factor, syms.star_expr}:
1685         # unary ops
1686         if not prev:
1687             prevp = preceding_leaf(p)
1688             if not prevp or prevp.type in OPENING_BRACKETS:
1689                 return NO
1690
1691             prevp_parent = prevp.parent
1692             assert prevp_parent is not None
1693             if (
1694                 prevp.type == token.COLON
1695                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1696             ):
1697                 return NO
1698
1699             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1700                 return NO
1701
1702         elif t == token.NAME or t == token.NUMBER:
1703             return NO
1704
1705     elif p.type == syms.import_from:
1706         if t == token.DOT:
1707             if prev and prev.type == token.DOT:
1708                 return NO
1709
1710         elif t == token.NAME:
1711             if v == "import":
1712                 return SPACE
1713
1714             if prev and prev.type == token.DOT:
1715                 return NO
1716
1717     elif p.type == syms.sliceop:
1718         return NO
1719
1720     return SPACE
1721
1722
1723 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1724     """Return the first leaf that precedes `node`, if any."""
1725     while node:
1726         res = node.prev_sibling
1727         if res:
1728             if isinstance(res, Leaf):
1729                 return res
1730
1731             try:
1732                 return list(res.leaves())[-1]
1733
1734             except IndexError:
1735                 return None
1736
1737         node = node.parent
1738     return None
1739
1740
1741 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1742     """Return the child of `ancestor` that contains `descendant`."""
1743     node: Optional[LN] = descendant
1744     while node and node.parent != ancestor:
1745         node = node.parent
1746     return node
1747
1748
1749 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1750     """Return the priority of the `leaf` delimiter, given a line break after it.
1751
1752     The delimiter priorities returned here are from those delimiters that would
1753     cause a line break after themselves.
1754
1755     Higher numbers are higher priority.
1756     """
1757     if leaf.type == token.COMMA:
1758         return COMMA_PRIORITY
1759
1760     return 0
1761
1762
1763 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1764     """Return the priority of the `leaf` delimiter, given a line before after it.
1765
1766     The delimiter priorities returned here are from those delimiters that would
1767     cause a line break before themselves.
1768
1769     Higher numbers are higher priority.
1770     """
1771     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1772         # * and ** might also be MATH_OPERATORS but in this case they are not.
1773         # Don't treat them as a delimiter.
1774         return 0
1775
1776     if (
1777         leaf.type in MATH_OPERATORS
1778         and leaf.parent
1779         and leaf.parent.type not in {syms.factor, syms.star_expr}
1780     ):
1781         return MATH_PRIORITIES[leaf.type]
1782
1783     if leaf.type in COMPARATORS:
1784         return COMPARATOR_PRIORITY
1785
1786     if (
1787         leaf.type == token.STRING
1788         and previous is not None
1789         and previous.type == token.STRING
1790     ):
1791         return STRING_PRIORITY
1792
1793     if (
1794         leaf.type == token.NAME
1795         and leaf.value == "for"
1796         and leaf.parent
1797         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1798     ):
1799         return COMPREHENSION_PRIORITY
1800
1801     if (
1802         leaf.type == token.NAME
1803         and leaf.value == "if"
1804         and leaf.parent
1805         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1806     ):
1807         return COMPREHENSION_PRIORITY
1808
1809     if (
1810         leaf.type == token.NAME
1811         and leaf.value in {"if", "else"}
1812         and leaf.parent
1813         and leaf.parent.type == syms.test
1814     ):
1815         return TERNARY_PRIORITY
1816
1817     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1818         return LOGIC_PRIORITY
1819
1820     return 0
1821
1822
1823 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1824     """Clean the prefix of the `leaf` and generate comments from it, if any.
1825
1826     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1827     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1828     move because it does away with modifying the grammar to include all the
1829     possible places in which comments can be placed.
1830
1831     The sad consequence for us though is that comments don't "belong" anywhere.
1832     This is why this function generates simple parentless Leaf objects for
1833     comments.  We simply don't know what the correct parent should be.
1834
1835     No matter though, we can live without this.  We really only need to
1836     differentiate between inline and standalone comments.  The latter don't
1837     share the line with any code.
1838
1839     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1840     are emitted with a fake STANDALONE_COMMENT token identifier.
1841     """
1842     p = leaf.prefix
1843     if not p:
1844         return
1845
1846     if "#" not in p:
1847         return
1848
1849     consumed = 0
1850     nlines = 0
1851     for index, line in enumerate(p.split("\n")):
1852         consumed += len(line) + 1  # adding the length of the split '\n'
1853         line = line.lstrip()
1854         if not line:
1855             nlines += 1
1856         if not line.startswith("#"):
1857             continue
1858
1859         if index == 0 and leaf.type != token.ENDMARKER:
1860             comment_type = token.COMMENT  # simple trailing comment
1861         else:
1862             comment_type = STANDALONE_COMMENT
1863         comment = make_comment(line)
1864         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1865
1866         if comment in {"# fmt: on", "# yapf: enable"}:
1867             raise FormatOn(consumed)
1868
1869         if comment in {"# fmt: off", "# yapf: disable"}:
1870             if comment_type == STANDALONE_COMMENT:
1871                 raise FormatOff(consumed)
1872
1873             prev = preceding_leaf(leaf)
1874             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1875                 raise FormatOff(consumed)
1876
1877         nlines = 0
1878
1879
1880 def make_comment(content: str) -> str:
1881     """Return a consistently formatted comment from the given `content` string.
1882
1883     All comments (except for "##", "#!", "#:") should have a single space between
1884     the hash sign and the content.
1885
1886     If `content` didn't start with a hash sign, one is provided.
1887     """
1888     content = content.rstrip()
1889     if not content:
1890         return "#"
1891
1892     if content[0] == "#":
1893         content = content[1:]
1894     if content and content[0] not in " !:#":
1895         content = " " + content
1896     return "#" + content
1897
1898
1899 def split_line(
1900     line: Line, line_length: int, inner: bool = False, py36: bool = False
1901 ) -> Iterator[Line]:
1902     """Split a `line` into potentially many lines.
1903
1904     They should fit in the allotted `line_length` but might not be able to.
1905     `inner` signifies that there were a pair of brackets somewhere around the
1906     current `line`, possibly transitively. This means we can fallback to splitting
1907     by delimiters if the LHS/RHS don't yield any results.
1908
1909     If `py36` is True, splitting may generate syntax that is only compatible
1910     with Python 3.6 and later.
1911     """
1912     if isinstance(line, UnformattedLines) or line.is_comment:
1913         yield line
1914         return
1915
1916     line_str = str(line).strip("\n")
1917     if is_line_short_enough(line, line_length=line_length, line_str=line_str):
1918         yield line
1919         return
1920
1921     split_funcs: List[SplitFunc]
1922     if line.is_def:
1923         split_funcs = [left_hand_split]
1924     elif line.is_import:
1925         split_funcs = [explode_split]
1926     else:
1927
1928         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
1929             for omit in generate_trailers_to_omit(line, line_length):
1930                 lines = list(right_hand_split(line, py36, omit=omit))
1931                 if is_line_short_enough(lines[0], line_length=line_length):
1932                     yield from lines
1933                     return
1934
1935             # All splits failed, best effort split with no omits.
1936             yield from right_hand_split(line, py36)
1937
1938         if line.inside_brackets:
1939             split_funcs = [delimiter_split, standalone_comment_split, rhs]
1940         else:
1941             split_funcs = [rhs]
1942     for split_func in split_funcs:
1943         # We are accumulating lines in `result` because we might want to abort
1944         # mission and return the original line in the end, or attempt a different
1945         # split altogether.
1946         result: List[Line] = []
1947         try:
1948             for l in split_func(line, py36):
1949                 if str(l).strip("\n") == line_str:
1950                     raise CannotSplit("Split function returned an unchanged result")
1951
1952                 result.extend(
1953                     split_line(l, line_length=line_length, inner=True, py36=py36)
1954                 )
1955         except CannotSplit as cs:
1956             continue
1957
1958         else:
1959             yield from result
1960             break
1961
1962     else:
1963         yield line
1964
1965
1966 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1967     """Split line into many lines, starting with the first matching bracket pair.
1968
1969     Note: this usually looks weird, only use this for function definitions.
1970     Prefer RHS otherwise.  This is why this function is not symmetrical with
1971     :func:`right_hand_split` which also handles optional parentheses.
1972     """
1973     head = Line(depth=line.depth)
1974     body = Line(depth=line.depth + 1, inside_brackets=True)
1975     tail = Line(depth=line.depth)
1976     tail_leaves: List[Leaf] = []
1977     body_leaves: List[Leaf] = []
1978     head_leaves: List[Leaf] = []
1979     current_leaves = head_leaves
1980     matching_bracket = None
1981     for leaf in line.leaves:
1982         if (
1983             current_leaves is body_leaves
1984             and leaf.type in CLOSING_BRACKETS
1985             and leaf.opening_bracket is matching_bracket
1986         ):
1987             current_leaves = tail_leaves if body_leaves else head_leaves
1988         current_leaves.append(leaf)
1989         if current_leaves is head_leaves:
1990             if leaf.type in OPENING_BRACKETS:
1991                 matching_bracket = leaf
1992                 current_leaves = body_leaves
1993     # Since body is a new indent level, remove spurious leading whitespace.
1994     if body_leaves:
1995         normalize_prefix(body_leaves[0], inside_brackets=True)
1996     # Build the new lines.
1997     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1998         for leaf in leaves:
1999             result.append(leaf, preformatted=True)
2000             for comment_after in line.comments_after(leaf):
2001                 result.append(comment_after, preformatted=True)
2002     bracket_split_succeeded_or_raise(head, body, tail)
2003     for result in (head, body, tail):
2004         if result:
2005             yield result
2006
2007
2008 def right_hand_split(
2009     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2010 ) -> Iterator[Line]:
2011     """Split line into many lines, starting with the last matching bracket pair.
2012
2013     If the split was by optional parentheses, attempt splitting without them, too.
2014     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2015     this split.
2016     """
2017     head = Line(depth=line.depth)
2018     body = Line(depth=line.depth + 1, inside_brackets=True)
2019     tail = Line(depth=line.depth)
2020     tail_leaves: List[Leaf] = []
2021     body_leaves: List[Leaf] = []
2022     head_leaves: List[Leaf] = []
2023     current_leaves = tail_leaves
2024     opening_bracket = None
2025     closing_bracket = None
2026     for leaf in reversed(line.leaves):
2027         if current_leaves is body_leaves:
2028             if leaf is opening_bracket:
2029                 current_leaves = head_leaves if body_leaves else tail_leaves
2030         current_leaves.append(leaf)
2031         if current_leaves is tail_leaves:
2032             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2033                 opening_bracket = leaf.opening_bracket
2034                 closing_bracket = leaf
2035                 current_leaves = body_leaves
2036     tail_leaves.reverse()
2037     body_leaves.reverse()
2038     head_leaves.reverse()
2039     # Since body is a new indent level, remove spurious leading whitespace.
2040     if body_leaves:
2041         normalize_prefix(body_leaves[0], inside_brackets=True)
2042     elif not head_leaves:
2043         # No `head` and no `body` means the split failed. `tail` has all content.
2044         raise CannotSplit("No brackets found")
2045
2046     # Build the new lines.
2047     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2048         for leaf in leaves:
2049             result.append(leaf, preformatted=True)
2050             for comment_after in line.comments_after(leaf):
2051                 result.append(comment_after, preformatted=True)
2052     bracket_split_succeeded_or_raise(head, body, tail)
2053     assert opening_bracket and closing_bracket
2054     if (
2055         # the opening bracket is an optional paren
2056         opening_bracket.type == token.LPAR
2057         and not opening_bracket.value
2058         # the closing bracket is an optional paren
2059         and closing_bracket.type == token.RPAR
2060         and not closing_bracket.value
2061         # there are no delimiters or standalone comments in the body
2062         and not body.bracket_tracker.delimiters
2063         and not line.contains_standalone_comments(0)
2064         # and it's not an import (optional parens are the only thing we can split
2065         # on in this case; attempting a split without them is a waste of time)
2066         and not line.is_import
2067     ):
2068         omit = {id(closing_bracket), *omit}
2069         try:
2070             yield from right_hand_split(line, py36=py36, omit=omit)
2071             return
2072         except CannotSplit:
2073             pass
2074
2075     ensure_visible(opening_bracket)
2076     ensure_visible(closing_bracket)
2077     for result in (head, body, tail):
2078         if result:
2079             yield result
2080
2081
2082 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2083     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2084
2085     Do nothing otherwise.
2086
2087     A left- or right-hand split is based on a pair of brackets. Content before
2088     (and including) the opening bracket is left on one line, content inside the
2089     brackets is put on a separate line, and finally content starting with and
2090     following the closing bracket is put on a separate line.
2091
2092     Those are called `head`, `body`, and `tail`, respectively. If the split
2093     produced the same line (all content in `head`) or ended up with an empty `body`
2094     and the `tail` is just the closing bracket, then it's considered failed.
2095     """
2096     tail_len = len(str(tail).strip())
2097     if not body:
2098         if tail_len == 0:
2099             raise CannotSplit("Splitting brackets produced the same line")
2100
2101         elif tail_len < 3:
2102             raise CannotSplit(
2103                 f"Splitting brackets on an empty body to save "
2104                 f"{tail_len} characters is not worth it"
2105             )
2106
2107
2108 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2109     """Normalize prefix of the first leaf in every line returned by `split_func`.
2110
2111     This is a decorator over relevant split functions.
2112     """
2113
2114     @wraps(split_func)
2115     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2116         for l in split_func(line, py36):
2117             normalize_prefix(l.leaves[0], inside_brackets=True)
2118             yield l
2119
2120     return split_wrapper
2121
2122
2123 @dont_increase_indentation
2124 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2125     """Split according to delimiters of the highest priority.
2126
2127     If `py36` is True, the split will add trailing commas also in function
2128     signatures that contain `*` and `**`.
2129     """
2130     try:
2131         last_leaf = line.leaves[-1]
2132     except IndexError:
2133         raise CannotSplit("Line empty")
2134
2135     delimiters = line.bracket_tracker.delimiters
2136     try:
2137         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
2138             exclude={id(last_leaf)}
2139         )
2140     except ValueError:
2141         raise CannotSplit("No delimiters found")
2142
2143     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2144     lowest_depth = sys.maxsize
2145     trailing_comma_safe = True
2146
2147     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2148         """Append `leaf` to current line or to new line if appending impossible."""
2149         nonlocal current_line
2150         try:
2151             current_line.append_safe(leaf, preformatted=True)
2152         except ValueError as ve:
2153             yield current_line
2154
2155             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2156             current_line.append(leaf)
2157
2158     for index, leaf in enumerate(line.leaves):
2159         yield from append_to_line(leaf)
2160
2161         for comment_after in line.comments_after(leaf, index):
2162             yield from append_to_line(comment_after)
2163
2164         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2165         if (
2166             leaf.bracket_depth == lowest_depth
2167             and is_vararg(leaf, within=VARARGS_PARENTS)
2168         ):
2169             trailing_comma_safe = trailing_comma_safe and py36
2170         leaf_priority = delimiters.get(id(leaf))
2171         if leaf_priority == delimiter_priority:
2172             yield current_line
2173
2174             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2175     if current_line:
2176         if (
2177             trailing_comma_safe
2178             and delimiter_priority == COMMA_PRIORITY
2179             and current_line.leaves[-1].type != token.COMMA
2180             and current_line.leaves[-1].type != STANDALONE_COMMENT
2181         ):
2182             current_line.append(Leaf(token.COMMA, ","))
2183         yield current_line
2184
2185
2186 @dont_increase_indentation
2187 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2188     """Split standalone comments from the rest of the line."""
2189     if not line.contains_standalone_comments(0):
2190         raise CannotSplit("Line does not have any standalone comments")
2191
2192     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2193
2194     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2195         """Append `leaf` to current line or to new line if appending impossible."""
2196         nonlocal current_line
2197         try:
2198             current_line.append_safe(leaf, preformatted=True)
2199         except ValueError as ve:
2200             yield current_line
2201
2202             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2203             current_line.append(leaf)
2204
2205     for index, leaf in enumerate(line.leaves):
2206         yield from append_to_line(leaf)
2207
2208         for comment_after in line.comments_after(leaf, index):
2209             yield from append_to_line(comment_after)
2210
2211     if current_line:
2212         yield current_line
2213
2214
2215 def explode_split(
2216     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2217 ) -> Iterator[Line]:
2218     """Split by rightmost bracket and immediately split contents by a delimiter."""
2219     new_lines = list(right_hand_split(line, py36, omit))
2220     if len(new_lines) != 3:
2221         yield from new_lines
2222         return
2223
2224     yield new_lines[0]
2225
2226     try:
2227         yield from delimiter_split(new_lines[1], py36)
2228
2229     except CannotSplit:
2230         yield new_lines[1]
2231
2232     yield new_lines[2]
2233
2234
2235 def is_import(leaf: Leaf) -> bool:
2236     """Return True if the given leaf starts an import statement."""
2237     p = leaf.parent
2238     t = leaf.type
2239     v = leaf.value
2240     return bool(
2241         t == token.NAME
2242         and (
2243             (v == "import" and p and p.type == syms.import_name)
2244             or (v == "from" and p and p.type == syms.import_from)
2245         )
2246     )
2247
2248
2249 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2250     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2251     else.
2252
2253     Note: don't use backslashes for formatting or you'll lose your voting rights.
2254     """
2255     if not inside_brackets:
2256         spl = leaf.prefix.split("#")
2257         if "\\" not in spl[0]:
2258             nl_count = spl[-1].count("\n")
2259             if len(spl) > 1:
2260                 nl_count -= 1
2261             leaf.prefix = "\n" * nl_count
2262             return
2263
2264     leaf.prefix = ""
2265
2266
2267 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2268     """Make all string prefixes lowercase.
2269
2270     If remove_u_prefix is given, also removes any u prefix from the string.
2271
2272     Note: Mutates its argument.
2273     """
2274     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2275     assert match is not None, f"failed to match string {leaf.value!r}"
2276     orig_prefix = match.group(1)
2277     new_prefix = orig_prefix.lower()
2278     if remove_u_prefix:
2279         new_prefix = new_prefix.replace("u", "")
2280     leaf.value = f"{new_prefix}{match.group(2)}"
2281
2282
2283 def normalize_string_quotes(leaf: Leaf) -> None:
2284     """Prefer double quotes but only if it doesn't cause more escaping.
2285
2286     Adds or removes backslashes as appropriate. Doesn't parse and fix
2287     strings nested in f-strings (yet).
2288
2289     Note: Mutates its argument.
2290     """
2291     value = leaf.value.lstrip("furbFURB")
2292     if value[:3] == '"""':
2293         return
2294
2295     elif value[:3] == "'''":
2296         orig_quote = "'''"
2297         new_quote = '"""'
2298     elif value[0] == '"':
2299         orig_quote = '"'
2300         new_quote = "'"
2301     else:
2302         orig_quote = "'"
2303         new_quote = '"'
2304     first_quote_pos = leaf.value.find(orig_quote)
2305     if first_quote_pos == -1:
2306         return  # There's an internal error
2307
2308     prefix = leaf.value[:first_quote_pos]
2309     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2310     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2311     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2312     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2313     if "r" in prefix.casefold():
2314         if unescaped_new_quote.search(body):
2315             # There's at least one unescaped new_quote in this raw string
2316             # so converting is impossible
2317             return
2318
2319         # Do not introduce or remove backslashes in raw strings
2320         new_body = body
2321     else:
2322         # remove unnecessary quotes
2323         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2324         if body != new_body:
2325             # Consider the string without unnecessary quotes as the original
2326             body = new_body
2327             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2328         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2329         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2330     if new_quote == '"""' and new_body[-1] == '"':
2331         # edge case:
2332         new_body = new_body[:-1] + '\\"'
2333     orig_escape_count = body.count("\\")
2334     new_escape_count = new_body.count("\\")
2335     if new_escape_count > orig_escape_count:
2336         return  # Do not introduce more escaping
2337
2338     if new_escape_count == orig_escape_count and orig_quote == '"':
2339         return  # Prefer double quotes
2340
2341     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2342
2343
2344 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2345     """Make existing optional parentheses invisible or create new ones.
2346
2347     `parens_after` is a set of string leaf values immeditely after which parens
2348     should be put.
2349
2350     Standardizes on visible parentheses for single-element tuples, and keeps
2351     existing visible parentheses for other tuples and generator expressions.
2352     """
2353     check_lpar = False
2354     for child in list(node.children):
2355         if check_lpar:
2356             if child.type == syms.atom:
2357                 maybe_make_parens_invisible_in_atom(child)
2358             elif is_one_tuple(child):
2359                 # wrap child in visible parentheses
2360                 lpar = Leaf(token.LPAR, "(")
2361                 rpar = Leaf(token.RPAR, ")")
2362                 index = child.remove() or 0
2363                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2364             else:
2365                 # wrap child in invisible parentheses
2366                 lpar = Leaf(token.LPAR, "")
2367                 rpar = Leaf(token.RPAR, "")
2368                 index = child.remove() or 0
2369                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2370
2371         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2372
2373
2374 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2375     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2376     if (
2377         node.type != syms.atom
2378         or is_empty_tuple(node)
2379         or is_one_tuple(node)
2380         or is_yield(node)
2381         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2382     ):
2383         return False
2384
2385     first = node.children[0]
2386     last = node.children[-1]
2387     if first.type == token.LPAR and last.type == token.RPAR:
2388         # make parentheses invisible
2389         first.value = ""  # type: ignore
2390         last.value = ""  # type: ignore
2391         if len(node.children) > 1:
2392             maybe_make_parens_invisible_in_atom(node.children[1])
2393         return True
2394
2395     return False
2396
2397
2398 def is_empty_tuple(node: LN) -> bool:
2399     """Return True if `node` holds an empty tuple."""
2400     return (
2401         node.type == syms.atom
2402         and len(node.children) == 2
2403         and node.children[0].type == token.LPAR
2404         and node.children[1].type == token.RPAR
2405     )
2406
2407
2408 def is_one_tuple(node: LN) -> bool:
2409     """Return True if `node` holds a tuple with one element, with or without parens."""
2410     if node.type == syms.atom:
2411         if len(node.children) != 3:
2412             return False
2413
2414         lpar, gexp, rpar = node.children
2415         if not (
2416             lpar.type == token.LPAR
2417             and gexp.type == syms.testlist_gexp
2418             and rpar.type == token.RPAR
2419         ):
2420             return False
2421
2422         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2423
2424     return (
2425         node.type in IMPLICIT_TUPLE
2426         and len(node.children) == 2
2427         and node.children[1].type == token.COMMA
2428     )
2429
2430
2431 def is_yield(node: LN) -> bool:
2432     """Return True if `node` holds a `yield` or `yield from` expression."""
2433     if node.type == syms.yield_expr:
2434         return True
2435
2436     if node.type == token.NAME and node.value == "yield":  # type: ignore
2437         return True
2438
2439     if node.type != syms.atom:
2440         return False
2441
2442     if len(node.children) != 3:
2443         return False
2444
2445     lpar, expr, rpar = node.children
2446     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2447         return is_yield(expr)
2448
2449     return False
2450
2451
2452 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2453     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2454
2455     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2456     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2457     extended iterable unpacking (PEP 3132) and additional unpacking
2458     generalizations (PEP 448).
2459     """
2460     if leaf.type not in STARS or not leaf.parent:
2461         return False
2462
2463     p = leaf.parent
2464     if p.type == syms.star_expr:
2465         # Star expressions are also used as assignment targets in extended
2466         # iterable unpacking (PEP 3132).  See what its parent is instead.
2467         if not p.parent:
2468             return False
2469
2470         p = p.parent
2471
2472     return p.type in within
2473
2474
2475 def max_delimiter_priority_in_atom(node: LN) -> int:
2476     """Return maximum delimiter priority inside `node`.
2477
2478     This is specific to atoms with contents contained in a pair of parentheses.
2479     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2480     """
2481     if node.type != syms.atom:
2482         return 0
2483
2484     first = node.children[0]
2485     last = node.children[-1]
2486     if not (first.type == token.LPAR and last.type == token.RPAR):
2487         return 0
2488
2489     bt = BracketTracker()
2490     for c in node.children[1:-1]:
2491         if isinstance(c, Leaf):
2492             bt.mark(c)
2493         else:
2494             for leaf in c.leaves():
2495                 bt.mark(leaf)
2496     try:
2497         return bt.max_delimiter_priority()
2498
2499     except ValueError:
2500         return 0
2501
2502
2503 def ensure_visible(leaf: Leaf) -> None:
2504     """Make sure parentheses are visible.
2505
2506     They could be invisible as part of some statements (see
2507     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2508     """
2509     if leaf.type == token.LPAR:
2510         leaf.value = "("
2511     elif leaf.type == token.RPAR:
2512         leaf.value = ")"
2513
2514
2515 def is_python36(node: Node) -> bool:
2516     """Return True if the current file is using Python 3.6+ features.
2517
2518     Currently looking for:
2519     - f-strings; and
2520     - trailing commas after * or ** in function signatures and calls.
2521     """
2522     for n in node.pre_order():
2523         if n.type == token.STRING:
2524             value_head = n.value[:2]  # type: ignore
2525             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2526                 return True
2527
2528         elif (
2529             n.type in {syms.typedargslist, syms.arglist}
2530             and n.children
2531             and n.children[-1].type == token.COMMA
2532         ):
2533             for ch in n.children:
2534                 if ch.type in STARS:
2535                     return True
2536
2537                 if ch.type == syms.argument:
2538                     for argch in ch.children:
2539                         if argch.type in STARS:
2540                             return True
2541
2542     return False
2543
2544
2545 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2546     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2547
2548     Brackets can be omitted if the entire trailer up to and including
2549     a preceding closing bracket fits in one line.
2550
2551     Yielded sets are cumulative (contain results of previous yields, too).  First
2552     set is empty.
2553     """
2554
2555     omit: Set[LeafID] = set()
2556     yield omit
2557
2558     length = 4 * line.depth
2559     opening_bracket = None
2560     closing_bracket = None
2561     optional_brackets: Set[LeafID] = set()
2562     inner_brackets: Set[LeafID] = set()
2563     for index, leaf in enumerate_reversed(line.leaves):
2564         length += len(leaf.prefix) + len(leaf.value)
2565         if length > line_length:
2566             break
2567
2568         comment: Optional[Leaf]
2569         for comment in line.comments_after(leaf, index):
2570             if "\n" in comment.prefix:
2571                 break  # Oops, standalone comment!
2572
2573             length += len(comment.value)
2574         else:
2575             comment = None
2576         if comment is not None:
2577             break  # There was a standalone comment, we can't continue.
2578
2579         optional_brackets.discard(id(leaf))
2580         if opening_bracket:
2581             if leaf is opening_bracket:
2582                 opening_bracket = None
2583             elif leaf.type in CLOSING_BRACKETS:
2584                 inner_brackets.add(id(leaf))
2585         elif leaf.type in CLOSING_BRACKETS:
2586             if not leaf.value:
2587                 optional_brackets.add(id(opening_bracket))
2588                 continue
2589
2590             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2591                 # Empty brackets would fail a split so treat them as "inner"
2592                 # brackets (e.g. only add them to the `omit` set if another
2593                 # pair of brackets was good enough.
2594                 inner_brackets.add(id(leaf))
2595                 continue
2596
2597             opening_bracket = leaf.opening_bracket
2598             if closing_bracket:
2599                 omit.add(id(closing_bracket))
2600                 omit.update(inner_brackets)
2601                 inner_brackets.clear()
2602                 yield omit
2603             closing_bracket = leaf
2604
2605
2606 def get_future_imports(node: Node) -> Set[str]:
2607     """Return a set of __future__ imports in the file."""
2608     imports = set()
2609     for child in node.children:
2610         if child.type != syms.simple_stmt:
2611             break
2612         first_child = child.children[0]
2613         if isinstance(first_child, Leaf):
2614             # Continue looking if we see a docstring; otherwise stop.
2615             if (
2616                 len(child.children) == 2
2617                 and first_child.type == token.STRING
2618                 and child.children[1].type == token.NEWLINE
2619             ):
2620                 continue
2621             else:
2622                 break
2623         elif first_child.type == syms.import_from:
2624             module_name = first_child.children[1]
2625             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2626                 break
2627             for import_from_child in first_child.children[3:]:
2628                 if isinstance(import_from_child, Leaf):
2629                     if import_from_child.type == token.NAME:
2630                         imports.add(import_from_child.value)
2631                 else:
2632                     assert import_from_child.type == syms.import_as_names
2633                     for leaf in import_from_child.children:
2634                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2635                             imports.add(leaf.value)
2636         else:
2637             break
2638     return imports
2639
2640
2641 PYTHON_EXTENSIONS = {".py", ".pyi"}
2642 BLACKLISTED_DIRECTORIES = {
2643     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2644 }
2645
2646
2647 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2648     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2649     and have one of the PYTHON_EXTENSIONS.
2650     """
2651     for child in path.iterdir():
2652         if child.is_dir():
2653             if child.name in BLACKLISTED_DIRECTORIES:
2654                 continue
2655
2656             yield from gen_python_files_in_dir(child)
2657
2658         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2659             yield child
2660
2661
2662 @dataclass
2663 class Report:
2664     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2665     check: bool = False
2666     quiet: bool = False
2667     change_count: int = 0
2668     same_count: int = 0
2669     failure_count: int = 0
2670
2671     def done(self, src: Path, changed: Changed) -> None:
2672         """Increment the counter for successful reformatting. Write out a message."""
2673         if changed is Changed.YES:
2674             reformatted = "would reformat" if self.check else "reformatted"
2675             if not self.quiet:
2676                 out(f"{reformatted} {src}")
2677             self.change_count += 1
2678         else:
2679             if not self.quiet:
2680                 if changed is Changed.NO:
2681                     msg = f"{src} already well formatted, good job."
2682                 else:
2683                     msg = f"{src} wasn't modified on disk since last run."
2684                 out(msg, bold=False)
2685             self.same_count += 1
2686
2687     def failed(self, src: Path, message: str) -> None:
2688         """Increment the counter for failed reformatting. Write out a message."""
2689         err(f"error: cannot format {src}: {message}")
2690         self.failure_count += 1
2691
2692     @property
2693     def return_code(self) -> int:
2694         """Return the exit code that the app should use.
2695
2696         This considers the current state of changed files and failures:
2697         - if there were any failures, return 123;
2698         - if any files were changed and --check is being used, return 1;
2699         - otherwise return 0.
2700         """
2701         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2702         # 126 we have special returncodes reserved by the shell.
2703         if self.failure_count:
2704             return 123
2705
2706         elif self.change_count and self.check:
2707             return 1
2708
2709         return 0
2710
2711     def __str__(self) -> str:
2712         """Render a color report of the current state.
2713
2714         Use `click.unstyle` to remove colors.
2715         """
2716         if self.check:
2717             reformatted = "would be reformatted"
2718             unchanged = "would be left unchanged"
2719             failed = "would fail to reformat"
2720         else:
2721             reformatted = "reformatted"
2722             unchanged = "left unchanged"
2723             failed = "failed to reformat"
2724         report = []
2725         if self.change_count:
2726             s = "s" if self.change_count > 1 else ""
2727             report.append(
2728                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2729             )
2730         if self.same_count:
2731             s = "s" if self.same_count > 1 else ""
2732             report.append(f"{self.same_count} file{s} {unchanged}")
2733         if self.failure_count:
2734             s = "s" if self.failure_count > 1 else ""
2735             report.append(
2736                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2737             )
2738         return ", ".join(report) + "."
2739
2740
2741 def assert_equivalent(src: str, dst: str) -> None:
2742     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2743
2744     import ast
2745     import traceback
2746
2747     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2748         """Simple visitor generating strings to compare ASTs by content."""
2749         yield f"{'  ' * depth}{node.__class__.__name__}("
2750
2751         for field in sorted(node._fields):
2752             try:
2753                 value = getattr(node, field)
2754             except AttributeError:
2755                 continue
2756
2757             yield f"{'  ' * (depth+1)}{field}="
2758
2759             if isinstance(value, list):
2760                 for item in value:
2761                     if isinstance(item, ast.AST):
2762                         yield from _v(item, depth + 2)
2763
2764             elif isinstance(value, ast.AST):
2765                 yield from _v(value, depth + 2)
2766
2767             else:
2768                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2769
2770         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2771
2772     try:
2773         src_ast = ast.parse(src)
2774     except Exception as exc:
2775         major, minor = sys.version_info[:2]
2776         raise AssertionError(
2777             f"cannot use --safe with this file; failed to parse source file "
2778             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2779             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2780         )
2781
2782     try:
2783         dst_ast = ast.parse(dst)
2784     except Exception as exc:
2785         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2786         raise AssertionError(
2787             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2788             f"Please report a bug on https://github.com/ambv/black/issues.  "
2789             f"This invalid output might be helpful: {log}"
2790         ) from None
2791
2792     src_ast_str = "\n".join(_v(src_ast))
2793     dst_ast_str = "\n".join(_v(dst_ast))
2794     if src_ast_str != dst_ast_str:
2795         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2796         raise AssertionError(
2797             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2798             f"the source.  "
2799             f"Please report a bug on https://github.com/ambv/black/issues.  "
2800             f"This diff might be helpful: {log}"
2801         ) from None
2802
2803
2804 def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
2805     """Raise AssertionError if `dst` reformats differently the second time."""
2806     newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
2807     if dst != newdst:
2808         log = dump_to_file(
2809             diff(src, dst, "source", "first pass"),
2810             diff(dst, newdst, "first pass", "second pass"),
2811         )
2812         raise AssertionError(
2813             f"INTERNAL ERROR: Black produced different code on the second pass "
2814             f"of the formatter.  "
2815             f"Please report a bug on https://github.com/ambv/black/issues.  "
2816             f"This diff might be helpful: {log}"
2817         ) from None
2818
2819
2820 def dump_to_file(*output: str) -> str:
2821     """Dump `output` to a temporary file. Return path to the file."""
2822     import tempfile
2823
2824     with tempfile.NamedTemporaryFile(
2825         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2826     ) as f:
2827         for lines in output:
2828             f.write(lines)
2829             if lines and lines[-1] != "\n":
2830                 f.write("\n")
2831     return f.name
2832
2833
2834 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2835     """Return a unified diff string between strings `a` and `b`."""
2836     import difflib
2837
2838     a_lines = [line + "\n" for line in a.split("\n")]
2839     b_lines = [line + "\n" for line in b.split("\n")]
2840     return "".join(
2841         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2842     )
2843
2844
2845 def cancel(tasks: List[asyncio.Task]) -> None:
2846     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2847     err("Aborted!")
2848     for task in tasks:
2849         task.cancel()
2850
2851
2852 def shutdown(loop: BaseEventLoop) -> None:
2853     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2854     try:
2855         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2856         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2857         if not to_cancel:
2858             return
2859
2860         for task in to_cancel:
2861             task.cancel()
2862         loop.run_until_complete(
2863             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2864         )
2865     finally:
2866         # `concurrent.futures.Future` objects cannot be cancelled once they
2867         # are already running. There might be some when the `shutdown()` happened.
2868         # Silence their logger's spew about the event loop being closed.
2869         cf_logger = logging.getLogger("concurrent.futures")
2870         cf_logger.setLevel(logging.CRITICAL)
2871         loop.close()
2872
2873
2874 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2875     """Replace `regex` with `replacement` twice on `original`.
2876
2877     This is used by string normalization to perform replaces on
2878     overlapping matches.
2879     """
2880     return regex.sub(replacement, regex.sub(replacement, original))
2881
2882
2883 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
2884     """Like `reversed(enumerate(sequence))` if that were possible."""
2885     index = len(sequence) - 1
2886     for element in reversed(sequence):
2887         yield (index, element)
2888         index -= 1
2889
2890
2891 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
2892     """Return True if `line` is no longer than `line_length`.
2893
2894     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
2895     """
2896     if not line_str:
2897         line_str = str(line).strip("\n")
2898     return (
2899         len(line_str) <= line_length
2900         and "\n" not in line_str  # multiline strings
2901         and not line.contains_standalone_comments()
2902     )
2903
2904
2905 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2906
2907
2908 def get_cache_file(line_length: int) -> Path:
2909     return CACHE_DIR / f"cache.{line_length}.pickle"
2910
2911
2912 def read_cache(line_length: int) -> Cache:
2913     """Read the cache if it exists and is well formed.
2914
2915     If it is not well formed, the call to write_cache later should resolve the issue.
2916     """
2917     cache_file = get_cache_file(line_length)
2918     if not cache_file.exists():
2919         return {}
2920
2921     with cache_file.open("rb") as fobj:
2922         try:
2923             cache: Cache = pickle.load(fobj)
2924         except pickle.UnpicklingError:
2925             return {}
2926
2927     return cache
2928
2929
2930 def get_cache_info(path: Path) -> CacheInfo:
2931     """Return the information used to check if a file is already formatted or not."""
2932     stat = path.stat()
2933     return stat.st_mtime, stat.st_size
2934
2935
2936 def filter_cached(
2937     cache: Cache, sources: Iterable[Path]
2938 ) -> Tuple[List[Path], List[Path]]:
2939     """Split a list of paths into two.
2940
2941     The first list contains paths of files that modified on disk or are not in the
2942     cache. The other list contains paths to non-modified files.
2943     """
2944     todo, done = [], []
2945     for src in sources:
2946         src = src.resolve()
2947         if cache.get(src) != get_cache_info(src):
2948             todo.append(src)
2949         else:
2950             done.append(src)
2951     return todo, done
2952
2953
2954 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2955     """Update the cache file."""
2956     cache_file = get_cache_file(line_length)
2957     try:
2958         if not CACHE_DIR.exists():
2959             CACHE_DIR.mkdir(parents=True)
2960         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2961         with cache_file.open("wb") as fobj:
2962             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2963     except OSError:
2964         pass
2965
2966
2967 if __name__ == "__main__":
2968     main()