]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Remove debug print
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from enum import Enum
7 from functools import partial, wraps
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import re
14 import tokenize
15 import signal
16 import sys
17 from typing import (
18     Any,
19     Callable,
20     Collection,
21     Dict,
22     Generic,
23     Iterable,
24     Iterator,
25     List,
26     Optional,
27     Set,
28     Tuple,
29     Type,
30     TypeVar,
31     Union,
32 )
33
34 from attr import dataclass, Factory
35 import click
36
37 # lib2to3 fork
38 from blib2to3.pytree import Node, Leaf, type_repr
39 from blib2to3 import pygram, pytree
40 from blib2to3.pgen2 import driver, token
41 from blib2to3.pgen2.parse import ParseError
42
43 __version__ = "18.4a1"
44 DEFAULT_LINE_LENGTH = 88
45 # types
46 syms = pygram.python_symbols
47 FileContent = str
48 Encoding = str
49 Depth = int
50 NodeType = int
51 LeafID = int
52 Priority = int
53 Index = int
54 LN = Union[Leaf, Node]
55 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
56 out = partial(click.secho, bold=True, err=True)
57 err = partial(click.secho, fg="red", err=True)
58
59
60 class NothingChanged(UserWarning):
61     """Raised by :func:`format_file` when reformatted code is the same as source."""
62
63
64 class CannotSplit(Exception):
65     """A readable split that fits the allotted line length is impossible.
66
67     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
68     :func:`delimiter_split`.
69     """
70
71
72 class FormatError(Exception):
73     """Base exception for `# fmt: on` and `# fmt: off` handling.
74
75     It holds the number of bytes of the prefix consumed before the format
76     control comment appeared.
77     """
78
79     def __init__(self, consumed: int) -> None:
80         super().__init__(consumed)
81         self.consumed = consumed
82
83     def trim_prefix(self, leaf: Leaf) -> None:
84         leaf.prefix = leaf.prefix[self.consumed:]
85
86     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
87         """Returns a new Leaf from the consumed part of the prefix."""
88         unformatted_prefix = leaf.prefix[:self.consumed]
89         return Leaf(token.NEWLINE, unformatted_prefix)
90
91
92 class FormatOn(FormatError):
93     """Found a comment like `# fmt: on` in the file."""
94
95
96 class FormatOff(FormatError):
97     """Found a comment like `# fmt: off` in the file."""
98
99
100 class WriteBack(Enum):
101     NO = 0
102     YES = 1
103     DIFF = 2
104
105
106 @click.command()
107 @click.option(
108     "-l",
109     "--line-length",
110     type=int,
111     default=DEFAULT_LINE_LENGTH,
112     help="How many character per line to allow.",
113     show_default=True,
114 )
115 @click.option(
116     "--check",
117     is_flag=True,
118     help=(
119         "Don't write the files back, just return the status.  Return code 0 "
120         "means nothing would change.  Return code 1 means some files would be "
121         "reformatted.  Return code 123 means there was an internal error."
122     ),
123 )
124 @click.option(
125     "--diff",
126     is_flag=True,
127     help="Don't write the files back, just output a diff for each file on stdout.",
128 )
129 @click.option(
130     "--fast/--safe",
131     is_flag=True,
132     help="If --fast given, skip temporary sanity checks. [default: --safe]",
133 )
134 @click.option(
135     "-q",
136     "--quiet",
137     is_flag=True,
138     help=(
139         "Don't emit non-error messages to stderr. Errors are still emitted, "
140         "silence those with 2>/dev/null."
141     ),
142 )
143 @click.version_option(version=__version__)
144 @click.argument(
145     "src",
146     nargs=-1,
147     type=click.Path(
148         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
149     ),
150 )
151 @click.pass_context
152 def main(
153     ctx: click.Context,
154     line_length: int,
155     check: bool,
156     diff: bool,
157     fast: bool,
158     quiet: bool,
159     src: List[str],
160 ) -> None:
161     """The uncompromising code formatter."""
162     sources: List[Path] = []
163     for s in src:
164         p = Path(s)
165         if p.is_dir():
166             sources.extend(gen_python_files_in_dir(p))
167         elif p.is_file():
168             # if a file was explicitly given, we don't care about its extension
169             sources.append(p)
170         elif s == "-":
171             sources.append(Path("-"))
172         else:
173             err(f"invalid path: {s}")
174     if check and diff:
175         exc = click.ClickException("Options --check and --diff are mutually exclusive")
176         exc.exit_code = 2
177         raise exc
178
179     if check:
180         write_back = WriteBack.NO
181     elif diff:
182         write_back = WriteBack.DIFF
183     else:
184         write_back = WriteBack.YES
185     if len(sources) == 0:
186         ctx.exit(0)
187     elif len(sources) == 1:
188         p = sources[0]
189         report = Report(check=check, quiet=quiet)
190         try:
191             if not p.is_file() and str(p) == "-":
192                 changed = format_stdin_to_stdout(
193                     line_length=line_length, fast=fast, write_back=write_back
194                 )
195             else:
196                 changed = format_file_in_place(
197                     p, line_length=line_length, fast=fast, write_back=write_back
198                 )
199             report.done(p, changed)
200         except Exception as exc:
201             report.failed(p, str(exc))
202         ctx.exit(report.return_code)
203     else:
204         loop = asyncio.get_event_loop()
205         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
206         return_code = 1
207         try:
208             return_code = loop.run_until_complete(
209                 schedule_formatting(
210                     sources, line_length, write_back, fast, quiet, loop, executor
211                 )
212             )
213         finally:
214             shutdown(loop)
215             ctx.exit(return_code)
216
217
218 async def schedule_formatting(
219     sources: List[Path],
220     line_length: int,
221     write_back: WriteBack,
222     fast: bool,
223     quiet: bool,
224     loop: BaseEventLoop,
225     executor: Executor,
226 ) -> int:
227     """Run formatting of `sources` in parallel using the provided `executor`.
228
229     (Use ProcessPoolExecutors for actual parallelism.)
230
231     `line_length`, `write_back`, and `fast` options are passed to
232     :func:`format_file_in_place`.
233     """
234     lock = None
235     if write_back == WriteBack.DIFF:
236         # For diff output, we need locks to ensure we don't interleave output
237         # from different processes.
238         manager = Manager()
239         lock = manager.Lock()
240     tasks = {
241         src: loop.run_in_executor(
242             executor, format_file_in_place, src, line_length, fast, write_back, lock
243         )
244         for src in sources
245     }
246     _task_values = list(tasks.values())
247     loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
248     loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
249     await asyncio.wait(tasks.values())
250     cancelled = []
251     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
252     for src, task in tasks.items():
253         if not task.done():
254             report.failed(src, "timed out, cancelling")
255             task.cancel()
256             cancelled.append(task)
257         elif task.cancelled():
258             cancelled.append(task)
259         elif task.exception():
260             report.failed(src, str(task.exception()))
261         else:
262             report.done(src, task.result())
263     if cancelled:
264         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
265     elif not quiet:
266         out("All done! ✨ 🍰 ✨")
267     if not quiet:
268         click.echo(str(report))
269     return report.return_code
270
271
272 def format_file_in_place(
273     src: Path,
274     line_length: int,
275     fast: bool,
276     write_back: WriteBack = WriteBack.NO,
277     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
278 ) -> bool:
279     """Format file under `src` path. Return True if changed.
280
281     If `write_back` is True, write reformatted code back to stdout.
282     `line_length` and `fast` options are passed to :func:`format_file_contents`.
283     """
284     with tokenize.open(src) as src_buffer:
285         src_contents = src_buffer.read()
286     try:
287         dst_contents = format_file_contents(
288             src_contents, line_length=line_length, fast=fast
289         )
290     except NothingChanged:
291         return False
292
293     if write_back == write_back.YES:
294         with open(src, "w", encoding=src_buffer.encoding) as f:
295             f.write(dst_contents)
296     elif write_back == write_back.DIFF:
297         src_name = f"{src.name}  (original)"
298         dst_name = f"{src.name}  (formatted)"
299         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
300         if lock:
301             lock.acquire()
302         try:
303             sys.stdout.write(diff_contents)
304         finally:
305             if lock:
306                 lock.release()
307     return True
308
309
310 def format_stdin_to_stdout(
311     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
312 ) -> bool:
313     """Format file on stdin. Return True if changed.
314
315     If `write_back` is True, write reformatted code back to stdout.
316     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
317     """
318     src = sys.stdin.read()
319     dst = src
320     try:
321         dst = format_file_contents(src, line_length=line_length, fast=fast)
322         return True
323
324     except NothingChanged:
325         return False
326
327     finally:
328         if write_back == WriteBack.YES:
329             sys.stdout.write(dst)
330         elif write_back == WriteBack.DIFF:
331             src_name = "<stdin>  (original)"
332             dst_name = "<stdin>  (formatted)"
333             sys.stdout.write(diff(src, dst, src_name, dst_name))
334
335
336 def format_file_contents(
337     src_contents: str, line_length: int, fast: bool
338 ) -> FileContent:
339     """Reformat contents a file and return new contents.
340
341     If `fast` is False, additionally confirm that the reformatted code is
342     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
343     `line_length` is passed to :func:`format_str`.
344     """
345     if src_contents.strip() == "":
346         raise NothingChanged
347
348     dst_contents = format_str(src_contents, line_length=line_length)
349     if src_contents == dst_contents:
350         raise NothingChanged
351
352     if not fast:
353         assert_equivalent(src_contents, dst_contents)
354         assert_stable(src_contents, dst_contents, line_length=line_length)
355     return dst_contents
356
357
358 def format_str(src_contents: str, line_length: int) -> FileContent:
359     """Reformat a string and return new contents.
360
361     `line_length` determines how many characters per line are allowed.
362     """
363     src_node = lib2to3_parse(src_contents)
364     dst_contents = ""
365     lines = LineGenerator()
366     elt = EmptyLineTracker()
367     py36 = is_python36(src_node)
368     empty_line = Line()
369     after = 0
370     for current_line in lines.visit(src_node):
371         for _ in range(after):
372             dst_contents += str(empty_line)
373         before, after = elt.maybe_empty_lines(current_line)
374         for _ in range(before):
375             dst_contents += str(empty_line)
376         for line in split_line(current_line, line_length=line_length, py36=py36):
377             dst_contents += str(line)
378     return dst_contents
379
380
381 GRAMMARS = [
382     pygram.python_grammar_no_print_statement_no_exec_statement,
383     pygram.python_grammar_no_print_statement,
384     pygram.python_grammar_no_exec_statement,
385     pygram.python_grammar,
386 ]
387
388
389 def lib2to3_parse(src_txt: str) -> Node:
390     """Given a string with source, return the lib2to3 Node."""
391     grammar = pygram.python_grammar_no_print_statement
392     if src_txt[-1] != "\n":
393         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
394         src_txt += nl
395     for grammar in GRAMMARS:
396         drv = driver.Driver(grammar, pytree.convert)
397         try:
398             result = drv.parse_string(src_txt, True)
399             break
400
401         except ParseError as pe:
402             lineno, column = pe.context[1]
403             lines = src_txt.splitlines()
404             try:
405                 faulty_line = lines[lineno - 1]
406             except IndexError:
407                 faulty_line = "<line number missing in source>"
408             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
409     else:
410         raise exc from None
411
412     if isinstance(result, Leaf):
413         result = Node(syms.file_input, [result])
414     return result
415
416
417 def lib2to3_unparse(node: Node) -> str:
418     """Given a lib2to3 node, return its string representation."""
419     code = str(node)
420     return code
421
422
423 T = TypeVar("T")
424
425
426 class Visitor(Generic[T]):
427     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
428
429     def visit(self, node: LN) -> Iterator[T]:
430         """Main method to visit `node` and its children.
431
432         It tries to find a `visit_*()` method for the given `node.type`, like
433         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
434         If no dedicated `visit_*()` method is found, chooses `visit_default()`
435         instead.
436
437         Then yields objects of type `T` from the selected visitor.
438         """
439         if node.type < 256:
440             name = token.tok_name[node.type]
441         else:
442             name = type_repr(node.type)
443         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
444
445     def visit_default(self, node: LN) -> Iterator[T]:
446         """Default `visit_*()` implementation. Recurses to children of `node`."""
447         if isinstance(node, Node):
448             for child in node.children:
449                 yield from self.visit(child)
450
451
452 @dataclass
453 class DebugVisitor(Visitor[T]):
454     tree_depth: int = 0
455
456     def visit_default(self, node: LN) -> Iterator[T]:
457         indent = " " * (2 * self.tree_depth)
458         if isinstance(node, Node):
459             _type = type_repr(node.type)
460             out(f"{indent}{_type}", fg="yellow")
461             self.tree_depth += 1
462             for child in node.children:
463                 yield from self.visit(child)
464
465             self.tree_depth -= 1
466             out(f"{indent}/{_type}", fg="yellow", bold=False)
467         else:
468             _type = token.tok_name.get(node.type, str(node.type))
469             out(f"{indent}{_type}", fg="blue", nl=False)
470             if node.prefix:
471                 # We don't have to handle prefixes for `Node` objects since
472                 # that delegates to the first child anyway.
473                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
474             out(f" {node.value!r}", fg="blue", bold=False)
475
476     @classmethod
477     def show(cls, code: str) -> None:
478         """Pretty-print the lib2to3 AST of a given string of `code`.
479
480         Convenience method for debugging.
481         """
482         v: DebugVisitor[None] = DebugVisitor()
483         list(v.visit(lib2to3_parse(code)))
484
485
486 KEYWORDS = set(keyword.kwlist)
487 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
488 FLOW_CONTROL = {"return", "raise", "break", "continue"}
489 STATEMENT = {
490     syms.if_stmt,
491     syms.while_stmt,
492     syms.for_stmt,
493     syms.try_stmt,
494     syms.except_clause,
495     syms.with_stmt,
496     syms.funcdef,
497     syms.classdef,
498 }
499 STANDALONE_COMMENT = 153
500 LOGIC_OPERATORS = {"and", "or"}
501 COMPARATORS = {
502     token.LESS,
503     token.GREATER,
504     token.EQEQUAL,
505     token.NOTEQUAL,
506     token.LESSEQUAL,
507     token.GREATEREQUAL,
508 }
509 MATH_OPERATORS = {
510     token.PLUS,
511     token.MINUS,
512     token.STAR,
513     token.SLASH,
514     token.VBAR,
515     token.AMPER,
516     token.PERCENT,
517     token.CIRCUMFLEX,
518     token.TILDE,
519     token.LEFTSHIFT,
520     token.RIGHTSHIFT,
521     token.DOUBLESTAR,
522     token.DOUBLESLASH,
523 }
524 VARARGS = {token.STAR, token.DOUBLESTAR}
525 COMPREHENSION_PRIORITY = 20
526 COMMA_PRIORITY = 10
527 LOGIC_PRIORITY = 5
528 STRING_PRIORITY = 4
529 COMPARATOR_PRIORITY = 3
530 MATH_PRIORITY = 1
531
532
533 @dataclass
534 class BracketTracker:
535     """Keeps track of brackets on a line."""
536
537     depth: int = 0
538     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
539     delimiters: Dict[LeafID, Priority] = Factory(dict)
540     previous: Optional[Leaf] = None
541
542     def mark(self, leaf: Leaf) -> None:
543         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
544
545         All leaves receive an int `bracket_depth` field that stores how deep
546         within brackets a given leaf is. 0 means there are no enclosing brackets
547         that started on this line.
548
549         If a leaf is itself a closing bracket, it receives an `opening_bracket`
550         field that it forms a pair with. This is a one-directional link to
551         avoid reference cycles.
552
553         If a leaf is a delimiter (a token on which Black can split the line if
554         needed) and it's on depth 0, its `id()` is stored in the tracker's
555         `delimiters` field.
556         """
557         if leaf.type == token.COMMENT:
558             return
559
560         if leaf.type in CLOSING_BRACKETS:
561             self.depth -= 1
562             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
563             leaf.opening_bracket = opening_bracket
564         leaf.bracket_depth = self.depth
565         if self.depth == 0:
566             delim = is_split_before_delimiter(leaf, self.previous)
567             if delim and self.previous is not None:
568                 self.delimiters[id(self.previous)] = delim
569             else:
570                 delim = is_split_after_delimiter(leaf, self.previous)
571                 if delim:
572                     self.delimiters[id(leaf)] = delim
573         if leaf.type in OPENING_BRACKETS:
574             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
575             self.depth += 1
576         self.previous = leaf
577
578     def any_open_brackets(self) -> bool:
579         """Return True if there is an yet unmatched open bracket on the line."""
580         return bool(self.bracket_match)
581
582     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
583         """Return the highest priority of a delimiter found on the line.
584
585         Values are consistent with what `is_delimiter()` returns.
586         Raises ValueError on no delimiters.
587         """
588         return max(v for k, v in self.delimiters.items() if k not in exclude)
589
590
591 @dataclass
592 class Line:
593     """Holds leaves and comments. Can be printed with `str(line)`."""
594
595     depth: int = 0
596     leaves: List[Leaf] = Factory(list)
597     comments: List[Tuple[Index, Leaf]] = Factory(list)
598     bracket_tracker: BracketTracker = Factory(BracketTracker)
599     inside_brackets: bool = False
600     has_for: bool = False
601     _for_loop_variable: bool = False
602
603     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
604         """Add a new `leaf` to the end of the line.
605
606         Unless `preformatted` is True, the `leaf` will receive a new consistent
607         whitespace prefix and metadata applied by :class:`BracketTracker`.
608         Trailing commas are maybe removed, unpacked for loop variables are
609         demoted from being delimiters.
610
611         Inline comments are put aside.
612         """
613         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
614         if not has_value:
615             return
616
617         if self.leaves and not preformatted:
618             # Note: at this point leaf.prefix should be empty except for
619             # imports, for which we only preserve newlines.
620             leaf.prefix += whitespace(leaf)
621         if self.inside_brackets or not preformatted:
622             self.maybe_decrement_after_for_loop_variable(leaf)
623             self.bracket_tracker.mark(leaf)
624             self.maybe_remove_trailing_comma(leaf)
625             self.maybe_increment_for_loop_variable(leaf)
626
627         if not self.append_comment(leaf):
628             self.leaves.append(leaf)
629
630     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
631         """Like :func:`append()` but disallow invalid standalone comment structure.
632
633         Raises ValueError when any `leaf` is appended after a standalone comment
634         or when a standalone comment is not the first leaf on the line.
635         """
636         if self.bracket_tracker.depth == 0:
637             if self.is_comment:
638                 raise ValueError("cannot append to standalone comments")
639
640             if self.leaves and leaf.type == STANDALONE_COMMENT:
641                 raise ValueError(
642                     "cannot append standalone comments to a populated line"
643                 )
644
645         self.append(leaf, preformatted=preformatted)
646
647     @property
648     def is_comment(self) -> bool:
649         """Is this line a standalone comment?"""
650         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
651
652     @property
653     def is_decorator(self) -> bool:
654         """Is this line a decorator?"""
655         return bool(self) and self.leaves[0].type == token.AT
656
657     @property
658     def is_import(self) -> bool:
659         """Is this an import line?"""
660         return bool(self) and is_import(self.leaves[0])
661
662     @property
663     def is_class(self) -> bool:
664         """Is this line a class definition?"""
665         return (
666             bool(self)
667             and self.leaves[0].type == token.NAME
668             and self.leaves[0].value == "class"
669         )
670
671     @property
672     def is_def(self) -> bool:
673         """Is this a function definition? (Also returns True for async defs.)"""
674         try:
675             first_leaf = self.leaves[0]
676         except IndexError:
677             return False
678
679         try:
680             second_leaf: Optional[Leaf] = self.leaves[1]
681         except IndexError:
682             second_leaf = None
683         return (
684             (first_leaf.type == token.NAME and first_leaf.value == "def")
685             or (
686                 first_leaf.type == token.ASYNC
687                 and second_leaf is not None
688                 and second_leaf.type == token.NAME
689                 and second_leaf.value == "def"
690             )
691         )
692
693     @property
694     def is_flow_control(self) -> bool:
695         """Is this line a flow control statement?
696
697         Those are `return`, `raise`, `break`, and `continue`.
698         """
699         return (
700             bool(self)
701             and self.leaves[0].type == token.NAME
702             and self.leaves[0].value in FLOW_CONTROL
703         )
704
705     @property
706     def is_yield(self) -> bool:
707         """Is this line a yield statement?"""
708         return (
709             bool(self)
710             and self.leaves[0].type == token.NAME
711             and self.leaves[0].value == "yield"
712         )
713
714     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
715         """If so, needs to be split before emitting."""
716         for leaf in self.leaves:
717             if leaf.type == STANDALONE_COMMENT:
718                 if leaf.bracket_depth <= depth_limit:
719                     return True
720
721         return False
722
723     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
724         """Remove trailing comma if there is one and it's safe."""
725         if not (
726             self.leaves
727             and self.leaves[-1].type == token.COMMA
728             and closing.type in CLOSING_BRACKETS
729         ):
730             return False
731
732         if closing.type == token.RBRACE:
733             self.remove_trailing_comma()
734             return True
735
736         if closing.type == token.RSQB:
737             comma = self.leaves[-1]
738             if comma.parent and comma.parent.type == syms.listmaker:
739                 self.remove_trailing_comma()
740                 return True
741
742         # For parens let's check if it's safe to remove the comma.  If the
743         # trailing one is the only one, we might mistakenly change a tuple
744         # into a different type by removing the comma.
745         depth = closing.bracket_depth + 1
746         commas = 0
747         opening = closing.opening_bracket
748         for _opening_index, leaf in enumerate(self.leaves):
749             if leaf is opening:
750                 break
751
752         else:
753             return False
754
755         for leaf in self.leaves[_opening_index + 1:]:
756             if leaf is closing:
757                 break
758
759             bracket_depth = leaf.bracket_depth
760             if bracket_depth == depth and leaf.type == token.COMMA:
761                 commas += 1
762                 if leaf.parent and leaf.parent.type == syms.arglist:
763                     commas += 1
764                     break
765
766         if commas > 1:
767             self.remove_trailing_comma()
768             return True
769
770         return False
771
772     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
773         """In a for loop, or comprehension, the variables are often unpacks.
774
775         To avoid splitting on the comma in this situation, increase the depth of
776         tokens between `for` and `in`.
777         """
778         if leaf.type == token.NAME and leaf.value == "for":
779             self.has_for = True
780             self.bracket_tracker.depth += 1
781             self._for_loop_variable = True
782             return True
783
784         return False
785
786     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
787         """See `maybe_increment_for_loop_variable` above for explanation."""
788         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
789             self.bracket_tracker.depth -= 1
790             self._for_loop_variable = False
791             return True
792
793         return False
794
795     def append_comment(self, comment: Leaf) -> bool:
796         """Add an inline or standalone comment to the line."""
797         if (
798             comment.type == STANDALONE_COMMENT
799             and self.bracket_tracker.any_open_brackets()
800         ):
801             comment.prefix = ""
802             return False
803
804         if comment.type != token.COMMENT:
805             return False
806
807         after = len(self.leaves) - 1
808         if after == -1:
809             comment.type = STANDALONE_COMMENT
810             comment.prefix = ""
811             return False
812
813         else:
814             self.comments.append((after, comment))
815             return True
816
817     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
818         """Generate comments that should appear directly after `leaf`."""
819         for _leaf_index, _leaf in enumerate(self.leaves):
820             if leaf is _leaf:
821                 break
822
823         else:
824             return
825
826         for index, comment_after in self.comments:
827             if _leaf_index == index:
828                 yield comment_after
829
830     def remove_trailing_comma(self) -> None:
831         """Remove the trailing comma and moves the comments attached to it."""
832         comma_index = len(self.leaves) - 1
833         for i in range(len(self.comments)):
834             comment_index, comment = self.comments[i]
835             if comment_index == comma_index:
836                 self.comments[i] = (comma_index - 1, comment)
837         self.leaves.pop()
838
839     def __str__(self) -> str:
840         """Render the line."""
841         if not self:
842             return "\n"
843
844         indent = "    " * self.depth
845         leaves = iter(self.leaves)
846         first = next(leaves)
847         res = f"{first.prefix}{indent}{first.value}"
848         for leaf in leaves:
849             res += str(leaf)
850         for _, comment in self.comments:
851             res += str(comment)
852         return res + "\n"
853
854     def __bool__(self) -> bool:
855         """Return True if the line has leaves or comments."""
856         return bool(self.leaves or self.comments)
857
858
859 class UnformattedLines(Line):
860     """Just like :class:`Line` but stores lines which aren't reformatted."""
861
862     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
863         """Just add a new `leaf` to the end of the lines.
864
865         The `preformatted` argument is ignored.
866
867         Keeps track of indentation `depth`, which is useful when the user
868         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
869         """
870         try:
871             list(generate_comments(leaf))
872         except FormatOn as f_on:
873             self.leaves.append(f_on.leaf_from_consumed(leaf))
874             raise
875
876         self.leaves.append(leaf)
877         if leaf.type == token.INDENT:
878             self.depth += 1
879         elif leaf.type == token.DEDENT:
880             self.depth -= 1
881
882     def __str__(self) -> str:
883         """Render unformatted lines from leaves which were added with `append()`.
884
885         `depth` is not used for indentation in this case.
886         """
887         if not self:
888             return "\n"
889
890         res = ""
891         for leaf in self.leaves:
892             res += str(leaf)
893         return res
894
895     def append_comment(self, comment: Leaf) -> bool:
896         """Not implemented in this class. Raises `NotImplementedError`."""
897         raise NotImplementedError("Unformatted lines don't store comments separately.")
898
899     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
900         """Does nothing and returns False."""
901         return False
902
903     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
904         """Does nothing and returns False."""
905         return False
906
907
908 @dataclass
909 class EmptyLineTracker:
910     """Provides a stateful method that returns the number of potential extra
911     empty lines needed before and after the currently processed line.
912
913     Note: this tracker works on lines that haven't been split yet.  It assumes
914     the prefix of the first leaf consists of optional newlines.  Those newlines
915     are consumed by `maybe_empty_lines()` and included in the computation.
916     """
917     previous_line: Optional[Line] = None
918     previous_after: int = 0
919     previous_defs: List[int] = Factory(list)
920
921     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
922         """Return the number of extra empty lines before and after the `current_line`.
923
924         This is for separating `def`, `async def` and `class` with extra empty
925         lines (two on module-level), as well as providing an extra empty line
926         after flow control keywords to make them more prominent.
927         """
928         if isinstance(current_line, UnformattedLines):
929             return 0, 0
930
931         before, after = self._maybe_empty_lines(current_line)
932         before -= self.previous_after
933         self.previous_after = after
934         self.previous_line = current_line
935         return before, after
936
937     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
938         max_allowed = 1
939         if current_line.depth == 0:
940             max_allowed = 2
941         if current_line.leaves:
942             # Consume the first leaf's extra newlines.
943             first_leaf = current_line.leaves[0]
944             before = first_leaf.prefix.count("\n")
945             before = min(before, max_allowed)
946             first_leaf.prefix = ""
947         else:
948             before = 0
949         depth = current_line.depth
950         while self.previous_defs and self.previous_defs[-1] >= depth:
951             self.previous_defs.pop()
952             before = 1 if depth else 2
953         is_decorator = current_line.is_decorator
954         if is_decorator or current_line.is_def or current_line.is_class:
955             if not is_decorator:
956                 self.previous_defs.append(depth)
957             if self.previous_line is None:
958                 # Don't insert empty lines before the first line in the file.
959                 return 0, 0
960
961             if self.previous_line and self.previous_line.is_decorator:
962                 # Don't insert empty lines between decorators.
963                 return 0, 0
964
965             newlines = 2
966             if current_line.depth:
967                 newlines -= 1
968             return newlines, 0
969
970         if current_line.is_flow_control:
971             return before, 1
972
973         if (
974             self.previous_line
975             and self.previous_line.is_import
976             and not current_line.is_import
977             and depth == self.previous_line.depth
978         ):
979             return (before or 1), 0
980
981         if (
982             self.previous_line
983             and self.previous_line.is_yield
984             and (not current_line.is_yield or depth != self.previous_line.depth)
985         ):
986             return (before or 1), 0
987
988         return before, 0
989
990
991 @dataclass
992 class LineGenerator(Visitor[Line]):
993     """Generates reformatted Line objects.  Empty lines are not emitted.
994
995     Note: destroys the tree it's visiting by mutating prefixes of its leaves
996     in ways that will no longer stringify to valid Python code on the tree.
997     """
998     current_line: Line = Factory(Line)
999
1000     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1001         """Generate a line.
1002
1003         If the line is empty, only emit if it makes sense.
1004         If the line is too long, split it first and then generate.
1005
1006         If any lines were generated, set up a new current_line.
1007         """
1008         if not self.current_line:
1009             if self.current_line.__class__ == type:
1010                 self.current_line.depth += indent
1011             else:
1012                 self.current_line = type(depth=self.current_line.depth + indent)
1013             return  # Line is empty, don't emit. Creating a new one unnecessary.
1014
1015         complete_line = self.current_line
1016         self.current_line = type(depth=complete_line.depth + indent)
1017         yield complete_line
1018
1019     def visit(self, node: LN) -> Iterator[Line]:
1020         """Main method to visit `node` and its children.
1021
1022         Yields :class:`Line` objects.
1023         """
1024         if isinstance(self.current_line, UnformattedLines):
1025             # File contained `# fmt: off`
1026             yield from self.visit_unformatted(node)
1027
1028         else:
1029             yield from super().visit(node)
1030
1031     def visit_default(self, node: LN) -> Iterator[Line]:
1032         """Default `visit_*()` implementation. Recurses to children of `node`."""
1033         if isinstance(node, Leaf):
1034             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1035             try:
1036                 for comment in generate_comments(node):
1037                     if any_open_brackets:
1038                         # any comment within brackets is subject to splitting
1039                         self.current_line.append(comment)
1040                     elif comment.type == token.COMMENT:
1041                         # regular trailing comment
1042                         self.current_line.append(comment)
1043                         yield from self.line()
1044
1045                     else:
1046                         # regular standalone comment
1047                         yield from self.line()
1048
1049                         self.current_line.append(comment)
1050                         yield from self.line()
1051
1052             except FormatOff as f_off:
1053                 f_off.trim_prefix(node)
1054                 yield from self.line(type=UnformattedLines)
1055                 yield from self.visit(node)
1056
1057             except FormatOn as f_on:
1058                 # This only happens here if somebody says "fmt: on" multiple
1059                 # times in a row.
1060                 f_on.trim_prefix(node)
1061                 yield from self.visit_default(node)
1062
1063             else:
1064                 normalize_prefix(node, inside_brackets=any_open_brackets)
1065                 if node.type == token.STRING:
1066                     normalize_string_quotes(node)
1067                 if node.type not in WHITESPACE:
1068                     self.current_line.append(node)
1069         yield from super().visit_default(node)
1070
1071     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1072         """Increase indentation level, maybe yield a line."""
1073         # In blib2to3 INDENT never holds comments.
1074         yield from self.line(+1)
1075         yield from self.visit_default(node)
1076
1077     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1078         """Decrease indentation level, maybe yield a line."""
1079         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1080         yield from self.line(-1)
1081
1082     def visit_stmt(
1083         self, node: Node, keywords: Set[str], parens: Set[str]
1084     ) -> Iterator[Line]:
1085         """Visit a statement.
1086
1087         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1088         `def`, `with`, `class`, and `assert`.
1089
1090         The relevant Python language `keywords` for a given statement will be
1091         NAME leaves within it. This methods puts those on a separate line.
1092
1093         `parens` holds pairs of nodes where invisible parentheses should be put.
1094         Keys hold nodes after which opening parentheses should be put, values
1095         hold nodes before which closing parentheses should be put.
1096         """
1097         normalize_invisible_parens(node, parens_after=parens)
1098         for child in node.children:
1099             if child.type == token.NAME and child.value in keywords:  # type: ignore
1100                 yield from self.line()
1101
1102             yield from self.visit(child)
1103
1104     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1105         """Visit a statement without nested statements."""
1106         is_suite_like = node.parent and node.parent.type in STATEMENT
1107         if is_suite_like:
1108             yield from self.line(+1)
1109             yield from self.visit_default(node)
1110             yield from self.line(-1)
1111
1112         else:
1113             yield from self.line()
1114             yield from self.visit_default(node)
1115
1116     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1117         """Visit `async def`, `async for`, `async with`."""
1118         yield from self.line()
1119
1120         children = iter(node.children)
1121         for child in children:
1122             yield from self.visit(child)
1123
1124             if child.type == token.ASYNC:
1125                 break
1126
1127         internal_stmt = next(children)
1128         for child in internal_stmt.children:
1129             yield from self.visit(child)
1130
1131     def visit_decorators(self, node: Node) -> Iterator[Line]:
1132         """Visit decorators."""
1133         for child in node.children:
1134             yield from self.line()
1135             yield from self.visit(child)
1136
1137     def visit_import_from(self, node: Node) -> Iterator[Line]:
1138         """Visit import_from and maybe put invisible parentheses.
1139
1140         This is separate from `visit_stmt` because import statements don't
1141         support arbitrary atoms and thus handling of parentheses is custom.
1142         """
1143         check_lpar = False
1144         for index, child in enumerate(node.children):
1145             if check_lpar:
1146                 if child.type == token.LPAR:
1147                     # make parentheses invisible
1148                     child.value = ""  # type: ignore
1149                     node.children[-1].value = ""  # type: ignore
1150                 else:
1151                     # insert invisible parentheses
1152                     node.insert_child(index, Leaf(token.LPAR, ""))
1153                     node.append_child(Leaf(token.RPAR, ""))
1154                 break
1155
1156             check_lpar = (
1157                 child.type == token.NAME and child.value == "import"  # type: ignore
1158             )
1159
1160         for child in node.children:
1161             yield from self.visit(child)
1162
1163     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1164         """Remove a semicolon and put the other statement on a separate line."""
1165         yield from self.line()
1166
1167     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1168         """End of file. Process outstanding comments and end with a newline."""
1169         yield from self.visit_default(leaf)
1170         yield from self.line()
1171
1172     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1173         """Used when file contained a `# fmt: off`."""
1174         if isinstance(node, Node):
1175             for child in node.children:
1176                 yield from self.visit(child)
1177
1178         else:
1179             try:
1180                 self.current_line.append(node)
1181             except FormatOn as f_on:
1182                 f_on.trim_prefix(node)
1183                 yield from self.line()
1184                 yield from self.visit(node)
1185
1186             if node.type == token.ENDMARKER:
1187                 # somebody decided not to put a final `# fmt: on`
1188                 yield from self.line()
1189
1190     def __attrs_post_init__(self) -> None:
1191         """You are in a twisty little maze of passages."""
1192         v = self.visit_stmt
1193         Ø: Set[str] = set()
1194         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1195         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1196         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1197         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1198         self.visit_try_stmt = partial(
1199             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1200         )
1201         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1202         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1203         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1204         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1205         self.visit_async_funcdef = self.visit_async_stmt
1206         self.visit_decorated = self.visit_decorators
1207
1208
1209 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1210 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1211 OPENING_BRACKETS = set(BRACKET.keys())
1212 CLOSING_BRACKETS = set(BRACKET.values())
1213 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1214 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1215
1216
1217 def whitespace(leaf: Leaf) -> str:  # noqa C901
1218     """Return whitespace prefix if needed for the given `leaf`."""
1219     NO = ""
1220     SPACE = " "
1221     DOUBLESPACE = "  "
1222     t = leaf.type
1223     p = leaf.parent
1224     v = leaf.value
1225     if t in ALWAYS_NO_SPACE:
1226         return NO
1227
1228     if t == token.COMMENT:
1229         return DOUBLESPACE
1230
1231     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1232     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1233         return NO
1234
1235     prev = leaf.prev_sibling
1236     if not prev:
1237         prevp = preceding_leaf(p)
1238         if not prevp or prevp.type in OPENING_BRACKETS:
1239             return NO
1240
1241         if t == token.COLON:
1242             return SPACE if prevp.type == token.COMMA else NO
1243
1244         if prevp.type == token.EQUAL:
1245             if prevp.parent:
1246                 if prevp.parent.type in {
1247                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1248                 }:
1249                     return NO
1250
1251                 elif prevp.parent.type == syms.typedargslist:
1252                     # A bit hacky: if the equal sign has whitespace, it means we
1253                     # previously found it's a typed argument.  So, we're using
1254                     # that, too.
1255                     return prevp.prefix
1256
1257         elif prevp.type == token.DOUBLESTAR:
1258             if (
1259                 prevp.parent
1260                 and prevp.parent.type in {
1261                     syms.arglist,
1262                     syms.argument,
1263                     syms.dictsetmaker,
1264                     syms.parameters,
1265                     syms.typedargslist,
1266                     syms.varargslist,
1267                 }
1268             ):
1269                 return NO
1270
1271         elif prevp.type == token.COLON:
1272             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1273                 return NO
1274
1275         elif (
1276             prevp.parent
1277             and prevp.parent.type in {syms.factor, syms.star_expr}
1278             and prevp.type in MATH_OPERATORS
1279         ):
1280             return NO
1281
1282         elif (
1283             prevp.type == token.RIGHTSHIFT
1284             and prevp.parent
1285             and prevp.parent.type == syms.shift_expr
1286             and prevp.prev_sibling
1287             and prevp.prev_sibling.type == token.NAME
1288             and prevp.prev_sibling.value == "print"  # type: ignore
1289         ):
1290             # Python 2 print chevron
1291             return NO
1292
1293     elif prev.type in OPENING_BRACKETS:
1294         return NO
1295
1296     if p.type in {syms.parameters, syms.arglist}:
1297         # untyped function signatures or calls
1298         if t == token.RPAR:
1299             return NO
1300
1301         if not prev or prev.type != token.COMMA:
1302             return NO
1303
1304     elif p.type == syms.varargslist:
1305         # lambdas
1306         if t == token.RPAR:
1307             return NO
1308
1309         if prev and prev.type != token.COMMA:
1310             return NO
1311
1312     elif p.type == syms.typedargslist:
1313         # typed function signatures
1314         if not prev:
1315             return NO
1316
1317         if t == token.EQUAL:
1318             if prev.type != syms.tname:
1319                 return NO
1320
1321         elif prev.type == token.EQUAL:
1322             # A bit hacky: if the equal sign has whitespace, it means we
1323             # previously found it's a typed argument.  So, we're using that, too.
1324             return prev.prefix
1325
1326         elif prev.type != token.COMMA:
1327             return NO
1328
1329     elif p.type == syms.tname:
1330         # type names
1331         if not prev:
1332             prevp = preceding_leaf(p)
1333             if not prevp or prevp.type != token.COMMA:
1334                 return NO
1335
1336     elif p.type == syms.trailer:
1337         # attributes and calls
1338         if t == token.LPAR or t == token.RPAR:
1339             return NO
1340
1341         if not prev:
1342             if t == token.DOT:
1343                 prevp = preceding_leaf(p)
1344                 if not prevp or prevp.type != token.NUMBER:
1345                     return NO
1346
1347             elif t == token.LSQB:
1348                 return NO
1349
1350         elif prev.type != token.COMMA:
1351             return NO
1352
1353     elif p.type == syms.argument:
1354         # single argument
1355         if t == token.EQUAL:
1356             return NO
1357
1358         if not prev:
1359             prevp = preceding_leaf(p)
1360             if not prevp or prevp.type == token.LPAR:
1361                 return NO
1362
1363         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1364             return NO
1365
1366     elif p.type == syms.decorator:
1367         # decorators
1368         return NO
1369
1370     elif p.type == syms.dotted_name:
1371         if prev:
1372             return NO
1373
1374         prevp = preceding_leaf(p)
1375         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1376             return NO
1377
1378     elif p.type == syms.classdef:
1379         if t == token.LPAR:
1380             return NO
1381
1382         if prev and prev.type == token.LPAR:
1383             return NO
1384
1385     elif p.type == syms.subscript:
1386         # indexing
1387         if not prev:
1388             assert p.parent is not None, "subscripts are always parented"
1389             if p.parent.type == syms.subscriptlist:
1390                 return SPACE
1391
1392             return NO
1393
1394         else:
1395             return NO
1396
1397     elif p.type == syms.atom:
1398         if prev and t == token.DOT:
1399             # dots, but not the first one.
1400             return NO
1401
1402     elif (
1403         p.type == syms.listmaker
1404         or p.type == syms.testlist_gexp
1405         or p.type == syms.subscriptlist
1406     ):
1407         # list interior, including unpacking
1408         if not prev:
1409             return NO
1410
1411     elif p.type == syms.dictsetmaker:
1412         # dict and set interior, including unpacking
1413         if not prev:
1414             return NO
1415
1416         if prev.type == token.DOUBLESTAR:
1417             return NO
1418
1419     elif p.type in {syms.factor, syms.star_expr}:
1420         # unary ops
1421         if not prev:
1422             prevp = preceding_leaf(p)
1423             if not prevp or prevp.type in OPENING_BRACKETS:
1424                 return NO
1425
1426             prevp_parent = prevp.parent
1427             assert prevp_parent is not None
1428             if (
1429                 prevp.type == token.COLON
1430                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1431             ):
1432                 return NO
1433
1434             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1435                 return NO
1436
1437         elif t == token.NAME or t == token.NUMBER:
1438             return NO
1439
1440     elif p.type == syms.import_from:
1441         if t == token.DOT:
1442             if prev and prev.type == token.DOT:
1443                 return NO
1444
1445         elif t == token.NAME:
1446             if v == "import":
1447                 return SPACE
1448
1449             if prev and prev.type == token.DOT:
1450                 return NO
1451
1452     elif p.type == syms.sliceop:
1453         return NO
1454
1455     return SPACE
1456
1457
1458 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1459     """Return the first leaf that precedes `node`, if any."""
1460     while node:
1461         res = node.prev_sibling
1462         if res:
1463             if isinstance(res, Leaf):
1464                 return res
1465
1466             try:
1467                 return list(res.leaves())[-1]
1468
1469             except IndexError:
1470                 return None
1471
1472         node = node.parent
1473     return None
1474
1475
1476 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1477     """Return the priority of the `leaf` delimiter, given a line break after it.
1478
1479     The delimiter priorities returned here are from those delimiters that would
1480     cause a line break after themselves.
1481
1482     Higher numbers are higher priority.
1483     """
1484     if leaf.type == token.COMMA:
1485         return COMMA_PRIORITY
1486
1487     return 0
1488
1489
1490 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1491     """Return the priority of the `leaf` delimiter, given a line before after it.
1492
1493     The delimiter priorities returned here are from those delimiters that would
1494     cause a line break before themselves.
1495
1496     Higher numbers are higher priority.
1497     """
1498     if (
1499         leaf.type in VARARGS
1500         and leaf.parent
1501         and leaf.parent.type in {syms.argument, syms.typedargslist}
1502     ):
1503         # * and ** might also be MATH_OPERATORS but in this case they are not.
1504         # Don't treat them as a delimiter.
1505         return 0
1506
1507     if (
1508         leaf.type in MATH_OPERATORS
1509         and leaf.parent
1510         and leaf.parent.type not in {syms.factor, syms.star_expr}
1511     ):
1512         return MATH_PRIORITY
1513
1514     if leaf.type in COMPARATORS:
1515         return COMPARATOR_PRIORITY
1516
1517     if (
1518         leaf.type == token.STRING
1519         and previous is not None
1520         and previous.type == token.STRING
1521     ):
1522         return STRING_PRIORITY
1523
1524     if (
1525         leaf.type == token.NAME
1526         and leaf.value == "for"
1527         and leaf.parent
1528         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1529     ):
1530         return COMPREHENSION_PRIORITY
1531
1532     if (
1533         leaf.type == token.NAME
1534         and leaf.value == "if"
1535         and leaf.parent
1536         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1537     ):
1538         return COMPREHENSION_PRIORITY
1539
1540     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1541         return LOGIC_PRIORITY
1542
1543     return 0
1544
1545
1546 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1547     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1548
1549     Higher numbers are higher priority.
1550     """
1551     return max(
1552         is_split_before_delimiter(leaf, previous),
1553         is_split_after_delimiter(leaf, previous),
1554     )
1555
1556
1557 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1558     """Clean the prefix of the `leaf` and generate comments from it, if any.
1559
1560     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1561     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1562     move because it does away with modifying the grammar to include all the
1563     possible places in which comments can be placed.
1564
1565     The sad consequence for us though is that comments don't "belong" anywhere.
1566     This is why this function generates simple parentless Leaf objects for
1567     comments.  We simply don't know what the correct parent should be.
1568
1569     No matter though, we can live without this.  We really only need to
1570     differentiate between inline and standalone comments.  The latter don't
1571     share the line with any code.
1572
1573     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1574     are emitted with a fake STANDALONE_COMMENT token identifier.
1575     """
1576     p = leaf.prefix
1577     if not p:
1578         return
1579
1580     if "#" not in p:
1581         return
1582
1583     consumed = 0
1584     nlines = 0
1585     for index, line in enumerate(p.split("\n")):
1586         consumed += len(line) + 1  # adding the length of the split '\n'
1587         line = line.lstrip()
1588         if not line:
1589             nlines += 1
1590         if not line.startswith("#"):
1591             continue
1592
1593         if index == 0 and leaf.type != token.ENDMARKER:
1594             comment_type = token.COMMENT  # simple trailing comment
1595         else:
1596             comment_type = STANDALONE_COMMENT
1597         comment = make_comment(line)
1598         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1599
1600         if comment in {"# fmt: on", "# yapf: enable"}:
1601             raise FormatOn(consumed)
1602
1603         if comment in {"# fmt: off", "# yapf: disable"}:
1604             if comment_type == STANDALONE_COMMENT:
1605                 raise FormatOff(consumed)
1606
1607             prev = preceding_leaf(leaf)
1608             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1609                 raise FormatOff(consumed)
1610
1611         nlines = 0
1612
1613
1614 def make_comment(content: str) -> str:
1615     """Return a consistently formatted comment from the given `content` string.
1616
1617     All comments (except for "##", "#!", "#:") should have a single space between
1618     the hash sign and the content.
1619
1620     If `content` didn't start with a hash sign, one is provided.
1621     """
1622     content = content.rstrip()
1623     if not content:
1624         return "#"
1625
1626     if content[0] == "#":
1627         content = content[1:]
1628     if content and content[0] not in " !:#":
1629         content = " " + content
1630     return "#" + content
1631
1632
1633 def split_line(
1634     line: Line, line_length: int, inner: bool = False, py36: bool = False
1635 ) -> Iterator[Line]:
1636     """Split a `line` into potentially many lines.
1637
1638     They should fit in the allotted `line_length` but might not be able to.
1639     `inner` signifies that there were a pair of brackets somewhere around the
1640     current `line`, possibly transitively. This means we can fallback to splitting
1641     by delimiters if the LHS/RHS don't yield any results.
1642
1643     If `py36` is True, splitting may generate syntax that is only compatible
1644     with Python 3.6 and later.
1645     """
1646     if isinstance(line, UnformattedLines) or line.is_comment:
1647         yield line
1648         return
1649
1650     line_str = str(line).strip("\n")
1651     if (
1652         len(line_str) <= line_length
1653         and "\n" not in line_str  # multiline strings
1654         and not line.contains_standalone_comments()
1655     ):
1656         yield line
1657         return
1658
1659     split_funcs: List[SplitFunc]
1660     if line.is_def:
1661         split_funcs = [left_hand_split]
1662     elif line.inside_brackets:
1663         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1664     else:
1665         split_funcs = [right_hand_split]
1666     for split_func in split_funcs:
1667         # We are accumulating lines in `result` because we might want to abort
1668         # mission and return the original line in the end, or attempt a different
1669         # split altogether.
1670         result: List[Line] = []
1671         try:
1672             for l in split_func(line, py36):
1673                 if str(l).strip("\n") == line_str:
1674                     raise CannotSplit("Split function returned an unchanged result")
1675
1676                 result.extend(
1677                     split_line(l, line_length=line_length, inner=True, py36=py36)
1678                 )
1679         except CannotSplit as cs:
1680             continue
1681
1682         else:
1683             yield from result
1684             break
1685
1686     else:
1687         yield line
1688
1689
1690 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1691     """Split line into many lines, starting with the first matching bracket pair.
1692
1693     Note: this usually looks weird, only use this for function definitions.
1694     Prefer RHS otherwise.
1695     """
1696     head = Line(depth=line.depth)
1697     body = Line(depth=line.depth + 1, inside_brackets=True)
1698     tail = Line(depth=line.depth)
1699     tail_leaves: List[Leaf] = []
1700     body_leaves: List[Leaf] = []
1701     head_leaves: List[Leaf] = []
1702     current_leaves = head_leaves
1703     matching_bracket = None
1704     for leaf in line.leaves:
1705         if (
1706             current_leaves is body_leaves
1707             and leaf.type in CLOSING_BRACKETS
1708             and leaf.opening_bracket is matching_bracket
1709         ):
1710             current_leaves = tail_leaves if body_leaves else head_leaves
1711         current_leaves.append(leaf)
1712         if current_leaves is head_leaves:
1713             if leaf.type in OPENING_BRACKETS:
1714                 matching_bracket = leaf
1715                 current_leaves = body_leaves
1716     # Since body is a new indent level, remove spurious leading whitespace.
1717     if body_leaves:
1718         normalize_prefix(body_leaves[0], inside_brackets=True)
1719     # Build the new lines.
1720     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1721         for leaf in leaves:
1722             result.append(leaf, preformatted=True)
1723             for comment_after in line.comments_after(leaf):
1724                 result.append(comment_after, preformatted=True)
1725     bracket_split_succeeded_or_raise(head, body, tail)
1726     for result in (head, body, tail):
1727         if result:
1728             yield result
1729
1730
1731 def right_hand_split(
1732     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1733 ) -> Iterator[Line]:
1734     """Split line into many lines, starting with the last matching bracket pair."""
1735     head = Line(depth=line.depth)
1736     body = Line(depth=line.depth + 1, inside_brackets=True)
1737     tail = Line(depth=line.depth)
1738     tail_leaves: List[Leaf] = []
1739     body_leaves: List[Leaf] = []
1740     head_leaves: List[Leaf] = []
1741     current_leaves = tail_leaves
1742     opening_bracket = None
1743     closing_bracket = None
1744     for leaf in reversed(line.leaves):
1745         if current_leaves is body_leaves:
1746             if leaf is opening_bracket:
1747                 current_leaves = head_leaves if body_leaves else tail_leaves
1748         current_leaves.append(leaf)
1749         if current_leaves is tail_leaves:
1750             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1751                 opening_bracket = leaf.opening_bracket
1752                 closing_bracket = leaf
1753                 current_leaves = body_leaves
1754     tail_leaves.reverse()
1755     body_leaves.reverse()
1756     head_leaves.reverse()
1757     # Since body is a new indent level, remove spurious leading whitespace.
1758     if body_leaves:
1759         normalize_prefix(body_leaves[0], inside_brackets=True)
1760     elif not head_leaves:
1761         # No `head` and no `body` means the split failed. `tail` has all content.
1762         raise CannotSplit("No brackets found")
1763
1764     # Build the new lines.
1765     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1766         for leaf in leaves:
1767             result.append(leaf, preformatted=True)
1768             for comment_after in line.comments_after(leaf):
1769                 result.append(comment_after, preformatted=True)
1770     bracket_split_succeeded_or_raise(head, body, tail)
1771     assert opening_bracket and closing_bracket
1772     if (
1773         opening_bracket.type == token.LPAR
1774         and not opening_bracket.value
1775         and closing_bracket.type == token.RPAR
1776         and not closing_bracket.value
1777     ):
1778         # These parens were optional. If there aren't any delimiters or standalone
1779         # comments in the body, they were unnecessary and another split without
1780         # them should be attempted.
1781         if not (
1782             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1783         ):
1784             omit = {id(closing_bracket), *omit}
1785             yield from right_hand_split(line, py36=py36, omit=omit)
1786             return
1787
1788     ensure_visible(opening_bracket)
1789     ensure_visible(closing_bracket)
1790     for result in (head, body, tail):
1791         if result:
1792             yield result
1793
1794
1795 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1796     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1797
1798     Do nothing otherwise.
1799
1800     A left- or right-hand split is based on a pair of brackets. Content before
1801     (and including) the opening bracket is left on one line, content inside the
1802     brackets is put on a separate line, and finally content starting with and
1803     following the closing bracket is put on a separate line.
1804
1805     Those are called `head`, `body`, and `tail`, respectively. If the split
1806     produced the same line (all content in `head`) or ended up with an empty `body`
1807     and the `tail` is just the closing bracket, then it's considered failed.
1808     """
1809     tail_len = len(str(tail).strip())
1810     if not body:
1811         if tail_len == 0:
1812             raise CannotSplit("Splitting brackets produced the same line")
1813
1814         elif tail_len < 3:
1815             raise CannotSplit(
1816                 f"Splitting brackets on an empty body to save "
1817                 f"{tail_len} characters is not worth it"
1818             )
1819
1820
1821 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1822     """Normalize prefix of the first leaf in every line returned by `split_func`.
1823
1824     This is a decorator over relevant split functions.
1825     """
1826
1827     @wraps(split_func)
1828     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1829         for l in split_func(line, py36):
1830             normalize_prefix(l.leaves[0], inside_brackets=True)
1831             yield l
1832
1833     return split_wrapper
1834
1835
1836 @dont_increase_indentation
1837 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1838     """Split according to delimiters of the highest priority.
1839
1840     If `py36` is True, the split will add trailing commas also in function
1841     signatures that contain `*` and `**`.
1842     """
1843     try:
1844         last_leaf = line.leaves[-1]
1845     except IndexError:
1846         raise CannotSplit("Line empty")
1847
1848     delimiters = line.bracket_tracker.delimiters
1849     try:
1850         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1851             exclude={id(last_leaf)}
1852         )
1853     except ValueError:
1854         raise CannotSplit("No delimiters found")
1855
1856     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1857     lowest_depth = sys.maxsize
1858     trailing_comma_safe = True
1859
1860     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1861         """Append `leaf` to current line or to new line if appending impossible."""
1862         nonlocal current_line
1863         try:
1864             current_line.append_safe(leaf, preformatted=True)
1865         except ValueError as ve:
1866             yield current_line
1867
1868             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1869             current_line.append(leaf)
1870
1871     for leaf in line.leaves:
1872         yield from append_to_line(leaf)
1873
1874         for comment_after in line.comments_after(leaf):
1875             yield from append_to_line(comment_after)
1876
1877         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1878         if (
1879             leaf.bracket_depth == lowest_depth
1880             and leaf.type == token.STAR
1881             or leaf.type == token.DOUBLESTAR
1882         ):
1883             trailing_comma_safe = trailing_comma_safe and py36
1884         leaf_priority = delimiters.get(id(leaf))
1885         if leaf_priority == delimiter_priority:
1886             yield current_line
1887
1888             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1889     if current_line:
1890         if (
1891             trailing_comma_safe
1892             and delimiter_priority == COMMA_PRIORITY
1893             and current_line.leaves[-1].type != token.COMMA
1894             and current_line.leaves[-1].type != STANDALONE_COMMENT
1895         ):
1896             current_line.append(Leaf(token.COMMA, ","))
1897         yield current_line
1898
1899
1900 @dont_increase_indentation
1901 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1902     """Split standalone comments from the rest of the line."""
1903     if not line.contains_standalone_comments(0):
1904         raise CannotSplit("Line does not have any standalone comments")
1905
1906     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1907
1908     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1909         """Append `leaf` to current line or to new line if appending impossible."""
1910         nonlocal current_line
1911         try:
1912             current_line.append_safe(leaf, preformatted=True)
1913         except ValueError as ve:
1914             yield current_line
1915
1916             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1917             current_line.append(leaf)
1918
1919     for leaf in line.leaves:
1920         yield from append_to_line(leaf)
1921
1922         for comment_after in line.comments_after(leaf):
1923             yield from append_to_line(comment_after)
1924
1925     if current_line:
1926         yield current_line
1927
1928
1929 def is_import(leaf: Leaf) -> bool:
1930     """Return True if the given leaf starts an import statement."""
1931     p = leaf.parent
1932     t = leaf.type
1933     v = leaf.value
1934     return bool(
1935         t == token.NAME
1936         and (
1937             (v == "import" and p and p.type == syms.import_name)
1938             or (v == "from" and p and p.type == syms.import_from)
1939         )
1940     )
1941
1942
1943 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1944     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1945     else.
1946
1947     Note: don't use backslashes for formatting or you'll lose your voting rights.
1948     """
1949     if not inside_brackets:
1950         spl = leaf.prefix.split("#")
1951         if "\\" not in spl[0]:
1952             nl_count = spl[-1].count("\n")
1953             if len(spl) > 1:
1954                 nl_count -= 1
1955             leaf.prefix = "\n" * nl_count
1956             return
1957
1958     leaf.prefix = ""
1959
1960
1961 def normalize_string_quotes(leaf: Leaf) -> None:
1962     """Prefer double quotes but only if it doesn't cause more escaping.
1963
1964     Adds or removes backslashes as appropriate. Doesn't parse and fix
1965     strings nested in f-strings (yet).
1966
1967     Note: Mutates its argument.
1968     """
1969     value = leaf.value.lstrip("furbFURB")
1970     if value[:3] == '"""':
1971         return
1972
1973     elif value[:3] == "'''":
1974         orig_quote = "'''"
1975         new_quote = '"""'
1976     elif value[0] == '"':
1977         orig_quote = '"'
1978         new_quote = "'"
1979     else:
1980         orig_quote = "'"
1981         new_quote = '"'
1982     first_quote_pos = leaf.value.find(orig_quote)
1983     if first_quote_pos == -1:
1984         return  # There's an internal error
1985
1986     prefix = leaf.value[:first_quote_pos]
1987     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
1988     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
1989     escaped_orig_quote = re.compile(rf"\\(\\\\)*{orig_quote}")
1990     if "r" in prefix.casefold():
1991         if unescaped_new_quote.search(body):
1992             # There's at least one unescaped new_quote in this raw string
1993             # so converting is impossible
1994             return
1995
1996         # Do not introduce or remove backslashes in raw strings
1997         new_body = body
1998     else:
1999         new_body = escaped_orig_quote.sub(rf"\1{orig_quote}", body)
2000         new_body = unescaped_new_quote.sub(rf"\1\\{new_quote}", new_body)
2001         # Add escapes again for consecutive occurences of new_quote (sub
2002         # doesn't match overlapping substrings).
2003         new_body = unescaped_new_quote.sub(rf"\1\\{new_quote}", new_body)
2004     if new_quote == '"""' and new_body[-1] == '"':
2005         # edge case:
2006         new_body = new_body[:-1] + '\\"'
2007     orig_escape_count = body.count("\\")
2008     new_escape_count = new_body.count("\\")
2009     if new_escape_count > orig_escape_count:
2010         return  # Do not introduce more escaping
2011
2012     if new_escape_count == orig_escape_count and orig_quote == '"':
2013         return  # Prefer double quotes
2014
2015     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2016
2017
2018 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2019     """Make existing optional parentheses invisible or create new ones.
2020
2021     Standardizes on visible parentheses for single-element tuples, and keeps
2022     existing visible parentheses for other tuples and generator expressions.
2023     """
2024     check_lpar = False
2025     for child in list(node.children):
2026         if check_lpar:
2027             if child.type == syms.atom:
2028                 if not (
2029                     is_empty_tuple(child)
2030                     or is_one_tuple(child)
2031                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2032                 ):
2033                     first = child.children[0]
2034                     last = child.children[-1]
2035                     if first.type == token.LPAR and last.type == token.RPAR:
2036                         # make parentheses invisible
2037                         first.value = ""  # type: ignore
2038                         last.value = ""  # type: ignore
2039             elif is_one_tuple(child):
2040                 # wrap child in visible parentheses
2041                 lpar = Leaf(token.LPAR, "(")
2042                 rpar = Leaf(token.RPAR, ")")
2043                 index = child.remove() or 0
2044                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2045             else:
2046                 # wrap child in invisible parentheses
2047                 lpar = Leaf(token.LPAR, "")
2048                 rpar = Leaf(token.RPAR, "")
2049                 index = child.remove() or 0
2050                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2051
2052         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2053
2054
2055 def is_empty_tuple(node: LN) -> bool:
2056     """Return True if `node` holds an empty tuple."""
2057     return (
2058         node.type == syms.atom
2059         and len(node.children) == 2
2060         and node.children[0].type == token.LPAR
2061         and node.children[1].type == token.RPAR
2062     )
2063
2064
2065 def is_one_tuple(node: LN) -> bool:
2066     """Return True if `node` holds a tuple with one element, with or without parens."""
2067     if node.type == syms.atom:
2068         if len(node.children) != 3:
2069             return False
2070
2071         lpar, gexp, rpar = node.children
2072         if not (
2073             lpar.type == token.LPAR
2074             and gexp.type == syms.testlist_gexp
2075             and rpar.type == token.RPAR
2076         ):
2077             return False
2078
2079         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2080
2081     return (
2082         node.type in IMPLICIT_TUPLE
2083         and len(node.children) == 2
2084         and node.children[1].type == token.COMMA
2085     )
2086
2087
2088 def max_delimiter_priority_in_atom(node: LN) -> int:
2089     if node.type != syms.atom:
2090         return 0
2091
2092     first = node.children[0]
2093     last = node.children[-1]
2094     if not (first.type == token.LPAR and last.type == token.RPAR):
2095         return 0
2096
2097     bt = BracketTracker()
2098     for c in node.children[1:-1]:
2099         if isinstance(c, Leaf):
2100             bt.mark(c)
2101         else:
2102             for leaf in c.leaves():
2103                 bt.mark(leaf)
2104     try:
2105         return bt.max_delimiter_priority()
2106
2107     except ValueError:
2108         return 0
2109
2110
2111 def ensure_visible(leaf: Leaf) -> None:
2112     """Make sure parentheses are visible.
2113
2114     They could be invisible as part of some statements (see
2115     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2116     """
2117     if leaf.type == token.LPAR:
2118         leaf.value = "("
2119     elif leaf.type == token.RPAR:
2120         leaf.value = ")"
2121
2122
2123 def is_python36(node: Node) -> bool:
2124     """Return True if the current file is using Python 3.6+ features.
2125
2126     Currently looking for:
2127     - f-strings; and
2128     - trailing commas after * or ** in function signatures.
2129     """
2130     for n in node.pre_order():
2131         if n.type == token.STRING:
2132             value_head = n.value[:2]  # type: ignore
2133             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2134                 return True
2135
2136         elif (
2137             n.type == syms.typedargslist
2138             and n.children
2139             and n.children[-1].type == token.COMMA
2140         ):
2141             for ch in n.children:
2142                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
2143                     return True
2144
2145     return False
2146
2147
2148 PYTHON_EXTENSIONS = {".py"}
2149 BLACKLISTED_DIRECTORIES = {
2150     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2151 }
2152
2153
2154 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2155     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2156     and have one of the PYTHON_EXTENSIONS.
2157     """
2158     for child in path.iterdir():
2159         if child.is_dir():
2160             if child.name in BLACKLISTED_DIRECTORIES:
2161                 continue
2162
2163             yield from gen_python_files_in_dir(child)
2164
2165         elif child.suffix in PYTHON_EXTENSIONS:
2166             yield child
2167
2168
2169 @dataclass
2170 class Report:
2171     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2172     check: bool = False
2173     quiet: bool = False
2174     change_count: int = 0
2175     same_count: int = 0
2176     failure_count: int = 0
2177
2178     def done(self, src: Path, changed: bool) -> None:
2179         """Increment the counter for successful reformatting. Write out a message."""
2180         if changed:
2181             reformatted = "would reformat" if self.check else "reformatted"
2182             if not self.quiet:
2183                 out(f"{reformatted} {src}")
2184             self.change_count += 1
2185         else:
2186             if not self.quiet:
2187                 out(f"{src} already well formatted, good job.", bold=False)
2188             self.same_count += 1
2189
2190     def failed(self, src: Path, message: str) -> None:
2191         """Increment the counter for failed reformatting. Write out a message."""
2192         err(f"error: cannot format {src}: {message}")
2193         self.failure_count += 1
2194
2195     @property
2196     def return_code(self) -> int:
2197         """Return the exit code that the app should use.
2198
2199         This considers the current state of changed files and failures:
2200         - if there were any failures, return 123;
2201         - if any files were changed and --check is being used, return 1;
2202         - otherwise return 0.
2203         """
2204         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2205         # 126 we have special returncodes reserved by the shell.
2206         if self.failure_count:
2207             return 123
2208
2209         elif self.change_count and self.check:
2210             return 1
2211
2212         return 0
2213
2214     def __str__(self) -> str:
2215         """Render a color report of the current state.
2216
2217         Use `click.unstyle` to remove colors.
2218         """
2219         if self.check:
2220             reformatted = "would be reformatted"
2221             unchanged = "would be left unchanged"
2222             failed = "would fail to reformat"
2223         else:
2224             reformatted = "reformatted"
2225             unchanged = "left unchanged"
2226             failed = "failed to reformat"
2227         report = []
2228         if self.change_count:
2229             s = "s" if self.change_count > 1 else ""
2230             report.append(
2231                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2232             )
2233         if self.same_count:
2234             s = "s" if self.same_count > 1 else ""
2235             report.append(f"{self.same_count} file{s} {unchanged}")
2236         if self.failure_count:
2237             s = "s" if self.failure_count > 1 else ""
2238             report.append(
2239                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2240             )
2241         return ", ".join(report) + "."
2242
2243
2244 def assert_equivalent(src: str, dst: str) -> None:
2245     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2246
2247     import ast
2248     import traceback
2249
2250     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2251         """Simple visitor generating strings to compare ASTs by content."""
2252         yield f"{'  ' * depth}{node.__class__.__name__}("
2253
2254         for field in sorted(node._fields):
2255             try:
2256                 value = getattr(node, field)
2257             except AttributeError:
2258                 continue
2259
2260             yield f"{'  ' * (depth+1)}{field}="
2261
2262             if isinstance(value, list):
2263                 for item in value:
2264                     if isinstance(item, ast.AST):
2265                         yield from _v(item, depth + 2)
2266
2267             elif isinstance(value, ast.AST):
2268                 yield from _v(value, depth + 2)
2269
2270             else:
2271                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2272
2273         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2274
2275     try:
2276         src_ast = ast.parse(src)
2277     except Exception as exc:
2278         major, minor = sys.version_info[:2]
2279         raise AssertionError(
2280             f"cannot use --safe with this file; failed to parse source file "
2281             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2282             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2283         )
2284
2285     try:
2286         dst_ast = ast.parse(dst)
2287     except Exception as exc:
2288         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2289         raise AssertionError(
2290             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2291             f"Please report a bug on https://github.com/ambv/black/issues.  "
2292             f"This invalid output might be helpful: {log}"
2293         ) from None
2294
2295     src_ast_str = "\n".join(_v(src_ast))
2296     dst_ast_str = "\n".join(_v(dst_ast))
2297     if src_ast_str != dst_ast_str:
2298         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2299         raise AssertionError(
2300             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2301             f"the source.  "
2302             f"Please report a bug on https://github.com/ambv/black/issues.  "
2303             f"This diff might be helpful: {log}"
2304         ) from None
2305
2306
2307 def assert_stable(src: str, dst: str, line_length: int) -> None:
2308     """Raise AssertionError if `dst` reformats differently the second time."""
2309     newdst = format_str(dst, line_length=line_length)
2310     if dst != newdst:
2311         log = dump_to_file(
2312             diff(src, dst, "source", "first pass"),
2313             diff(dst, newdst, "first pass", "second pass"),
2314         )
2315         raise AssertionError(
2316             f"INTERNAL ERROR: Black produced different code on the second pass "
2317             f"of the formatter.  "
2318             f"Please report a bug on https://github.com/ambv/black/issues.  "
2319             f"This diff might be helpful: {log}"
2320         ) from None
2321
2322
2323 def dump_to_file(*output: str) -> str:
2324     """Dump `output` to a temporary file. Return path to the file."""
2325     import tempfile
2326
2327     with tempfile.NamedTemporaryFile(
2328         mode="w", prefix="blk_", suffix=".log", delete=False
2329     ) as f:
2330         for lines in output:
2331             f.write(lines)
2332             if lines and lines[-1] != "\n":
2333                 f.write("\n")
2334     return f.name
2335
2336
2337 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2338     """Return a unified diff string between strings `a` and `b`."""
2339     import difflib
2340
2341     a_lines = [line + "\n" for line in a.split("\n")]
2342     b_lines = [line + "\n" for line in b.split("\n")]
2343     return "".join(
2344         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2345     )
2346
2347
2348 def cancel(tasks: List[asyncio.Task]) -> None:
2349     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2350     err("Aborted!")
2351     for task in tasks:
2352         task.cancel()
2353
2354
2355 def shutdown(loop: BaseEventLoop) -> None:
2356     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2357     try:
2358         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2359         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2360         if not to_cancel:
2361             return
2362
2363         for task in to_cancel:
2364             task.cancel()
2365         loop.run_until_complete(
2366             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2367         )
2368     finally:
2369         # `concurrent.futures.Future` objects cannot be cancelled once they
2370         # are already running. There might be some when the `shutdown()` happened.
2371         # Silence their logger's spew about the event loop being closed.
2372         cf_logger = logging.getLogger("concurrent.futures")
2373         cf_logger.setLevel(logging.CRITICAL)
2374         loop.close()
2375
2376
2377 if __name__ == "__main__":
2378     main()