]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Automatic detection of deprecated Python 2 forms of print and exec
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
14 )
15
16 from attr import dataclass, Factory
17 import click
18
19 # lib2to3 fork
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
24
25 __version__ = "18.3a3"
26 DEFAULT_LINE_LENGTH = 88
27 # types
28 syms = pygram.python_symbols
29 FileContent = str
30 Encoding = str
31 Depth = int
32 NodeType = int
33 LeafID = int
34 Priority = int
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
38
39
40 class NothingChanged(UserWarning):
41     """Raised by `format_file` when the reformatted code is the same as source."""
42
43
44 class CannotSplit(Exception):
45     """A readable split that fits the allotted line length is impossible.
46
47     Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`.
48     """
49
50
51 @click.command()
52 @click.option(
53     '-l',
54     '--line-length',
55     type=int,
56     default=DEFAULT_LINE_LENGTH,
57     help='How many character per line to allow.',
58     show_default=True,
59 )
60 @click.option(
61     '--check',
62     is_flag=True,
63     help=(
64         "Don't write back the files, just return the status.  Return code 0 "
65         "means nothing would change.  Return code 1 means some files would be "
66         "reformatted.  Return code 123 means there was an internal error."
67     ),
68 )
69 @click.option(
70     '--fast/--safe',
71     is_flag=True,
72     help='If --fast given, skip temporary sanity checks. [default: --safe]',
73 )
74 @click.version_option(version=__version__)
75 @click.argument(
76     'src',
77     nargs=-1,
78     type=click.Path(
79         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
80     ),
81 )
82 @click.pass_context
83 def main(
84     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
85 ) -> None:
86     """The uncompromising code formatter."""
87     sources: List[Path] = []
88     for s in src:
89         p = Path(s)
90         if p.is_dir():
91             sources.extend(gen_python_files_in_dir(p))
92         elif p.is_file():
93             # if a file was explicitly given, we don't care about its extension
94             sources.append(p)
95         elif s == '-':
96             sources.append(Path('-'))
97         else:
98             err(f'invalid path: {s}')
99     if len(sources) == 0:
100         ctx.exit(0)
101     elif len(sources) == 1:
102         p = sources[0]
103         report = Report(check=check)
104         try:
105             if not p.is_file() and str(p) == '-':
106                 changed = format_stdin_to_stdout(
107                     line_length=line_length, fast=fast, write_back=not check
108                 )
109             else:
110                 changed = format_file_in_place(
111                     p, line_length=line_length, fast=fast, write_back=not check
112                 )
113             report.done(p, changed)
114         except Exception as exc:
115             report.failed(p, str(exc))
116         ctx.exit(report.return_code)
117     else:
118         loop = asyncio.get_event_loop()
119         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
120         return_code = 1
121         try:
122             return_code = loop.run_until_complete(
123                 schedule_formatting(
124                     sources, line_length, not check, fast, loop, executor
125                 )
126             )
127         finally:
128             loop.close()
129             ctx.exit(return_code)
130
131
132 async def schedule_formatting(
133     sources: List[Path],
134     line_length: int,
135     write_back: bool,
136     fast: bool,
137     loop: BaseEventLoop,
138     executor: Executor,
139 ) -> int:
140     tasks = {
141         src: loop.run_in_executor(
142             executor, format_file_in_place, src, line_length, fast, write_back
143         )
144         for src in sources
145     }
146     await asyncio.wait(tasks.values())
147     cancelled = []
148     report = Report()
149     for src, task in tasks.items():
150         if not task.done():
151             report.failed(src, 'timed out, cancelling')
152             task.cancel()
153             cancelled.append(task)
154         elif task.exception():
155             report.failed(src, str(task.exception()))
156         else:
157             report.done(src, task.result())
158     if cancelled:
159         await asyncio.wait(cancelled, timeout=2)
160     out('All done! ✨ 🍰 ✨')
161     click.echo(str(report))
162     return report.return_code
163
164
165 def format_file_in_place(
166     src: Path, line_length: int, fast: bool, write_back: bool = False
167 ) -> bool:
168     """Format the file and rewrite if changed. Return True if changed."""
169     with tokenize.open(src) as src_buffer:
170         src_contents = src_buffer.read()
171     try:
172         contents = format_file_contents(
173             src_contents, line_length=line_length, fast=fast
174         )
175     except NothingChanged:
176         return False
177
178     if write_back:
179         with open(src, "w", encoding=src_buffer.encoding) as f:
180             f.write(contents)
181     return True
182
183
184 def format_stdin_to_stdout(
185     line_length: int, fast: bool, write_back: bool = False
186 ) -> bool:
187     """Format file on stdin and pipe output to stdout. Return True if changed."""
188     contents = sys.stdin.read()
189     try:
190         contents = format_file_contents(contents, line_length=line_length, fast=fast)
191         return True
192
193     except NothingChanged:
194         return False
195
196     finally:
197         if write_back:
198             sys.stdout.write(contents)
199
200
201 def format_file_contents(
202     src_contents: str, line_length: int, fast: bool
203 ) -> FileContent:
204     """Reformats a file and returns its contents and encoding."""
205     if src_contents.strip() == '':
206         raise NothingChanged
207
208     dst_contents = format_str(src_contents, line_length=line_length)
209     if src_contents == dst_contents:
210         raise NothingChanged
211
212     if not fast:
213         assert_equivalent(src_contents, dst_contents)
214         assert_stable(src_contents, dst_contents, line_length=line_length)
215     return dst_contents
216
217
218 def format_str(src_contents: str, line_length: int) -> FileContent:
219     """Reformats a string and returns new contents."""
220     src_node = lib2to3_parse(src_contents)
221     dst_contents = ""
222     lines = LineGenerator()
223     elt = EmptyLineTracker()
224     py36 = is_python36(src_node)
225     empty_line = Line()
226     after = 0
227     for current_line in lines.visit(src_node):
228         for _ in range(after):
229             dst_contents += str(empty_line)
230         before, after = elt.maybe_empty_lines(current_line)
231         for _ in range(before):
232             dst_contents += str(empty_line)
233         for line in split_line(current_line, line_length=line_length, py36=py36):
234             dst_contents += str(line)
235     return dst_contents
236
237
238 GRAMMARS = [
239     pygram.python_grammar_no_print_statement_no_exec_statement,
240     pygram.python_grammar_no_print_statement,
241     pygram.python_grammar_no_exec_statement,
242     pygram.python_grammar,
243 ]
244
245
246 def lib2to3_parse(src_txt: str) -> Node:
247     """Given a string with source, return the lib2to3 Node."""
248     grammar = pygram.python_grammar_no_print_statement
249     if src_txt[-1] != '\n':
250         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
251         src_txt += nl
252     for grammar in GRAMMARS:
253         drv = driver.Driver(grammar, pytree.convert)
254         try:
255             result = drv.parse_string(src_txt, True)
256             break
257
258         except ParseError as pe:
259             lineno, column = pe.context[1]
260             lines = src_txt.splitlines()
261             try:
262                 faulty_line = lines[lineno - 1]
263             except IndexError:
264                 faulty_line = "<line number missing in source>"
265             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
266     else:
267         raise exc from None
268
269     if isinstance(result, Leaf):
270         result = Node(syms.file_input, [result])
271     return result
272
273
274 def lib2to3_unparse(node: Node) -> str:
275     """Given a lib2to3 node, return its string representation."""
276     code = str(node)
277     return code
278
279
280 T = TypeVar('T')
281
282
283 class Visitor(Generic[T]):
284     """Basic lib2to3 visitor that yields things on visiting."""
285
286     def visit(self, node: LN) -> Iterator[T]:
287         if node.type < 256:
288             name = token.tok_name[node.type]
289         else:
290             name = type_repr(node.type)
291         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
292
293     def visit_default(self, node: LN) -> Iterator[T]:
294         if isinstance(node, Node):
295             for child in node.children:
296                 yield from self.visit(child)
297
298
299 @dataclass
300 class DebugVisitor(Visitor[T]):
301     tree_depth: int = 0
302
303     def visit_default(self, node: LN) -> Iterator[T]:
304         indent = ' ' * (2 * self.tree_depth)
305         if isinstance(node, Node):
306             _type = type_repr(node.type)
307             out(f'{indent}{_type}', fg='yellow')
308             self.tree_depth += 1
309             for child in node.children:
310                 yield from self.visit(child)
311
312             self.tree_depth -= 1
313             out(f'{indent}/{_type}', fg='yellow', bold=False)
314         else:
315             _type = token.tok_name.get(node.type, str(node.type))
316             out(f'{indent}{_type}', fg='blue', nl=False)
317             if node.prefix:
318                 # We don't have to handle prefixes for `Node` objects since
319                 # that delegates to the first child anyway.
320                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
321             out(f' {node.value!r}', fg='blue', bold=False)
322
323
324 KEYWORDS = set(keyword.kwlist)
325 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
326 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
327 STATEMENT = {
328     syms.if_stmt,
329     syms.while_stmt,
330     syms.for_stmt,
331     syms.try_stmt,
332     syms.except_clause,
333     syms.with_stmt,
334     syms.funcdef,
335     syms.classdef,
336 }
337 STANDALONE_COMMENT = 153
338 LOGIC_OPERATORS = {'and', 'or'}
339 COMPARATORS = {
340     token.LESS,
341     token.GREATER,
342     token.EQEQUAL,
343     token.NOTEQUAL,
344     token.LESSEQUAL,
345     token.GREATEREQUAL,
346 }
347 MATH_OPERATORS = {
348     token.PLUS,
349     token.MINUS,
350     token.STAR,
351     token.SLASH,
352     token.VBAR,
353     token.AMPER,
354     token.PERCENT,
355     token.CIRCUMFLEX,
356     token.TILDE,
357     token.LEFTSHIFT,
358     token.RIGHTSHIFT,
359     token.DOUBLESTAR,
360     token.DOUBLESLASH,
361 }
362 COMPREHENSION_PRIORITY = 20
363 COMMA_PRIORITY = 10
364 LOGIC_PRIORITY = 5
365 STRING_PRIORITY = 4
366 COMPARATOR_PRIORITY = 3
367 MATH_PRIORITY = 1
368
369
370 @dataclass
371 class BracketTracker:
372     depth: int = 0
373     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
374     delimiters: Dict[LeafID, Priority] = Factory(dict)
375     previous: Optional[Leaf] = None
376
377     def mark(self, leaf: Leaf) -> None:
378         if leaf.type == token.COMMENT:
379             return
380
381         if leaf.type in CLOSING_BRACKETS:
382             self.depth -= 1
383             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
384             leaf.opening_bracket = opening_bracket
385         leaf.bracket_depth = self.depth
386         if self.depth == 0:
387             delim = is_delimiter(leaf)
388             if delim:
389                 self.delimiters[id(leaf)] = delim
390             elif self.previous is not None:
391                 if leaf.type == token.STRING and self.previous.type == token.STRING:
392                     self.delimiters[id(self.previous)] = STRING_PRIORITY
393                 elif (
394                     leaf.type == token.NAME
395                     and leaf.value == 'for'
396                     and leaf.parent
397                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
398                 ):
399                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
400                 elif (
401                     leaf.type == token.NAME
402                     and leaf.value == 'if'
403                     and leaf.parent
404                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
405                 ):
406                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
407                 elif (
408                     leaf.type == token.NAME
409                     and leaf.value in LOGIC_OPERATORS
410                     and leaf.parent
411                 ):
412                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
413         if leaf.type in OPENING_BRACKETS:
414             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
415             self.depth += 1
416         self.previous = leaf
417
418     def any_open_brackets(self) -> bool:
419         """Returns True if there is an yet unmatched open bracket on the line."""
420         return bool(self.bracket_match)
421
422     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
423         """Returns the highest priority of a delimiter found on the line.
424
425         Values are consistent with what `is_delimiter()` returns.
426         """
427         return max(v for k, v in self.delimiters.items() if k not in exclude)
428
429
430 @dataclass
431 class Line:
432     depth: int = 0
433     leaves: List[Leaf] = Factory(list)
434     comments: Dict[LeafID, Leaf] = Factory(dict)
435     bracket_tracker: BracketTracker = Factory(BracketTracker)
436     inside_brackets: bool = False
437     has_for: bool = False
438     _for_loop_variable: bool = False
439
440     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
441         has_value = leaf.value.strip()
442         if not has_value:
443             return
444
445         if self.leaves and not preformatted:
446             # Note: at this point leaf.prefix should be empty except for
447             # imports, for which we only preserve newlines.
448             leaf.prefix += whitespace(leaf)
449         if self.inside_brackets or not preformatted:
450             self.maybe_decrement_after_for_loop_variable(leaf)
451             self.bracket_tracker.mark(leaf)
452             self.maybe_remove_trailing_comma(leaf)
453             self.maybe_increment_for_loop_variable(leaf)
454             if self.maybe_adapt_standalone_comment(leaf):
455                 return
456
457         if not self.append_comment(leaf):
458             self.leaves.append(leaf)
459
460     @property
461     def is_comment(self) -> bool:
462         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
463
464     @property
465     def is_decorator(self) -> bool:
466         return bool(self) and self.leaves[0].type == token.AT
467
468     @property
469     def is_import(self) -> bool:
470         return bool(self) and is_import(self.leaves[0])
471
472     @property
473     def is_class(self) -> bool:
474         return (
475             bool(self)
476             and self.leaves[0].type == token.NAME
477             and self.leaves[0].value == 'class'
478         )
479
480     @property
481     def is_def(self) -> bool:
482         """Also returns True for async defs."""
483         try:
484             first_leaf = self.leaves[0]
485         except IndexError:
486             return False
487
488         try:
489             second_leaf: Optional[Leaf] = self.leaves[1]
490         except IndexError:
491             second_leaf = None
492         return (
493             (first_leaf.type == token.NAME and first_leaf.value == 'def')
494             or (
495                 first_leaf.type == token.ASYNC
496                 and second_leaf is not None
497                 and second_leaf.type == token.NAME
498                 and second_leaf.value == 'def'
499             )
500         )
501
502     @property
503     def is_flow_control(self) -> bool:
504         return (
505             bool(self)
506             and self.leaves[0].type == token.NAME
507             and self.leaves[0].value in FLOW_CONTROL
508         )
509
510     @property
511     def is_yield(self) -> bool:
512         return (
513             bool(self)
514             and self.leaves[0].type == token.NAME
515             and self.leaves[0].value == 'yield'
516         )
517
518     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
519         if not (
520             self.leaves
521             and self.leaves[-1].type == token.COMMA
522             and closing.type in CLOSING_BRACKETS
523         ):
524             return False
525
526         if closing.type == token.RBRACE:
527             self.leaves.pop()
528             return True
529
530         if closing.type == token.RSQB:
531             comma = self.leaves[-1]
532             if comma.parent and comma.parent.type == syms.listmaker:
533                 self.leaves.pop()
534                 return True
535
536         # For parens let's check if it's safe to remove the comma.  If the
537         # trailing one is the only one, we might mistakenly change a tuple
538         # into a different type by removing the comma.
539         depth = closing.bracket_depth + 1
540         commas = 0
541         opening = closing.opening_bracket
542         for _opening_index, leaf in enumerate(self.leaves):
543             if leaf is opening:
544                 break
545
546         else:
547             return False
548
549         for leaf in self.leaves[_opening_index + 1:]:
550             if leaf is closing:
551                 break
552
553             bracket_depth = leaf.bracket_depth
554             if bracket_depth == depth and leaf.type == token.COMMA:
555                 commas += 1
556                 if leaf.parent and leaf.parent.type == syms.arglist:
557                     commas += 1
558                     break
559
560         if commas > 1:
561             self.leaves.pop()
562             return True
563
564         return False
565
566     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
567         """In a for loop, or comprehension, the variables are often unpacks.
568
569         To avoid splitting on the comma in this situation, we will increase
570         the depth of tokens between `for` and `in`.
571         """
572         if leaf.type == token.NAME and leaf.value == 'for':
573             self.has_for = True
574             self.bracket_tracker.depth += 1
575             self._for_loop_variable = True
576             return True
577
578         return False
579
580     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
581         # See `maybe_increment_for_loop_variable` above for explanation.
582         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
583             self.bracket_tracker.depth -= 1
584             self._for_loop_variable = False
585             return True
586
587         return False
588
589     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
590         """Hack a standalone comment to act as a trailing comment for line splitting.
591
592         If this line has brackets and a standalone `comment`, we need to adapt
593         it to be able to still reformat the line.
594
595         This is not perfect, the line to which the standalone comment gets
596         appended will appear "too long" when splitting.
597         """
598         if not (
599             comment.type == STANDALONE_COMMENT
600             and self.bracket_tracker.any_open_brackets()
601         ):
602             return False
603
604         comment.type = token.COMMENT
605         comment.prefix = '\n' + '    ' * (self.depth + 1)
606         return self.append_comment(comment)
607
608     def append_comment(self, comment: Leaf) -> bool:
609         if comment.type != token.COMMENT:
610             return False
611
612         try:
613             after = id(self.last_non_delimiter())
614         except LookupError:
615             comment.type = STANDALONE_COMMENT
616             comment.prefix = ''
617             return False
618
619         else:
620             if after in self.comments:
621                 self.comments[after].value += str(comment)
622             else:
623                 self.comments[after] = comment
624             return True
625
626     def last_non_delimiter(self) -> Leaf:
627         for i in range(len(self.leaves)):
628             last = self.leaves[-i - 1]
629             if not is_delimiter(last):
630                 return last
631
632         raise LookupError("No non-delimiters found")
633
634     def __str__(self) -> str:
635         if not self:
636             return '\n'
637
638         indent = '    ' * self.depth
639         leaves = iter(self.leaves)
640         first = next(leaves)
641         res = f'{first.prefix}{indent}{first.value}'
642         for leaf in leaves:
643             res += str(leaf)
644         for comment in self.comments.values():
645             res += str(comment)
646         return res + '\n'
647
648     def __bool__(self) -> bool:
649         return bool(self.leaves or self.comments)
650
651
652 @dataclass
653 class EmptyLineTracker:
654     """Provides a stateful method that returns the number of potential extra
655     empty lines needed before and after the currently processed line.
656
657     Note: this tracker works on lines that haven't been split yet.  It assumes
658     the prefix of the first leaf consists of optional newlines.  Those newlines
659     are consumed by `maybe_empty_lines()` and included in the computation.
660     """
661     previous_line: Optional[Line] = None
662     previous_after: int = 0
663     previous_defs: List[int] = Factory(list)
664
665     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
666         """Returns the number of extra empty lines before and after the `current_line`.
667
668         This is for separating `def`, `async def` and `class` with extra empty lines
669         (two on module-level), as well as providing an extra empty line after flow
670         control keywords to make them more prominent.
671         """
672         before, after = self._maybe_empty_lines(current_line)
673         before -= self.previous_after
674         self.previous_after = after
675         self.previous_line = current_line
676         return before, after
677
678     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
679         max_allowed = 1
680         if current_line.is_comment and current_line.depth == 0:
681             max_allowed = 2
682         if current_line.leaves:
683             # Consume the first leaf's extra newlines.
684             first_leaf = current_line.leaves[0]
685             before = first_leaf.prefix.count('\n')
686             before = min(before, max(before, max_allowed))
687             first_leaf.prefix = ''
688         else:
689             before = 0
690         depth = current_line.depth
691         while self.previous_defs and self.previous_defs[-1] >= depth:
692             self.previous_defs.pop()
693             before = 1 if depth else 2
694         is_decorator = current_line.is_decorator
695         if is_decorator or current_line.is_def or current_line.is_class:
696             if not is_decorator:
697                 self.previous_defs.append(depth)
698             if self.previous_line is None:
699                 # Don't insert empty lines before the first line in the file.
700                 return 0, 0
701
702             if self.previous_line and self.previous_line.is_decorator:
703                 # Don't insert empty lines between decorators.
704                 return 0, 0
705
706             newlines = 2
707             if current_line.depth:
708                 newlines -= 1
709             return newlines, 0
710
711         if current_line.is_flow_control:
712             return before, 1
713
714         if (
715             self.previous_line
716             and self.previous_line.is_import
717             and not current_line.is_import
718             and depth == self.previous_line.depth
719         ):
720             return (before or 1), 0
721
722         if (
723             self.previous_line
724             and self.previous_line.is_yield
725             and (not current_line.is_yield or depth != self.previous_line.depth)
726         ):
727             return (before or 1), 0
728
729         return before, 0
730
731
732 @dataclass
733 class LineGenerator(Visitor[Line]):
734     """Generates reformatted Line objects.  Empty lines are not emitted.
735
736     Note: destroys the tree it's visiting by mutating prefixes of its leaves
737     in ways that will no longer stringify to valid Python code on the tree.
738     """
739     current_line: Line = Factory(Line)
740
741     def line(self, indent: int = 0) -> Iterator[Line]:
742         """Generate a line.
743
744         If the line is empty, only emit if it makes sense.
745         If the line is too long, split it first and then generate.
746
747         If any lines were generated, set up a new current_line.
748         """
749         if not self.current_line:
750             self.current_line.depth += indent
751             return  # Line is empty, don't emit. Creating a new one unnecessary.
752
753         complete_line = self.current_line
754         self.current_line = Line(depth=complete_line.depth + indent)
755         yield complete_line
756
757     def visit_default(self, node: LN) -> Iterator[Line]:
758         if isinstance(node, Leaf):
759             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
760             for comment in generate_comments(node):
761                 if any_open_brackets:
762                     # any comment within brackets is subject to splitting
763                     self.current_line.append(comment)
764                 elif comment.type == token.COMMENT:
765                     # regular trailing comment
766                     self.current_line.append(comment)
767                     yield from self.line()
768
769                 else:
770                     # regular standalone comment
771                     yield from self.line()
772
773                     self.current_line.append(comment)
774                     yield from self.line()
775
776             normalize_prefix(node, inside_brackets=any_open_brackets)
777             if node.type not in WHITESPACE:
778                 self.current_line.append(node)
779         yield from super().visit_default(node)
780
781     def visit_INDENT(self, node: Node) -> Iterator[Line]:
782         yield from self.line(+1)
783         yield from self.visit_default(node)
784
785     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
786         yield from self.line(-1)
787
788     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
789         """Visit a statement.
790
791         The relevant Python language keywords for this statement are NAME leaves
792         within it.
793         """
794         for child in node.children:
795             if child.type == token.NAME and child.value in keywords:  # type: ignore
796                 yield from self.line()
797
798             yield from self.visit(child)
799
800     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
801         """A statement without nested statements."""
802         is_suite_like = node.parent and node.parent.type in STATEMENT
803         if is_suite_like:
804             yield from self.line(+1)
805             yield from self.visit_default(node)
806             yield from self.line(-1)
807
808         else:
809             yield from self.line()
810             yield from self.visit_default(node)
811
812     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
813         yield from self.line()
814
815         children = iter(node.children)
816         for child in children:
817             yield from self.visit(child)
818
819             if child.type == token.ASYNC:
820                 break
821
822         internal_stmt = next(children)
823         for child in internal_stmt.children:
824             yield from self.visit(child)
825
826     def visit_decorators(self, node: Node) -> Iterator[Line]:
827         for child in node.children:
828             yield from self.line()
829             yield from self.visit(child)
830
831     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
832         yield from self.line()
833
834     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
835         yield from self.visit_default(leaf)
836         yield from self.line()
837
838     def __attrs_post_init__(self) -> None:
839         """You are in a twisty little maze of passages."""
840         v = self.visit_stmt
841         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
842         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
843         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
844         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
845         self.visit_except_clause = partial(v, keywords={'except'})
846         self.visit_funcdef = partial(v, keywords={'def'})
847         self.visit_with_stmt = partial(v, keywords={'with'})
848         self.visit_classdef = partial(v, keywords={'class'})
849         self.visit_async_funcdef = self.visit_async_stmt
850         self.visit_decorated = self.visit_decorators
851
852
853 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
854 OPENING_BRACKETS = set(BRACKET.keys())
855 CLOSING_BRACKETS = set(BRACKET.values())
856 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
857 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
858
859
860 def whitespace(leaf: Leaf) -> str:  # noqa C901
861     """Return whitespace prefix if needed for the given `leaf`."""
862     NO = ''
863     SPACE = ' '
864     DOUBLESPACE = '  '
865     t = leaf.type
866     p = leaf.parent
867     v = leaf.value
868     if t in ALWAYS_NO_SPACE:
869         return NO
870
871     if t == token.COMMENT:
872         return DOUBLESPACE
873
874     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
875     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
876         return NO
877
878     prev = leaf.prev_sibling
879     if not prev:
880         prevp = preceding_leaf(p)
881         if not prevp or prevp.type in OPENING_BRACKETS:
882             return NO
883
884         if t == token.COLON:
885             return SPACE if prevp.type == token.COMMA else NO
886
887         if prevp.type == token.EQUAL:
888             if prevp.parent and prevp.parent.type in {
889                 syms.arglist,
890                 syms.argument,
891                 syms.parameters,
892                 syms.typedargslist,
893                 syms.varargslist,
894             }:
895                 return NO
896
897         elif prevp.type == token.DOUBLESTAR:
898             if prevp.parent and prevp.parent.type in {
899                 syms.arglist,
900                 syms.argument,
901                 syms.dictsetmaker,
902                 syms.parameters,
903                 syms.typedargslist,
904                 syms.varargslist,
905             }:
906                 return NO
907
908         elif prevp.type == token.COLON:
909             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
910                 return NO
911
912         elif (
913             prevp.parent
914             and prevp.parent.type in {syms.factor, syms.star_expr}
915             and prevp.type in MATH_OPERATORS
916         ):
917             return NO
918
919         elif (
920             prevp.type == token.RIGHTSHIFT
921             and prevp.parent
922             and prevp.parent.type == syms.shift_expr
923             and prevp.prev_sibling
924             and prevp.prev_sibling.type == token.NAME
925             and prevp.prev_sibling.value == 'print'
926         ):
927             # Python 2 print chevron
928             return NO
929
930     elif prev.type in OPENING_BRACKETS:
931         return NO
932
933     if p.type in {syms.parameters, syms.arglist}:
934         # untyped function signatures or calls
935         if t == token.RPAR:
936             return NO
937
938         if not prev or prev.type != token.COMMA:
939             return NO
940
941     if p.type == syms.varargslist:
942         # lambdas
943         if t == token.RPAR:
944             return NO
945
946         if prev and prev.type != token.COMMA:
947             return NO
948
949     elif p.type == syms.typedargslist:
950         # typed function signatures
951         if not prev:
952             return NO
953
954         if t == token.EQUAL:
955             if prev.type != syms.tname:
956                 return NO
957
958         elif prev.type == token.EQUAL:
959             # A bit hacky: if the equal sign has whitespace, it means we
960             # previously found it's a typed argument.  So, we're using that, too.
961             return prev.prefix
962
963         elif prev.type != token.COMMA:
964             return NO
965
966     elif p.type == syms.tname:
967         # type names
968         if not prev:
969             prevp = preceding_leaf(p)
970             if not prevp or prevp.type != token.COMMA:
971                 return NO
972
973     elif p.type == syms.trailer:
974         # attributes and calls
975         if t == token.LPAR or t == token.RPAR:
976             return NO
977
978         if not prev:
979             if t == token.DOT:
980                 prevp = preceding_leaf(p)
981                 if not prevp or prevp.type != token.NUMBER:
982                     return NO
983
984             elif t == token.LSQB:
985                 return NO
986
987         elif prev.type != token.COMMA:
988             return NO
989
990     elif p.type == syms.argument:
991         # single argument
992         if t == token.EQUAL:
993             return NO
994
995         if not prev:
996             prevp = preceding_leaf(p)
997             if not prevp or prevp.type == token.LPAR:
998                 return NO
999
1000         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1001             return NO
1002
1003     elif p.type == syms.decorator:
1004         # decorators
1005         return NO
1006
1007     elif p.type == syms.dotted_name:
1008         if prev:
1009             return NO
1010
1011         prevp = preceding_leaf(p)
1012         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1013             return NO
1014
1015     elif p.type == syms.classdef:
1016         if t == token.LPAR:
1017             return NO
1018
1019         if prev and prev.type == token.LPAR:
1020             return NO
1021
1022     elif p.type == syms.subscript:
1023         # indexing
1024         if not prev:
1025             assert p.parent is not None, "subscripts are always parented"
1026             if p.parent.type == syms.subscriptlist:
1027                 return SPACE
1028
1029             return NO
1030
1031         else:
1032             return NO
1033
1034     elif p.type == syms.atom:
1035         if prev and t == token.DOT:
1036             # dots, but not the first one.
1037             return NO
1038
1039     elif (
1040         p.type == syms.listmaker
1041         or p.type == syms.testlist_gexp
1042         or p.type == syms.subscriptlist
1043     ):
1044         # list interior, including unpacking
1045         if not prev:
1046             return NO
1047
1048     elif p.type == syms.dictsetmaker:
1049         # dict and set interior, including unpacking
1050         if not prev:
1051             return NO
1052
1053         if prev.type == token.DOUBLESTAR:
1054             return NO
1055
1056     elif p.type in {syms.factor, syms.star_expr}:
1057         # unary ops
1058         if not prev:
1059             prevp = preceding_leaf(p)
1060             if not prevp or prevp.type in OPENING_BRACKETS:
1061                 return NO
1062
1063             prevp_parent = prevp.parent
1064             assert prevp_parent is not None
1065             if prevp.type == token.COLON and prevp_parent.type in {
1066                 syms.subscript, syms.sliceop
1067             }:
1068                 return NO
1069
1070             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1071                 return NO
1072
1073         elif t == token.NAME or t == token.NUMBER:
1074             return NO
1075
1076     elif p.type == syms.import_from:
1077         if t == token.DOT:
1078             if prev and prev.type == token.DOT:
1079                 return NO
1080
1081         elif t == token.NAME:
1082             if v == 'import':
1083                 return SPACE
1084
1085             if prev and prev.type == token.DOT:
1086                 return NO
1087
1088     elif p.type == syms.sliceop:
1089         return NO
1090
1091     return SPACE
1092
1093
1094 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1095     """Returns the first leaf that precedes `node`, if any."""
1096     while node:
1097         res = node.prev_sibling
1098         if res:
1099             if isinstance(res, Leaf):
1100                 return res
1101
1102             try:
1103                 return list(res.leaves())[-1]
1104
1105             except IndexError:
1106                 return None
1107
1108         node = node.parent
1109     return None
1110
1111
1112 def is_delimiter(leaf: Leaf) -> int:
1113     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1114
1115     Higher numbers are higher priority.
1116     """
1117     if leaf.type == token.COMMA:
1118         return COMMA_PRIORITY
1119
1120     if leaf.type in COMPARATORS:
1121         return COMPARATOR_PRIORITY
1122
1123     if (
1124         leaf.type in MATH_OPERATORS
1125         and leaf.parent
1126         and leaf.parent.type not in {syms.factor, syms.star_expr}
1127     ):
1128         return MATH_PRIORITY
1129
1130     return 0
1131
1132
1133 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1134     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1135
1136     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1137     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1138     move because it does away with modifying the grammar to include all the
1139     possible places in which comments can be placed.
1140
1141     The sad consequence for us though is that comments don't "belong" anywhere.
1142     This is why this function generates simple parentless Leaf objects for
1143     comments.  We simply don't know what the correct parent should be.
1144
1145     No matter though, we can live without this.  We really only need to
1146     differentiate between inline and standalone comments.  The latter don't
1147     share the line with any code.
1148
1149     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1150     are emitted with a fake STANDALONE_COMMENT token identifier.
1151     """
1152     p = leaf.prefix
1153     if not p:
1154         return
1155
1156     if '#' not in p:
1157         return
1158
1159     nlines = 0
1160     for index, line in enumerate(p.split('\n')):
1161         line = line.lstrip()
1162         if not line:
1163             nlines += 1
1164         if not line.startswith('#'):
1165             continue
1166
1167         if index == 0 and leaf.type != token.ENDMARKER:
1168             comment_type = token.COMMENT  # simple trailing comment
1169         else:
1170             comment_type = STANDALONE_COMMENT
1171         yield Leaf(comment_type, make_comment(line), prefix='\n' * nlines)
1172
1173         nlines = 0
1174
1175
1176 def make_comment(content: str) -> str:
1177     content = content.rstrip()
1178     if not content:
1179         return '#'
1180
1181     if content[0] == '#':
1182         content = content[1:]
1183     if content and content[0] not in {' ', '!', '#'}:
1184         content = ' ' + content
1185     return '#' + content
1186
1187
1188 def split_line(
1189     line: Line, line_length: int, inner: bool = False, py36: bool = False
1190 ) -> Iterator[Line]:
1191     """Splits a `line` into potentially many lines.
1192
1193     They should fit in the allotted `line_length` but might not be able to.
1194     `inner` signifies that there were a pair of brackets somewhere around the
1195     current `line`, possibly transitively. This means we can fallback to splitting
1196     by delimiters if the LHS/RHS don't yield any results.
1197
1198     If `py36` is True, splitting may generate syntax that is only compatible
1199     with Python 3.6 and later.
1200     """
1201     line_str = str(line).strip('\n')
1202     if len(line_str) <= line_length and '\n' not in line_str:
1203         yield line
1204         return
1205
1206     if line.is_def:
1207         split_funcs = [left_hand_split]
1208     elif line.inside_brackets:
1209         split_funcs = [delimiter_split]
1210         if '\n' not in line_str:
1211             # Only attempt RHS if we don't have multiline strings or comments
1212             # on this line.
1213             split_funcs.append(right_hand_split)
1214     else:
1215         split_funcs = [right_hand_split]
1216     for split_func in split_funcs:
1217         # We are accumulating lines in `result` because we might want to abort
1218         # mission and return the original line in the end, or attempt a different
1219         # split altogether.
1220         result: List[Line] = []
1221         try:
1222             for l in split_func(line, py36=py36):
1223                 if str(l).strip('\n') == line_str:
1224                     raise CannotSplit("Split function returned an unchanged result")
1225
1226                 result.extend(
1227                     split_line(l, line_length=line_length, inner=True, py36=py36)
1228                 )
1229         except CannotSplit as cs:
1230             continue
1231
1232         else:
1233             yield from result
1234             break
1235
1236     else:
1237         yield line
1238
1239
1240 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1241     """Split line into many lines, starting with the first matching bracket pair.
1242
1243     Note: this usually looks weird, only use this for function definitions.
1244     Prefer RHS otherwise.
1245     """
1246     head = Line(depth=line.depth)
1247     body = Line(depth=line.depth + 1, inside_brackets=True)
1248     tail = Line(depth=line.depth)
1249     tail_leaves: List[Leaf] = []
1250     body_leaves: List[Leaf] = []
1251     head_leaves: List[Leaf] = []
1252     current_leaves = head_leaves
1253     matching_bracket = None
1254     for leaf in line.leaves:
1255         if (
1256             current_leaves is body_leaves
1257             and leaf.type in CLOSING_BRACKETS
1258             and leaf.opening_bracket is matching_bracket
1259         ):
1260             current_leaves = tail_leaves if body_leaves else head_leaves
1261         current_leaves.append(leaf)
1262         if current_leaves is head_leaves:
1263             if leaf.type in OPENING_BRACKETS:
1264                 matching_bracket = leaf
1265                 current_leaves = body_leaves
1266     # Since body is a new indent level, remove spurious leading whitespace.
1267     if body_leaves:
1268         normalize_prefix(body_leaves[0], inside_brackets=True)
1269     # Build the new lines.
1270     for result, leaves in (
1271         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1272     ):
1273         for leaf in leaves:
1274             result.append(leaf, preformatted=True)
1275             comment_after = line.comments.get(id(leaf))
1276             if comment_after:
1277                 result.append(comment_after, preformatted=True)
1278     split_succeeded_or_raise(head, body, tail)
1279     for result in (head, body, tail):
1280         if result:
1281             yield result
1282
1283
1284 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1285     """Split line into many lines, starting with the last matching bracket pair."""
1286     head = Line(depth=line.depth)
1287     body = Line(depth=line.depth + 1, inside_brackets=True)
1288     tail = Line(depth=line.depth)
1289     tail_leaves: List[Leaf] = []
1290     body_leaves: List[Leaf] = []
1291     head_leaves: List[Leaf] = []
1292     current_leaves = tail_leaves
1293     opening_bracket = None
1294     for leaf in reversed(line.leaves):
1295         if current_leaves is body_leaves:
1296             if leaf is opening_bracket:
1297                 current_leaves = head_leaves if body_leaves else tail_leaves
1298         current_leaves.append(leaf)
1299         if current_leaves is tail_leaves:
1300             if leaf.type in CLOSING_BRACKETS:
1301                 opening_bracket = leaf.opening_bracket
1302                 current_leaves = body_leaves
1303     tail_leaves.reverse()
1304     body_leaves.reverse()
1305     head_leaves.reverse()
1306     # Since body is a new indent level, remove spurious leading whitespace.
1307     if body_leaves:
1308         normalize_prefix(body_leaves[0], inside_brackets=True)
1309     # Build the new lines.
1310     for result, leaves in (
1311         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1312     ):
1313         for leaf in leaves:
1314             result.append(leaf, preformatted=True)
1315             comment_after = line.comments.get(id(leaf))
1316             if comment_after:
1317                 result.append(comment_after, preformatted=True)
1318     split_succeeded_or_raise(head, body, tail)
1319     for result in (head, body, tail):
1320         if result:
1321             yield result
1322
1323
1324 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1325     tail_len = len(str(tail).strip())
1326     if not body:
1327         if tail_len == 0:
1328             raise CannotSplit("Splitting brackets produced the same line")
1329
1330         elif tail_len < 3:
1331             raise CannotSplit(
1332                 f"Splitting brackets on an empty body to save "
1333                 f"{tail_len} characters is not worth it"
1334             )
1335
1336
1337 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1338     """Split according to delimiters of the highest priority.
1339
1340     This kind of split doesn't increase indentation.
1341     If `py36` is True, the split will add trailing commas also in function
1342     signatures that contain * and **.
1343     """
1344     try:
1345         last_leaf = line.leaves[-1]
1346     except IndexError:
1347         raise CannotSplit("Line empty")
1348
1349     delimiters = line.bracket_tracker.delimiters
1350     try:
1351         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1352     except ValueError:
1353         raise CannotSplit("No delimiters found")
1354
1355     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1356     lowest_depth = sys.maxsize
1357     trailing_comma_safe = True
1358     for leaf in line.leaves:
1359         current_line.append(leaf, preformatted=True)
1360         comment_after = line.comments.get(id(leaf))
1361         if comment_after:
1362             current_line.append(comment_after, preformatted=True)
1363         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1364         if (
1365             leaf.bracket_depth == lowest_depth
1366             and leaf.type == token.STAR
1367             or leaf.type == token.DOUBLESTAR
1368         ):
1369             trailing_comma_safe = trailing_comma_safe and py36
1370         leaf_priority = delimiters.get(id(leaf))
1371         if leaf_priority == delimiter_priority:
1372             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1373             yield current_line
1374
1375             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1376     if current_line:
1377         if (
1378             delimiter_priority == COMMA_PRIORITY
1379             and current_line.leaves[-1].type != token.COMMA
1380             and trailing_comma_safe
1381         ):
1382             current_line.append(Leaf(token.COMMA, ','))
1383         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1384         yield current_line
1385
1386
1387 def is_import(leaf: Leaf) -> bool:
1388     """Returns True if the given leaf starts an import statement."""
1389     p = leaf.parent
1390     t = leaf.type
1391     v = leaf.value
1392     return bool(
1393         t == token.NAME
1394         and (
1395             (v == 'import' and p and p.type == syms.import_name)
1396             or (v == 'from' and p and p.type == syms.import_from)
1397         )
1398     )
1399
1400
1401 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1402     """Leave existing extra newlines if not `inside_brackets`.
1403
1404     Remove everything else.  Note: don't use backslashes for formatting or
1405     you'll lose your voting rights.
1406     """
1407     if not inside_brackets:
1408         spl = leaf.prefix.split('#')
1409         if '\\' not in spl[0]:
1410             nl_count = spl[-1].count('\n')
1411             if len(spl) > 1:
1412                 nl_count -= 1
1413             leaf.prefix = '\n' * nl_count
1414             return
1415
1416     leaf.prefix = ''
1417
1418
1419 def is_python36(node: Node) -> bool:
1420     """Returns True if the current file is using Python 3.6+ features.
1421
1422     Currently looking for:
1423     - f-strings; and
1424     - trailing commas after * or ** in function signatures.
1425     """
1426     for n in node.pre_order():
1427         if n.type == token.STRING:
1428             value_head = n.value[:2]  # type: ignore
1429             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1430                 return True
1431
1432         elif (
1433             n.type == syms.typedargslist
1434             and n.children
1435             and n.children[-1].type == token.COMMA
1436         ):
1437             for ch in n.children:
1438                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1439                     return True
1440
1441     return False
1442
1443
1444 PYTHON_EXTENSIONS = {'.py'}
1445 BLACKLISTED_DIRECTORIES = {
1446     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1447 }
1448
1449
1450 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1451     for child in path.iterdir():
1452         if child.is_dir():
1453             if child.name in BLACKLISTED_DIRECTORIES:
1454                 continue
1455
1456             yield from gen_python_files_in_dir(child)
1457
1458         elif child.suffix in PYTHON_EXTENSIONS:
1459             yield child
1460
1461
1462 @dataclass
1463 class Report:
1464     """Provides a reformatting counter."""
1465     check: bool = False
1466     change_count: int = 0
1467     same_count: int = 0
1468     failure_count: int = 0
1469
1470     def done(self, src: Path, changed: bool) -> None:
1471         """Increment the counter for successful reformatting. Write out a message."""
1472         if changed:
1473             reformatted = 'would reformat' if self.check else 'reformatted'
1474             out(f'{reformatted} {src}')
1475             self.change_count += 1
1476         else:
1477             out(f'{src} already well formatted, good job.', bold=False)
1478             self.same_count += 1
1479
1480     def failed(self, src: Path, message: str) -> None:
1481         """Increment the counter for failed reformatting. Write out a message."""
1482         err(f'error: cannot format {src}: {message}')
1483         self.failure_count += 1
1484
1485     @property
1486     def return_code(self) -> int:
1487         """Which return code should the app use considering the current state."""
1488         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1489         # 126 we have special returncodes reserved by the shell.
1490         if self.failure_count:
1491             return 123
1492
1493         elif self.change_count and self.check:
1494             return 1
1495
1496         return 0
1497
1498     def __str__(self) -> str:
1499         """A color report of the current state.
1500
1501         Use `click.unstyle` to remove colors.
1502         """
1503         if self.check:
1504             reformatted = "would be reformatted"
1505             unchanged = "would be left unchanged"
1506             failed = "would fail to reformat"
1507         else:
1508             reformatted = "reformatted"
1509             unchanged = "left unchanged"
1510             failed = "failed to reformat"
1511         report = []
1512         if self.change_count:
1513             s = 's' if self.change_count > 1 else ''
1514             report.append(
1515                 click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
1516             )
1517         if self.same_count:
1518             s = 's' if self.same_count > 1 else ''
1519             report.append(f'{self.same_count} file{s} {unchanged}')
1520         if self.failure_count:
1521             s = 's' if self.failure_count > 1 else ''
1522             report.append(
1523                 click.style(f'{self.failure_count} file{s} {failed}', fg='red')
1524             )
1525         return ', '.join(report) + '.'
1526
1527
1528 def assert_equivalent(src: str, dst: str) -> None:
1529     """Raises AssertionError if `src` and `dst` aren't equivalent.
1530
1531     This is a temporary sanity check until Black becomes stable.
1532     """
1533
1534     import ast
1535     import traceback
1536
1537     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1538         """Simple visitor generating strings to compare ASTs by content."""
1539         yield f"{'  ' * depth}{node.__class__.__name__}("
1540
1541         for field in sorted(node._fields):
1542             try:
1543                 value = getattr(node, field)
1544             except AttributeError:
1545                 continue
1546
1547             yield f"{'  ' * (depth+1)}{field}="
1548
1549             if isinstance(value, list):
1550                 for item in value:
1551                     if isinstance(item, ast.AST):
1552                         yield from _v(item, depth + 2)
1553
1554             elif isinstance(value, ast.AST):
1555                 yield from _v(value, depth + 2)
1556
1557             else:
1558                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1559
1560         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1561
1562     try:
1563         src_ast = ast.parse(src)
1564     except Exception as exc:
1565         major, minor = sys.version_info[:2]
1566         raise AssertionError(
1567             f"cannot use --safe with this file; failed to parse source file "
1568             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
1569             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
1570         )
1571
1572     try:
1573         dst_ast = ast.parse(dst)
1574     except Exception as exc:
1575         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1576         raise AssertionError(
1577             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1578             f"Please report a bug on https://github.com/ambv/black/issues.  "
1579             f"This invalid output might be helpful: {log}"
1580         ) from None
1581
1582     src_ast_str = '\n'.join(_v(src_ast))
1583     dst_ast_str = '\n'.join(_v(dst_ast))
1584     if src_ast_str != dst_ast_str:
1585         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1586         raise AssertionError(
1587             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1588             f"the source.  "
1589             f"Please report a bug on https://github.com/ambv/black/issues.  "
1590             f"This diff might be helpful: {log}"
1591         ) from None
1592
1593
1594 def assert_stable(src: str, dst: str, line_length: int) -> None:
1595     """Raises AssertionError if `dst` reformats differently the second time.
1596
1597     This is a temporary sanity check until Black becomes stable.
1598     """
1599     newdst = format_str(dst, line_length=line_length)
1600     if dst != newdst:
1601         log = dump_to_file(
1602             diff(src, dst, 'source', 'first pass'),
1603             diff(dst, newdst, 'first pass', 'second pass'),
1604         )
1605         raise AssertionError(
1606             f"INTERNAL ERROR: Black produced different code on the second pass "
1607             f"of the formatter.  "
1608             f"Please report a bug on https://github.com/ambv/black/issues.  "
1609             f"This diff might be helpful: {log}"
1610         ) from None
1611
1612
1613 def dump_to_file(*output: str) -> str:
1614     """Dumps `output` to a temporary file. Returns path to the file."""
1615     import tempfile
1616
1617     with tempfile.NamedTemporaryFile(
1618         mode='w', prefix='blk_', suffix='.log', delete=False
1619     ) as f:
1620         for lines in output:
1621             f.write(lines)
1622             f.write('\n')
1623     return f.name
1624
1625
1626 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1627     """Returns a udiff string between strings `a` and `b`."""
1628     import difflib
1629
1630     a_lines = [line + '\n' for line in a.split('\n')]
1631     b_lines = [line + '\n' for line in b.split('\n')]
1632     return ''.join(
1633         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1634     )
1635
1636
1637 if __name__ == '__main__':
1638     main()