]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

18.4a2
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from enum import Enum
7 from functools import partial, wraps
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import re
14 import tokenize
15 import signal
16 import sys
17 from typing import (
18     Any,
19     Callable,
20     Collection,
21     Dict,
22     Generic,
23     Iterable,
24     Iterator,
25     List,
26     Optional,
27     Pattern,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33 )
34
35 from attr import dataclass, Factory
36 import click
37
38 # lib2to3 fork
39 from blib2to3.pytree import Node, Leaf, type_repr
40 from blib2to3 import pygram, pytree
41 from blib2to3.pgen2 import driver, token
42 from blib2to3.pgen2.parse import ParseError
43
44 __version__ = "18.4a2"
45 DEFAULT_LINE_LENGTH = 88
46 # types
47 syms = pygram.python_symbols
48 FileContent = str
49 Encoding = str
50 Depth = int
51 NodeType = int
52 LeafID = int
53 Priority = int
54 Index = int
55 LN = Union[Leaf, Node]
56 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
57 out = partial(click.secho, bold=True, err=True)
58 err = partial(click.secho, fg="red", err=True)
59
60
61 class NothingChanged(UserWarning):
62     """Raised by :func:`format_file` when reformatted code is the same as source."""
63
64
65 class CannotSplit(Exception):
66     """A readable split that fits the allotted line length is impossible.
67
68     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
69     :func:`delimiter_split`.
70     """
71
72
73 class FormatError(Exception):
74     """Base exception for `# fmt: on` and `# fmt: off` handling.
75
76     It holds the number of bytes of the prefix consumed before the format
77     control comment appeared.
78     """
79
80     def __init__(self, consumed: int) -> None:
81         super().__init__(consumed)
82         self.consumed = consumed
83
84     def trim_prefix(self, leaf: Leaf) -> None:
85         leaf.prefix = leaf.prefix[self.consumed:]
86
87     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
88         """Returns a new Leaf from the consumed part of the prefix."""
89         unformatted_prefix = leaf.prefix[:self.consumed]
90         return Leaf(token.NEWLINE, unformatted_prefix)
91
92
93 class FormatOn(FormatError):
94     """Found a comment like `# fmt: on` in the file."""
95
96
97 class FormatOff(FormatError):
98     """Found a comment like `# fmt: off` in the file."""
99
100
101 class WriteBack(Enum):
102     NO = 0
103     YES = 1
104     DIFF = 2
105
106
107 @click.command()
108 @click.option(
109     "-l",
110     "--line-length",
111     type=int,
112     default=DEFAULT_LINE_LENGTH,
113     help="How many character per line to allow.",
114     show_default=True,
115 )
116 @click.option(
117     "--check",
118     is_flag=True,
119     help=(
120         "Don't write the files back, just return the status.  Return code 0 "
121         "means nothing would change.  Return code 1 means some files would be "
122         "reformatted.  Return code 123 means there was an internal error."
123     ),
124 )
125 @click.option(
126     "--diff",
127     is_flag=True,
128     help="Don't write the files back, just output a diff for each file on stdout.",
129 )
130 @click.option(
131     "--fast/--safe",
132     is_flag=True,
133     help="If --fast given, skip temporary sanity checks. [default: --safe]",
134 )
135 @click.option(
136     "-q",
137     "--quiet",
138     is_flag=True,
139     help=(
140         "Don't emit non-error messages to stderr. Errors are still emitted, "
141         "silence those with 2>/dev/null."
142     ),
143 )
144 @click.version_option(version=__version__)
145 @click.argument(
146     "src",
147     nargs=-1,
148     type=click.Path(
149         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
150     ),
151 )
152 @click.pass_context
153 def main(
154     ctx: click.Context,
155     line_length: int,
156     check: bool,
157     diff: bool,
158     fast: bool,
159     quiet: bool,
160     src: List[str],
161 ) -> None:
162     """The uncompromising code formatter."""
163     sources: List[Path] = []
164     for s in src:
165         p = Path(s)
166         if p.is_dir():
167             sources.extend(gen_python_files_in_dir(p))
168         elif p.is_file():
169             # if a file was explicitly given, we don't care about its extension
170             sources.append(p)
171         elif s == "-":
172             sources.append(Path("-"))
173         else:
174             err(f"invalid path: {s}")
175     if check and diff:
176         exc = click.ClickException("Options --check and --diff are mutually exclusive")
177         exc.exit_code = 2
178         raise exc
179
180     if check:
181         write_back = WriteBack.NO
182     elif diff:
183         write_back = WriteBack.DIFF
184     else:
185         write_back = WriteBack.YES
186     if len(sources) == 0:
187         ctx.exit(0)
188     elif len(sources) == 1:
189         p = sources[0]
190         report = Report(check=check, quiet=quiet)
191         try:
192             if not p.is_file() and str(p) == "-":
193                 changed = format_stdin_to_stdout(
194                     line_length=line_length, fast=fast, write_back=write_back
195                 )
196             else:
197                 changed = format_file_in_place(
198                     p, line_length=line_length, fast=fast, write_back=write_back
199                 )
200             report.done(p, changed)
201         except Exception as exc:
202             report.failed(p, str(exc))
203         ctx.exit(report.return_code)
204     else:
205         loop = asyncio.get_event_loop()
206         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
207         return_code = 1
208         try:
209             return_code = loop.run_until_complete(
210                 schedule_formatting(
211                     sources, line_length, write_back, fast, quiet, loop, executor
212                 )
213             )
214         finally:
215             shutdown(loop)
216             ctx.exit(return_code)
217
218
219 async def schedule_formatting(
220     sources: List[Path],
221     line_length: int,
222     write_back: WriteBack,
223     fast: bool,
224     quiet: bool,
225     loop: BaseEventLoop,
226     executor: Executor,
227 ) -> int:
228     """Run formatting of `sources` in parallel using the provided `executor`.
229
230     (Use ProcessPoolExecutors for actual parallelism.)
231
232     `line_length`, `write_back`, and `fast` options are passed to
233     :func:`format_file_in_place`.
234     """
235     lock = None
236     if write_back == WriteBack.DIFF:
237         # For diff output, we need locks to ensure we don't interleave output
238         # from different processes.
239         manager = Manager()
240         lock = manager.Lock()
241     tasks = {
242         src: loop.run_in_executor(
243             executor, format_file_in_place, src, line_length, fast, write_back, lock
244         )
245         for src in sources
246     }
247     _task_values = list(tasks.values())
248     loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
249     loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
250     await asyncio.wait(tasks.values())
251     cancelled = []
252     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
253     for src, task in tasks.items():
254         if not task.done():
255             report.failed(src, "timed out, cancelling")
256             task.cancel()
257             cancelled.append(task)
258         elif task.cancelled():
259             cancelled.append(task)
260         elif task.exception():
261             report.failed(src, str(task.exception()))
262         else:
263             report.done(src, task.result())
264     if cancelled:
265         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
266     elif not quiet:
267         out("All done! ✨ 🍰 ✨")
268     if not quiet:
269         click.echo(str(report))
270     return report.return_code
271
272
273 def format_file_in_place(
274     src: Path,
275     line_length: int,
276     fast: bool,
277     write_back: WriteBack = WriteBack.NO,
278     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
279 ) -> bool:
280     """Format file under `src` path. Return True if changed.
281
282     If `write_back` is True, write reformatted code back to stdout.
283     `line_length` and `fast` options are passed to :func:`format_file_contents`.
284     """
285     with tokenize.open(src) as src_buffer:
286         src_contents = src_buffer.read()
287     try:
288         dst_contents = format_file_contents(
289             src_contents, line_length=line_length, fast=fast
290         )
291     except NothingChanged:
292         return False
293
294     if write_back == write_back.YES:
295         with open(src, "w", encoding=src_buffer.encoding) as f:
296             f.write(dst_contents)
297     elif write_back == write_back.DIFF:
298         src_name = f"{src.name}  (original)"
299         dst_name = f"{src.name}  (formatted)"
300         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
301         if lock:
302             lock.acquire()
303         try:
304             sys.stdout.write(diff_contents)
305         finally:
306             if lock:
307                 lock.release()
308     return True
309
310
311 def format_stdin_to_stdout(
312     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
313 ) -> bool:
314     """Format file on stdin. Return True if changed.
315
316     If `write_back` is True, write reformatted code back to stdout.
317     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
318     """
319     src = sys.stdin.read()
320     dst = src
321     try:
322         dst = format_file_contents(src, line_length=line_length, fast=fast)
323         return True
324
325     except NothingChanged:
326         return False
327
328     finally:
329         if write_back == WriteBack.YES:
330             sys.stdout.write(dst)
331         elif write_back == WriteBack.DIFF:
332             src_name = "<stdin>  (original)"
333             dst_name = "<stdin>  (formatted)"
334             sys.stdout.write(diff(src, dst, src_name, dst_name))
335
336
337 def format_file_contents(
338     src_contents: str, line_length: int, fast: bool
339 ) -> FileContent:
340     """Reformat contents a file and return new contents.
341
342     If `fast` is False, additionally confirm that the reformatted code is
343     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
344     `line_length` is passed to :func:`format_str`.
345     """
346     if src_contents.strip() == "":
347         raise NothingChanged
348
349     dst_contents = format_str(src_contents, line_length=line_length)
350     if src_contents == dst_contents:
351         raise NothingChanged
352
353     if not fast:
354         assert_equivalent(src_contents, dst_contents)
355         assert_stable(src_contents, dst_contents, line_length=line_length)
356     return dst_contents
357
358
359 def format_str(src_contents: str, line_length: int) -> FileContent:
360     """Reformat a string and return new contents.
361
362     `line_length` determines how many characters per line are allowed.
363     """
364     src_node = lib2to3_parse(src_contents)
365     dst_contents = ""
366     lines = LineGenerator()
367     elt = EmptyLineTracker()
368     py36 = is_python36(src_node)
369     empty_line = Line()
370     after = 0
371     for current_line in lines.visit(src_node):
372         for _ in range(after):
373             dst_contents += str(empty_line)
374         before, after = elt.maybe_empty_lines(current_line)
375         for _ in range(before):
376             dst_contents += str(empty_line)
377         for line in split_line(current_line, line_length=line_length, py36=py36):
378             dst_contents += str(line)
379     return dst_contents
380
381
382 GRAMMARS = [
383     pygram.python_grammar_no_print_statement_no_exec_statement,
384     pygram.python_grammar_no_print_statement,
385     pygram.python_grammar_no_exec_statement,
386     pygram.python_grammar,
387 ]
388
389
390 def lib2to3_parse(src_txt: str) -> Node:
391     """Given a string with source, return the lib2to3 Node."""
392     grammar = pygram.python_grammar_no_print_statement
393     if src_txt[-1] != "\n":
394         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
395         src_txt += nl
396     for grammar in GRAMMARS:
397         drv = driver.Driver(grammar, pytree.convert)
398         try:
399             result = drv.parse_string(src_txt, True)
400             break
401
402         except ParseError as pe:
403             lineno, column = pe.context[1]
404             lines = src_txt.splitlines()
405             try:
406                 faulty_line = lines[lineno - 1]
407             except IndexError:
408                 faulty_line = "<line number missing in source>"
409             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
410     else:
411         raise exc from None
412
413     if isinstance(result, Leaf):
414         result = Node(syms.file_input, [result])
415     return result
416
417
418 def lib2to3_unparse(node: Node) -> str:
419     """Given a lib2to3 node, return its string representation."""
420     code = str(node)
421     return code
422
423
424 T = TypeVar("T")
425
426
427 class Visitor(Generic[T]):
428     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
429
430     def visit(self, node: LN) -> Iterator[T]:
431         """Main method to visit `node` and its children.
432
433         It tries to find a `visit_*()` method for the given `node.type`, like
434         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
435         If no dedicated `visit_*()` method is found, chooses `visit_default()`
436         instead.
437
438         Then yields objects of type `T` from the selected visitor.
439         """
440         if node.type < 256:
441             name = token.tok_name[node.type]
442         else:
443             name = type_repr(node.type)
444         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
445
446     def visit_default(self, node: LN) -> Iterator[T]:
447         """Default `visit_*()` implementation. Recurses to children of `node`."""
448         if isinstance(node, Node):
449             for child in node.children:
450                 yield from self.visit(child)
451
452
453 @dataclass
454 class DebugVisitor(Visitor[T]):
455     tree_depth: int = 0
456
457     def visit_default(self, node: LN) -> Iterator[T]:
458         indent = " " * (2 * self.tree_depth)
459         if isinstance(node, Node):
460             _type = type_repr(node.type)
461             out(f"{indent}{_type}", fg="yellow")
462             self.tree_depth += 1
463             for child in node.children:
464                 yield from self.visit(child)
465
466             self.tree_depth -= 1
467             out(f"{indent}/{_type}", fg="yellow", bold=False)
468         else:
469             _type = token.tok_name.get(node.type, str(node.type))
470             out(f"{indent}{_type}", fg="blue", nl=False)
471             if node.prefix:
472                 # We don't have to handle prefixes for `Node` objects since
473                 # that delegates to the first child anyway.
474                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
475             out(f" {node.value!r}", fg="blue", bold=False)
476
477     @classmethod
478     def show(cls, code: str) -> None:
479         """Pretty-print the lib2to3 AST of a given string of `code`.
480
481         Convenience method for debugging.
482         """
483         v: DebugVisitor[None] = DebugVisitor()
484         list(v.visit(lib2to3_parse(code)))
485
486
487 KEYWORDS = set(keyword.kwlist)
488 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
489 FLOW_CONTROL = {"return", "raise", "break", "continue"}
490 STATEMENT = {
491     syms.if_stmt,
492     syms.while_stmt,
493     syms.for_stmt,
494     syms.try_stmt,
495     syms.except_clause,
496     syms.with_stmt,
497     syms.funcdef,
498     syms.classdef,
499 }
500 STANDALONE_COMMENT = 153
501 LOGIC_OPERATORS = {"and", "or"}
502 COMPARATORS = {
503     token.LESS,
504     token.GREATER,
505     token.EQEQUAL,
506     token.NOTEQUAL,
507     token.LESSEQUAL,
508     token.GREATEREQUAL,
509 }
510 MATH_OPERATORS = {
511     token.PLUS,
512     token.MINUS,
513     token.STAR,
514     token.SLASH,
515     token.VBAR,
516     token.AMPER,
517     token.PERCENT,
518     token.CIRCUMFLEX,
519     token.TILDE,
520     token.LEFTSHIFT,
521     token.RIGHTSHIFT,
522     token.DOUBLESTAR,
523     token.DOUBLESLASH,
524 }
525 VARARGS = {token.STAR, token.DOUBLESTAR}
526 COMPREHENSION_PRIORITY = 20
527 COMMA_PRIORITY = 10
528 LOGIC_PRIORITY = 5
529 STRING_PRIORITY = 4
530 COMPARATOR_PRIORITY = 3
531 MATH_PRIORITY = 1
532
533
534 @dataclass
535 class BracketTracker:
536     """Keeps track of brackets on a line."""
537
538     depth: int = 0
539     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
540     delimiters: Dict[LeafID, Priority] = Factory(dict)
541     previous: Optional[Leaf] = None
542
543     def mark(self, leaf: Leaf) -> None:
544         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
545
546         All leaves receive an int `bracket_depth` field that stores how deep
547         within brackets a given leaf is. 0 means there are no enclosing brackets
548         that started on this line.
549
550         If a leaf is itself a closing bracket, it receives an `opening_bracket`
551         field that it forms a pair with. This is a one-directional link to
552         avoid reference cycles.
553
554         If a leaf is a delimiter (a token on which Black can split the line if
555         needed) and it's on depth 0, its `id()` is stored in the tracker's
556         `delimiters` field.
557         """
558         if leaf.type == token.COMMENT:
559             return
560
561         if leaf.type in CLOSING_BRACKETS:
562             self.depth -= 1
563             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
564             leaf.opening_bracket = opening_bracket
565         leaf.bracket_depth = self.depth
566         if self.depth == 0:
567             delim = is_split_before_delimiter(leaf, self.previous)
568             if delim and self.previous is not None:
569                 self.delimiters[id(self.previous)] = delim
570             else:
571                 delim = is_split_after_delimiter(leaf, self.previous)
572                 if delim:
573                     self.delimiters[id(leaf)] = delim
574         if leaf.type in OPENING_BRACKETS:
575             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
576             self.depth += 1
577         self.previous = leaf
578
579     def any_open_brackets(self) -> bool:
580         """Return True if there is an yet unmatched open bracket on the line."""
581         return bool(self.bracket_match)
582
583     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
584         """Return the highest priority of a delimiter found on the line.
585
586         Values are consistent with what `is_delimiter()` returns.
587         Raises ValueError on no delimiters.
588         """
589         return max(v for k, v in self.delimiters.items() if k not in exclude)
590
591
592 @dataclass
593 class Line:
594     """Holds leaves and comments. Can be printed with `str(line)`."""
595
596     depth: int = 0
597     leaves: List[Leaf] = Factory(list)
598     comments: List[Tuple[Index, Leaf]] = Factory(list)
599     bracket_tracker: BracketTracker = Factory(BracketTracker)
600     inside_brackets: bool = False
601     has_for: bool = False
602     _for_loop_variable: bool = False
603
604     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
605         """Add a new `leaf` to the end of the line.
606
607         Unless `preformatted` is True, the `leaf` will receive a new consistent
608         whitespace prefix and metadata applied by :class:`BracketTracker`.
609         Trailing commas are maybe removed, unpacked for loop variables are
610         demoted from being delimiters.
611
612         Inline comments are put aside.
613         """
614         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
615         if not has_value:
616             return
617
618         if self.leaves and not preformatted:
619             # Note: at this point leaf.prefix should be empty except for
620             # imports, for which we only preserve newlines.
621             leaf.prefix += whitespace(leaf)
622         if self.inside_brackets or not preformatted:
623             self.maybe_decrement_after_for_loop_variable(leaf)
624             self.bracket_tracker.mark(leaf)
625             self.maybe_remove_trailing_comma(leaf)
626             self.maybe_increment_for_loop_variable(leaf)
627
628         if not self.append_comment(leaf):
629             self.leaves.append(leaf)
630
631     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
632         """Like :func:`append()` but disallow invalid standalone comment structure.
633
634         Raises ValueError when any `leaf` is appended after a standalone comment
635         or when a standalone comment is not the first leaf on the line.
636         """
637         if self.bracket_tracker.depth == 0:
638             if self.is_comment:
639                 raise ValueError("cannot append to standalone comments")
640
641             if self.leaves and leaf.type == STANDALONE_COMMENT:
642                 raise ValueError(
643                     "cannot append standalone comments to a populated line"
644                 )
645
646         self.append(leaf, preformatted=preformatted)
647
648     @property
649     def is_comment(self) -> bool:
650         """Is this line a standalone comment?"""
651         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
652
653     @property
654     def is_decorator(self) -> bool:
655         """Is this line a decorator?"""
656         return bool(self) and self.leaves[0].type == token.AT
657
658     @property
659     def is_import(self) -> bool:
660         """Is this an import line?"""
661         return bool(self) and is_import(self.leaves[0])
662
663     @property
664     def is_class(self) -> bool:
665         """Is this line a class definition?"""
666         return (
667             bool(self)
668             and self.leaves[0].type == token.NAME
669             and self.leaves[0].value == "class"
670         )
671
672     @property
673     def is_def(self) -> bool:
674         """Is this a function definition? (Also returns True for async defs.)"""
675         try:
676             first_leaf = self.leaves[0]
677         except IndexError:
678             return False
679
680         try:
681             second_leaf: Optional[Leaf] = self.leaves[1]
682         except IndexError:
683             second_leaf = None
684         return (
685             (first_leaf.type == token.NAME and first_leaf.value == "def")
686             or (
687                 first_leaf.type == token.ASYNC
688                 and second_leaf is not None
689                 and second_leaf.type == token.NAME
690                 and second_leaf.value == "def"
691             )
692         )
693
694     @property
695     def is_flow_control(self) -> bool:
696         """Is this line a flow control statement?
697
698         Those are `return`, `raise`, `break`, and `continue`.
699         """
700         return (
701             bool(self)
702             and self.leaves[0].type == token.NAME
703             and self.leaves[0].value in FLOW_CONTROL
704         )
705
706     @property
707     def is_yield(self) -> bool:
708         """Is this line a yield statement?"""
709         return (
710             bool(self)
711             and self.leaves[0].type == token.NAME
712             and self.leaves[0].value == "yield"
713         )
714
715     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
716         """If so, needs to be split before emitting."""
717         for leaf in self.leaves:
718             if leaf.type == STANDALONE_COMMENT:
719                 if leaf.bracket_depth <= depth_limit:
720                     return True
721
722         return False
723
724     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
725         """Remove trailing comma if there is one and it's safe."""
726         if not (
727             self.leaves
728             and self.leaves[-1].type == token.COMMA
729             and closing.type in CLOSING_BRACKETS
730         ):
731             return False
732
733         if closing.type == token.RBRACE:
734             self.remove_trailing_comma()
735             return True
736
737         if closing.type == token.RSQB:
738             comma = self.leaves[-1]
739             if comma.parent and comma.parent.type == syms.listmaker:
740                 self.remove_trailing_comma()
741                 return True
742
743         # For parens let's check if it's safe to remove the comma.  If the
744         # trailing one is the only one, we might mistakenly change a tuple
745         # into a different type by removing the comma.
746         depth = closing.bracket_depth + 1
747         commas = 0
748         opening = closing.opening_bracket
749         for _opening_index, leaf in enumerate(self.leaves):
750             if leaf is opening:
751                 break
752
753         else:
754             return False
755
756         for leaf in self.leaves[_opening_index + 1:]:
757             if leaf is closing:
758                 break
759
760             bracket_depth = leaf.bracket_depth
761             if bracket_depth == depth and leaf.type == token.COMMA:
762                 commas += 1
763                 if leaf.parent and leaf.parent.type == syms.arglist:
764                     commas += 1
765                     break
766
767         if commas > 1:
768             self.remove_trailing_comma()
769             return True
770
771         return False
772
773     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
774         """In a for loop, or comprehension, the variables are often unpacks.
775
776         To avoid splitting on the comma in this situation, increase the depth of
777         tokens between `for` and `in`.
778         """
779         if leaf.type == token.NAME and leaf.value == "for":
780             self.has_for = True
781             self.bracket_tracker.depth += 1
782             self._for_loop_variable = True
783             return True
784
785         return False
786
787     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
788         """See `maybe_increment_for_loop_variable` above for explanation."""
789         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
790             self.bracket_tracker.depth -= 1
791             self._for_loop_variable = False
792             return True
793
794         return False
795
796     def append_comment(self, comment: Leaf) -> bool:
797         """Add an inline or standalone comment to the line."""
798         if (
799             comment.type == STANDALONE_COMMENT
800             and self.bracket_tracker.any_open_brackets()
801         ):
802             comment.prefix = ""
803             return False
804
805         if comment.type != token.COMMENT:
806             return False
807
808         after = len(self.leaves) - 1
809         if after == -1:
810             comment.type = STANDALONE_COMMENT
811             comment.prefix = ""
812             return False
813
814         else:
815             self.comments.append((after, comment))
816             return True
817
818     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
819         """Generate comments that should appear directly after `leaf`."""
820         for _leaf_index, _leaf in enumerate(self.leaves):
821             if leaf is _leaf:
822                 break
823
824         else:
825             return
826
827         for index, comment_after in self.comments:
828             if _leaf_index == index:
829                 yield comment_after
830
831     def remove_trailing_comma(self) -> None:
832         """Remove the trailing comma and moves the comments attached to it."""
833         comma_index = len(self.leaves) - 1
834         for i in range(len(self.comments)):
835             comment_index, comment = self.comments[i]
836             if comment_index == comma_index:
837                 self.comments[i] = (comma_index - 1, comment)
838         self.leaves.pop()
839
840     def __str__(self) -> str:
841         """Render the line."""
842         if not self:
843             return "\n"
844
845         indent = "    " * self.depth
846         leaves = iter(self.leaves)
847         first = next(leaves)
848         res = f"{first.prefix}{indent}{first.value}"
849         for leaf in leaves:
850             res += str(leaf)
851         for _, comment in self.comments:
852             res += str(comment)
853         return res + "\n"
854
855     def __bool__(self) -> bool:
856         """Return True if the line has leaves or comments."""
857         return bool(self.leaves or self.comments)
858
859
860 class UnformattedLines(Line):
861     """Just like :class:`Line` but stores lines which aren't reformatted."""
862
863     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
864         """Just add a new `leaf` to the end of the lines.
865
866         The `preformatted` argument is ignored.
867
868         Keeps track of indentation `depth`, which is useful when the user
869         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
870         """
871         try:
872             list(generate_comments(leaf))
873         except FormatOn as f_on:
874             self.leaves.append(f_on.leaf_from_consumed(leaf))
875             raise
876
877         self.leaves.append(leaf)
878         if leaf.type == token.INDENT:
879             self.depth += 1
880         elif leaf.type == token.DEDENT:
881             self.depth -= 1
882
883     def __str__(self) -> str:
884         """Render unformatted lines from leaves which were added with `append()`.
885
886         `depth` is not used for indentation in this case.
887         """
888         if not self:
889             return "\n"
890
891         res = ""
892         for leaf in self.leaves:
893             res += str(leaf)
894         return res
895
896     def append_comment(self, comment: Leaf) -> bool:
897         """Not implemented in this class. Raises `NotImplementedError`."""
898         raise NotImplementedError("Unformatted lines don't store comments separately.")
899
900     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
901         """Does nothing and returns False."""
902         return False
903
904     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
905         """Does nothing and returns False."""
906         return False
907
908
909 @dataclass
910 class EmptyLineTracker:
911     """Provides a stateful method that returns the number of potential extra
912     empty lines needed before and after the currently processed line.
913
914     Note: this tracker works on lines that haven't been split yet.  It assumes
915     the prefix of the first leaf consists of optional newlines.  Those newlines
916     are consumed by `maybe_empty_lines()` and included in the computation.
917     """
918     previous_line: Optional[Line] = None
919     previous_after: int = 0
920     previous_defs: List[int] = Factory(list)
921
922     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
923         """Return the number of extra empty lines before and after the `current_line`.
924
925         This is for separating `def`, `async def` and `class` with extra empty
926         lines (two on module-level), as well as providing an extra empty line
927         after flow control keywords to make them more prominent.
928         """
929         if isinstance(current_line, UnformattedLines):
930             return 0, 0
931
932         before, after = self._maybe_empty_lines(current_line)
933         before -= self.previous_after
934         self.previous_after = after
935         self.previous_line = current_line
936         return before, after
937
938     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
939         max_allowed = 1
940         if current_line.depth == 0:
941             max_allowed = 2
942         if current_line.leaves:
943             # Consume the first leaf's extra newlines.
944             first_leaf = current_line.leaves[0]
945             before = first_leaf.prefix.count("\n")
946             before = min(before, max_allowed)
947             first_leaf.prefix = ""
948         else:
949             before = 0
950         depth = current_line.depth
951         while self.previous_defs and self.previous_defs[-1] >= depth:
952             self.previous_defs.pop()
953             before = 1 if depth else 2
954         is_decorator = current_line.is_decorator
955         if is_decorator or current_line.is_def or current_line.is_class:
956             if not is_decorator:
957                 self.previous_defs.append(depth)
958             if self.previous_line is None:
959                 # Don't insert empty lines before the first line in the file.
960                 return 0, 0
961
962             if self.previous_line and self.previous_line.is_decorator:
963                 # Don't insert empty lines between decorators.
964                 return 0, 0
965
966             newlines = 2
967             if current_line.depth:
968                 newlines -= 1
969             return newlines, 0
970
971         if current_line.is_flow_control:
972             return before, 1
973
974         if (
975             self.previous_line
976             and self.previous_line.is_import
977             and not current_line.is_import
978             and depth == self.previous_line.depth
979         ):
980             return (before or 1), 0
981
982         if (
983             self.previous_line
984             and self.previous_line.is_yield
985             and (not current_line.is_yield or depth != self.previous_line.depth)
986         ):
987             return (before or 1), 0
988
989         return before, 0
990
991
992 @dataclass
993 class LineGenerator(Visitor[Line]):
994     """Generates reformatted Line objects.  Empty lines are not emitted.
995
996     Note: destroys the tree it's visiting by mutating prefixes of its leaves
997     in ways that will no longer stringify to valid Python code on the tree.
998     """
999     current_line: Line = Factory(Line)
1000
1001     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1002         """Generate a line.
1003
1004         If the line is empty, only emit if it makes sense.
1005         If the line is too long, split it first and then generate.
1006
1007         If any lines were generated, set up a new current_line.
1008         """
1009         if not self.current_line:
1010             if self.current_line.__class__ == type:
1011                 self.current_line.depth += indent
1012             else:
1013                 self.current_line = type(depth=self.current_line.depth + indent)
1014             return  # Line is empty, don't emit. Creating a new one unnecessary.
1015
1016         complete_line = self.current_line
1017         self.current_line = type(depth=complete_line.depth + indent)
1018         yield complete_line
1019
1020     def visit(self, node: LN) -> Iterator[Line]:
1021         """Main method to visit `node` and its children.
1022
1023         Yields :class:`Line` objects.
1024         """
1025         if isinstance(self.current_line, UnformattedLines):
1026             # File contained `# fmt: off`
1027             yield from self.visit_unformatted(node)
1028
1029         else:
1030             yield from super().visit(node)
1031
1032     def visit_default(self, node: LN) -> Iterator[Line]:
1033         """Default `visit_*()` implementation. Recurses to children of `node`."""
1034         if isinstance(node, Leaf):
1035             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1036             try:
1037                 for comment in generate_comments(node):
1038                     if any_open_brackets:
1039                         # any comment within brackets is subject to splitting
1040                         self.current_line.append(comment)
1041                     elif comment.type == token.COMMENT:
1042                         # regular trailing comment
1043                         self.current_line.append(comment)
1044                         yield from self.line()
1045
1046                     else:
1047                         # regular standalone comment
1048                         yield from self.line()
1049
1050                         self.current_line.append(comment)
1051                         yield from self.line()
1052
1053             except FormatOff as f_off:
1054                 f_off.trim_prefix(node)
1055                 yield from self.line(type=UnformattedLines)
1056                 yield from self.visit(node)
1057
1058             except FormatOn as f_on:
1059                 # This only happens here if somebody says "fmt: on" multiple
1060                 # times in a row.
1061                 f_on.trim_prefix(node)
1062                 yield from self.visit_default(node)
1063
1064             else:
1065                 normalize_prefix(node, inside_brackets=any_open_brackets)
1066                 if node.type == token.STRING:
1067                     normalize_string_quotes(node)
1068                 if node.type not in WHITESPACE:
1069                     self.current_line.append(node)
1070         yield from super().visit_default(node)
1071
1072     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1073         """Increase indentation level, maybe yield a line."""
1074         # In blib2to3 INDENT never holds comments.
1075         yield from self.line(+1)
1076         yield from self.visit_default(node)
1077
1078     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1079         """Decrease indentation level, maybe yield a line."""
1080         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1081         yield from self.line(-1)
1082
1083     def visit_stmt(
1084         self, node: Node, keywords: Set[str], parens: Set[str]
1085     ) -> Iterator[Line]:
1086         """Visit a statement.
1087
1088         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1089         `def`, `with`, `class`, and `assert`.
1090
1091         The relevant Python language `keywords` for a given statement will be
1092         NAME leaves within it. This methods puts those on a separate line.
1093
1094         `parens` holds pairs of nodes where invisible parentheses should be put.
1095         Keys hold nodes after which opening parentheses should be put, values
1096         hold nodes before which closing parentheses should be put.
1097         """
1098         normalize_invisible_parens(node, parens_after=parens)
1099         for child in node.children:
1100             if child.type == token.NAME and child.value in keywords:  # type: ignore
1101                 yield from self.line()
1102
1103             yield from self.visit(child)
1104
1105     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1106         """Visit a statement without nested statements."""
1107         is_suite_like = node.parent and node.parent.type in STATEMENT
1108         if is_suite_like:
1109             yield from self.line(+1)
1110             yield from self.visit_default(node)
1111             yield from self.line(-1)
1112
1113         else:
1114             yield from self.line()
1115             yield from self.visit_default(node)
1116
1117     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1118         """Visit `async def`, `async for`, `async with`."""
1119         yield from self.line()
1120
1121         children = iter(node.children)
1122         for child in children:
1123             yield from self.visit(child)
1124
1125             if child.type == token.ASYNC:
1126                 break
1127
1128         internal_stmt = next(children)
1129         for child in internal_stmt.children:
1130             yield from self.visit(child)
1131
1132     def visit_decorators(self, node: Node) -> Iterator[Line]:
1133         """Visit decorators."""
1134         for child in node.children:
1135             yield from self.line()
1136             yield from self.visit(child)
1137
1138     def visit_import_from(self, node: Node) -> Iterator[Line]:
1139         """Visit import_from and maybe put invisible parentheses.
1140
1141         This is separate from `visit_stmt` because import statements don't
1142         support arbitrary atoms and thus handling of parentheses is custom.
1143         """
1144         check_lpar = False
1145         for index, child in enumerate(node.children):
1146             if check_lpar:
1147                 if child.type == token.LPAR:
1148                     # make parentheses invisible
1149                     child.value = ""  # type: ignore
1150                     node.children[-1].value = ""  # type: ignore
1151                 else:
1152                     # insert invisible parentheses
1153                     node.insert_child(index, Leaf(token.LPAR, ""))
1154                     node.append_child(Leaf(token.RPAR, ""))
1155                 break
1156
1157             check_lpar = (
1158                 child.type == token.NAME and child.value == "import"  # type: ignore
1159             )
1160
1161         for child in node.children:
1162             yield from self.visit(child)
1163
1164     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1165         """Remove a semicolon and put the other statement on a separate line."""
1166         yield from self.line()
1167
1168     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1169         """End of file. Process outstanding comments and end with a newline."""
1170         yield from self.visit_default(leaf)
1171         yield from self.line()
1172
1173     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1174         """Used when file contained a `# fmt: off`."""
1175         if isinstance(node, Node):
1176             for child in node.children:
1177                 yield from self.visit(child)
1178
1179         else:
1180             try:
1181                 self.current_line.append(node)
1182             except FormatOn as f_on:
1183                 f_on.trim_prefix(node)
1184                 yield from self.line()
1185                 yield from self.visit(node)
1186
1187             if node.type == token.ENDMARKER:
1188                 # somebody decided not to put a final `# fmt: on`
1189                 yield from self.line()
1190
1191     def __attrs_post_init__(self) -> None:
1192         """You are in a twisty little maze of passages."""
1193         v = self.visit_stmt
1194         Ø: Set[str] = set()
1195         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1196         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1197         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1198         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1199         self.visit_try_stmt = partial(
1200             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1201         )
1202         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1203         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1204         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1205         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1206         self.visit_async_funcdef = self.visit_async_stmt
1207         self.visit_decorated = self.visit_decorators
1208
1209
1210 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1211 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1212 OPENING_BRACKETS = set(BRACKET.keys())
1213 CLOSING_BRACKETS = set(BRACKET.values())
1214 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1215 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1216
1217
1218 def whitespace(leaf: Leaf) -> str:  # noqa C901
1219     """Return whitespace prefix if needed for the given `leaf`."""
1220     NO = ""
1221     SPACE = " "
1222     DOUBLESPACE = "  "
1223     t = leaf.type
1224     p = leaf.parent
1225     v = leaf.value
1226     if t in ALWAYS_NO_SPACE:
1227         return NO
1228
1229     if t == token.COMMENT:
1230         return DOUBLESPACE
1231
1232     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1233     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1234         return NO
1235
1236     prev = leaf.prev_sibling
1237     if not prev:
1238         prevp = preceding_leaf(p)
1239         if not prevp or prevp.type in OPENING_BRACKETS:
1240             return NO
1241
1242         if t == token.COLON:
1243             return SPACE if prevp.type == token.COMMA else NO
1244
1245         if prevp.type == token.EQUAL:
1246             if prevp.parent:
1247                 if prevp.parent.type in {
1248                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1249                 }:
1250                     return NO
1251
1252                 elif prevp.parent.type == syms.typedargslist:
1253                     # A bit hacky: if the equal sign has whitespace, it means we
1254                     # previously found it's a typed argument.  So, we're using
1255                     # that, too.
1256                     return prevp.prefix
1257
1258         elif prevp.type == token.DOUBLESTAR:
1259             if (
1260                 prevp.parent
1261                 and prevp.parent.type in {
1262                     syms.arglist,
1263                     syms.argument,
1264                     syms.dictsetmaker,
1265                     syms.parameters,
1266                     syms.typedargslist,
1267                     syms.varargslist,
1268                 }
1269             ):
1270                 return NO
1271
1272         elif prevp.type == token.COLON:
1273             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1274                 return NO
1275
1276         elif (
1277             prevp.parent
1278             and prevp.parent.type in {syms.factor, syms.star_expr}
1279             and prevp.type in MATH_OPERATORS
1280         ):
1281             return NO
1282
1283         elif (
1284             prevp.type == token.RIGHTSHIFT
1285             and prevp.parent
1286             and prevp.parent.type == syms.shift_expr
1287             and prevp.prev_sibling
1288             and prevp.prev_sibling.type == token.NAME
1289             and prevp.prev_sibling.value == "print"  # type: ignore
1290         ):
1291             # Python 2 print chevron
1292             return NO
1293
1294     elif prev.type in OPENING_BRACKETS:
1295         return NO
1296
1297     if p.type in {syms.parameters, syms.arglist}:
1298         # untyped function signatures or calls
1299         if t == token.RPAR:
1300             return NO
1301
1302         if not prev or prev.type != token.COMMA:
1303             return NO
1304
1305     elif p.type == syms.varargslist:
1306         # lambdas
1307         if t == token.RPAR:
1308             return NO
1309
1310         if prev and prev.type != token.COMMA:
1311             return NO
1312
1313     elif p.type == syms.typedargslist:
1314         # typed function signatures
1315         if not prev:
1316             return NO
1317
1318         if t == token.EQUAL:
1319             if prev.type != syms.tname:
1320                 return NO
1321
1322         elif prev.type == token.EQUAL:
1323             # A bit hacky: if the equal sign has whitespace, it means we
1324             # previously found it's a typed argument.  So, we're using that, too.
1325             return prev.prefix
1326
1327         elif prev.type != token.COMMA:
1328             return NO
1329
1330     elif p.type == syms.tname:
1331         # type names
1332         if not prev:
1333             prevp = preceding_leaf(p)
1334             if not prevp or prevp.type != token.COMMA:
1335                 return NO
1336
1337     elif p.type == syms.trailer:
1338         # attributes and calls
1339         if t == token.LPAR or t == token.RPAR:
1340             return NO
1341
1342         if not prev:
1343             if t == token.DOT:
1344                 prevp = preceding_leaf(p)
1345                 if not prevp or prevp.type != token.NUMBER:
1346                     return NO
1347
1348             elif t == token.LSQB:
1349                 return NO
1350
1351         elif prev.type != token.COMMA:
1352             return NO
1353
1354     elif p.type == syms.argument:
1355         # single argument
1356         if t == token.EQUAL:
1357             return NO
1358
1359         if not prev:
1360             prevp = preceding_leaf(p)
1361             if not prevp or prevp.type == token.LPAR:
1362                 return NO
1363
1364         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1365             return NO
1366
1367     elif p.type == syms.decorator:
1368         # decorators
1369         return NO
1370
1371     elif p.type == syms.dotted_name:
1372         if prev:
1373             return NO
1374
1375         prevp = preceding_leaf(p)
1376         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1377             return NO
1378
1379     elif p.type == syms.classdef:
1380         if t == token.LPAR:
1381             return NO
1382
1383         if prev and prev.type == token.LPAR:
1384             return NO
1385
1386     elif p.type == syms.subscript:
1387         # indexing
1388         if not prev:
1389             assert p.parent is not None, "subscripts are always parented"
1390             if p.parent.type == syms.subscriptlist:
1391                 return SPACE
1392
1393             return NO
1394
1395         else:
1396             return NO
1397
1398     elif p.type == syms.atom:
1399         if prev and t == token.DOT:
1400             # dots, but not the first one.
1401             return NO
1402
1403     elif (
1404         p.type == syms.listmaker
1405         or p.type == syms.testlist_gexp
1406         or p.type == syms.subscriptlist
1407     ):
1408         # list interior, including unpacking
1409         if not prev:
1410             return NO
1411
1412     elif p.type == syms.dictsetmaker:
1413         # dict and set interior, including unpacking
1414         if not prev:
1415             return NO
1416
1417         if prev.type == token.DOUBLESTAR:
1418             return NO
1419
1420     elif p.type in {syms.factor, syms.star_expr}:
1421         # unary ops
1422         if not prev:
1423             prevp = preceding_leaf(p)
1424             if not prevp or prevp.type in OPENING_BRACKETS:
1425                 return NO
1426
1427             prevp_parent = prevp.parent
1428             assert prevp_parent is not None
1429             if (
1430                 prevp.type == token.COLON
1431                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1432             ):
1433                 return NO
1434
1435             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1436                 return NO
1437
1438         elif t == token.NAME or t == token.NUMBER:
1439             return NO
1440
1441     elif p.type == syms.import_from:
1442         if t == token.DOT:
1443             if prev and prev.type == token.DOT:
1444                 return NO
1445
1446         elif t == token.NAME:
1447             if v == "import":
1448                 return SPACE
1449
1450             if prev and prev.type == token.DOT:
1451                 return NO
1452
1453     elif p.type == syms.sliceop:
1454         return NO
1455
1456     return SPACE
1457
1458
1459 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1460     """Return the first leaf that precedes `node`, if any."""
1461     while node:
1462         res = node.prev_sibling
1463         if res:
1464             if isinstance(res, Leaf):
1465                 return res
1466
1467             try:
1468                 return list(res.leaves())[-1]
1469
1470             except IndexError:
1471                 return None
1472
1473         node = node.parent
1474     return None
1475
1476
1477 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1478     """Return the priority of the `leaf` delimiter, given a line break after it.
1479
1480     The delimiter priorities returned here are from those delimiters that would
1481     cause a line break after themselves.
1482
1483     Higher numbers are higher priority.
1484     """
1485     if leaf.type == token.COMMA:
1486         return COMMA_PRIORITY
1487
1488     return 0
1489
1490
1491 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1492     """Return the priority of the `leaf` delimiter, given a line before after it.
1493
1494     The delimiter priorities returned here are from those delimiters that would
1495     cause a line break before themselves.
1496
1497     Higher numbers are higher priority.
1498     """
1499     if (
1500         leaf.type in VARARGS
1501         and leaf.parent
1502         and leaf.parent.type in {syms.argument, syms.typedargslist, syms.dictsetmaker}
1503     ):
1504         # * and ** might also be MATH_OPERATORS but in this case they are not.
1505         # Don't treat them as a delimiter.
1506         return 0
1507
1508     if (
1509         leaf.type in MATH_OPERATORS
1510         and leaf.parent
1511         and leaf.parent.type not in {syms.factor, syms.star_expr}
1512     ):
1513         return MATH_PRIORITY
1514
1515     if leaf.type in COMPARATORS:
1516         return COMPARATOR_PRIORITY
1517
1518     if (
1519         leaf.type == token.STRING
1520         and previous is not None
1521         and previous.type == token.STRING
1522     ):
1523         return STRING_PRIORITY
1524
1525     if (
1526         leaf.type == token.NAME
1527         and leaf.value == "for"
1528         and leaf.parent
1529         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1530     ):
1531         return COMPREHENSION_PRIORITY
1532
1533     if (
1534         leaf.type == token.NAME
1535         and leaf.value == "if"
1536         and leaf.parent
1537         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1538     ):
1539         return COMPREHENSION_PRIORITY
1540
1541     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1542         return LOGIC_PRIORITY
1543
1544     return 0
1545
1546
1547 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1548     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1549
1550     Higher numbers are higher priority.
1551     """
1552     return max(
1553         is_split_before_delimiter(leaf, previous),
1554         is_split_after_delimiter(leaf, previous),
1555     )
1556
1557
1558 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1559     """Clean the prefix of the `leaf` and generate comments from it, if any.
1560
1561     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1562     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1563     move because it does away with modifying the grammar to include all the
1564     possible places in which comments can be placed.
1565
1566     The sad consequence for us though is that comments don't "belong" anywhere.
1567     This is why this function generates simple parentless Leaf objects for
1568     comments.  We simply don't know what the correct parent should be.
1569
1570     No matter though, we can live without this.  We really only need to
1571     differentiate between inline and standalone comments.  The latter don't
1572     share the line with any code.
1573
1574     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1575     are emitted with a fake STANDALONE_COMMENT token identifier.
1576     """
1577     p = leaf.prefix
1578     if not p:
1579         return
1580
1581     if "#" not in p:
1582         return
1583
1584     consumed = 0
1585     nlines = 0
1586     for index, line in enumerate(p.split("\n")):
1587         consumed += len(line) + 1  # adding the length of the split '\n'
1588         line = line.lstrip()
1589         if not line:
1590             nlines += 1
1591         if not line.startswith("#"):
1592             continue
1593
1594         if index == 0 and leaf.type != token.ENDMARKER:
1595             comment_type = token.COMMENT  # simple trailing comment
1596         else:
1597             comment_type = STANDALONE_COMMENT
1598         comment = make_comment(line)
1599         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1600
1601         if comment in {"# fmt: on", "# yapf: enable"}:
1602             raise FormatOn(consumed)
1603
1604         if comment in {"# fmt: off", "# yapf: disable"}:
1605             if comment_type == STANDALONE_COMMENT:
1606                 raise FormatOff(consumed)
1607
1608             prev = preceding_leaf(leaf)
1609             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1610                 raise FormatOff(consumed)
1611
1612         nlines = 0
1613
1614
1615 def make_comment(content: str) -> str:
1616     """Return a consistently formatted comment from the given `content` string.
1617
1618     All comments (except for "##", "#!", "#:") should have a single space between
1619     the hash sign and the content.
1620
1621     If `content` didn't start with a hash sign, one is provided.
1622     """
1623     content = content.rstrip()
1624     if not content:
1625         return "#"
1626
1627     if content[0] == "#":
1628         content = content[1:]
1629     if content and content[0] not in " !:#":
1630         content = " " + content
1631     return "#" + content
1632
1633
1634 def split_line(
1635     line: Line, line_length: int, inner: bool = False, py36: bool = False
1636 ) -> Iterator[Line]:
1637     """Split a `line` into potentially many lines.
1638
1639     They should fit in the allotted `line_length` but might not be able to.
1640     `inner` signifies that there were a pair of brackets somewhere around the
1641     current `line`, possibly transitively. This means we can fallback to splitting
1642     by delimiters if the LHS/RHS don't yield any results.
1643
1644     If `py36` is True, splitting may generate syntax that is only compatible
1645     with Python 3.6 and later.
1646     """
1647     if isinstance(line, UnformattedLines) or line.is_comment:
1648         yield line
1649         return
1650
1651     line_str = str(line).strip("\n")
1652     if (
1653         len(line_str) <= line_length
1654         and "\n" not in line_str  # multiline strings
1655         and not line.contains_standalone_comments()
1656     ):
1657         yield line
1658         return
1659
1660     split_funcs: List[SplitFunc]
1661     if line.is_def:
1662         split_funcs = [left_hand_split]
1663     elif line.inside_brackets:
1664         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1665     else:
1666         split_funcs = [right_hand_split]
1667     for split_func in split_funcs:
1668         # We are accumulating lines in `result` because we might want to abort
1669         # mission and return the original line in the end, or attempt a different
1670         # split altogether.
1671         result: List[Line] = []
1672         try:
1673             for l in split_func(line, py36):
1674                 if str(l).strip("\n") == line_str:
1675                     raise CannotSplit("Split function returned an unchanged result")
1676
1677                 result.extend(
1678                     split_line(l, line_length=line_length, inner=True, py36=py36)
1679                 )
1680         except CannotSplit as cs:
1681             continue
1682
1683         else:
1684             yield from result
1685             break
1686
1687     else:
1688         yield line
1689
1690
1691 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1692     """Split line into many lines, starting with the first matching bracket pair.
1693
1694     Note: this usually looks weird, only use this for function definitions.
1695     Prefer RHS otherwise.
1696     """
1697     head = Line(depth=line.depth)
1698     body = Line(depth=line.depth + 1, inside_brackets=True)
1699     tail = Line(depth=line.depth)
1700     tail_leaves: List[Leaf] = []
1701     body_leaves: List[Leaf] = []
1702     head_leaves: List[Leaf] = []
1703     current_leaves = head_leaves
1704     matching_bracket = None
1705     for leaf in line.leaves:
1706         if (
1707             current_leaves is body_leaves
1708             and leaf.type in CLOSING_BRACKETS
1709             and leaf.opening_bracket is matching_bracket
1710         ):
1711             current_leaves = tail_leaves if body_leaves else head_leaves
1712         current_leaves.append(leaf)
1713         if current_leaves is head_leaves:
1714             if leaf.type in OPENING_BRACKETS:
1715                 matching_bracket = leaf
1716                 current_leaves = body_leaves
1717     # Since body is a new indent level, remove spurious leading whitespace.
1718     if body_leaves:
1719         normalize_prefix(body_leaves[0], inside_brackets=True)
1720     # Build the new lines.
1721     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1722         for leaf in leaves:
1723             result.append(leaf, preformatted=True)
1724             for comment_after in line.comments_after(leaf):
1725                 result.append(comment_after, preformatted=True)
1726     bracket_split_succeeded_or_raise(head, body, tail)
1727     for result in (head, body, tail):
1728         if result:
1729             yield result
1730
1731
1732 def right_hand_split(
1733     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1734 ) -> Iterator[Line]:
1735     """Split line into many lines, starting with the last matching bracket pair."""
1736     head = Line(depth=line.depth)
1737     body = Line(depth=line.depth + 1, inside_brackets=True)
1738     tail = Line(depth=line.depth)
1739     tail_leaves: List[Leaf] = []
1740     body_leaves: List[Leaf] = []
1741     head_leaves: List[Leaf] = []
1742     current_leaves = tail_leaves
1743     opening_bracket = None
1744     closing_bracket = None
1745     for leaf in reversed(line.leaves):
1746         if current_leaves is body_leaves:
1747             if leaf is opening_bracket:
1748                 current_leaves = head_leaves if body_leaves else tail_leaves
1749         current_leaves.append(leaf)
1750         if current_leaves is tail_leaves:
1751             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1752                 opening_bracket = leaf.opening_bracket
1753                 closing_bracket = leaf
1754                 current_leaves = body_leaves
1755     tail_leaves.reverse()
1756     body_leaves.reverse()
1757     head_leaves.reverse()
1758     # Since body is a new indent level, remove spurious leading whitespace.
1759     if body_leaves:
1760         normalize_prefix(body_leaves[0], inside_brackets=True)
1761     elif not head_leaves:
1762         # No `head` and no `body` means the split failed. `tail` has all content.
1763         raise CannotSplit("No brackets found")
1764
1765     # Build the new lines.
1766     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1767         for leaf in leaves:
1768             result.append(leaf, preformatted=True)
1769             for comment_after in line.comments_after(leaf):
1770                 result.append(comment_after, preformatted=True)
1771     bracket_split_succeeded_or_raise(head, body, tail)
1772     assert opening_bracket and closing_bracket
1773     if (
1774         opening_bracket.type == token.LPAR
1775         and not opening_bracket.value
1776         and closing_bracket.type == token.RPAR
1777         and not closing_bracket.value
1778     ):
1779         # These parens were optional. If there aren't any delimiters or standalone
1780         # comments in the body, they were unnecessary and another split without
1781         # them should be attempted.
1782         if not (
1783             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1784         ):
1785             omit = {id(closing_bracket), *omit}
1786             yield from right_hand_split(line, py36=py36, omit=omit)
1787             return
1788
1789     ensure_visible(opening_bracket)
1790     ensure_visible(closing_bracket)
1791     for result in (head, body, tail):
1792         if result:
1793             yield result
1794
1795
1796 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1797     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1798
1799     Do nothing otherwise.
1800
1801     A left- or right-hand split is based on a pair of brackets. Content before
1802     (and including) the opening bracket is left on one line, content inside the
1803     brackets is put on a separate line, and finally content starting with and
1804     following the closing bracket is put on a separate line.
1805
1806     Those are called `head`, `body`, and `tail`, respectively. If the split
1807     produced the same line (all content in `head`) or ended up with an empty `body`
1808     and the `tail` is just the closing bracket, then it's considered failed.
1809     """
1810     tail_len = len(str(tail).strip())
1811     if not body:
1812         if tail_len == 0:
1813             raise CannotSplit("Splitting brackets produced the same line")
1814
1815         elif tail_len < 3:
1816             raise CannotSplit(
1817                 f"Splitting brackets on an empty body to save "
1818                 f"{tail_len} characters is not worth it"
1819             )
1820
1821
1822 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1823     """Normalize prefix of the first leaf in every line returned by `split_func`.
1824
1825     This is a decorator over relevant split functions.
1826     """
1827
1828     @wraps(split_func)
1829     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1830         for l in split_func(line, py36):
1831             normalize_prefix(l.leaves[0], inside_brackets=True)
1832             yield l
1833
1834     return split_wrapper
1835
1836
1837 @dont_increase_indentation
1838 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1839     """Split according to delimiters of the highest priority.
1840
1841     If `py36` is True, the split will add trailing commas also in function
1842     signatures that contain `*` and `**`.
1843     """
1844     try:
1845         last_leaf = line.leaves[-1]
1846     except IndexError:
1847         raise CannotSplit("Line empty")
1848
1849     delimiters = line.bracket_tracker.delimiters
1850     try:
1851         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1852             exclude={id(last_leaf)}
1853         )
1854     except ValueError:
1855         raise CannotSplit("No delimiters found")
1856
1857     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1858     lowest_depth = sys.maxsize
1859     trailing_comma_safe = True
1860
1861     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1862         """Append `leaf` to current line or to new line if appending impossible."""
1863         nonlocal current_line
1864         try:
1865             current_line.append_safe(leaf, preformatted=True)
1866         except ValueError as ve:
1867             yield current_line
1868
1869             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1870             current_line.append(leaf)
1871
1872     for leaf in line.leaves:
1873         yield from append_to_line(leaf)
1874
1875         for comment_after in line.comments_after(leaf):
1876             yield from append_to_line(comment_after)
1877
1878         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1879         if (
1880             leaf.bracket_depth == lowest_depth
1881             and leaf.type == token.STAR
1882             or leaf.type == token.DOUBLESTAR
1883         ):
1884             trailing_comma_safe = trailing_comma_safe and py36
1885         leaf_priority = delimiters.get(id(leaf))
1886         if leaf_priority == delimiter_priority:
1887             yield current_line
1888
1889             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1890     if current_line:
1891         if (
1892             trailing_comma_safe
1893             and delimiter_priority == COMMA_PRIORITY
1894             and current_line.leaves[-1].type != token.COMMA
1895             and current_line.leaves[-1].type != STANDALONE_COMMENT
1896         ):
1897             current_line.append(Leaf(token.COMMA, ","))
1898         yield current_line
1899
1900
1901 @dont_increase_indentation
1902 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1903     """Split standalone comments from the rest of the line."""
1904     if not line.contains_standalone_comments(0):
1905         raise CannotSplit("Line does not have any standalone comments")
1906
1907     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1908
1909     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1910         """Append `leaf` to current line or to new line if appending impossible."""
1911         nonlocal current_line
1912         try:
1913             current_line.append_safe(leaf, preformatted=True)
1914         except ValueError as ve:
1915             yield current_line
1916
1917             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1918             current_line.append(leaf)
1919
1920     for leaf in line.leaves:
1921         yield from append_to_line(leaf)
1922
1923         for comment_after in line.comments_after(leaf):
1924             yield from append_to_line(comment_after)
1925
1926     if current_line:
1927         yield current_line
1928
1929
1930 def is_import(leaf: Leaf) -> bool:
1931     """Return True if the given leaf starts an import statement."""
1932     p = leaf.parent
1933     t = leaf.type
1934     v = leaf.value
1935     return bool(
1936         t == token.NAME
1937         and (
1938             (v == "import" and p and p.type == syms.import_name)
1939             or (v == "from" and p and p.type == syms.import_from)
1940         )
1941     )
1942
1943
1944 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1945     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1946     else.
1947
1948     Note: don't use backslashes for formatting or you'll lose your voting rights.
1949     """
1950     if not inside_brackets:
1951         spl = leaf.prefix.split("#")
1952         if "\\" not in spl[0]:
1953             nl_count = spl[-1].count("\n")
1954             if len(spl) > 1:
1955                 nl_count -= 1
1956             leaf.prefix = "\n" * nl_count
1957             return
1958
1959     leaf.prefix = ""
1960
1961
1962 def normalize_string_quotes(leaf: Leaf) -> None:
1963     """Prefer double quotes but only if it doesn't cause more escaping.
1964
1965     Adds or removes backslashes as appropriate. Doesn't parse and fix
1966     strings nested in f-strings (yet).
1967
1968     Note: Mutates its argument.
1969     """
1970     value = leaf.value.lstrip("furbFURB")
1971     if value[:3] == '"""':
1972         return
1973
1974     elif value[:3] == "'''":
1975         orig_quote = "'''"
1976         new_quote = '"""'
1977     elif value[0] == '"':
1978         orig_quote = '"'
1979         new_quote = "'"
1980     else:
1981         orig_quote = "'"
1982         new_quote = '"'
1983     first_quote_pos = leaf.value.find(orig_quote)
1984     if first_quote_pos == -1:
1985         return  # There's an internal error
1986
1987     prefix = leaf.value[:first_quote_pos]
1988     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
1989     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
1990     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
1991     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
1992     if "r" in prefix.casefold():
1993         if unescaped_new_quote.search(body):
1994             # There's at least one unescaped new_quote in this raw string
1995             # so converting is impossible
1996             return
1997
1998         # Do not introduce or remove backslashes in raw strings
1999         new_body = body
2000     else:
2001         # remove unnecessary quotes
2002         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2003         if body != new_body:
2004             # Consider the string without unnecessary quotes as the original
2005             body = new_body
2006             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2007         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2008         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2009     if new_quote == '"""' and new_body[-1] == '"':
2010         # edge case:
2011         new_body = new_body[:-1] + '\\"'
2012     orig_escape_count = body.count("\\")
2013     new_escape_count = new_body.count("\\")
2014     if new_escape_count > orig_escape_count:
2015         return  # Do not introduce more escaping
2016
2017     if new_escape_count == orig_escape_count and orig_quote == '"':
2018         return  # Prefer double quotes
2019
2020     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2021
2022
2023 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2024     """Make existing optional parentheses invisible or create new ones.
2025
2026     Standardizes on visible parentheses for single-element tuples, and keeps
2027     existing visible parentheses for other tuples and generator expressions.
2028     """
2029     check_lpar = False
2030     for child in list(node.children):
2031         if check_lpar:
2032             if child.type == syms.atom:
2033                 if not (
2034                     is_empty_tuple(child)
2035                     or is_one_tuple(child)
2036                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2037                 ):
2038                     first = child.children[0]
2039                     last = child.children[-1]
2040                     if first.type == token.LPAR and last.type == token.RPAR:
2041                         # make parentheses invisible
2042                         first.value = ""  # type: ignore
2043                         last.value = ""  # type: ignore
2044             elif is_one_tuple(child):
2045                 # wrap child in visible parentheses
2046                 lpar = Leaf(token.LPAR, "(")
2047                 rpar = Leaf(token.RPAR, ")")
2048                 index = child.remove() or 0
2049                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2050             else:
2051                 # wrap child in invisible parentheses
2052                 lpar = Leaf(token.LPAR, "")
2053                 rpar = Leaf(token.RPAR, "")
2054                 index = child.remove() or 0
2055                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2056
2057         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2058
2059
2060 def is_empty_tuple(node: LN) -> bool:
2061     """Return True if `node` holds an empty tuple."""
2062     return (
2063         node.type == syms.atom
2064         and len(node.children) == 2
2065         and node.children[0].type == token.LPAR
2066         and node.children[1].type == token.RPAR
2067     )
2068
2069
2070 def is_one_tuple(node: LN) -> bool:
2071     """Return True if `node` holds a tuple with one element, with or without parens."""
2072     if node.type == syms.atom:
2073         if len(node.children) != 3:
2074             return False
2075
2076         lpar, gexp, rpar = node.children
2077         if not (
2078             lpar.type == token.LPAR
2079             and gexp.type == syms.testlist_gexp
2080             and rpar.type == token.RPAR
2081         ):
2082             return False
2083
2084         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2085
2086     return (
2087         node.type in IMPLICIT_TUPLE
2088         and len(node.children) == 2
2089         and node.children[1].type == token.COMMA
2090     )
2091
2092
2093 def max_delimiter_priority_in_atom(node: LN) -> int:
2094     if node.type != syms.atom:
2095         return 0
2096
2097     first = node.children[0]
2098     last = node.children[-1]
2099     if not (first.type == token.LPAR and last.type == token.RPAR):
2100         return 0
2101
2102     bt = BracketTracker()
2103     for c in node.children[1:-1]:
2104         if isinstance(c, Leaf):
2105             bt.mark(c)
2106         else:
2107             for leaf in c.leaves():
2108                 bt.mark(leaf)
2109     try:
2110         return bt.max_delimiter_priority()
2111
2112     except ValueError:
2113         return 0
2114
2115
2116 def ensure_visible(leaf: Leaf) -> None:
2117     """Make sure parentheses are visible.
2118
2119     They could be invisible as part of some statements (see
2120     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2121     """
2122     if leaf.type == token.LPAR:
2123         leaf.value = "("
2124     elif leaf.type == token.RPAR:
2125         leaf.value = ")"
2126
2127
2128 def is_python36(node: Node) -> bool:
2129     """Return True if the current file is using Python 3.6+ features.
2130
2131     Currently looking for:
2132     - f-strings; and
2133     - trailing commas after * or ** in function signatures.
2134     """
2135     for n in node.pre_order():
2136         if n.type == token.STRING:
2137             value_head = n.value[:2]  # type: ignore
2138             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2139                 return True
2140
2141         elif (
2142             n.type == syms.typedargslist
2143             and n.children
2144             and n.children[-1].type == token.COMMA
2145         ):
2146             for ch in n.children:
2147                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
2148                     return True
2149
2150     return False
2151
2152
2153 PYTHON_EXTENSIONS = {".py"}
2154 BLACKLISTED_DIRECTORIES = {
2155     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2156 }
2157
2158
2159 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2160     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2161     and have one of the PYTHON_EXTENSIONS.
2162     """
2163     for child in path.iterdir():
2164         if child.is_dir():
2165             if child.name in BLACKLISTED_DIRECTORIES:
2166                 continue
2167
2168             yield from gen_python_files_in_dir(child)
2169
2170         elif child.suffix in PYTHON_EXTENSIONS:
2171             yield child
2172
2173
2174 @dataclass
2175 class Report:
2176     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2177     check: bool = False
2178     quiet: bool = False
2179     change_count: int = 0
2180     same_count: int = 0
2181     failure_count: int = 0
2182
2183     def done(self, src: Path, changed: bool) -> None:
2184         """Increment the counter for successful reformatting. Write out a message."""
2185         if changed:
2186             reformatted = "would reformat" if self.check else "reformatted"
2187             if not self.quiet:
2188                 out(f"{reformatted} {src}")
2189             self.change_count += 1
2190         else:
2191             if not self.quiet:
2192                 out(f"{src} already well formatted, good job.", bold=False)
2193             self.same_count += 1
2194
2195     def failed(self, src: Path, message: str) -> None:
2196         """Increment the counter for failed reformatting. Write out a message."""
2197         err(f"error: cannot format {src}: {message}")
2198         self.failure_count += 1
2199
2200     @property
2201     def return_code(self) -> int:
2202         """Return the exit code that the app should use.
2203
2204         This considers the current state of changed files and failures:
2205         - if there were any failures, return 123;
2206         - if any files were changed and --check is being used, return 1;
2207         - otherwise return 0.
2208         """
2209         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2210         # 126 we have special returncodes reserved by the shell.
2211         if self.failure_count:
2212             return 123
2213
2214         elif self.change_count and self.check:
2215             return 1
2216
2217         return 0
2218
2219     def __str__(self) -> str:
2220         """Render a color report of the current state.
2221
2222         Use `click.unstyle` to remove colors.
2223         """
2224         if self.check:
2225             reformatted = "would be reformatted"
2226             unchanged = "would be left unchanged"
2227             failed = "would fail to reformat"
2228         else:
2229             reformatted = "reformatted"
2230             unchanged = "left unchanged"
2231             failed = "failed to reformat"
2232         report = []
2233         if self.change_count:
2234             s = "s" if self.change_count > 1 else ""
2235             report.append(
2236                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2237             )
2238         if self.same_count:
2239             s = "s" if self.same_count > 1 else ""
2240             report.append(f"{self.same_count} file{s} {unchanged}")
2241         if self.failure_count:
2242             s = "s" if self.failure_count > 1 else ""
2243             report.append(
2244                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2245             )
2246         return ", ".join(report) + "."
2247
2248
2249 def assert_equivalent(src: str, dst: str) -> None:
2250     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2251
2252     import ast
2253     import traceback
2254
2255     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2256         """Simple visitor generating strings to compare ASTs by content."""
2257         yield f"{'  ' * depth}{node.__class__.__name__}("
2258
2259         for field in sorted(node._fields):
2260             try:
2261                 value = getattr(node, field)
2262             except AttributeError:
2263                 continue
2264
2265             yield f"{'  ' * (depth+1)}{field}="
2266
2267             if isinstance(value, list):
2268                 for item in value:
2269                     if isinstance(item, ast.AST):
2270                         yield from _v(item, depth + 2)
2271
2272             elif isinstance(value, ast.AST):
2273                 yield from _v(value, depth + 2)
2274
2275             else:
2276                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2277
2278         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2279
2280     try:
2281         src_ast = ast.parse(src)
2282     except Exception as exc:
2283         major, minor = sys.version_info[:2]
2284         raise AssertionError(
2285             f"cannot use --safe with this file; failed to parse source file "
2286             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2287             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2288         )
2289
2290     try:
2291         dst_ast = ast.parse(dst)
2292     except Exception as exc:
2293         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2294         raise AssertionError(
2295             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2296             f"Please report a bug on https://github.com/ambv/black/issues.  "
2297             f"This invalid output might be helpful: {log}"
2298         ) from None
2299
2300     src_ast_str = "\n".join(_v(src_ast))
2301     dst_ast_str = "\n".join(_v(dst_ast))
2302     if src_ast_str != dst_ast_str:
2303         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2304         raise AssertionError(
2305             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2306             f"the source.  "
2307             f"Please report a bug on https://github.com/ambv/black/issues.  "
2308             f"This diff might be helpful: {log}"
2309         ) from None
2310
2311
2312 def assert_stable(src: str, dst: str, line_length: int) -> None:
2313     """Raise AssertionError if `dst` reformats differently the second time."""
2314     newdst = format_str(dst, line_length=line_length)
2315     if dst != newdst:
2316         log = dump_to_file(
2317             diff(src, dst, "source", "first pass"),
2318             diff(dst, newdst, "first pass", "second pass"),
2319         )
2320         raise AssertionError(
2321             f"INTERNAL ERROR: Black produced different code on the second pass "
2322             f"of the formatter.  "
2323             f"Please report a bug on https://github.com/ambv/black/issues.  "
2324             f"This diff might be helpful: {log}"
2325         ) from None
2326
2327
2328 def dump_to_file(*output: str) -> str:
2329     """Dump `output` to a temporary file. Return path to the file."""
2330     import tempfile
2331
2332     with tempfile.NamedTemporaryFile(
2333         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2334     ) as f:
2335         for lines in output:
2336             f.write(lines)
2337             if lines and lines[-1] != "\n":
2338                 f.write("\n")
2339     return f.name
2340
2341
2342 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2343     """Return a unified diff string between strings `a` and `b`."""
2344     import difflib
2345
2346     a_lines = [line + "\n" for line in a.split("\n")]
2347     b_lines = [line + "\n" for line in b.split("\n")]
2348     return "".join(
2349         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2350     )
2351
2352
2353 def cancel(tasks: List[asyncio.Task]) -> None:
2354     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2355     err("Aborted!")
2356     for task in tasks:
2357         task.cancel()
2358
2359
2360 def shutdown(loop: BaseEventLoop) -> None:
2361     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2362     try:
2363         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2364         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2365         if not to_cancel:
2366             return
2367
2368         for task in to_cancel:
2369             task.cancel()
2370         loop.run_until_complete(
2371             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2372         )
2373     finally:
2374         # `concurrent.futures.Future` objects cannot be cancelled once they
2375         # are already running. There might be some when the `shutdown()` happened.
2376         # Silence their logger's spew about the event loop being closed.
2377         cf_logger = logging.getLogger("concurrent.futures")
2378         cf_logger.setLevel(logging.CRITICAL)
2379         loop.close()
2380
2381
2382 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2383     """Replace `regex` with `replacement` twice on `original`.
2384
2385     This is used by string normalization to perform replaces on
2386     overlapping matches.
2387     """
2388     return regex.sub(replacement, regex.sub(replacement, original))
2389
2390
2391 if __name__ == "__main__":
2392     main()