]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

54044a5f564630289d7236e95fa0548d0049342b
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(
78         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
79     ),
80 )
81 @click.pass_context
82 def main(
83     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
84 ) -> None:
85     """The uncompromising code formatter."""
86     sources: List[Path] = []
87     for s in src:
88         p = Path(s)
89         if p.is_dir():
90             sources.extend(gen_python_files_in_dir(p))
91         elif p.is_file():
92             # if a file was explicitly given, we don't care about its extension
93             sources.append(p)
94         elif s == '-':
95             sources.append(Path('-'))
96         else:
97             err(f'invalid path: {s}')
98     if len(sources) == 0:
99         ctx.exit(0)
100     elif len(sources) == 1:
101         p = sources[0]
102         report = Report()
103         try:
104             if not p.is_file() and str(p) == '-':
105                 changed = format_stdin_to_stdout(
106                     line_length=line_length, fast=fast, write_back=not check
107                 )
108             else:
109                 changed = format_file_in_place(
110                     p, line_length=line_length, fast=fast, write_back=not check
111                 )
112             report.done(p, changed)
113         except Exception as exc:
114             report.failed(p, str(exc))
115         ctx.exit(report.return_code)
116     else:
117         loop = asyncio.get_event_loop()
118         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
119         return_code = 1
120         try:
121             return_code = loop.run_until_complete(
122                 schedule_formatting(
123                     sources, line_length, not check, fast, loop, executor
124                 )
125             )
126         finally:
127             loop.close()
128             ctx.exit(return_code)
129
130
131 async def schedule_formatting(
132     sources: List[Path],
133     line_length: int,
134     write_back: bool,
135     fast: bool,
136     loop: BaseEventLoop,
137     executor: Executor,
138 ) -> int:
139     tasks = {
140         src: loop.run_in_executor(
141             executor, format_file_in_place, src, line_length, fast, write_back
142         )
143         for src in sources
144     }
145     await asyncio.wait(tasks.values())
146     cancelled = []
147     report = Report()
148     for src, task in tasks.items():
149         if not task.done():
150             report.failed(src, 'timed out, cancelling')
151             task.cancel()
152             cancelled.append(task)
153         elif task.exception():
154             report.failed(src, str(task.exception()))
155         else:
156             report.done(src, task.result())
157     if cancelled:
158         await asyncio.wait(cancelled, timeout=2)
159     out('All done! ✨ 🍰 ✨')
160     click.echo(str(report))
161     return report.return_code
162
163
164 def format_file_in_place(
165     src: Path, line_length: int, fast: bool, write_back: bool = False
166 ) -> bool:
167     """Format the file and rewrite if changed. Return True if changed."""
168     with tokenize.open(src) as src_buffer:
169         src_contents = src_buffer.read()
170     try:
171         contents = format_file_contents(
172             src_contents, line_length=line_length, fast=fast
173         )
174     except NothingChanged:
175         return False
176
177     if write_back:
178         with open(src, "w", encoding=src_buffer.encoding) as f:
179             f.write(contents)
180     return True
181
182
183 def format_stdin_to_stdout(
184     line_length: int, fast: bool, write_back: bool = False
185 ) -> bool:
186     """Format file on stdin and pipe output to stdout. Return True if changed."""
187     contents = sys.stdin.read()
188     try:
189         contents = format_file_contents(contents, line_length=line_length, fast=fast)
190         return True
191
192     except NothingChanged:
193         return False
194
195     finally:
196         if write_back:
197             sys.stdout.write(contents)
198
199
200 def format_file_contents(
201     src_contents: str, line_length: int, fast: bool
202 ) -> FileContent:
203     """Reformats a file and returns its contents and encoding."""
204     if src_contents.strip() == '':
205         raise NothingChanged
206
207     dst_contents = format_str(src_contents, line_length=line_length)
208     if src_contents == dst_contents:
209         raise NothingChanged
210
211     if not fast:
212         assert_equivalent(src_contents, dst_contents)
213         assert_stable(src_contents, dst_contents, line_length=line_length)
214     return dst_contents
215
216
217 def format_str(src_contents: str, line_length: int) -> FileContent:
218     """Reformats a string and returns new contents."""
219     src_node = lib2to3_parse(src_contents)
220     dst_contents = ""
221     comments: List[Line] = []
222     lines = LineGenerator()
223     elt = EmptyLineTracker()
224     py36 = is_python36(src_node)
225     empty_line = Line()
226     after = 0
227     for current_line in lines.visit(src_node):
228         for _ in range(after):
229             dst_contents += str(empty_line)
230         before, after = elt.maybe_empty_lines(current_line)
231         for _ in range(before):
232             dst_contents += str(empty_line)
233         if not current_line.is_comment:
234             for comment in comments:
235                 dst_contents += str(comment)
236             comments = []
237             for line in split_line(current_line, line_length=line_length, py36=py36):
238                 dst_contents += str(line)
239         else:
240             comments.append(current_line)
241     if comments:
242         if elt.previous_defs:
243             # Separate postscriptum comments from the last module-level def.
244             dst_contents += str(empty_line)
245             dst_contents += str(empty_line)
246         for comment in comments:
247             dst_contents += str(comment)
248     return dst_contents
249
250
251 def lib2to3_parse(src_txt: str) -> Node:
252     """Given a string with source, return the lib2to3 Node."""
253     grammar = pygram.python_grammar_no_print_statement
254     drv = driver.Driver(grammar, pytree.convert)
255     if src_txt[-1] != '\n':
256         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
257         src_txt += nl
258     try:
259         result = drv.parse_string(src_txt, True)
260     except ParseError as pe:
261         lineno, column = pe.context[1]
262         lines = src_txt.splitlines()
263         try:
264             faulty_line = lines[lineno - 1]
265         except IndexError:
266             faulty_line = "<line number missing in source>"
267         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
268
269     if isinstance(result, Leaf):
270         result = Node(syms.file_input, [result])
271     return result
272
273
274 def lib2to3_unparse(node: Node) -> str:
275     """Given a lib2to3 node, return its string representation."""
276     code = str(node)
277     return code
278
279
280 T = TypeVar('T')
281
282
283 class Visitor(Generic[T]):
284     """Basic lib2to3 visitor that yields things on visiting."""
285
286     def visit(self, node: LN) -> Iterator[T]:
287         if node.type < 256:
288             name = token.tok_name[node.type]
289         else:
290             name = type_repr(node.type)
291         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
292
293     def visit_default(self, node: LN) -> Iterator[T]:
294         if isinstance(node, Node):
295             for child in node.children:
296                 yield from self.visit(child)
297
298
299 @dataclass
300 class DebugVisitor(Visitor[T]):
301     tree_depth: int = 0
302
303     def visit_default(self, node: LN) -> Iterator[T]:
304         indent = ' ' * (2 * self.tree_depth)
305         if isinstance(node, Node):
306             _type = type_repr(node.type)
307             out(f'{indent}{_type}', fg='yellow')
308             self.tree_depth += 1
309             for child in node.children:
310                 yield from self.visit(child)
311
312             self.tree_depth -= 1
313             out(f'{indent}/{_type}', fg='yellow', bold=False)
314         else:
315             _type = token.tok_name.get(node.type, str(node.type))
316             out(f'{indent}{_type}', fg='blue', nl=False)
317             if node.prefix:
318                 # We don't have to handle prefixes for `Node` objects since
319                 # that delegates to the first child anyway.
320                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
321             out(f' {node.value!r}', fg='blue', bold=False)
322
323
324 KEYWORDS = set(keyword.kwlist)
325 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
326 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
327 STATEMENT = {
328     syms.if_stmt,
329     syms.while_stmt,
330     syms.for_stmt,
331     syms.try_stmt,
332     syms.except_clause,
333     syms.with_stmt,
334     syms.funcdef,
335     syms.classdef,
336 }
337 STANDALONE_COMMENT = 153
338 LOGIC_OPERATORS = {'and', 'or'}
339 COMPARATORS = {
340     token.LESS,
341     token.GREATER,
342     token.EQEQUAL,
343     token.NOTEQUAL,
344     token.LESSEQUAL,
345     token.GREATEREQUAL,
346 }
347 MATH_OPERATORS = {
348     token.PLUS,
349     token.MINUS,
350     token.STAR,
351     token.SLASH,
352     token.VBAR,
353     token.AMPER,
354     token.PERCENT,
355     token.CIRCUMFLEX,
356     token.LEFTSHIFT,
357     token.RIGHTSHIFT,
358     token.DOUBLESTAR,
359     token.DOUBLESLASH,
360 }
361 COMPREHENSION_PRIORITY = 20
362 COMMA_PRIORITY = 10
363 LOGIC_PRIORITY = 5
364 STRING_PRIORITY = 4
365 COMPARATOR_PRIORITY = 3
366 MATH_PRIORITY = 1
367
368
369 @dataclass
370 class BracketTracker:
371     depth: int = 0
372     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
373     delimiters: Dict[LeafID, Priority] = Factory(dict)
374     previous: Optional[Leaf] = None
375
376     def mark(self, leaf: Leaf) -> None:
377         if leaf.type == token.COMMENT:
378             return
379
380         if leaf.type in CLOSING_BRACKETS:
381             self.depth -= 1
382             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
383             leaf.opening_bracket = opening_bracket
384         leaf.bracket_depth = self.depth
385         if self.depth == 0:
386             delim = is_delimiter(leaf)
387             if delim:
388                 self.delimiters[id(leaf)] = delim
389             elif self.previous is not None:
390                 if leaf.type == token.STRING and self.previous.type == token.STRING:
391                     self.delimiters[id(self.previous)] = STRING_PRIORITY
392                 elif (
393                     leaf.type == token.NAME
394                     and leaf.value == 'for'
395                     and leaf.parent
396                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
397                 ):
398                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
399                 elif (
400                     leaf.type == token.NAME
401                     and leaf.value == 'if'
402                     and leaf.parent
403                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
404                 ):
405                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
406                 elif (
407                     leaf.type == token.NAME
408                     and leaf.value in LOGIC_OPERATORS
409                     and leaf.parent
410                 ):
411                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
412         if leaf.type in OPENING_BRACKETS:
413             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
414             self.depth += 1
415         self.previous = leaf
416
417     def any_open_brackets(self) -> bool:
418         """Returns True if there is an yet unmatched open bracket on the line."""
419         return bool(self.bracket_match)
420
421     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
422         """Returns the highest priority of a delimiter found on the line.
423
424         Values are consistent with what `is_delimiter()` returns.
425         """
426         return max(v for k, v in self.delimiters.items() if k not in exclude)
427
428
429 @dataclass
430 class Line:
431     depth: int = 0
432     leaves: List[Leaf] = Factory(list)
433     comments: Dict[LeafID, Leaf] = Factory(dict)
434     bracket_tracker: BracketTracker = Factory(BracketTracker)
435     inside_brackets: bool = False
436     has_for: bool = False
437     _for_loop_variable: bool = False
438
439     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
440         has_value = leaf.value.strip()
441         if not has_value:
442             return
443
444         if self.leaves and not preformatted:
445             # Note: at this point leaf.prefix should be empty except for
446             # imports, for which we only preserve newlines.
447             leaf.prefix += whitespace(leaf)
448         if self.inside_brackets or not preformatted:
449             self.maybe_decrement_after_for_loop_variable(leaf)
450             self.bracket_tracker.mark(leaf)
451             self.maybe_remove_trailing_comma(leaf)
452             self.maybe_increment_for_loop_variable(leaf)
453             if self.maybe_adapt_standalone_comment(leaf):
454                 return
455
456         if not self.append_comment(leaf):
457             self.leaves.append(leaf)
458
459     @property
460     def is_comment(self) -> bool:
461         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
462
463     @property
464     def is_decorator(self) -> bool:
465         return bool(self) and self.leaves[0].type == token.AT
466
467     @property
468     def is_import(self) -> bool:
469         return bool(self) and is_import(self.leaves[0])
470
471     @property
472     def is_class(self) -> bool:
473         return (
474             bool(self)
475             and self.leaves[0].type == token.NAME
476             and self.leaves[0].value == 'class'
477         )
478
479     @property
480     def is_def(self) -> bool:
481         """Also returns True for async defs."""
482         try:
483             first_leaf = self.leaves[0]
484         except IndexError:
485             return False
486
487         try:
488             second_leaf: Optional[Leaf] = self.leaves[1]
489         except IndexError:
490             second_leaf = None
491         return (
492             (first_leaf.type == token.NAME and first_leaf.value == 'def')
493             or (
494                 first_leaf.type == token.ASYNC
495                 and second_leaf is not None
496                 and second_leaf.type == token.NAME
497                 and second_leaf.value == 'def'
498             )
499         )
500
501     @property
502     def is_flow_control(self) -> bool:
503         return (
504             bool(self)
505             and self.leaves[0].type == token.NAME
506             and self.leaves[0].value in FLOW_CONTROL
507         )
508
509     @property
510     def is_yield(self) -> bool:
511         return (
512             bool(self)
513             and self.leaves[0].type == token.NAME
514             and self.leaves[0].value == 'yield'
515         )
516
517     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
518         if not (
519             self.leaves
520             and self.leaves[-1].type == token.COMMA
521             and closing.type in CLOSING_BRACKETS
522         ):
523             return False
524
525         if closing.type == token.RSQB or closing.type == token.RBRACE:
526             self.leaves.pop()
527             return True
528
529         # For parens let's check if it's safe to remove the comma.  If the
530         # trailing one is the only one, we might mistakenly change a tuple
531         # into a different type by removing the comma.
532         depth = closing.bracket_depth + 1
533         commas = 0
534         opening = closing.opening_bracket
535         for _opening_index, leaf in enumerate(self.leaves):
536             if leaf is opening:
537                 break
538
539         else:
540             return False
541
542         for leaf in self.leaves[_opening_index + 1:]:
543             if leaf is closing:
544                 break
545
546             bracket_depth = leaf.bracket_depth
547             if bracket_depth == depth and leaf.type == token.COMMA:
548                 commas += 1
549                 if leaf.parent and leaf.parent.type == syms.arglist:
550                     commas += 1
551                     break
552
553         if commas > 1:
554             self.leaves.pop()
555             return True
556
557         return False
558
559     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
560         """In a for loop, or comprehension, the variables are often unpacks.
561
562         To avoid splitting on the comma in this situation, we will increase
563         the depth of tokens between `for` and `in`.
564         """
565         if leaf.type == token.NAME and leaf.value == 'for':
566             self.has_for = True
567             self.bracket_tracker.depth += 1
568             self._for_loop_variable = True
569             return True
570
571         return False
572
573     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
574         # See `maybe_increment_for_loop_variable` above for explanation.
575         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
576             self.bracket_tracker.depth -= 1
577             self._for_loop_variable = False
578             return True
579
580         return False
581
582     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
583         """Hack a standalone comment to act as a trailing comment for line splitting.
584
585         If this line has brackets and a standalone `comment`, we need to adapt
586         it to be able to still reformat the line.
587
588         This is not perfect, the line to which the standalone comment gets
589         appended will appear "too long" when splitting.
590         """
591         if not (
592             comment.type == STANDALONE_COMMENT
593             and self.bracket_tracker.any_open_brackets()
594         ):
595             return False
596
597         comment.type = token.COMMENT
598         comment.prefix = '\n' + '    ' * (self.depth + 1)
599         return self.append_comment(comment)
600
601     def append_comment(self, comment: Leaf) -> bool:
602         if comment.type != token.COMMENT:
603             return False
604
605         try:
606             after = id(self.last_non_delimiter())
607         except LookupError:
608             comment.type = STANDALONE_COMMENT
609             comment.prefix = ''
610             return False
611
612         else:
613             if after in self.comments:
614                 self.comments[after].value += str(comment)
615             else:
616                 self.comments[after] = comment
617             return True
618
619     def last_non_delimiter(self) -> Leaf:
620         for i in range(len(self.leaves)):
621             last = self.leaves[-i - 1]
622             if not is_delimiter(last):
623                 return last
624
625         raise LookupError("No non-delimiters found")
626
627     def __str__(self) -> str:
628         if not self:
629             return '\n'
630
631         indent = '    ' * self.depth
632         leaves = iter(self.leaves)
633         first = next(leaves)
634         res = f'{first.prefix}{indent}{first.value}'
635         for leaf in leaves:
636             res += str(leaf)
637         for comment in self.comments.values():
638             res += str(comment)
639         return res + '\n'
640
641     def __bool__(self) -> bool:
642         return bool(self.leaves or self.comments)
643
644
645 @dataclass
646 class EmptyLineTracker:
647     """Provides a stateful method that returns the number of potential extra
648     empty lines needed before and after the currently processed line.
649
650     Note: this tracker works on lines that haven't been split yet.  It assumes
651     the prefix of the first leaf consists of optional newlines.  Those newlines
652     are consumed by `maybe_empty_lines()` and included in the computation.
653     """
654     previous_line: Optional[Line] = None
655     previous_after: int = 0
656     previous_defs: List[int] = Factory(list)
657
658     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
659         """Returns the number of extra empty lines before and after the `current_line`.
660
661         This is for separating `def`, `async def` and `class` with extra empty lines
662         (two on module-level), as well as providing an extra empty line after flow
663         control keywords to make them more prominent.
664         """
665         if current_line.is_comment:
666             # Don't count standalone comments towards previous empty lines.
667             return 0, 0
668
669         before, after = self._maybe_empty_lines(current_line)
670         before -= self.previous_after
671         self.previous_after = after
672         self.previous_line = current_line
673         return before, after
674
675     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
676         if current_line.leaves:
677             # Consume the first leaf's extra newlines.
678             first_leaf = current_line.leaves[0]
679             before = int('\n' in first_leaf.prefix)
680             first_leaf.prefix = ''
681         else:
682             before = 0
683         depth = current_line.depth
684         while self.previous_defs and self.previous_defs[-1] >= depth:
685             self.previous_defs.pop()
686             before = 1 if depth else 2
687         is_decorator = current_line.is_decorator
688         if is_decorator or current_line.is_def or current_line.is_class:
689             if not is_decorator:
690                 self.previous_defs.append(depth)
691             if self.previous_line is None:
692                 # Don't insert empty lines before the first line in the file.
693                 return 0, 0
694
695             if self.previous_line and self.previous_line.is_decorator:
696                 # Don't insert empty lines between decorators.
697                 return 0, 0
698
699             newlines = 2
700             if current_line.depth:
701                 newlines -= 1
702             return newlines, 0
703
704         if current_line.is_flow_control:
705             return before, 1
706
707         if (
708             self.previous_line
709             and self.previous_line.is_import
710             and not current_line.is_import
711             and depth == self.previous_line.depth
712         ):
713             return (before or 1), 0
714
715         if (
716             self.previous_line
717             and self.previous_line.is_yield
718             and (not current_line.is_yield or depth != self.previous_line.depth)
719         ):
720             return (before or 1), 0
721
722         return before, 0
723
724
725 @dataclass
726 class LineGenerator(Visitor[Line]):
727     """Generates reformatted Line objects.  Empty lines are not emitted.
728
729     Note: destroys the tree it's visiting by mutating prefixes of its leaves
730     in ways that will no longer stringify to valid Python code on the tree.
731     """
732     current_line: Line = Factory(Line)
733     standalone_comments: List[Leaf] = Factory(list)
734
735     def line(self, indent: int = 0) -> Iterator[Line]:
736         """Generate a line.
737
738         If the line is empty, only emit if it makes sense.
739         If the line is too long, split it first and then generate.
740
741         If any lines were generated, set up a new current_line.
742         """
743         if not self.current_line:
744             self.current_line.depth += indent
745             return  # Line is empty, don't emit. Creating a new one unnecessary.
746
747         complete_line = self.current_line
748         self.current_line = Line(depth=complete_line.depth + indent)
749         yield complete_line
750
751     def visit_default(self, node: LN) -> Iterator[Line]:
752         if isinstance(node, Leaf):
753             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
754             for comment in generate_comments(node):
755                 if any_open_brackets:
756                     # any comment within brackets is subject to splitting
757                     self.current_line.append(comment)
758                 elif comment.type == token.COMMENT:
759                     # regular trailing comment
760                     self.current_line.append(comment)
761                     yield from self.line()
762
763                 else:
764                     # regular standalone comment, to be processed later (see
765                     # docstring in `generate_comments()`
766                     self.standalone_comments.append(comment)
767             normalize_prefix(node, inside_brackets=any_open_brackets)
768             if node.type not in WHITESPACE:
769                 for comment in self.standalone_comments:
770                     yield from self.line()
771
772                     self.current_line.append(comment)
773                     yield from self.line()
774
775                 self.standalone_comments = []
776                 self.current_line.append(node)
777         yield from super().visit_default(node)
778
779     def visit_suite(self, node: Node) -> Iterator[Line]:
780         """Body of a statement after a colon."""
781         children = iter(node.children)
782         # Process newline before indenting.  It might contain an inline
783         # comment that should go right after the colon.
784         newline = next(children)
785         yield from self.visit(newline)
786         yield from self.line(+1)
787
788         for child in children:
789             yield from self.visit(child)
790
791         yield from self.line(-1)
792
793     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
794         """Visit a statement.
795
796         The relevant Python language keywords for this statement are NAME leaves
797         within it.
798         """
799         for child in node.children:
800             if child.type == token.NAME and child.value in keywords:  # type: ignore
801                 yield from self.line()
802
803             yield from self.visit(child)
804
805     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
806         """A statement without nested statements."""
807         is_suite_like = node.parent and node.parent.type in STATEMENT
808         if is_suite_like:
809             yield from self.line(+1)
810             yield from self.visit_default(node)
811             yield from self.line(-1)
812
813         else:
814             yield from self.line()
815             yield from self.visit_default(node)
816
817     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
818         yield from self.line()
819
820         children = iter(node.children)
821         for child in children:
822             yield from self.visit(child)
823
824             if child.type == token.ASYNC:
825                 break
826
827         internal_stmt = next(children)
828         for child in internal_stmt.children:
829             yield from self.visit(child)
830
831     def visit_decorators(self, node: Node) -> Iterator[Line]:
832         for child in node.children:
833             yield from self.line()
834             yield from self.visit(child)
835
836     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
837         yield from self.line()
838
839     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
840         yield from self.visit_default(leaf)
841         yield from self.line()
842
843     def __attrs_post_init__(self) -> None:
844         """You are in a twisty little maze of passages."""
845         v = self.visit_stmt
846         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
847         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
848         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
849         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
850         self.visit_except_clause = partial(v, keywords={'except'})
851         self.visit_funcdef = partial(v, keywords={'def'})
852         self.visit_with_stmt = partial(v, keywords={'with'})
853         self.visit_classdef = partial(v, keywords={'class'})
854         self.visit_async_funcdef = self.visit_async_stmt
855         self.visit_decorated = self.visit_decorators
856
857
858 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
859 OPENING_BRACKETS = set(BRACKET.keys())
860 CLOSING_BRACKETS = set(BRACKET.values())
861 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
862 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
863
864
865 def whitespace(leaf: Leaf) -> str:  # noqa C901
866     """Return whitespace prefix if needed for the given `leaf`."""
867     NO = ''
868     SPACE = ' '
869     DOUBLESPACE = '  '
870     t = leaf.type
871     p = leaf.parent
872     v = leaf.value
873     if t in ALWAYS_NO_SPACE:
874         return NO
875
876     if t == token.COMMENT:
877         return DOUBLESPACE
878
879     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
880     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
881         return NO
882
883     prev = leaf.prev_sibling
884     if not prev:
885         prevp = preceding_leaf(p)
886         if not prevp or prevp.type in OPENING_BRACKETS:
887             return NO
888
889         if t == token.COLON:
890             return SPACE if prevp.type == token.COMMA else NO
891
892         if prevp.type == token.EQUAL:
893             if prevp.parent and prevp.parent.type in {
894                 syms.typedargslist,
895                 syms.varargslist,
896                 syms.parameters,
897                 syms.arglist,
898                 syms.argument,
899             }:
900                 return NO
901
902         elif prevp.type == token.DOUBLESTAR:
903             if prevp.parent and prevp.parent.type in {
904                 syms.typedargslist,
905                 syms.varargslist,
906                 syms.parameters,
907                 syms.arglist,
908                 syms.dictsetmaker,
909             }:
910                 return NO
911
912         elif prevp.type == token.COLON:
913             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
914                 return NO
915
916         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
917             return NO
918
919     elif prev.type in OPENING_BRACKETS:
920         return NO
921
922     if p.type in {syms.parameters, syms.arglist}:
923         # untyped function signatures or calls
924         if t == token.RPAR:
925             return NO
926
927         if not prev or prev.type != token.COMMA:
928             return NO
929
930     if p.type == syms.varargslist:
931         # lambdas
932         if t == token.RPAR:
933             return NO
934
935         if prev and prev.type != token.COMMA:
936             return NO
937
938     elif p.type == syms.typedargslist:
939         # typed function signatures
940         if not prev:
941             return NO
942
943         if t == token.EQUAL:
944             if prev.type != syms.tname:
945                 return NO
946
947         elif prev.type == token.EQUAL:
948             # A bit hacky: if the equal sign has whitespace, it means we
949             # previously found it's a typed argument.  So, we're using that, too.
950             return prev.prefix
951
952         elif prev.type != token.COMMA:
953             return NO
954
955     elif p.type == syms.tname:
956         # type names
957         if not prev:
958             prevp = preceding_leaf(p)
959             if not prevp or prevp.type != token.COMMA:
960                 return NO
961
962     elif p.type == syms.trailer:
963         # attributes and calls
964         if t == token.LPAR or t == token.RPAR:
965             return NO
966
967         if not prev:
968             if t == token.DOT:
969                 prevp = preceding_leaf(p)
970                 if not prevp or prevp.type != token.NUMBER:
971                     return NO
972
973             elif t == token.LSQB:
974                 return NO
975
976         elif prev.type != token.COMMA:
977             return NO
978
979     elif p.type == syms.argument:
980         # single argument
981         if t == token.EQUAL:
982             return NO
983
984         if not prev:
985             prevp = preceding_leaf(p)
986             if not prevp or prevp.type == token.LPAR:
987                 return NO
988
989         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
990             return NO
991
992     elif p.type == syms.decorator:
993         # decorators
994         return NO
995
996     elif p.type == syms.dotted_name:
997         if prev:
998             return NO
999
1000         prevp = preceding_leaf(p)
1001         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1002             return NO
1003
1004     elif p.type == syms.classdef:
1005         if t == token.LPAR:
1006             return NO
1007
1008         if prev and prev.type == token.LPAR:
1009             return NO
1010
1011     elif p.type == syms.subscript:
1012         # indexing
1013         if not prev:
1014             assert p.parent is not None, "subscripts are always parented"
1015             if p.parent.type == syms.subscriptlist:
1016                 return SPACE
1017
1018             return NO
1019
1020         else:
1021             return NO
1022
1023     elif p.type == syms.atom:
1024         if prev and t == token.DOT:
1025             # dots, but not the first one.
1026             return NO
1027
1028     elif (
1029         p.type == syms.listmaker
1030         or p.type == syms.testlist_gexp
1031         or p.type == syms.subscriptlist
1032     ):
1033         # list interior, including unpacking
1034         if not prev:
1035             return NO
1036
1037     elif p.type == syms.dictsetmaker:
1038         # dict and set interior, including unpacking
1039         if not prev:
1040             return NO
1041
1042         if prev.type == token.DOUBLESTAR:
1043             return NO
1044
1045     elif p.type in {syms.factor, syms.star_expr}:
1046         # unary ops
1047         if not prev:
1048             prevp = preceding_leaf(p)
1049             if not prevp or prevp.type in OPENING_BRACKETS:
1050                 return NO
1051
1052             prevp_parent = prevp.parent
1053             assert prevp_parent is not None
1054             if prevp.type == token.COLON and prevp_parent.type in {
1055                 syms.subscript, syms.sliceop
1056             }:
1057                 return NO
1058
1059             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1060                 return NO
1061
1062         elif t == token.NAME or t == token.NUMBER:
1063             return NO
1064
1065     elif p.type == syms.import_from:
1066         if t == token.DOT:
1067             if prev and prev.type == token.DOT:
1068                 return NO
1069
1070         elif t == token.NAME:
1071             if v == 'import':
1072                 return SPACE
1073
1074             if prev and prev.type == token.DOT:
1075                 return NO
1076
1077     elif p.type == syms.sliceop:
1078         return NO
1079
1080     return SPACE
1081
1082
1083 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1084     """Returns the first leaf that precedes `node`, if any."""
1085     while node:
1086         res = node.prev_sibling
1087         if res:
1088             if isinstance(res, Leaf):
1089                 return res
1090
1091             try:
1092                 return list(res.leaves())[-1]
1093
1094             except IndexError:
1095                 return None
1096
1097         node = node.parent
1098     return None
1099
1100
1101 def is_delimiter(leaf: Leaf) -> int:
1102     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1103
1104     Higher numbers are higher priority.
1105     """
1106     if leaf.type == token.COMMA:
1107         return COMMA_PRIORITY
1108
1109     if leaf.type in COMPARATORS:
1110         return COMPARATOR_PRIORITY
1111
1112     if (
1113         leaf.type in MATH_OPERATORS
1114         and leaf.parent
1115         and leaf.parent.type not in {syms.factor, syms.star_expr}
1116     ):
1117         return MATH_PRIORITY
1118
1119     return 0
1120
1121
1122 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1123     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1124
1125     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1126     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1127     move because it does away with modifying the grammar to include all the
1128     possible places in which comments can be placed.
1129
1130     The sad consequence for us though is that comments don't "belong" anywhere.
1131     This is why this function generates simple parentless Leaf objects for
1132     comments.  We simply don't know what the correct parent should be.
1133
1134     No matter though, we can live without this.  We really only need to
1135     differentiate between inline and standalone comments.  The latter don't
1136     share the line with any code.
1137
1138     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1139     are emitted with a fake STANDALONE_COMMENT token identifier.
1140     """
1141     if not leaf.prefix:
1142         return
1143
1144     if '#' not in leaf.prefix:
1145         return
1146
1147     before_comment, content = leaf.prefix.split('#', 1)
1148     content = content.rstrip()
1149     if content and (content[0] not in {' ', '!', '#'}):
1150         content = ' ' + content
1151     is_standalone_comment = (
1152         '\n' in before_comment or '\n' in content or leaf.type == token.ENDMARKER
1153     )
1154     if not is_standalone_comment:
1155         # simple trailing comment
1156         yield Leaf(token.COMMENT, value='#' + content)
1157         return
1158
1159     for line in ('#' + content).split('\n'):
1160         line = line.lstrip()
1161         if not line.startswith('#'):
1162             continue
1163
1164         yield Leaf(STANDALONE_COMMENT, line)
1165
1166
1167 def split_line(
1168     line: Line, line_length: int, inner: bool = False, py36: bool = False
1169 ) -> Iterator[Line]:
1170     """Splits a `line` into potentially many lines.
1171
1172     They should fit in the allotted `line_length` but might not be able to.
1173     `inner` signifies that there were a pair of brackets somewhere around the
1174     current `line`, possibly transitively. This means we can fallback to splitting
1175     by delimiters if the LHS/RHS don't yield any results.
1176
1177     If `py36` is True, splitting may generate syntax that is only compatible
1178     with Python 3.6 and later.
1179     """
1180     line_str = str(line).strip('\n')
1181     if len(line_str) <= line_length and '\n' not in line_str:
1182         yield line
1183         return
1184
1185     if line.is_def:
1186         split_funcs = [left_hand_split]
1187     elif line.inside_brackets:
1188         split_funcs = [delimiter_split]
1189         if '\n' not in line_str:
1190             # Only attempt RHS if we don't have multiline strings or comments
1191             # on this line.
1192             split_funcs.append(right_hand_split)
1193     else:
1194         split_funcs = [right_hand_split]
1195     for split_func in split_funcs:
1196         # We are accumulating lines in `result` because we might want to abort
1197         # mission and return the original line in the end, or attempt a different
1198         # split altogether.
1199         result: List[Line] = []
1200         try:
1201             for l in split_func(line, py36=py36):
1202                 if str(l).strip('\n') == line_str:
1203                     raise CannotSplit("Split function returned an unchanged result")
1204
1205                 result.extend(
1206                     split_line(l, line_length=line_length, inner=True, py36=py36)
1207                 )
1208         except CannotSplit as cs:
1209             continue
1210
1211         else:
1212             yield from result
1213             break
1214
1215     else:
1216         yield line
1217
1218
1219 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1220     """Split line into many lines, starting with the first matching bracket pair.
1221
1222     Note: this usually looks weird, only use this for function definitions.
1223     Prefer RHS otherwise.
1224     """
1225     head = Line(depth=line.depth)
1226     body = Line(depth=line.depth + 1, inside_brackets=True)
1227     tail = Line(depth=line.depth)
1228     tail_leaves: List[Leaf] = []
1229     body_leaves: List[Leaf] = []
1230     head_leaves: List[Leaf] = []
1231     current_leaves = head_leaves
1232     matching_bracket = None
1233     for leaf in line.leaves:
1234         if (
1235             current_leaves is body_leaves
1236             and leaf.type in CLOSING_BRACKETS
1237             and leaf.opening_bracket is matching_bracket
1238         ):
1239             current_leaves = tail_leaves if body_leaves else head_leaves
1240         current_leaves.append(leaf)
1241         if current_leaves is head_leaves:
1242             if leaf.type in OPENING_BRACKETS:
1243                 matching_bracket = leaf
1244                 current_leaves = body_leaves
1245     # Since body is a new indent level, remove spurious leading whitespace.
1246     if body_leaves:
1247         normalize_prefix(body_leaves[0], inside_brackets=True)
1248     # Build the new lines.
1249     for result, leaves in (
1250         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1251     ):
1252         for leaf in leaves:
1253             result.append(leaf, preformatted=True)
1254             comment_after = line.comments.get(id(leaf))
1255             if comment_after:
1256                 result.append(comment_after, preformatted=True)
1257     split_succeeded_or_raise(head, body, tail)
1258     for result in (head, body, tail):
1259         if result:
1260             yield result
1261
1262
1263 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1264     """Split line into many lines, starting with the last matching bracket pair."""
1265     head = Line(depth=line.depth)
1266     body = Line(depth=line.depth + 1, inside_brackets=True)
1267     tail = Line(depth=line.depth)
1268     tail_leaves: List[Leaf] = []
1269     body_leaves: List[Leaf] = []
1270     head_leaves: List[Leaf] = []
1271     current_leaves = tail_leaves
1272     opening_bracket = None
1273     for leaf in reversed(line.leaves):
1274         if current_leaves is body_leaves:
1275             if leaf is opening_bracket:
1276                 current_leaves = head_leaves if body_leaves else tail_leaves
1277         current_leaves.append(leaf)
1278         if current_leaves is tail_leaves:
1279             if leaf.type in CLOSING_BRACKETS:
1280                 opening_bracket = leaf.opening_bracket
1281                 current_leaves = body_leaves
1282     tail_leaves.reverse()
1283     body_leaves.reverse()
1284     head_leaves.reverse()
1285     # Since body is a new indent level, remove spurious leading whitespace.
1286     if body_leaves:
1287         normalize_prefix(body_leaves[0], inside_brackets=True)
1288     # Build the new lines.
1289     for result, leaves in (
1290         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1291     ):
1292         for leaf in leaves:
1293             result.append(leaf, preformatted=True)
1294             comment_after = line.comments.get(id(leaf))
1295             if comment_after:
1296                 result.append(comment_after, preformatted=True)
1297     split_succeeded_or_raise(head, body, tail)
1298     for result in (head, body, tail):
1299         if result:
1300             yield result
1301
1302
1303 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1304     tail_len = len(str(tail).strip())
1305     if not body:
1306         if tail_len == 0:
1307             raise CannotSplit("Splitting brackets produced the same line")
1308
1309         elif tail_len < 3:
1310             raise CannotSplit(
1311                 f"Splitting brackets on an empty body to save "
1312                 f"{tail_len} characters is not worth it"
1313             )
1314
1315
1316 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1317     """Split according to delimiters of the highest priority.
1318
1319     This kind of split doesn't increase indentation.
1320     If `py36` is True, the split will add trailing commas also in function
1321     signatures that contain * and **.
1322     """
1323     try:
1324         last_leaf = line.leaves[-1]
1325     except IndexError:
1326         raise CannotSplit("Line empty")
1327
1328     delimiters = line.bracket_tracker.delimiters
1329     try:
1330         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1331     except ValueError:
1332         raise CannotSplit("No delimiters found")
1333
1334     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1335     lowest_depth = sys.maxsize
1336     trailing_comma_safe = True
1337     for leaf in line.leaves:
1338         current_line.append(leaf, preformatted=True)
1339         comment_after = line.comments.get(id(leaf))
1340         if comment_after:
1341             current_line.append(comment_after, preformatted=True)
1342         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1343         if (
1344             leaf.bracket_depth == lowest_depth
1345             and leaf.type == token.STAR
1346             or leaf.type == token.DOUBLESTAR
1347         ):
1348             trailing_comma_safe = trailing_comma_safe and py36
1349         leaf_priority = delimiters.get(id(leaf))
1350         if leaf_priority == delimiter_priority:
1351             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1352             yield current_line
1353
1354             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1355     if current_line:
1356         if (
1357             delimiter_priority == COMMA_PRIORITY
1358             and current_line.leaves[-1].type != token.COMMA
1359             and trailing_comma_safe
1360         ):
1361             current_line.append(Leaf(token.COMMA, ','))
1362         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1363         yield current_line
1364
1365
1366 def is_import(leaf: Leaf) -> bool:
1367     """Returns True if the given leaf starts an import statement."""
1368     p = leaf.parent
1369     t = leaf.type
1370     v = leaf.value
1371     return bool(
1372         t == token.NAME
1373         and (
1374             (v == 'import' and p and p.type == syms.import_name)
1375             or (v == 'from' and p and p.type == syms.import_from)
1376         )
1377     )
1378
1379
1380 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1381     """Leave existing extra newlines if not `inside_brackets`.
1382
1383     Remove everything else.  Note: don't use backslashes for formatting or
1384     you'll lose your voting rights.
1385     """
1386     if not inside_brackets:
1387         spl = leaf.prefix.split('#', 1)
1388         if '\\' not in spl[0]:
1389             nl_count = spl[0].count('\n')
1390             leaf.prefix = '\n' * nl_count
1391             return
1392
1393     leaf.prefix = ''
1394
1395
1396 def is_python36(node: Node) -> bool:
1397     """Returns True if the current file is using Python 3.6+ features.
1398
1399     Currently looking for:
1400     - f-strings; and
1401     - trailing commas after * or ** in function signatures.
1402     """
1403     for n in node.pre_order():
1404         if n.type == token.STRING:
1405             value_head = n.value[:2]  # type: ignore
1406             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1407                 return True
1408
1409         elif (
1410             n.type == syms.typedargslist
1411             and n.children
1412             and n.children[-1].type == token.COMMA
1413         ):
1414             for ch in n.children:
1415                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1416                     return True
1417
1418     return False
1419
1420
1421 PYTHON_EXTENSIONS = {'.py'}
1422 BLACKLISTED_DIRECTORIES = {
1423     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1424 }
1425
1426
1427 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1428     for child in path.iterdir():
1429         if child.is_dir():
1430             if child.name in BLACKLISTED_DIRECTORIES:
1431                 continue
1432
1433             yield from gen_python_files_in_dir(child)
1434
1435         elif child.suffix in PYTHON_EXTENSIONS:
1436             yield child
1437
1438
1439 @dataclass
1440 class Report:
1441     """Provides a reformatting counter."""
1442     change_count: int = 0
1443     same_count: int = 0
1444     failure_count: int = 0
1445
1446     def done(self, src: Path, changed: bool) -> None:
1447         """Increment the counter for successful reformatting. Write out a message."""
1448         if changed:
1449             out(f'reformatted {src}')
1450             self.change_count += 1
1451         else:
1452             out(f'{src} already well formatted, good job.', bold=False)
1453             self.same_count += 1
1454
1455     def failed(self, src: Path, message: str) -> None:
1456         """Increment the counter for failed reformatting. Write out a message."""
1457         err(f'error: cannot format {src}: {message}')
1458         self.failure_count += 1
1459
1460     @property
1461     def return_code(self) -> int:
1462         """Which return code should the app use considering the current state."""
1463         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1464         # 126 we have special returncodes reserved by the shell.
1465         if self.failure_count:
1466             return 123
1467
1468         elif self.change_count:
1469             return 1
1470
1471         return 0
1472
1473     def __str__(self) -> str:
1474         """A color report of the current state.
1475
1476         Use `click.unstyle` to remove colors.
1477         """
1478         report = []
1479         if self.change_count:
1480             s = 's' if self.change_count > 1 else ''
1481             report.append(
1482                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1483             )
1484         if self.same_count:
1485             s = 's' if self.same_count > 1 else ''
1486             report.append(f'{self.same_count} file{s} left unchanged')
1487         if self.failure_count:
1488             s = 's' if self.failure_count > 1 else ''
1489             report.append(
1490                 click.style(
1491                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1492                 )
1493             )
1494         return ', '.join(report) + '.'
1495
1496
1497 def assert_equivalent(src: str, dst: str) -> None:
1498     """Raises AssertionError if `src` and `dst` aren't equivalent.
1499
1500     This is a temporary sanity check until Black becomes stable.
1501     """
1502
1503     import ast
1504     import traceback
1505
1506     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1507         """Simple visitor generating strings to compare ASTs by content."""
1508         yield f"{'  ' * depth}{node.__class__.__name__}("
1509
1510         for field in sorted(node._fields):
1511             try:
1512                 value = getattr(node, field)
1513             except AttributeError:
1514                 continue
1515
1516             yield f"{'  ' * (depth+1)}{field}="
1517
1518             if isinstance(value, list):
1519                 for item in value:
1520                     if isinstance(item, ast.AST):
1521                         yield from _v(item, depth + 2)
1522
1523             elif isinstance(value, ast.AST):
1524                 yield from _v(value, depth + 2)
1525
1526             else:
1527                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1528
1529         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1530
1531     try:
1532         src_ast = ast.parse(src)
1533     except Exception as exc:
1534         raise AssertionError(f"cannot parse source: {exc}") from None
1535
1536     try:
1537         dst_ast = ast.parse(dst)
1538     except Exception as exc:
1539         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1540         raise AssertionError(
1541             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1542             f"Please report a bug on https://github.com/ambv/black/issues.  "
1543             f"This invalid output might be helpful: {log}"
1544         ) from None
1545
1546     src_ast_str = '\n'.join(_v(src_ast))
1547     dst_ast_str = '\n'.join(_v(dst_ast))
1548     if src_ast_str != dst_ast_str:
1549         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1550         raise AssertionError(
1551             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1552             f"the source.  "
1553             f"Please report a bug on https://github.com/ambv/black/issues.  "
1554             f"This diff might be helpful: {log}"
1555         ) from None
1556
1557
1558 def assert_stable(src: str, dst: str, line_length: int) -> None:
1559     """Raises AssertionError if `dst` reformats differently the second time.
1560
1561     This is a temporary sanity check until Black becomes stable.
1562     """
1563     newdst = format_str(dst, line_length=line_length)
1564     if dst != newdst:
1565         log = dump_to_file(
1566             diff(src, dst, 'source', 'first pass'),
1567             diff(dst, newdst, 'first pass', 'second pass'),
1568         )
1569         raise AssertionError(
1570             f"INTERNAL ERROR: Black produced different code on the second pass "
1571             f"of the formatter.  "
1572             f"Please report a bug on https://github.com/ambv/black/issues.  "
1573             f"This diff might be helpful: {log}"
1574         ) from None
1575
1576
1577 def dump_to_file(*output: str) -> str:
1578     """Dumps `output` to a temporary file. Returns path to the file."""
1579     import tempfile
1580
1581     with tempfile.NamedTemporaryFile(
1582         mode='w', prefix='blk_', suffix='.log', delete=False
1583     ) as f:
1584         for lines in output:
1585             f.write(lines)
1586             f.write('\n')
1587     return f.name
1588
1589
1590 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1591     """Returns a udiff string between strings `a` and `b`."""
1592     import difflib
1593
1594     a_lines = [line + '\n' for line in a.split('\n')]
1595     b_lines = [line + '\n' for line in b.split('\n')]
1596     return ''.join(
1597         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1598     )
1599
1600
1601 if __name__ == '__main__':
1602     main()