]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Automatic parentheses management
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from enum import Enum
7 from functools import partial, wraps
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import re
14 import tokenize
15 import signal
16 import sys
17 from typing import (
18     Any,
19     Callable,
20     Collection,
21     Dict,
22     Generic,
23     Iterable,
24     Iterator,
25     List,
26     Optional,
27     Set,
28     Tuple,
29     Type,
30     TypeVar,
31     Union,
32 )
33
34 from attr import dataclass, Factory
35 import click
36
37 # lib2to3 fork
38 from blib2to3.pytree import Node, Leaf, type_repr
39 from blib2to3 import pygram, pytree
40 from blib2to3.pgen2 import driver, token
41 from blib2to3.pgen2.parse import ParseError
42
43 __version__ = "18.4a0"
44 DEFAULT_LINE_LENGTH = 88
45 # types
46 syms = pygram.python_symbols
47 FileContent = str
48 Encoding = str
49 Depth = int
50 NodeType = int
51 LeafID = int
52 Priority = int
53 Index = int
54 LN = Union[Leaf, Node]
55 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
56 out = partial(click.secho, bold=True, err=True)
57 err = partial(click.secho, fg="red", err=True)
58
59
60 class NothingChanged(UserWarning):
61     """Raised by :func:`format_file` when reformatted code is the same as source."""
62
63
64 class CannotSplit(Exception):
65     """A readable split that fits the allotted line length is impossible.
66
67     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
68     :func:`delimiter_split`.
69     """
70
71
72 class FormatError(Exception):
73     """Base exception for `# fmt: on` and `# fmt: off` handling.
74
75     It holds the number of bytes of the prefix consumed before the format
76     control comment appeared.
77     """
78
79     def __init__(self, consumed: int) -> None:
80         super().__init__(consumed)
81         self.consumed = consumed
82
83     def trim_prefix(self, leaf: Leaf) -> None:
84         leaf.prefix = leaf.prefix[self.consumed:]
85
86     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
87         """Returns a new Leaf from the consumed part of the prefix."""
88         unformatted_prefix = leaf.prefix[:self.consumed]
89         return Leaf(token.NEWLINE, unformatted_prefix)
90
91
92 class FormatOn(FormatError):
93     """Found a comment like `# fmt: on` in the file."""
94
95
96 class FormatOff(FormatError):
97     """Found a comment like `# fmt: off` in the file."""
98
99
100 class WriteBack(Enum):
101     NO = 0
102     YES = 1
103     DIFF = 2
104
105
106 @click.command()
107 @click.option(
108     "-l",
109     "--line-length",
110     type=int,
111     default=DEFAULT_LINE_LENGTH,
112     help="How many character per line to allow.",
113     show_default=True,
114 )
115 @click.option(
116     "--check",
117     is_flag=True,
118     help=(
119         "Don't write the files back, just return the status.  Return code 0 "
120         "means nothing would change.  Return code 1 means some files would be "
121         "reformatted.  Return code 123 means there was an internal error."
122     ),
123 )
124 @click.option(
125     "--diff",
126     is_flag=True,
127     help="Don't write the files back, just output a diff for each file on stdout.",
128 )
129 @click.option(
130     "--fast/--safe",
131     is_flag=True,
132     help="If --fast given, skip temporary sanity checks. [default: --safe]",
133 )
134 @click.option(
135     "-q",
136     "--quiet",
137     is_flag=True,
138     help=(
139         "Don't emit non-error messages to stderr. Errors are still emitted, "
140         "silence those with 2>/dev/null."
141     ),
142 )
143 @click.version_option(version=__version__)
144 @click.argument(
145     "src",
146     nargs=-1,
147     type=click.Path(
148         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
149     ),
150 )
151 @click.pass_context
152 def main(
153     ctx: click.Context,
154     line_length: int,
155     check: bool,
156     diff: bool,
157     fast: bool,
158     quiet: bool,
159     src: List[str],
160 ) -> None:
161     """The uncompromising code formatter."""
162     sources: List[Path] = []
163     for s in src:
164         p = Path(s)
165         if p.is_dir():
166             sources.extend(gen_python_files_in_dir(p))
167         elif p.is_file():
168             # if a file was explicitly given, we don't care about its extension
169             sources.append(p)
170         elif s == "-":
171             sources.append(Path("-"))
172         else:
173             err(f"invalid path: {s}")
174     if check and diff:
175         exc = click.ClickException("Options --check and --diff are mutually exclusive")
176         exc.exit_code = 2
177         raise exc
178
179     if check:
180         write_back = WriteBack.NO
181     elif diff:
182         write_back = WriteBack.DIFF
183     else:
184         write_back = WriteBack.YES
185     if len(sources) == 0:
186         ctx.exit(0)
187     elif len(sources) == 1:
188         p = sources[0]
189         report = Report(check=check, quiet=quiet)
190         try:
191             if not p.is_file() and str(p) == "-":
192                 changed = format_stdin_to_stdout(
193                     line_length=line_length, fast=fast, write_back=write_back
194                 )
195             else:
196                 changed = format_file_in_place(
197                     p, line_length=line_length, fast=fast, write_back=write_back
198                 )
199             report.done(p, changed)
200         except Exception as exc:
201             report.failed(p, str(exc))
202         ctx.exit(report.return_code)
203     else:
204         loop = asyncio.get_event_loop()
205         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
206         return_code = 1
207         try:
208             return_code = loop.run_until_complete(
209                 schedule_formatting(
210                     sources, line_length, write_back, fast, quiet, loop, executor
211                 )
212             )
213         finally:
214             shutdown(loop)
215             ctx.exit(return_code)
216
217
218 async def schedule_formatting(
219     sources: List[Path],
220     line_length: int,
221     write_back: WriteBack,
222     fast: bool,
223     quiet: bool,
224     loop: BaseEventLoop,
225     executor: Executor,
226 ) -> int:
227     """Run formatting of `sources` in parallel using the provided `executor`.
228
229     (Use ProcessPoolExecutors for actual parallelism.)
230
231     `line_length`, `write_back`, and `fast` options are passed to
232     :func:`format_file_in_place`.
233     """
234     lock = None
235     if write_back == WriteBack.DIFF:
236         # For diff output, we need locks to ensure we don't interleave output
237         # from different processes.
238         manager = Manager()
239         lock = manager.Lock()
240     tasks = {
241         src: loop.run_in_executor(
242             executor, format_file_in_place, src, line_length, fast, write_back, lock
243         )
244         for src in sources
245     }
246     _task_values = list(tasks.values())
247     loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
248     loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
249     await asyncio.wait(tasks.values())
250     cancelled = []
251     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
252     for src, task in tasks.items():
253         if not task.done():
254             report.failed(src, "timed out, cancelling")
255             task.cancel()
256             cancelled.append(task)
257         elif task.cancelled():
258             cancelled.append(task)
259         elif task.exception():
260             report.failed(src, str(task.exception()))
261         else:
262             report.done(src, task.result())
263     if cancelled:
264         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
265     elif not quiet:
266         out("All done! ✨ 🍰 ✨")
267     if not quiet:
268         click.echo(str(report))
269     return report.return_code
270
271
272 def format_file_in_place(
273     src: Path,
274     line_length: int,
275     fast: bool,
276     write_back: WriteBack = WriteBack.NO,
277     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
278 ) -> bool:
279     """Format file under `src` path. Return True if changed.
280
281     If `write_back` is True, write reformatted code back to stdout.
282     `line_length` and `fast` options are passed to :func:`format_file_contents`.
283     """
284     with tokenize.open(src) as src_buffer:
285         src_contents = src_buffer.read()
286     try:
287         dst_contents = format_file_contents(
288             src_contents, line_length=line_length, fast=fast
289         )
290     except NothingChanged:
291         return False
292
293     if write_back == write_back.YES:
294         with open(src, "w", encoding=src_buffer.encoding) as f:
295             f.write(dst_contents)
296     elif write_back == write_back.DIFF:
297         src_name = f"{src.name}  (original)"
298         dst_name = f"{src.name}  (formatted)"
299         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
300         if lock:
301             lock.acquire()
302         try:
303             sys.stdout.write(diff_contents)
304         finally:
305             if lock:
306                 lock.release()
307     return True
308
309
310 def format_stdin_to_stdout(
311     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
312 ) -> bool:
313     """Format file on stdin. Return True if changed.
314
315     If `write_back` is True, write reformatted code back to stdout.
316     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
317     """
318     src = sys.stdin.read()
319     dst = src
320     try:
321         dst = format_file_contents(src, line_length=line_length, fast=fast)
322         return True
323
324     except NothingChanged:
325         return False
326
327     finally:
328         if write_back == WriteBack.YES:
329             sys.stdout.write(dst)
330         elif write_back == WriteBack.DIFF:
331             src_name = "<stdin>  (original)"
332             dst_name = "<stdin>  (formatted)"
333             sys.stdout.write(diff(src, dst, src_name, dst_name))
334
335
336 def format_file_contents(
337     src_contents: str, line_length: int, fast: bool
338 ) -> FileContent:
339     """Reformat contents a file and return new contents.
340
341     If `fast` is False, additionally confirm that the reformatted code is
342     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
343     `line_length` is passed to :func:`format_str`.
344     """
345     if src_contents.strip() == "":
346         raise NothingChanged
347
348     dst_contents = format_str(src_contents, line_length=line_length)
349     if src_contents == dst_contents:
350         raise NothingChanged
351
352     if not fast:
353         assert_equivalent(src_contents, dst_contents)
354         assert_stable(src_contents, dst_contents, line_length=line_length)
355     return dst_contents
356
357
358 def format_str(src_contents: str, line_length: int) -> FileContent:
359     """Reformat a string and return new contents.
360
361     `line_length` determines how many characters per line are allowed.
362     """
363     src_node = lib2to3_parse(src_contents)
364     dst_contents = ""
365     lines = LineGenerator()
366     elt = EmptyLineTracker()
367     py36 = is_python36(src_node)
368     empty_line = Line()
369     after = 0
370     for current_line in lines.visit(src_node):
371         for _ in range(after):
372             dst_contents += str(empty_line)
373         before, after = elt.maybe_empty_lines(current_line)
374         for _ in range(before):
375             dst_contents += str(empty_line)
376         for line in split_line(current_line, line_length=line_length, py36=py36):
377             dst_contents += str(line)
378     return dst_contents
379
380
381 GRAMMARS = [
382     pygram.python_grammar_no_print_statement_no_exec_statement,
383     pygram.python_grammar_no_print_statement,
384     pygram.python_grammar_no_exec_statement,
385     pygram.python_grammar,
386 ]
387
388
389 def lib2to3_parse(src_txt: str) -> Node:
390     """Given a string with source, return the lib2to3 Node."""
391     grammar = pygram.python_grammar_no_print_statement
392     if src_txt[-1] != "\n":
393         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
394         src_txt += nl
395     for grammar in GRAMMARS:
396         drv = driver.Driver(grammar, pytree.convert)
397         try:
398             result = drv.parse_string(src_txt, True)
399             break
400
401         except ParseError as pe:
402             lineno, column = pe.context[1]
403             lines = src_txt.splitlines()
404             try:
405                 faulty_line = lines[lineno - 1]
406             except IndexError:
407                 faulty_line = "<line number missing in source>"
408             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
409     else:
410         raise exc from None
411
412     if isinstance(result, Leaf):
413         result = Node(syms.file_input, [result])
414     return result
415
416
417 def lib2to3_unparse(node: Node) -> str:
418     """Given a lib2to3 node, return its string representation."""
419     code = str(node)
420     return code
421
422
423 T = TypeVar("T")
424
425
426 class Visitor(Generic[T]):
427     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
428
429     def visit(self, node: LN) -> Iterator[T]:
430         """Main method to visit `node` and its children.
431
432         It tries to find a `visit_*()` method for the given `node.type`, like
433         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
434         If no dedicated `visit_*()` method is found, chooses `visit_default()`
435         instead.
436
437         Then yields objects of type `T` from the selected visitor.
438         """
439         if node.type < 256:
440             name = token.tok_name[node.type]
441         else:
442             name = type_repr(node.type)
443         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
444
445     def visit_default(self, node: LN) -> Iterator[T]:
446         """Default `visit_*()` implementation. Recurses to children of `node`."""
447         if isinstance(node, Node):
448             for child in node.children:
449                 yield from self.visit(child)
450
451
452 @dataclass
453 class DebugVisitor(Visitor[T]):
454     tree_depth: int = 0
455
456     def visit_default(self, node: LN) -> Iterator[T]:
457         indent = " " * (2 * self.tree_depth)
458         if isinstance(node, Node):
459             _type = type_repr(node.type)
460             out(f"{indent}{_type}", fg="yellow")
461             self.tree_depth += 1
462             for child in node.children:
463                 yield from self.visit(child)
464
465             self.tree_depth -= 1
466             out(f"{indent}/{_type}", fg="yellow", bold=False)
467         else:
468             _type = token.tok_name.get(node.type, str(node.type))
469             out(f"{indent}{_type}", fg="blue", nl=False)
470             if node.prefix:
471                 # We don't have to handle prefixes for `Node` objects since
472                 # that delegates to the first child anyway.
473                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
474             out(f" {node.value!r}", fg="blue", bold=False)
475
476     @classmethod
477     def show(cls, code: str) -> None:
478         """Pretty-print the lib2to3 AST of a given string of `code`.
479
480         Convenience method for debugging.
481         """
482         v: DebugVisitor[None] = DebugVisitor()
483         list(v.visit(lib2to3_parse(code)))
484
485
486 KEYWORDS = set(keyword.kwlist)
487 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
488 FLOW_CONTROL = {"return", "raise", "break", "continue"}
489 STATEMENT = {
490     syms.if_stmt,
491     syms.while_stmt,
492     syms.for_stmt,
493     syms.try_stmt,
494     syms.except_clause,
495     syms.with_stmt,
496     syms.funcdef,
497     syms.classdef,
498 }
499 STANDALONE_COMMENT = 153
500 LOGIC_OPERATORS = {"and", "or"}
501 COMPARATORS = {
502     token.LESS,
503     token.GREATER,
504     token.EQEQUAL,
505     token.NOTEQUAL,
506     token.LESSEQUAL,
507     token.GREATEREQUAL,
508 }
509 MATH_OPERATORS = {
510     token.PLUS,
511     token.MINUS,
512     token.STAR,
513     token.SLASH,
514     token.VBAR,
515     token.AMPER,
516     token.PERCENT,
517     token.CIRCUMFLEX,
518     token.TILDE,
519     token.LEFTSHIFT,
520     token.RIGHTSHIFT,
521     token.DOUBLESTAR,
522     token.DOUBLESLASH,
523 }
524 VARARGS = {token.STAR, token.DOUBLESTAR}
525 COMPREHENSION_PRIORITY = 20
526 COMMA_PRIORITY = 10
527 LOGIC_PRIORITY = 5
528 STRING_PRIORITY = 4
529 COMPARATOR_PRIORITY = 3
530 MATH_PRIORITY = 1
531
532
533 @dataclass
534 class BracketTracker:
535     """Keeps track of brackets on a line."""
536
537     depth: int = 0
538     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
539     delimiters: Dict[LeafID, Priority] = Factory(dict)
540     previous: Optional[Leaf] = None
541
542     def mark(self, leaf: Leaf) -> None:
543         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
544
545         All leaves receive an int `bracket_depth` field that stores how deep
546         within brackets a given leaf is. 0 means there are no enclosing brackets
547         that started on this line.
548
549         If a leaf is itself a closing bracket, it receives an `opening_bracket`
550         field that it forms a pair with. This is a one-directional link to
551         avoid reference cycles.
552
553         If a leaf is a delimiter (a token on which Black can split the line if
554         needed) and it's on depth 0, its `id()` is stored in the tracker's
555         `delimiters` field.
556         """
557         if leaf.type == token.COMMENT:
558             return
559
560         if leaf.type in CLOSING_BRACKETS:
561             self.depth -= 1
562             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
563             leaf.opening_bracket = opening_bracket
564         leaf.bracket_depth = self.depth
565         if self.depth == 0:
566             delim = is_split_before_delimiter(leaf, self.previous)
567             if delim and self.previous is not None:
568                 self.delimiters[id(self.previous)] = delim
569             else:
570                 delim = is_split_after_delimiter(leaf, self.previous)
571                 if delim:
572                     self.delimiters[id(leaf)] = delim
573         if leaf.type in OPENING_BRACKETS:
574             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
575             self.depth += 1
576         self.previous = leaf
577
578     def any_open_brackets(self) -> bool:
579         """Return True if there is an yet unmatched open bracket on the line."""
580         return bool(self.bracket_match)
581
582     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
583         """Return the highest priority of a delimiter found on the line.
584
585         Values are consistent with what `is_delimiter()` returns.
586         Raises ValueError on no delimiters.
587         """
588         return max(v for k, v in self.delimiters.items() if k not in exclude)
589
590
591 @dataclass
592 class Line:
593     """Holds leaves and comments. Can be printed with `str(line)`."""
594
595     depth: int = 0
596     leaves: List[Leaf] = Factory(list)
597     comments: List[Tuple[Index, Leaf]] = Factory(list)
598     bracket_tracker: BracketTracker = Factory(BracketTracker)
599     inside_brackets: bool = False
600     has_for: bool = False
601     _for_loop_variable: bool = False
602
603     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
604         """Add a new `leaf` to the end of the line.
605
606         Unless `preformatted` is True, the `leaf` will receive a new consistent
607         whitespace prefix and metadata applied by :class:`BracketTracker`.
608         Trailing commas are maybe removed, unpacked for loop variables are
609         demoted from being delimiters.
610
611         Inline comments are put aside.
612         """
613         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
614         if not has_value:
615             return
616
617         if self.leaves and not preformatted:
618             # Note: at this point leaf.prefix should be empty except for
619             # imports, for which we only preserve newlines.
620             leaf.prefix += whitespace(leaf)
621         if self.inside_brackets or not preformatted:
622             self.maybe_decrement_after_for_loop_variable(leaf)
623             self.bracket_tracker.mark(leaf)
624             self.maybe_remove_trailing_comma(leaf)
625             self.maybe_increment_for_loop_variable(leaf)
626
627         if not self.append_comment(leaf):
628             self.leaves.append(leaf)
629
630     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
631         """Like :func:`append()` but disallow invalid standalone comment structure.
632
633         Raises ValueError when any `leaf` is appended after a standalone comment
634         or when a standalone comment is not the first leaf on the line.
635         """
636         if self.bracket_tracker.depth == 0:
637             if self.is_comment:
638                 raise ValueError("cannot append to standalone comments")
639
640             if self.leaves and leaf.type == STANDALONE_COMMENT:
641                 raise ValueError(
642                     "cannot append standalone comments to a populated line"
643                 )
644
645         self.append(leaf, preformatted=preformatted)
646
647     @property
648     def is_comment(self) -> bool:
649         """Is this line a standalone comment?"""
650         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
651
652     @property
653     def is_decorator(self) -> bool:
654         """Is this line a decorator?"""
655         return bool(self) and self.leaves[0].type == token.AT
656
657     @property
658     def is_import(self) -> bool:
659         """Is this an import line?"""
660         return bool(self) and is_import(self.leaves[0])
661
662     @property
663     def is_class(self) -> bool:
664         """Is this line a class definition?"""
665         return (
666             bool(self)
667             and self.leaves[0].type == token.NAME
668             and self.leaves[0].value == "class"
669         )
670
671     @property
672     def is_def(self) -> bool:
673         """Is this a function definition? (Also returns True for async defs.)"""
674         try:
675             first_leaf = self.leaves[0]
676         except IndexError:
677             return False
678
679         try:
680             second_leaf: Optional[Leaf] = self.leaves[1]
681         except IndexError:
682             second_leaf = None
683         return (
684             (first_leaf.type == token.NAME and first_leaf.value == "def")
685             or (
686                 first_leaf.type == token.ASYNC
687                 and second_leaf is not None
688                 and second_leaf.type == token.NAME
689                 and second_leaf.value == "def"
690             )
691         )
692
693     @property
694     def is_flow_control(self) -> bool:
695         """Is this line a flow control statement?
696
697         Those are `return`, `raise`, `break`, and `continue`.
698         """
699         return (
700             bool(self)
701             and self.leaves[0].type == token.NAME
702             and self.leaves[0].value in FLOW_CONTROL
703         )
704
705     @property
706     def is_yield(self) -> bool:
707         """Is this line a yield statement?"""
708         return (
709             bool(self)
710             and self.leaves[0].type == token.NAME
711             and self.leaves[0].value == "yield"
712         )
713
714     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
715         """If so, needs to be split before emitting."""
716         for leaf in self.leaves:
717             if leaf.type == STANDALONE_COMMENT:
718                 if leaf.bracket_depth <= depth_limit:
719                     return True
720
721         return False
722
723     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
724         """Remove trailing comma if there is one and it's safe."""
725         if not (
726             self.leaves
727             and self.leaves[-1].type == token.COMMA
728             and closing.type in CLOSING_BRACKETS
729         ):
730             return False
731
732         if closing.type == token.RBRACE:
733             self.remove_trailing_comma()
734             return True
735
736         if closing.type == token.RSQB:
737             comma = self.leaves[-1]
738             if comma.parent and comma.parent.type == syms.listmaker:
739                 self.remove_trailing_comma()
740                 return True
741
742         # For parens let's check if it's safe to remove the comma.  If the
743         # trailing one is the only one, we might mistakenly change a tuple
744         # into a different type by removing the comma.
745         depth = closing.bracket_depth + 1
746         commas = 0
747         opening = closing.opening_bracket
748         for _opening_index, leaf in enumerate(self.leaves):
749             if leaf is opening:
750                 break
751
752         else:
753             return False
754
755         for leaf in self.leaves[_opening_index + 1:]:
756             if leaf is closing:
757                 break
758
759             bracket_depth = leaf.bracket_depth
760             if bracket_depth == depth and leaf.type == token.COMMA:
761                 commas += 1
762                 if leaf.parent and leaf.parent.type == syms.arglist:
763                     commas += 1
764                     break
765
766         if commas > 1:
767             self.remove_trailing_comma()
768             return True
769
770         return False
771
772     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
773         """In a for loop, or comprehension, the variables are often unpacks.
774
775         To avoid splitting on the comma in this situation, increase the depth of
776         tokens between `for` and `in`.
777         """
778         if leaf.type == token.NAME and leaf.value == "for":
779             self.has_for = True
780             self.bracket_tracker.depth += 1
781             self._for_loop_variable = True
782             return True
783
784         return False
785
786     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
787         """See `maybe_increment_for_loop_variable` above for explanation."""
788         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
789             self.bracket_tracker.depth -= 1
790             self._for_loop_variable = False
791             return True
792
793         return False
794
795     def append_comment(self, comment: Leaf) -> bool:
796         """Add an inline or standalone comment to the line."""
797         if (
798             comment.type == STANDALONE_COMMENT
799             and self.bracket_tracker.any_open_brackets()
800         ):
801             comment.prefix = ""
802             return False
803
804         if comment.type != token.COMMENT:
805             return False
806
807         after = len(self.leaves) - 1
808         if after == -1:
809             comment.type = STANDALONE_COMMENT
810             comment.prefix = ""
811             return False
812
813         else:
814             self.comments.append((after, comment))
815             return True
816
817     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
818         """Generate comments that should appear directly after `leaf`."""
819         for _leaf_index, _leaf in enumerate(self.leaves):
820             if leaf is _leaf:
821                 break
822
823         else:
824             return
825
826         for index, comment_after in self.comments:
827             if _leaf_index == index:
828                 yield comment_after
829
830     def remove_trailing_comma(self) -> None:
831         """Remove the trailing comma and moves the comments attached to it."""
832         comma_index = len(self.leaves) - 1
833         for i in range(len(self.comments)):
834             comment_index, comment = self.comments[i]
835             if comment_index == comma_index:
836                 self.comments[i] = (comma_index - 1, comment)
837         self.leaves.pop()
838
839     def __str__(self) -> str:
840         """Render the line."""
841         if not self:
842             return "\n"
843
844         indent = "    " * self.depth
845         leaves = iter(self.leaves)
846         first = next(leaves)
847         res = f"{first.prefix}{indent}{first.value}"
848         for leaf in leaves:
849             res += str(leaf)
850         for _, comment in self.comments:
851             res += str(comment)
852         return res + "\n"
853
854     def __bool__(self) -> bool:
855         """Return True if the line has leaves or comments."""
856         return bool(self.leaves or self.comments)
857
858
859 class UnformattedLines(Line):
860     """Just like :class:`Line` but stores lines which aren't reformatted."""
861
862     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
863         """Just add a new `leaf` to the end of the lines.
864
865         The `preformatted` argument is ignored.
866
867         Keeps track of indentation `depth`, which is useful when the user
868         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
869         """
870         try:
871             list(generate_comments(leaf))
872         except FormatOn as f_on:
873             self.leaves.append(f_on.leaf_from_consumed(leaf))
874             raise
875
876         self.leaves.append(leaf)
877         if leaf.type == token.INDENT:
878             self.depth += 1
879         elif leaf.type == token.DEDENT:
880             self.depth -= 1
881
882     def __str__(self) -> str:
883         """Render unformatted lines from leaves which were added with `append()`.
884
885         `depth` is not used for indentation in this case.
886         """
887         if not self:
888             return "\n"
889
890         res = ""
891         for leaf in self.leaves:
892             res += str(leaf)
893         return res
894
895     def append_comment(self, comment: Leaf) -> bool:
896         """Not implemented in this class. Raises `NotImplementedError`."""
897         raise NotImplementedError("Unformatted lines don't store comments separately.")
898
899     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
900         """Does nothing and returns False."""
901         return False
902
903     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
904         """Does nothing and returns False."""
905         return False
906
907
908 @dataclass
909 class EmptyLineTracker:
910     """Provides a stateful method that returns the number of potential extra
911     empty lines needed before and after the currently processed line.
912
913     Note: this tracker works on lines that haven't been split yet.  It assumes
914     the prefix of the first leaf consists of optional newlines.  Those newlines
915     are consumed by `maybe_empty_lines()` and included in the computation.
916     """
917     previous_line: Optional[Line] = None
918     previous_after: int = 0
919     previous_defs: List[int] = Factory(list)
920
921     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
922         """Return the number of extra empty lines before and after the `current_line`.
923
924         This is for separating `def`, `async def` and `class` with extra empty
925         lines (two on module-level), as well as providing an extra empty line
926         after flow control keywords to make them more prominent.
927         """
928         if isinstance(current_line, UnformattedLines):
929             return 0, 0
930
931         before, after = self._maybe_empty_lines(current_line)
932         before -= self.previous_after
933         self.previous_after = after
934         self.previous_line = current_line
935         return before, after
936
937     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
938         max_allowed = 1
939         if current_line.depth == 0:
940             max_allowed = 2
941         if current_line.leaves:
942             # Consume the first leaf's extra newlines.
943             first_leaf = current_line.leaves[0]
944             before = first_leaf.prefix.count("\n")
945             before = min(before, max_allowed)
946             first_leaf.prefix = ""
947         else:
948             before = 0
949         depth = current_line.depth
950         while self.previous_defs and self.previous_defs[-1] >= depth:
951             self.previous_defs.pop()
952             before = 1 if depth else 2
953         is_decorator = current_line.is_decorator
954         if is_decorator or current_line.is_def or current_line.is_class:
955             if not is_decorator:
956                 self.previous_defs.append(depth)
957             if self.previous_line is None:
958                 # Don't insert empty lines before the first line in the file.
959                 return 0, 0
960
961             if self.previous_line and self.previous_line.is_decorator:
962                 # Don't insert empty lines between decorators.
963                 return 0, 0
964
965             newlines = 2
966             if current_line.depth:
967                 newlines -= 1
968             return newlines, 0
969
970         if current_line.is_flow_control:
971             return before, 1
972
973         if (
974             self.previous_line
975             and self.previous_line.is_import
976             and not current_line.is_import
977             and depth == self.previous_line.depth
978         ):
979             return (before or 1), 0
980
981         if (
982             self.previous_line
983             and self.previous_line.is_yield
984             and (not current_line.is_yield or depth != self.previous_line.depth)
985         ):
986             return (before or 1), 0
987
988         return before, 0
989
990
991 @dataclass
992 class LineGenerator(Visitor[Line]):
993     """Generates reformatted Line objects.  Empty lines are not emitted.
994
995     Note: destroys the tree it's visiting by mutating prefixes of its leaves
996     in ways that will no longer stringify to valid Python code on the tree.
997     """
998     current_line: Line = Factory(Line)
999
1000     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1001         """Generate a line.
1002
1003         If the line is empty, only emit if it makes sense.
1004         If the line is too long, split it first and then generate.
1005
1006         If any lines were generated, set up a new current_line.
1007         """
1008         if not self.current_line:
1009             if self.current_line.__class__ == type:
1010                 self.current_line.depth += indent
1011             else:
1012                 self.current_line = type(depth=self.current_line.depth + indent)
1013             return  # Line is empty, don't emit. Creating a new one unnecessary.
1014
1015         complete_line = self.current_line
1016         self.current_line = type(depth=complete_line.depth + indent)
1017         yield complete_line
1018
1019     def visit(self, node: LN) -> Iterator[Line]:
1020         """Main method to visit `node` and its children.
1021
1022         Yields :class:`Line` objects.
1023         """
1024         if isinstance(self.current_line, UnformattedLines):
1025             # File contained `# fmt: off`
1026             yield from self.visit_unformatted(node)
1027
1028         else:
1029             yield from super().visit(node)
1030
1031     def visit_default(self, node: LN) -> Iterator[Line]:
1032         """Default `visit_*()` implementation. Recurses to children of `node`."""
1033         if isinstance(node, Leaf):
1034             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1035             try:
1036                 for comment in generate_comments(node):
1037                     if any_open_brackets:
1038                         # any comment within brackets is subject to splitting
1039                         self.current_line.append(comment)
1040                     elif comment.type == token.COMMENT:
1041                         # regular trailing comment
1042                         self.current_line.append(comment)
1043                         yield from self.line()
1044
1045                     else:
1046                         # regular standalone comment
1047                         yield from self.line()
1048
1049                         self.current_line.append(comment)
1050                         yield from self.line()
1051
1052             except FormatOff as f_off:
1053                 f_off.trim_prefix(node)
1054                 yield from self.line(type=UnformattedLines)
1055                 yield from self.visit(node)
1056
1057             except FormatOn as f_on:
1058                 # This only happens here if somebody says "fmt: on" multiple
1059                 # times in a row.
1060                 f_on.trim_prefix(node)
1061                 yield from self.visit_default(node)
1062
1063             else:
1064                 normalize_prefix(node, inside_brackets=any_open_brackets)
1065                 if node.type == token.STRING:
1066                     normalize_string_quotes(node)
1067                 if node.type not in WHITESPACE:
1068                     self.current_line.append(node)
1069         yield from super().visit_default(node)
1070
1071     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1072         """Increase indentation level, maybe yield a line."""
1073         # In blib2to3 INDENT never holds comments.
1074         yield from self.line(+1)
1075         yield from self.visit_default(node)
1076
1077     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1078         """Decrease indentation level, maybe yield a line."""
1079         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1080         yield from self.line(-1)
1081
1082     def visit_stmt(
1083         self, node: Node, keywords: Set[str], parens: Set[str]
1084     ) -> Iterator[Line]:
1085         """Visit a statement.
1086
1087         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1088         `def`, `with`, `class`, and `assert`.
1089
1090         The relevant Python language `keywords` for a given statement will be
1091         NAME leaves within it. This methods puts those on a separate line.
1092
1093         `parens` holds pairs of nodes where invisible parentheses should be put.
1094         Keys hold nodes after which opening parentheses should be put, values
1095         hold nodes before which closing parentheses should be put.
1096         """
1097         normalize_invisible_parens(node, parens_after=parens)
1098         for child in node.children:
1099             if child.type == token.NAME and child.value in keywords:  # type: ignore
1100                 yield from self.line()
1101
1102             yield from self.visit(child)
1103
1104     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1105         """Visit a statement without nested statements."""
1106         is_suite_like = node.parent and node.parent.type in STATEMENT
1107         if is_suite_like:
1108             yield from self.line(+1)
1109             yield from self.visit_default(node)
1110             yield from self.line(-1)
1111
1112         else:
1113             yield from self.line()
1114             yield from self.visit_default(node)
1115
1116     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1117         """Visit `async def`, `async for`, `async with`."""
1118         yield from self.line()
1119
1120         children = iter(node.children)
1121         for child in children:
1122             yield from self.visit(child)
1123
1124             if child.type == token.ASYNC:
1125                 break
1126
1127         internal_stmt = next(children)
1128         for child in internal_stmt.children:
1129             yield from self.visit(child)
1130
1131     def visit_decorators(self, node: Node) -> Iterator[Line]:
1132         """Visit decorators."""
1133         for child in node.children:
1134             yield from self.line()
1135             yield from self.visit(child)
1136
1137     def visit_import_from(self, node: Node) -> Iterator[Line]:
1138         """Visit import_from and maybe put invisible parentheses.
1139
1140         This is separate from `visit_stmt` because import statements don't
1141         support arbitrary atoms and thus handling of parentheses is custom.
1142         """
1143         check_lpar = False
1144         for index, child in enumerate(node.children):
1145             if check_lpar:
1146                 if child.type == token.LPAR:
1147                     # make parentheses invisible
1148                     child.value = ""  # type: ignore
1149                     node.children[-1].value = ""  # type: ignore
1150                 else:
1151                     # insert invisible parentheses
1152                     node.insert_child(index, Leaf(token.LPAR, ""))
1153                     node.append_child(Leaf(token.RPAR, ""))
1154                 break
1155
1156             check_lpar = (
1157                 child.type == token.NAME and child.value == "import"  # type: ignore
1158             )
1159
1160         for child in node.children:
1161             yield from self.visit(child)
1162
1163     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1164         """Remove a semicolon and put the other statement on a separate line."""
1165         yield from self.line()
1166
1167     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1168         """End of file. Process outstanding comments and end with a newline."""
1169         yield from self.visit_default(leaf)
1170         yield from self.line()
1171
1172     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1173         """Used when file contained a `# fmt: off`."""
1174         if isinstance(node, Node):
1175             for child in node.children:
1176                 yield from self.visit(child)
1177
1178         else:
1179             try:
1180                 self.current_line.append(node)
1181             except FormatOn as f_on:
1182                 f_on.trim_prefix(node)
1183                 yield from self.line()
1184                 yield from self.visit(node)
1185
1186             if node.type == token.ENDMARKER:
1187                 # somebody decided not to put a final `# fmt: on`
1188                 yield from self.line()
1189
1190     def __attrs_post_init__(self) -> None:
1191         """You are in a twisty little maze of passages."""
1192         v = self.visit_stmt
1193         Ø: Set[str] = set()
1194         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1195         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1196         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1197         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1198         self.visit_try_stmt = partial(
1199             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1200         )
1201         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1202         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1203         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1204         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1205         self.visit_async_funcdef = self.visit_async_stmt
1206         self.visit_decorated = self.visit_decorators
1207
1208
1209 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1210 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1211 OPENING_BRACKETS = set(BRACKET.keys())
1212 CLOSING_BRACKETS = set(BRACKET.values())
1213 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1214 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1215
1216
1217 def whitespace(leaf: Leaf) -> str:  # noqa C901
1218     """Return whitespace prefix if needed for the given `leaf`."""
1219     NO = ""
1220     SPACE = " "
1221     DOUBLESPACE = "  "
1222     t = leaf.type
1223     p = leaf.parent
1224     v = leaf.value
1225     if t in ALWAYS_NO_SPACE:
1226         return NO
1227
1228     if t == token.COMMENT:
1229         return DOUBLESPACE
1230
1231     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1232     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1233         return NO
1234
1235     prev = leaf.prev_sibling
1236     if not prev:
1237         prevp = preceding_leaf(p)
1238         if not prevp or prevp.type in OPENING_BRACKETS:
1239             return NO
1240
1241         if t == token.COLON:
1242             return SPACE if prevp.type == token.COMMA else NO
1243
1244         if prevp.type == token.EQUAL:
1245             if prevp.parent:
1246                 if prevp.parent.type in {
1247                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1248                 }:
1249                     return NO
1250
1251                 elif prevp.parent.type == syms.typedargslist:
1252                     # A bit hacky: if the equal sign has whitespace, it means we
1253                     # previously found it's a typed argument.  So, we're using
1254                     # that, too.
1255                     return prevp.prefix
1256
1257         elif prevp.type == token.DOUBLESTAR:
1258             if (
1259                 prevp.parent
1260                 and prevp.parent.type in {
1261                     syms.arglist,
1262                     syms.argument,
1263                     syms.dictsetmaker,
1264                     syms.parameters,
1265                     syms.typedargslist,
1266                     syms.varargslist,
1267                 }
1268             ):
1269                 return NO
1270
1271         elif prevp.type == token.COLON:
1272             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1273                 return NO
1274
1275         elif (
1276             prevp.parent
1277             and prevp.parent.type in {syms.factor, syms.star_expr}
1278             and prevp.type in MATH_OPERATORS
1279         ):
1280             return NO
1281
1282         elif (
1283             prevp.type == token.RIGHTSHIFT
1284             and prevp.parent
1285             and prevp.parent.type == syms.shift_expr
1286             and prevp.prev_sibling
1287             and prevp.prev_sibling.type == token.NAME
1288             and prevp.prev_sibling.value == "print"  # type: ignore
1289         ):
1290             # Python 2 print chevron
1291             return NO
1292
1293     elif prev.type in OPENING_BRACKETS:
1294         return NO
1295
1296     if p.type in {syms.parameters, syms.arglist}:
1297         # untyped function signatures or calls
1298         if t == token.RPAR:
1299             return NO
1300
1301         if not prev or prev.type != token.COMMA:
1302             return NO
1303
1304     elif p.type == syms.varargslist:
1305         # lambdas
1306         if t == token.RPAR:
1307             return NO
1308
1309         if prev and prev.type != token.COMMA:
1310             return NO
1311
1312     elif p.type == syms.typedargslist:
1313         # typed function signatures
1314         if not prev:
1315             return NO
1316
1317         if t == token.EQUAL:
1318             if prev.type != syms.tname:
1319                 return NO
1320
1321         elif prev.type == token.EQUAL:
1322             # A bit hacky: if the equal sign has whitespace, it means we
1323             # previously found it's a typed argument.  So, we're using that, too.
1324             return prev.prefix
1325
1326         elif prev.type != token.COMMA:
1327             return NO
1328
1329     elif p.type == syms.tname:
1330         # type names
1331         if not prev:
1332             prevp = preceding_leaf(p)
1333             if not prevp or prevp.type != token.COMMA:
1334                 return NO
1335
1336     elif p.type == syms.trailer:
1337         # attributes and calls
1338         if t == token.LPAR or t == token.RPAR:
1339             return NO
1340
1341         if not prev:
1342             if t == token.DOT:
1343                 prevp = preceding_leaf(p)
1344                 if not prevp or prevp.type != token.NUMBER:
1345                     return NO
1346
1347             elif t == token.LSQB:
1348                 return NO
1349
1350         elif prev.type != token.COMMA:
1351             return NO
1352
1353     elif p.type == syms.argument:
1354         # single argument
1355         if t == token.EQUAL:
1356             return NO
1357
1358         if not prev:
1359             prevp = preceding_leaf(p)
1360             if not prevp or prevp.type == token.LPAR:
1361                 return NO
1362
1363         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1364             return NO
1365
1366     elif p.type == syms.decorator:
1367         # decorators
1368         return NO
1369
1370     elif p.type == syms.dotted_name:
1371         if prev:
1372             return NO
1373
1374         prevp = preceding_leaf(p)
1375         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1376             return NO
1377
1378     elif p.type == syms.classdef:
1379         if t == token.LPAR:
1380             return NO
1381
1382         if prev and prev.type == token.LPAR:
1383             return NO
1384
1385     elif p.type == syms.subscript:
1386         # indexing
1387         if not prev:
1388             assert p.parent is not None, "subscripts are always parented"
1389             if p.parent.type == syms.subscriptlist:
1390                 return SPACE
1391
1392             return NO
1393
1394         else:
1395             return NO
1396
1397     elif p.type == syms.atom:
1398         if prev and t == token.DOT:
1399             # dots, but not the first one.
1400             return NO
1401
1402     elif (
1403         p.type == syms.listmaker
1404         or p.type == syms.testlist_gexp
1405         or p.type == syms.subscriptlist
1406     ):
1407         # list interior, including unpacking
1408         if not prev:
1409             return NO
1410
1411     elif p.type == syms.dictsetmaker:
1412         # dict and set interior, including unpacking
1413         if not prev:
1414             return NO
1415
1416         if prev.type == token.DOUBLESTAR:
1417             return NO
1418
1419     elif p.type in {syms.factor, syms.star_expr}:
1420         # unary ops
1421         if not prev:
1422             prevp = preceding_leaf(p)
1423             if not prevp or prevp.type in OPENING_BRACKETS:
1424                 return NO
1425
1426             prevp_parent = prevp.parent
1427             assert prevp_parent is not None
1428             if (
1429                 prevp.type == token.COLON
1430                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1431             ):
1432                 return NO
1433
1434             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1435                 return NO
1436
1437         elif t == token.NAME or t == token.NUMBER:
1438             return NO
1439
1440     elif p.type == syms.import_from:
1441         if t == token.DOT:
1442             if prev and prev.type == token.DOT:
1443                 return NO
1444
1445         elif t == token.NAME:
1446             if v == "import":
1447                 return SPACE
1448
1449             if prev and prev.type == token.DOT:
1450                 return NO
1451
1452     elif p.type == syms.sliceop:
1453         return NO
1454
1455     return SPACE
1456
1457
1458 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1459     """Return the first leaf that precedes `node`, if any."""
1460     while node:
1461         res = node.prev_sibling
1462         if res:
1463             if isinstance(res, Leaf):
1464                 return res
1465
1466             try:
1467                 return list(res.leaves())[-1]
1468
1469             except IndexError:
1470                 return None
1471
1472         node = node.parent
1473     return None
1474
1475
1476 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1477     """Return the priority of the `leaf` delimiter, given a line break after it.
1478
1479     The delimiter priorities returned here are from those delimiters that would
1480     cause a line break after themselves.
1481
1482     Higher numbers are higher priority.
1483     """
1484     if leaf.type == token.COMMA:
1485         return COMMA_PRIORITY
1486
1487     return 0
1488
1489
1490 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1491     """Return the priority of the `leaf` delimiter, given a line before after it.
1492
1493     The delimiter priorities returned here are from those delimiters that would
1494     cause a line break before themselves.
1495
1496     Higher numbers are higher priority.
1497     """
1498     if (
1499         leaf.type in VARARGS
1500         and leaf.parent
1501         and leaf.parent.type in {syms.argument, syms.typedargslist}
1502     ):
1503         # * and ** might also be MATH_OPERATORS but in this case they are not.
1504         # Don't treat them as a delimiter.
1505         return 0
1506
1507     if (
1508         leaf.type in MATH_OPERATORS
1509         and leaf.parent
1510         and leaf.parent.type not in {syms.factor, syms.star_expr}
1511     ):
1512         return MATH_PRIORITY
1513
1514     if leaf.type in COMPARATORS:
1515         return COMPARATOR_PRIORITY
1516
1517     if (
1518         leaf.type == token.STRING
1519         and previous is not None
1520         and previous.type == token.STRING
1521     ):
1522         return STRING_PRIORITY
1523
1524     if (
1525         leaf.type == token.NAME
1526         and leaf.value == "for"
1527         and leaf.parent
1528         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1529     ):
1530         return COMPREHENSION_PRIORITY
1531
1532     if (
1533         leaf.type == token.NAME
1534         and leaf.value == "if"
1535         and leaf.parent
1536         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1537     ):
1538         return COMPREHENSION_PRIORITY
1539
1540     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1541         return LOGIC_PRIORITY
1542
1543     return 0
1544
1545
1546 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1547     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1548
1549     Higher numbers are higher priority.
1550     """
1551     return max(
1552         is_split_before_delimiter(leaf, previous),
1553         is_split_after_delimiter(leaf, previous),
1554     )
1555
1556
1557 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1558     """Clean the prefix of the `leaf` and generate comments from it, if any.
1559
1560     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1561     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1562     move because it does away with modifying the grammar to include all the
1563     possible places in which comments can be placed.
1564
1565     The sad consequence for us though is that comments don't "belong" anywhere.
1566     This is why this function generates simple parentless Leaf objects for
1567     comments.  We simply don't know what the correct parent should be.
1568
1569     No matter though, we can live without this.  We really only need to
1570     differentiate between inline and standalone comments.  The latter don't
1571     share the line with any code.
1572
1573     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1574     are emitted with a fake STANDALONE_COMMENT token identifier.
1575     """
1576     p = leaf.prefix
1577     if not p:
1578         return
1579
1580     if "#" not in p:
1581         return
1582
1583     consumed = 0
1584     nlines = 0
1585     for index, line in enumerate(p.split("\n")):
1586         consumed += len(line) + 1  # adding the length of the split '\n'
1587         line = line.lstrip()
1588         if not line:
1589             nlines += 1
1590         if not line.startswith("#"):
1591             continue
1592
1593         if index == 0 and leaf.type != token.ENDMARKER:
1594             comment_type = token.COMMENT  # simple trailing comment
1595         else:
1596             comment_type = STANDALONE_COMMENT
1597         comment = make_comment(line)
1598         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1599
1600         if comment in {"# fmt: on", "# yapf: enable"}:
1601             raise FormatOn(consumed)
1602
1603         if comment in {"# fmt: off", "# yapf: disable"}:
1604             if comment_type == STANDALONE_COMMENT:
1605                 raise FormatOff(consumed)
1606
1607             prev = preceding_leaf(leaf)
1608             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1609                 raise FormatOff(consumed)
1610
1611         nlines = 0
1612
1613
1614 def make_comment(content: str) -> str:
1615     """Return a consistently formatted comment from the given `content` string.
1616
1617     All comments (except for "##", "#!", "#:") should have a single space between
1618     the hash sign and the content.
1619
1620     If `content` didn't start with a hash sign, one is provided.
1621     """
1622     content = content.rstrip()
1623     if not content:
1624         return "#"
1625
1626     if content[0] == "#":
1627         content = content[1:]
1628     if content and content[0] not in " !:#":
1629         content = " " + content
1630     return "#" + content
1631
1632
1633 def split_line(
1634     line: Line, line_length: int, inner: bool = False, py36: bool = False
1635 ) -> Iterator[Line]:
1636     """Split a `line` into potentially many lines.
1637
1638     They should fit in the allotted `line_length` but might not be able to.
1639     `inner` signifies that there were a pair of brackets somewhere around the
1640     current `line`, possibly transitively. This means we can fallback to splitting
1641     by delimiters if the LHS/RHS don't yield any results.
1642
1643     If `py36` is True, splitting may generate syntax that is only compatible
1644     with Python 3.6 and later.
1645     """
1646     if isinstance(line, UnformattedLines) or line.is_comment:
1647         yield line
1648         return
1649
1650     line_str = str(line).strip("\n")
1651     if (
1652         len(line_str) <= line_length
1653         and "\n" not in line_str  # multiline strings
1654         and not line.contains_standalone_comments()
1655     ):
1656         yield line
1657         return
1658
1659     split_funcs: List[SplitFunc]
1660     if line.is_def:
1661         split_funcs = [left_hand_split]
1662     elif line.inside_brackets:
1663         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1664     else:
1665         split_funcs = [right_hand_split]
1666     for split_func in split_funcs:
1667         # We are accumulating lines in `result` because we might want to abort
1668         # mission and return the original line in the end, or attempt a different
1669         # split altogether.
1670         result: List[Line] = []
1671         try:
1672             for l in split_func(line, py36):
1673                 if str(l).strip("\n") == line_str:
1674                     raise CannotSplit("Split function returned an unchanged result")
1675
1676                 result.extend(
1677                     split_line(l, line_length=line_length, inner=True, py36=py36)
1678                 )
1679         except CannotSplit as cs:
1680             continue
1681
1682         else:
1683             yield from result
1684             break
1685
1686     else:
1687         yield line
1688
1689
1690 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1691     """Split line into many lines, starting with the first matching bracket pair.
1692
1693     Note: this usually looks weird, only use this for function definitions.
1694     Prefer RHS otherwise.
1695     """
1696     head = Line(depth=line.depth)
1697     body = Line(depth=line.depth + 1, inside_brackets=True)
1698     tail = Line(depth=line.depth)
1699     tail_leaves: List[Leaf] = []
1700     body_leaves: List[Leaf] = []
1701     head_leaves: List[Leaf] = []
1702     current_leaves = head_leaves
1703     matching_bracket = None
1704     for leaf in line.leaves:
1705         if (
1706             current_leaves is body_leaves
1707             and leaf.type in CLOSING_BRACKETS
1708             and leaf.opening_bracket is matching_bracket
1709         ):
1710             current_leaves = tail_leaves if body_leaves else head_leaves
1711         current_leaves.append(leaf)
1712         if current_leaves is head_leaves:
1713             if leaf.type in OPENING_BRACKETS:
1714                 matching_bracket = leaf
1715                 current_leaves = body_leaves
1716     # Since body is a new indent level, remove spurious leading whitespace.
1717     if body_leaves:
1718         normalize_prefix(body_leaves[0], inside_brackets=True)
1719     # Build the new lines.
1720     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1721         for leaf in leaves:
1722             result.append(leaf, preformatted=True)
1723             for comment_after in line.comments_after(leaf):
1724                 result.append(comment_after, preformatted=True)
1725     bracket_split_succeeded_or_raise(head, body, tail)
1726     for result in (head, body, tail):
1727         if result:
1728             yield result
1729
1730
1731 def right_hand_split(
1732     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1733 ) -> Iterator[Line]:
1734     """Split line into many lines, starting with the last matching bracket pair."""
1735     head = Line(depth=line.depth)
1736     body = Line(depth=line.depth + 1, inside_brackets=True)
1737     tail = Line(depth=line.depth)
1738     tail_leaves: List[Leaf] = []
1739     body_leaves: List[Leaf] = []
1740     head_leaves: List[Leaf] = []
1741     current_leaves = tail_leaves
1742     opening_bracket = None
1743     closing_bracket = None
1744     for leaf in reversed(line.leaves):
1745         if current_leaves is body_leaves:
1746             if leaf is opening_bracket:
1747                 current_leaves = head_leaves if body_leaves else tail_leaves
1748         current_leaves.append(leaf)
1749         if current_leaves is tail_leaves:
1750             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1751                 opening_bracket = leaf.opening_bracket
1752                 closing_bracket = leaf
1753                 current_leaves = body_leaves
1754     tail_leaves.reverse()
1755     body_leaves.reverse()
1756     head_leaves.reverse()
1757     # Since body is a new indent level, remove spurious leading whitespace.
1758     if body_leaves:
1759         normalize_prefix(body_leaves[0], inside_brackets=True)
1760     elif not head_leaves:
1761         # No `head` and no `body` means the split failed. `tail` has all content.
1762         raise CannotSplit("No brackets found")
1763
1764     # Build the new lines.
1765     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1766         for leaf in leaves:
1767             result.append(leaf, preformatted=True)
1768             for comment_after in line.comments_after(leaf):
1769                 result.append(comment_after, preformatted=True)
1770     bracket_split_succeeded_or_raise(head, body, tail)
1771     assert opening_bracket and closing_bracket
1772     if (
1773         opening_bracket.type == token.LPAR
1774         and not opening_bracket.value
1775         and closing_bracket.type == token.RPAR
1776         and not closing_bracket.value
1777     ):
1778         # These parens were optional. If there aren't any delimiters or standalone
1779         # comments in the body, they were unnecessary and another split without
1780         # them should be attempted.
1781         if not (
1782             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1783         ):
1784             omit = {id(closing_bracket), *omit}
1785             yield from right_hand_split(line, py36=py36, omit=omit)
1786             return
1787
1788     ensure_visible(opening_bracket)
1789     ensure_visible(closing_bracket)
1790     for result in (head, body, tail):
1791         if result:
1792             yield result
1793
1794
1795 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1796     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1797
1798     Do nothing otherwise.
1799
1800     A left- or right-hand split is based on a pair of brackets. Content before
1801     (and including) the opening bracket is left on one line, content inside the
1802     brackets is put on a separate line, and finally content starting with and
1803     following the closing bracket is put on a separate line.
1804
1805     Those are called `head`, `body`, and `tail`, respectively. If the split
1806     produced the same line (all content in `head`) or ended up with an empty `body`
1807     and the `tail` is just the closing bracket, then it's considered failed.
1808     """
1809     tail_len = len(str(tail).strip())
1810     if not body:
1811         if tail_len == 0:
1812             raise CannotSplit("Splitting brackets produced the same line")
1813
1814         elif tail_len < 3:
1815             raise CannotSplit(
1816                 f"Splitting brackets on an empty body to save "
1817                 f"{tail_len} characters is not worth it"
1818             )
1819
1820
1821 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1822     """Normalize prefix of the first leaf in every line returned by `split_func`.
1823
1824     This is a decorator over relevant split functions.
1825     """
1826
1827     @wraps(split_func)
1828     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1829         for l in split_func(line, py36):
1830             normalize_prefix(l.leaves[0], inside_brackets=True)
1831             yield l
1832
1833     return split_wrapper
1834
1835
1836 @dont_increase_indentation
1837 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1838     """Split according to delimiters of the highest priority.
1839
1840     If `py36` is True, the split will add trailing commas also in function
1841     signatures that contain `*` and `**`.
1842     """
1843     try:
1844         last_leaf = line.leaves[-1]
1845     except IndexError:
1846         raise CannotSplit("Line empty")
1847
1848     delimiters = line.bracket_tracker.delimiters
1849     try:
1850         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1851             exclude={id(last_leaf)}
1852         )
1853     except ValueError:
1854         raise CannotSplit("No delimiters found")
1855
1856     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1857     lowest_depth = sys.maxsize
1858     trailing_comma_safe = True
1859
1860     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1861         """Append `leaf` to current line or to new line if appending impossible."""
1862         nonlocal current_line
1863         try:
1864             current_line.append_safe(leaf, preformatted=True)
1865         except ValueError as ve:
1866             yield current_line
1867
1868             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1869             current_line.append(leaf)
1870
1871     for leaf in line.leaves:
1872         yield from append_to_line(leaf)
1873
1874         for comment_after in line.comments_after(leaf):
1875             yield from append_to_line(comment_after)
1876
1877         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1878         if (
1879             leaf.bracket_depth == lowest_depth
1880             and leaf.type == token.STAR
1881             or leaf.type == token.DOUBLESTAR
1882         ):
1883             trailing_comma_safe = trailing_comma_safe and py36
1884         leaf_priority = delimiters.get(id(leaf))
1885         if leaf_priority == delimiter_priority:
1886             yield current_line
1887
1888             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1889     if current_line:
1890         if (
1891             trailing_comma_safe
1892             and delimiter_priority == COMMA_PRIORITY
1893             and current_line.leaves[-1].type != token.COMMA
1894             and current_line.leaves[-1].type != STANDALONE_COMMENT
1895         ):
1896             current_line.append(Leaf(token.COMMA, ","))
1897         yield current_line
1898
1899
1900 @dont_increase_indentation
1901 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1902     """Split standalone comments from the rest of the line."""
1903     if not line.contains_standalone_comments(0):
1904         raise CannotSplit("Line does not have any standalone comments")
1905
1906     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1907
1908     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1909         """Append `leaf` to current line or to new line if appending impossible."""
1910         nonlocal current_line
1911         try:
1912             current_line.append_safe(leaf, preformatted=True)
1913         except ValueError as ve:
1914             yield current_line
1915
1916             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1917             current_line.append(leaf)
1918
1919     for leaf in line.leaves:
1920         yield from append_to_line(leaf)
1921
1922         for comment_after in line.comments_after(leaf):
1923             yield from append_to_line(comment_after)
1924
1925     if current_line:
1926         yield current_line
1927
1928
1929 def is_import(leaf: Leaf) -> bool:
1930     """Return True if the given leaf starts an import statement."""
1931     p = leaf.parent
1932     t = leaf.type
1933     v = leaf.value
1934     return bool(
1935         t == token.NAME
1936         and (
1937             (v == "import" and p and p.type == syms.import_name)
1938             or (v == "from" and p and p.type == syms.import_from)
1939         )
1940     )
1941
1942
1943 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1944     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1945     else.
1946
1947     Note: don't use backslashes for formatting or you'll lose your voting rights.
1948     """
1949     if not inside_brackets:
1950         spl = leaf.prefix.split("#")
1951         if "\\" not in spl[0]:
1952             nl_count = spl[-1].count("\n")
1953             if len(spl) > 1:
1954                 nl_count -= 1
1955             leaf.prefix = "\n" * nl_count
1956             return
1957
1958     leaf.prefix = ""
1959
1960
1961 def normalize_string_quotes(leaf: Leaf) -> None:
1962     """Prefer double quotes but only if it doesn't cause more escaping.
1963
1964     Adds or removes backslashes as appropriate. Doesn't parse and fix
1965     strings nested in f-strings (yet).
1966
1967     Note: Mutates its argument.
1968     """
1969     value = leaf.value.lstrip("furbFURB")
1970     if value[:3] == '"""':
1971         return
1972
1973     elif value[:3] == "'''":
1974         orig_quote = "'''"
1975         new_quote = '"""'
1976     elif value[0] == '"':
1977         orig_quote = '"'
1978         new_quote = "'"
1979     else:
1980         orig_quote = "'"
1981         new_quote = '"'
1982     first_quote_pos = leaf.value.find(orig_quote)
1983     if first_quote_pos == -1:
1984         return  # There's an internal error
1985
1986     prefix = leaf.value[:first_quote_pos]
1987     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
1988     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
1989     escaped_orig_quote = re.compile(rf"\\(\\\\)*{orig_quote}")
1990     if "r" in prefix.casefold():
1991         if unescaped_new_quote.search(body):
1992             # There's at least one unescaped new_quote in this raw string
1993             # so converting is impossible
1994             return
1995
1996         # Do not introduce or remove backslashes in raw strings
1997         new_body = body
1998     else:
1999         new_body = escaped_orig_quote.sub(rf"\1{orig_quote}", body)
2000         new_body = unescaped_new_quote.sub(rf"\1\\{new_quote}", new_body)
2001     if new_quote == '"""' and new_body[-1] == '"':
2002         # edge case:
2003         new_body = new_body[:-1] + '\\"'
2004     orig_escape_count = body.count("\\")
2005     new_escape_count = new_body.count("\\")
2006     if new_escape_count > orig_escape_count:
2007         return  # Do not introduce more escaping
2008
2009     if new_escape_count == orig_escape_count and orig_quote == '"':
2010         return  # Prefer double quotes
2011
2012     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2013
2014
2015 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2016     """Make existing optional parentheses invisible or create new ones.
2017
2018     Standardizes on visible parentheses for single-element tuples, and keeps
2019     existing visible parentheses for other tuples and generator expressions.
2020     """
2021     check_lpar = False
2022     for child in list(node.children):
2023         if check_lpar:
2024             if child.type == syms.atom:
2025                 if not (
2026                     is_empty_tuple(child)
2027                     or is_one_tuple(child)
2028                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2029                 ):
2030                     first = child.children[0]
2031                     last = child.children[-1]
2032                     if first.type == token.LPAR and last.type == token.RPAR:
2033                         # make parentheses invisible
2034                         first.value = ""  # type: ignore
2035                         last.value = ""  # type: ignore
2036             elif is_one_tuple(child):
2037                 # wrap child in visible parentheses
2038                 lpar = Leaf(token.LPAR, "(")
2039                 rpar = Leaf(token.RPAR, ")")
2040                 index = child.remove() or 0
2041                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2042             else:
2043                 # wrap child in invisible parentheses
2044                 lpar = Leaf(token.LPAR, "")
2045                 rpar = Leaf(token.RPAR, "")
2046                 index = child.remove() or 0
2047                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2048
2049         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2050
2051
2052 def is_empty_tuple(node: LN) -> bool:
2053     """Return True if `node` holds an empty tuple."""
2054     return (
2055         node.type == syms.atom
2056         and len(node.children) == 2
2057         and node.children[0].type == token.LPAR
2058         and node.children[1].type == token.RPAR
2059     )
2060
2061
2062 def is_one_tuple(node: LN) -> bool:
2063     """Return True if `node` holds a tuple with one element, with or without parens."""
2064     if node.type == syms.atom:
2065         if len(node.children) != 3:
2066             return False
2067
2068         lpar, gexp, rpar = node.children
2069         if not (
2070             lpar.type == token.LPAR
2071             and gexp.type == syms.testlist_gexp
2072             and rpar.type == token.RPAR
2073         ):
2074             return False
2075
2076         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2077
2078     return (
2079         node.type in IMPLICIT_TUPLE
2080         and len(node.children) == 2
2081         and node.children[1].type == token.COMMA
2082     )
2083
2084
2085 def max_delimiter_priority_in_atom(node: LN) -> int:
2086     if node.type != syms.atom:
2087         return 0
2088
2089     first = node.children[0]
2090     last = node.children[-1]
2091     if first.type == token.LPAR and last.type == token.RPAR:
2092         bt = BracketTracker()
2093         for c in node.children[1:-1]:
2094             if isinstance(c, Leaf):
2095                 bt.mark(c)
2096             else:
2097                 for leaf in c.leaves():
2098                     bt.mark(leaf)
2099     try:
2100         return bt.max_delimiter_priority()
2101
2102     except ValueError:
2103         return 0
2104
2105
2106 def ensure_visible(leaf: Leaf) -> None:
2107     """Make sure parentheses are visible.
2108
2109     They could be invisible as part of some statements (see
2110     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2111     """
2112     if leaf.type == token.LPAR:
2113         leaf.value = "("
2114     elif leaf.type == token.RPAR:
2115         leaf.value = ")"
2116
2117
2118 def is_python36(node: Node) -> bool:
2119     """Return True if the current file is using Python 3.6+ features.
2120
2121     Currently looking for:
2122     - f-strings; and
2123     - trailing commas after * or ** in function signatures.
2124     """
2125     for n in node.pre_order():
2126         if n.type == token.STRING:
2127             value_head = n.value[:2]  # type: ignore
2128             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2129                 return True
2130
2131         elif (
2132             n.type == syms.typedargslist
2133             and n.children
2134             and n.children[-1].type == token.COMMA
2135         ):
2136             for ch in n.children:
2137                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
2138                     return True
2139
2140     return False
2141
2142
2143 PYTHON_EXTENSIONS = {".py"}
2144 BLACKLISTED_DIRECTORIES = {
2145     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2146 }
2147
2148
2149 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2150     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2151     and have one of the PYTHON_EXTENSIONS.
2152     """
2153     for child in path.iterdir():
2154         if child.is_dir():
2155             if child.name in BLACKLISTED_DIRECTORIES:
2156                 continue
2157
2158             yield from gen_python_files_in_dir(child)
2159
2160         elif child.suffix in PYTHON_EXTENSIONS:
2161             yield child
2162
2163
2164 @dataclass
2165 class Report:
2166     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2167     check: bool = False
2168     quiet: bool = False
2169     change_count: int = 0
2170     same_count: int = 0
2171     failure_count: int = 0
2172
2173     def done(self, src: Path, changed: bool) -> None:
2174         """Increment the counter for successful reformatting. Write out a message."""
2175         if changed:
2176             reformatted = "would reformat" if self.check else "reformatted"
2177             if not self.quiet:
2178                 out(f"{reformatted} {src}")
2179             self.change_count += 1
2180         else:
2181             if not self.quiet:
2182                 out(f"{src} already well formatted, good job.", bold=False)
2183             self.same_count += 1
2184
2185     def failed(self, src: Path, message: str) -> None:
2186         """Increment the counter for failed reformatting. Write out a message."""
2187         err(f"error: cannot format {src}: {message}")
2188         self.failure_count += 1
2189
2190     @property
2191     def return_code(self) -> int:
2192         """Return the exit code that the app should use.
2193
2194         This considers the current state of changed files and failures:
2195         - if there were any failures, return 123;
2196         - if any files were changed and --check is being used, return 1;
2197         - otherwise return 0.
2198         """
2199         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2200         # 126 we have special returncodes reserved by the shell.
2201         if self.failure_count:
2202             return 123
2203
2204         elif self.change_count and self.check:
2205             return 1
2206
2207         return 0
2208
2209     def __str__(self) -> str:
2210         """Render a color report of the current state.
2211
2212         Use `click.unstyle` to remove colors.
2213         """
2214         if self.check:
2215             reformatted = "would be reformatted"
2216             unchanged = "would be left unchanged"
2217             failed = "would fail to reformat"
2218         else:
2219             reformatted = "reformatted"
2220             unchanged = "left unchanged"
2221             failed = "failed to reformat"
2222         report = []
2223         if self.change_count:
2224             s = "s" if self.change_count > 1 else ""
2225             report.append(
2226                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2227             )
2228         if self.same_count:
2229             s = "s" if self.same_count > 1 else ""
2230             report.append(f"{self.same_count} file{s} {unchanged}")
2231         if self.failure_count:
2232             s = "s" if self.failure_count > 1 else ""
2233             report.append(
2234                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2235             )
2236         return ", ".join(report) + "."
2237
2238
2239 def assert_equivalent(src: str, dst: str) -> None:
2240     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2241
2242     import ast
2243     import traceback
2244
2245     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2246         """Simple visitor generating strings to compare ASTs by content."""
2247         yield f"{'  ' * depth}{node.__class__.__name__}("
2248
2249         for field in sorted(node._fields):
2250             try:
2251                 value = getattr(node, field)
2252             except AttributeError:
2253                 continue
2254
2255             yield f"{'  ' * (depth+1)}{field}="
2256
2257             if isinstance(value, list):
2258                 for item in value:
2259                     if isinstance(item, ast.AST):
2260                         yield from _v(item, depth + 2)
2261
2262             elif isinstance(value, ast.AST):
2263                 yield from _v(value, depth + 2)
2264
2265             else:
2266                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2267
2268         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2269
2270     try:
2271         src_ast = ast.parse(src)
2272     except Exception as exc:
2273         major, minor = sys.version_info[:2]
2274         raise AssertionError(
2275             f"cannot use --safe with this file; failed to parse source file "
2276             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2277             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2278         )
2279
2280     try:
2281         dst_ast = ast.parse(dst)
2282     except Exception as exc:
2283         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2284         raise AssertionError(
2285             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2286             f"Please report a bug on https://github.com/ambv/black/issues.  "
2287             f"This invalid output might be helpful: {log}"
2288         ) from None
2289
2290     src_ast_str = "\n".join(_v(src_ast))
2291     dst_ast_str = "\n".join(_v(dst_ast))
2292     if src_ast_str != dst_ast_str:
2293         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2294         raise AssertionError(
2295             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2296             f"the source.  "
2297             f"Please report a bug on https://github.com/ambv/black/issues.  "
2298             f"This diff might be helpful: {log}"
2299         ) from None
2300
2301
2302 def assert_stable(src: str, dst: str, line_length: int) -> None:
2303     """Raise AssertionError if `dst` reformats differently the second time."""
2304     newdst = format_str(dst, line_length=line_length)
2305     if dst != newdst:
2306         log = dump_to_file(
2307             diff(src, dst, "source", "first pass"),
2308             diff(dst, newdst, "first pass", "second pass"),
2309         )
2310         raise AssertionError(
2311             f"INTERNAL ERROR: Black produced different code on the second pass "
2312             f"of the formatter.  "
2313             f"Please report a bug on https://github.com/ambv/black/issues.  "
2314             f"This diff might be helpful: {log}"
2315         ) from None
2316
2317
2318 def dump_to_file(*output: str) -> str:
2319     """Dump `output` to a temporary file. Return path to the file."""
2320     import tempfile
2321
2322     with tempfile.NamedTemporaryFile(
2323         mode="w", prefix="blk_", suffix=".log", delete=False
2324     ) as f:
2325         for lines in output:
2326             f.write(lines)
2327             if lines and lines[-1] != "\n":
2328                 f.write("\n")
2329     return f.name
2330
2331
2332 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2333     """Return a unified diff string between strings `a` and `b`."""
2334     import difflib
2335
2336     a_lines = [line + "\n" for line in a.split("\n")]
2337     b_lines = [line + "\n" for line in b.split("\n")]
2338     return "".join(
2339         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2340     )
2341
2342
2343 def cancel(tasks: List[asyncio.Task]) -> None:
2344     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2345     err("Aborted!")
2346     for task in tasks:
2347         task.cancel()
2348
2349
2350 def shutdown(loop: BaseEventLoop) -> None:
2351     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2352     try:
2353         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2354         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2355         if not to_cancel:
2356             return
2357
2358         for task in to_cancel:
2359             task.cancel()
2360         loop.run_until_complete(
2361             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2362         )
2363     finally:
2364         # `concurrent.futures.Future` objects cannot be cancelled once they
2365         # are already running. There might be some when the `shutdown()` happened.
2366         # Silence their logger's spew about the event loop being closed.
2367         cf_logger = logging.getLogger("concurrent.futures")
2368         cf_logger.setLevel(logging.CRITICAL)
2369         loop.close()
2370
2371
2372 if __name__ == "__main__":
2373     main()