]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Only return exit code 1 when --check is used
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
14 )
15
16 from attr import dataclass, Factory
17 import click
18
19 # lib2to3 fork
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
24
25 __version__ = "18.3a3"
26 DEFAULT_LINE_LENGTH = 88
27 # types
28 syms = pygram.python_symbols
29 FileContent = str
30 Encoding = str
31 Depth = int
32 NodeType = int
33 LeafID = int
34 Priority = int
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
38
39
40 class NothingChanged(UserWarning):
41     """Raised by `format_file` when the reformatted code is the same as source."""
42
43
44 class CannotSplit(Exception):
45     """A readable split that fits the allotted line length is impossible.
46
47     Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`.
48     """
49
50
51 @click.command()
52 @click.option(
53     '-l',
54     '--line-length',
55     type=int,
56     default=DEFAULT_LINE_LENGTH,
57     help='How many character per line to allow.',
58     show_default=True,
59 )
60 @click.option(
61     '--check',
62     is_flag=True,
63     help=(
64         "Don't write back the files, just return the status.  Return code 0 "
65         "means nothing would change.  Return code 1 means some files would be "
66         "reformatted.  Return code 123 means there was an internal error."
67     ),
68 )
69 @click.option(
70     '--fast/--safe',
71     is_flag=True,
72     help='If --fast given, skip temporary sanity checks. [default: --safe]',
73 )
74 @click.version_option(version=__version__)
75 @click.argument(
76     'src',
77     nargs=-1,
78     type=click.Path(
79         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
80     ),
81 )
82 @click.pass_context
83 def main(
84     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
85 ) -> None:
86     """The uncompromising code formatter."""
87     sources: List[Path] = []
88     for s in src:
89         p = Path(s)
90         if p.is_dir():
91             sources.extend(gen_python_files_in_dir(p))
92         elif p.is_file():
93             # if a file was explicitly given, we don't care about its extension
94             sources.append(p)
95         elif s == '-':
96             sources.append(Path('-'))
97         else:
98             err(f'invalid path: {s}')
99     if len(sources) == 0:
100         ctx.exit(0)
101     elif len(sources) == 1:
102         p = sources[0]
103         report = Report(check=check)
104         try:
105             if not p.is_file() and str(p) == '-':
106                 changed = format_stdin_to_stdout(
107                     line_length=line_length, fast=fast, write_back=not check
108                 )
109             else:
110                 changed = format_file_in_place(
111                     p, line_length=line_length, fast=fast, write_back=not check
112                 )
113             report.done(p, changed)
114         except Exception as exc:
115             report.failed(p, str(exc))
116         ctx.exit(report.return_code)
117     else:
118         loop = asyncio.get_event_loop()
119         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
120         return_code = 1
121         try:
122             return_code = loop.run_until_complete(
123                 schedule_formatting(
124                     sources, line_length, not check, fast, loop, executor
125                 )
126             )
127         finally:
128             loop.close()
129             ctx.exit(return_code)
130
131
132 async def schedule_formatting(
133     sources: List[Path],
134     line_length: int,
135     write_back: bool,
136     fast: bool,
137     loop: BaseEventLoop,
138     executor: Executor,
139 ) -> int:
140     tasks = {
141         src: loop.run_in_executor(
142             executor, format_file_in_place, src, line_length, fast, write_back
143         )
144         for src in sources
145     }
146     await asyncio.wait(tasks.values())
147     cancelled = []
148     report = Report()
149     for src, task in tasks.items():
150         if not task.done():
151             report.failed(src, 'timed out, cancelling')
152             task.cancel()
153             cancelled.append(task)
154         elif task.exception():
155             report.failed(src, str(task.exception()))
156         else:
157             report.done(src, task.result())
158     if cancelled:
159         await asyncio.wait(cancelled, timeout=2)
160     out('All done! ✨ 🍰 ✨')
161     click.echo(str(report))
162     return report.return_code
163
164
165 def format_file_in_place(
166     src: Path, line_length: int, fast: bool, write_back: bool = False
167 ) -> bool:
168     """Format the file and rewrite if changed. Return True if changed."""
169     with tokenize.open(src) as src_buffer:
170         src_contents = src_buffer.read()
171     try:
172         contents = format_file_contents(
173             src_contents, line_length=line_length, fast=fast
174         )
175     except NothingChanged:
176         return False
177
178     if write_back:
179         with open(src, "w", encoding=src_buffer.encoding) as f:
180             f.write(contents)
181     return True
182
183
184 def format_stdin_to_stdout(
185     line_length: int, fast: bool, write_back: bool = False
186 ) -> bool:
187     """Format file on stdin and pipe output to stdout. Return True if changed."""
188     contents = sys.stdin.read()
189     try:
190         contents = format_file_contents(contents, line_length=line_length, fast=fast)
191         return True
192
193     except NothingChanged:
194         return False
195
196     finally:
197         if write_back:
198             sys.stdout.write(contents)
199
200
201 def format_file_contents(
202     src_contents: str, line_length: int, fast: bool
203 ) -> FileContent:
204     """Reformats a file and returns its contents and encoding."""
205     if src_contents.strip() == '':
206         raise NothingChanged
207
208     dst_contents = format_str(src_contents, line_length=line_length)
209     if src_contents == dst_contents:
210         raise NothingChanged
211
212     if not fast:
213         assert_equivalent(src_contents, dst_contents)
214         assert_stable(src_contents, dst_contents, line_length=line_length)
215     return dst_contents
216
217
218 def format_str(src_contents: str, line_length: int) -> FileContent:
219     """Reformats a string and returns new contents."""
220     src_node = lib2to3_parse(src_contents)
221     dst_contents = ""
222     lines = LineGenerator()
223     elt = EmptyLineTracker()
224     py36 = is_python36(src_node)
225     empty_line = Line()
226     after = 0
227     for current_line in lines.visit(src_node):
228         for _ in range(after):
229             dst_contents += str(empty_line)
230         before, after = elt.maybe_empty_lines(current_line)
231         for _ in range(before):
232             dst_contents += str(empty_line)
233         for line in split_line(current_line, line_length=line_length, py36=py36):
234             dst_contents += str(line)
235     return dst_contents
236
237
238 def lib2to3_parse(src_txt: str) -> Node:
239     """Given a string with source, return the lib2to3 Node."""
240     grammar = pygram.python_grammar_no_print_statement
241     drv = driver.Driver(grammar, pytree.convert)
242     if src_txt[-1] != '\n':
243         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
244         src_txt += nl
245     try:
246         result = drv.parse_string(src_txt, True)
247     except ParseError as pe:
248         lineno, column = pe.context[1]
249         lines = src_txt.splitlines()
250         try:
251             faulty_line = lines[lineno - 1]
252         except IndexError:
253             faulty_line = "<line number missing in source>"
254         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
255
256     if isinstance(result, Leaf):
257         result = Node(syms.file_input, [result])
258     return result
259
260
261 def lib2to3_unparse(node: Node) -> str:
262     """Given a lib2to3 node, return its string representation."""
263     code = str(node)
264     return code
265
266
267 T = TypeVar('T')
268
269
270 class Visitor(Generic[T]):
271     """Basic lib2to3 visitor that yields things on visiting."""
272
273     def visit(self, node: LN) -> Iterator[T]:
274         if node.type < 256:
275             name = token.tok_name[node.type]
276         else:
277             name = type_repr(node.type)
278         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
279
280     def visit_default(self, node: LN) -> Iterator[T]:
281         if isinstance(node, Node):
282             for child in node.children:
283                 yield from self.visit(child)
284
285
286 @dataclass
287 class DebugVisitor(Visitor[T]):
288     tree_depth: int = 0
289
290     def visit_default(self, node: LN) -> Iterator[T]:
291         indent = ' ' * (2 * self.tree_depth)
292         if isinstance(node, Node):
293             _type = type_repr(node.type)
294             out(f'{indent}{_type}', fg='yellow')
295             self.tree_depth += 1
296             for child in node.children:
297                 yield from self.visit(child)
298
299             self.tree_depth -= 1
300             out(f'{indent}/{_type}', fg='yellow', bold=False)
301         else:
302             _type = token.tok_name.get(node.type, str(node.type))
303             out(f'{indent}{_type}', fg='blue', nl=False)
304             if node.prefix:
305                 # We don't have to handle prefixes for `Node` objects since
306                 # that delegates to the first child anyway.
307                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
308             out(f' {node.value!r}', fg='blue', bold=False)
309
310
311 KEYWORDS = set(keyword.kwlist)
312 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
313 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
314 STATEMENT = {
315     syms.if_stmt,
316     syms.while_stmt,
317     syms.for_stmt,
318     syms.try_stmt,
319     syms.except_clause,
320     syms.with_stmt,
321     syms.funcdef,
322     syms.classdef,
323 }
324 STANDALONE_COMMENT = 153
325 LOGIC_OPERATORS = {'and', 'or'}
326 COMPARATORS = {
327     token.LESS,
328     token.GREATER,
329     token.EQEQUAL,
330     token.NOTEQUAL,
331     token.LESSEQUAL,
332     token.GREATEREQUAL,
333 }
334 MATH_OPERATORS = {
335     token.PLUS,
336     token.MINUS,
337     token.STAR,
338     token.SLASH,
339     token.VBAR,
340     token.AMPER,
341     token.PERCENT,
342     token.CIRCUMFLEX,
343     token.TILDE,
344     token.LEFTSHIFT,
345     token.RIGHTSHIFT,
346     token.DOUBLESTAR,
347     token.DOUBLESLASH,
348 }
349 COMPREHENSION_PRIORITY = 20
350 COMMA_PRIORITY = 10
351 LOGIC_PRIORITY = 5
352 STRING_PRIORITY = 4
353 COMPARATOR_PRIORITY = 3
354 MATH_PRIORITY = 1
355
356
357 @dataclass
358 class BracketTracker:
359     depth: int = 0
360     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
361     delimiters: Dict[LeafID, Priority] = Factory(dict)
362     previous: Optional[Leaf] = None
363
364     def mark(self, leaf: Leaf) -> None:
365         if leaf.type == token.COMMENT:
366             return
367
368         if leaf.type in CLOSING_BRACKETS:
369             self.depth -= 1
370             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
371             leaf.opening_bracket = opening_bracket
372         leaf.bracket_depth = self.depth
373         if self.depth == 0:
374             delim = is_delimiter(leaf)
375             if delim:
376                 self.delimiters[id(leaf)] = delim
377             elif self.previous is not None:
378                 if leaf.type == token.STRING and self.previous.type == token.STRING:
379                     self.delimiters[id(self.previous)] = STRING_PRIORITY
380                 elif (
381                     leaf.type == token.NAME
382                     and leaf.value == 'for'
383                     and leaf.parent
384                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
385                 ):
386                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
387                 elif (
388                     leaf.type == token.NAME
389                     and leaf.value == 'if'
390                     and leaf.parent
391                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
392                 ):
393                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
394                 elif (
395                     leaf.type == token.NAME
396                     and leaf.value in LOGIC_OPERATORS
397                     and leaf.parent
398                 ):
399                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
400         if leaf.type in OPENING_BRACKETS:
401             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
402             self.depth += 1
403         self.previous = leaf
404
405     def any_open_brackets(self) -> bool:
406         """Returns True if there is an yet unmatched open bracket on the line."""
407         return bool(self.bracket_match)
408
409     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
410         """Returns the highest priority of a delimiter found on the line.
411
412         Values are consistent with what `is_delimiter()` returns.
413         """
414         return max(v for k, v in self.delimiters.items() if k not in exclude)
415
416
417 @dataclass
418 class Line:
419     depth: int = 0
420     leaves: List[Leaf] = Factory(list)
421     comments: Dict[LeafID, Leaf] = Factory(dict)
422     bracket_tracker: BracketTracker = Factory(BracketTracker)
423     inside_brackets: bool = False
424     has_for: bool = False
425     _for_loop_variable: bool = False
426
427     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
428         has_value = leaf.value.strip()
429         if not has_value:
430             return
431
432         if self.leaves and not preformatted:
433             # Note: at this point leaf.prefix should be empty except for
434             # imports, for which we only preserve newlines.
435             leaf.prefix += whitespace(leaf)
436         if self.inside_brackets or not preformatted:
437             self.maybe_decrement_after_for_loop_variable(leaf)
438             self.bracket_tracker.mark(leaf)
439             self.maybe_remove_trailing_comma(leaf)
440             self.maybe_increment_for_loop_variable(leaf)
441             if self.maybe_adapt_standalone_comment(leaf):
442                 return
443
444         if not self.append_comment(leaf):
445             self.leaves.append(leaf)
446
447     @property
448     def is_comment(self) -> bool:
449         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
450
451     @property
452     def is_decorator(self) -> bool:
453         return bool(self) and self.leaves[0].type == token.AT
454
455     @property
456     def is_import(self) -> bool:
457         return bool(self) and is_import(self.leaves[0])
458
459     @property
460     def is_class(self) -> bool:
461         return (
462             bool(self)
463             and self.leaves[0].type == token.NAME
464             and self.leaves[0].value == 'class'
465         )
466
467     @property
468     def is_def(self) -> bool:
469         """Also returns True for async defs."""
470         try:
471             first_leaf = self.leaves[0]
472         except IndexError:
473             return False
474
475         try:
476             second_leaf: Optional[Leaf] = self.leaves[1]
477         except IndexError:
478             second_leaf = None
479         return (
480             (first_leaf.type == token.NAME and first_leaf.value == 'def')
481             or (
482                 first_leaf.type == token.ASYNC
483                 and second_leaf is not None
484                 and second_leaf.type == token.NAME
485                 and second_leaf.value == 'def'
486             )
487         )
488
489     @property
490     def is_flow_control(self) -> bool:
491         return (
492             bool(self)
493             and self.leaves[0].type == token.NAME
494             and self.leaves[0].value in FLOW_CONTROL
495         )
496
497     @property
498     def is_yield(self) -> bool:
499         return (
500             bool(self)
501             and self.leaves[0].type == token.NAME
502             and self.leaves[0].value == 'yield'
503         )
504
505     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
506         if not (
507             self.leaves
508             and self.leaves[-1].type == token.COMMA
509             and closing.type in CLOSING_BRACKETS
510         ):
511             return False
512
513         if closing.type == token.RBRACE:
514             self.leaves.pop()
515             return True
516
517         if closing.type == token.RSQB:
518             comma = self.leaves[-1]
519             if comma.parent and comma.parent.type == syms.listmaker:
520                 self.leaves.pop()
521                 return True
522
523         # For parens let's check if it's safe to remove the comma.  If the
524         # trailing one is the only one, we might mistakenly change a tuple
525         # into a different type by removing the comma.
526         depth = closing.bracket_depth + 1
527         commas = 0
528         opening = closing.opening_bracket
529         for _opening_index, leaf in enumerate(self.leaves):
530             if leaf is opening:
531                 break
532
533         else:
534             return False
535
536         for leaf in self.leaves[_opening_index + 1:]:
537             if leaf is closing:
538                 break
539
540             bracket_depth = leaf.bracket_depth
541             if bracket_depth == depth and leaf.type == token.COMMA:
542                 commas += 1
543                 if leaf.parent and leaf.parent.type == syms.arglist:
544                     commas += 1
545                     break
546
547         if commas > 1:
548             self.leaves.pop()
549             return True
550
551         return False
552
553     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
554         """In a for loop, or comprehension, the variables are often unpacks.
555
556         To avoid splitting on the comma in this situation, we will increase
557         the depth of tokens between `for` and `in`.
558         """
559         if leaf.type == token.NAME and leaf.value == 'for':
560             self.has_for = True
561             self.bracket_tracker.depth += 1
562             self._for_loop_variable = True
563             return True
564
565         return False
566
567     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
568         # See `maybe_increment_for_loop_variable` above for explanation.
569         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
570             self.bracket_tracker.depth -= 1
571             self._for_loop_variable = False
572             return True
573
574         return False
575
576     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
577         """Hack a standalone comment to act as a trailing comment for line splitting.
578
579         If this line has brackets and a standalone `comment`, we need to adapt
580         it to be able to still reformat the line.
581
582         This is not perfect, the line to which the standalone comment gets
583         appended will appear "too long" when splitting.
584         """
585         if not (
586             comment.type == STANDALONE_COMMENT
587             and self.bracket_tracker.any_open_brackets()
588         ):
589             return False
590
591         comment.type = token.COMMENT
592         comment.prefix = '\n' + '    ' * (self.depth + 1)
593         return self.append_comment(comment)
594
595     def append_comment(self, comment: Leaf) -> bool:
596         if comment.type != token.COMMENT:
597             return False
598
599         try:
600             after = id(self.last_non_delimiter())
601         except LookupError:
602             comment.type = STANDALONE_COMMENT
603             comment.prefix = ''
604             return False
605
606         else:
607             if after in self.comments:
608                 self.comments[after].value += str(comment)
609             else:
610                 self.comments[after] = comment
611             return True
612
613     def last_non_delimiter(self) -> Leaf:
614         for i in range(len(self.leaves)):
615             last = self.leaves[-i - 1]
616             if not is_delimiter(last):
617                 return last
618
619         raise LookupError("No non-delimiters found")
620
621     def __str__(self) -> str:
622         if not self:
623             return '\n'
624
625         indent = '    ' * self.depth
626         leaves = iter(self.leaves)
627         first = next(leaves)
628         res = f'{first.prefix}{indent}{first.value}'
629         for leaf in leaves:
630             res += str(leaf)
631         for comment in self.comments.values():
632             res += str(comment)
633         return res + '\n'
634
635     def __bool__(self) -> bool:
636         return bool(self.leaves or self.comments)
637
638
639 @dataclass
640 class EmptyLineTracker:
641     """Provides a stateful method that returns the number of potential extra
642     empty lines needed before and after the currently processed line.
643
644     Note: this tracker works on lines that haven't been split yet.  It assumes
645     the prefix of the first leaf consists of optional newlines.  Those newlines
646     are consumed by `maybe_empty_lines()` and included in the computation.
647     """
648     previous_line: Optional[Line] = None
649     previous_after: int = 0
650     previous_defs: List[int] = Factory(list)
651
652     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
653         """Returns the number of extra empty lines before and after the `current_line`.
654
655         This is for separating `def`, `async def` and `class` with extra empty lines
656         (two on module-level), as well as providing an extra empty line after flow
657         control keywords to make them more prominent.
658         """
659         before, after = self._maybe_empty_lines(current_line)
660         before -= self.previous_after
661         self.previous_after = after
662         self.previous_line = current_line
663         return before, after
664
665     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
666         max_allowed = 1
667         if current_line.is_comment and current_line.depth == 0:
668             max_allowed = 2
669         if current_line.leaves:
670             # Consume the first leaf's extra newlines.
671             first_leaf = current_line.leaves[0]
672             before = first_leaf.prefix.count('\n')
673             before = min(before, max(before, max_allowed))
674             first_leaf.prefix = ''
675         else:
676             before = 0
677         depth = current_line.depth
678         while self.previous_defs and self.previous_defs[-1] >= depth:
679             self.previous_defs.pop()
680             before = 1 if depth else 2
681         is_decorator = current_line.is_decorator
682         if is_decorator or current_line.is_def or current_line.is_class:
683             if not is_decorator:
684                 self.previous_defs.append(depth)
685             if self.previous_line is None:
686                 # Don't insert empty lines before the first line in the file.
687                 return 0, 0
688
689             if self.previous_line and self.previous_line.is_decorator:
690                 # Don't insert empty lines between decorators.
691                 return 0, 0
692
693             newlines = 2
694             if current_line.depth:
695                 newlines -= 1
696             return newlines, 0
697
698         if current_line.is_flow_control:
699             return before, 1
700
701         if (
702             self.previous_line
703             and self.previous_line.is_import
704             and not current_line.is_import
705             and depth == self.previous_line.depth
706         ):
707             return (before or 1), 0
708
709         if (
710             self.previous_line
711             and self.previous_line.is_yield
712             and (not current_line.is_yield or depth != self.previous_line.depth)
713         ):
714             return (before or 1), 0
715
716         return before, 0
717
718
719 @dataclass
720 class LineGenerator(Visitor[Line]):
721     """Generates reformatted Line objects.  Empty lines are not emitted.
722
723     Note: destroys the tree it's visiting by mutating prefixes of its leaves
724     in ways that will no longer stringify to valid Python code on the tree.
725     """
726     current_line: Line = Factory(Line)
727
728     def line(self, indent: int = 0) -> Iterator[Line]:
729         """Generate a line.
730
731         If the line is empty, only emit if it makes sense.
732         If the line is too long, split it first and then generate.
733
734         If any lines were generated, set up a new current_line.
735         """
736         if not self.current_line:
737             self.current_line.depth += indent
738             return  # Line is empty, don't emit. Creating a new one unnecessary.
739
740         complete_line = self.current_line
741         self.current_line = Line(depth=complete_line.depth + indent)
742         yield complete_line
743
744     def visit_default(self, node: LN) -> Iterator[Line]:
745         if isinstance(node, Leaf):
746             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
747             for comment in generate_comments(node):
748                 if any_open_brackets:
749                     # any comment within brackets is subject to splitting
750                     self.current_line.append(comment)
751                 elif comment.type == token.COMMENT:
752                     # regular trailing comment
753                     self.current_line.append(comment)
754                     yield from self.line()
755
756                 else:
757                     # regular standalone comment
758                     yield from self.line()
759
760                     self.current_line.append(comment)
761                     yield from self.line()
762
763             normalize_prefix(node, inside_brackets=any_open_brackets)
764             if node.type not in WHITESPACE:
765                 self.current_line.append(node)
766         yield from super().visit_default(node)
767
768     def visit_INDENT(self, node: Node) -> Iterator[Line]:
769         yield from self.line(+1)
770         yield from self.visit_default(node)
771
772     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
773         yield from self.line(-1)
774
775     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
776         """Visit a statement.
777
778         The relevant Python language keywords for this statement are NAME leaves
779         within it.
780         """
781         for child in node.children:
782             if child.type == token.NAME and child.value in keywords:  # type: ignore
783                 yield from self.line()
784
785             yield from self.visit(child)
786
787     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
788         """A statement without nested statements."""
789         is_suite_like = node.parent and node.parent.type in STATEMENT
790         if is_suite_like:
791             yield from self.line(+1)
792             yield from self.visit_default(node)
793             yield from self.line(-1)
794
795         else:
796             yield from self.line()
797             yield from self.visit_default(node)
798
799     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
800         yield from self.line()
801
802         children = iter(node.children)
803         for child in children:
804             yield from self.visit(child)
805
806             if child.type == token.ASYNC:
807                 break
808
809         internal_stmt = next(children)
810         for child in internal_stmt.children:
811             yield from self.visit(child)
812
813     def visit_decorators(self, node: Node) -> Iterator[Line]:
814         for child in node.children:
815             yield from self.line()
816             yield from self.visit(child)
817
818     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
819         yield from self.line()
820
821     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
822         yield from self.visit_default(leaf)
823         yield from self.line()
824
825     def __attrs_post_init__(self) -> None:
826         """You are in a twisty little maze of passages."""
827         v = self.visit_stmt
828         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
829         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
830         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
831         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
832         self.visit_except_clause = partial(v, keywords={'except'})
833         self.visit_funcdef = partial(v, keywords={'def'})
834         self.visit_with_stmt = partial(v, keywords={'with'})
835         self.visit_classdef = partial(v, keywords={'class'})
836         self.visit_async_funcdef = self.visit_async_stmt
837         self.visit_decorated = self.visit_decorators
838
839
840 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
841 OPENING_BRACKETS = set(BRACKET.keys())
842 CLOSING_BRACKETS = set(BRACKET.values())
843 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
844 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
845
846
847 def whitespace(leaf: Leaf) -> str:  # noqa C901
848     """Return whitespace prefix if needed for the given `leaf`."""
849     NO = ''
850     SPACE = ' '
851     DOUBLESPACE = '  '
852     t = leaf.type
853     p = leaf.parent
854     v = leaf.value
855     if t in ALWAYS_NO_SPACE:
856         return NO
857
858     if t == token.COMMENT:
859         return DOUBLESPACE
860
861     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
862     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
863         return NO
864
865     prev = leaf.prev_sibling
866     if not prev:
867         prevp = preceding_leaf(p)
868         if not prevp or prevp.type in OPENING_BRACKETS:
869             return NO
870
871         if t == token.COLON:
872             return SPACE if prevp.type == token.COMMA else NO
873
874         if prevp.type == token.EQUAL:
875             if prevp.parent and prevp.parent.type in {
876                 syms.arglist,
877                 syms.argument,
878                 syms.parameters,
879                 syms.typedargslist,
880                 syms.varargslist,
881             }:
882                 return NO
883
884         elif prevp.type == token.DOUBLESTAR:
885             if prevp.parent and prevp.parent.type in {
886                 syms.arglist,
887                 syms.argument,
888                 syms.dictsetmaker,
889                 syms.parameters,
890                 syms.typedargslist,
891                 syms.varargslist,
892             }:
893                 return NO
894
895         elif prevp.type == token.COLON:
896             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
897                 return NO
898
899         elif (
900             prevp.parent
901             and prevp.parent.type in {syms.factor, syms.star_expr}
902             and prevp.type in MATH_OPERATORS
903         ):
904             return NO
905
906     elif prev.type in OPENING_BRACKETS:
907         return NO
908
909     if p.type in {syms.parameters, syms.arglist}:
910         # untyped function signatures or calls
911         if t == token.RPAR:
912             return NO
913
914         if not prev or prev.type != token.COMMA:
915             return NO
916
917     if p.type == syms.varargslist:
918         # lambdas
919         if t == token.RPAR:
920             return NO
921
922         if prev and prev.type != token.COMMA:
923             return NO
924
925     elif p.type == syms.typedargslist:
926         # typed function signatures
927         if not prev:
928             return NO
929
930         if t == token.EQUAL:
931             if prev.type != syms.tname:
932                 return NO
933
934         elif prev.type == token.EQUAL:
935             # A bit hacky: if the equal sign has whitespace, it means we
936             # previously found it's a typed argument.  So, we're using that, too.
937             return prev.prefix
938
939         elif prev.type != token.COMMA:
940             return NO
941
942     elif p.type == syms.tname:
943         # type names
944         if not prev:
945             prevp = preceding_leaf(p)
946             if not prevp or prevp.type != token.COMMA:
947                 return NO
948
949     elif p.type == syms.trailer:
950         # attributes and calls
951         if t == token.LPAR or t == token.RPAR:
952             return NO
953
954         if not prev:
955             if t == token.DOT:
956                 prevp = preceding_leaf(p)
957                 if not prevp or prevp.type != token.NUMBER:
958                     return NO
959
960             elif t == token.LSQB:
961                 return NO
962
963         elif prev.type != token.COMMA:
964             return NO
965
966     elif p.type == syms.argument:
967         # single argument
968         if t == token.EQUAL:
969             return NO
970
971         if not prev:
972             prevp = preceding_leaf(p)
973             if not prevp or prevp.type == token.LPAR:
974                 return NO
975
976         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
977             return NO
978
979     elif p.type == syms.decorator:
980         # decorators
981         return NO
982
983     elif p.type == syms.dotted_name:
984         if prev:
985             return NO
986
987         prevp = preceding_leaf(p)
988         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
989             return NO
990
991     elif p.type == syms.classdef:
992         if t == token.LPAR:
993             return NO
994
995         if prev and prev.type == token.LPAR:
996             return NO
997
998     elif p.type == syms.subscript:
999         # indexing
1000         if not prev:
1001             assert p.parent is not None, "subscripts are always parented"
1002             if p.parent.type == syms.subscriptlist:
1003                 return SPACE
1004
1005             return NO
1006
1007         else:
1008             return NO
1009
1010     elif p.type == syms.atom:
1011         if prev and t == token.DOT:
1012             # dots, but not the first one.
1013             return NO
1014
1015     elif (
1016         p.type == syms.listmaker
1017         or p.type == syms.testlist_gexp
1018         or p.type == syms.subscriptlist
1019     ):
1020         # list interior, including unpacking
1021         if not prev:
1022             return NO
1023
1024     elif p.type == syms.dictsetmaker:
1025         # dict and set interior, including unpacking
1026         if not prev:
1027             return NO
1028
1029         if prev.type == token.DOUBLESTAR:
1030             return NO
1031
1032     elif p.type in {syms.factor, syms.star_expr}:
1033         # unary ops
1034         if not prev:
1035             prevp = preceding_leaf(p)
1036             if not prevp or prevp.type in OPENING_BRACKETS:
1037                 return NO
1038
1039             prevp_parent = prevp.parent
1040             assert prevp_parent is not None
1041             if prevp.type == token.COLON and prevp_parent.type in {
1042                 syms.subscript, syms.sliceop
1043             }:
1044                 return NO
1045
1046             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1047                 return NO
1048
1049         elif t == token.NAME or t == token.NUMBER:
1050             return NO
1051
1052     elif p.type == syms.import_from:
1053         if t == token.DOT:
1054             if prev and prev.type == token.DOT:
1055                 return NO
1056
1057         elif t == token.NAME:
1058             if v == 'import':
1059                 return SPACE
1060
1061             if prev and prev.type == token.DOT:
1062                 return NO
1063
1064     elif p.type == syms.sliceop:
1065         return NO
1066
1067     return SPACE
1068
1069
1070 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1071     """Returns the first leaf that precedes `node`, if any."""
1072     while node:
1073         res = node.prev_sibling
1074         if res:
1075             if isinstance(res, Leaf):
1076                 return res
1077
1078             try:
1079                 return list(res.leaves())[-1]
1080
1081             except IndexError:
1082                 return None
1083
1084         node = node.parent
1085     return None
1086
1087
1088 def is_delimiter(leaf: Leaf) -> int:
1089     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1090
1091     Higher numbers are higher priority.
1092     """
1093     if leaf.type == token.COMMA:
1094         return COMMA_PRIORITY
1095
1096     if leaf.type in COMPARATORS:
1097         return COMPARATOR_PRIORITY
1098
1099     if (
1100         leaf.type in MATH_OPERATORS
1101         and leaf.parent
1102         and leaf.parent.type not in {syms.factor, syms.star_expr}
1103     ):
1104         return MATH_PRIORITY
1105
1106     return 0
1107
1108
1109 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1110     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1111
1112     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1113     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1114     move because it does away with modifying the grammar to include all the
1115     possible places in which comments can be placed.
1116
1117     The sad consequence for us though is that comments don't "belong" anywhere.
1118     This is why this function generates simple parentless Leaf objects for
1119     comments.  We simply don't know what the correct parent should be.
1120
1121     No matter though, we can live without this.  We really only need to
1122     differentiate between inline and standalone comments.  The latter don't
1123     share the line with any code.
1124
1125     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1126     are emitted with a fake STANDALONE_COMMENT token identifier.
1127     """
1128     p = leaf.prefix
1129     if not p:
1130         return
1131
1132     if '#' not in p:
1133         return
1134
1135     nlines = 0
1136     for index, line in enumerate(p.split('\n')):
1137         line = line.lstrip()
1138         if not line:
1139             nlines += 1
1140         if not line.startswith('#'):
1141             continue
1142
1143         if index == 0 and leaf.type != token.ENDMARKER:
1144             comment_type = token.COMMENT  # simple trailing comment
1145         else:
1146             comment_type = STANDALONE_COMMENT
1147         yield Leaf(comment_type, make_comment(line), prefix='\n' * nlines)
1148
1149         nlines = 0
1150
1151
1152 def make_comment(content: str) -> str:
1153     content = content.rstrip()
1154     if not content:
1155         return '#'
1156
1157     if content[0] == '#':
1158         content = content[1:]
1159     if content and content[0] not in {' ', '!', '#'}:
1160         content = ' ' + content
1161     return '#' + content
1162
1163
1164 def split_line(
1165     line: Line, line_length: int, inner: bool = False, py36: bool = False
1166 ) -> Iterator[Line]:
1167     """Splits a `line` into potentially many lines.
1168
1169     They should fit in the allotted `line_length` but might not be able to.
1170     `inner` signifies that there were a pair of brackets somewhere around the
1171     current `line`, possibly transitively. This means we can fallback to splitting
1172     by delimiters if the LHS/RHS don't yield any results.
1173
1174     If `py36` is True, splitting may generate syntax that is only compatible
1175     with Python 3.6 and later.
1176     """
1177     line_str = str(line).strip('\n')
1178     if len(line_str) <= line_length and '\n' not in line_str:
1179         yield line
1180         return
1181
1182     if line.is_def:
1183         split_funcs = [left_hand_split]
1184     elif line.inside_brackets:
1185         split_funcs = [delimiter_split]
1186         if '\n' not in line_str:
1187             # Only attempt RHS if we don't have multiline strings or comments
1188             # on this line.
1189             split_funcs.append(right_hand_split)
1190     else:
1191         split_funcs = [right_hand_split]
1192     for split_func in split_funcs:
1193         # We are accumulating lines in `result` because we might want to abort
1194         # mission and return the original line in the end, or attempt a different
1195         # split altogether.
1196         result: List[Line] = []
1197         try:
1198             for l in split_func(line, py36=py36):
1199                 if str(l).strip('\n') == line_str:
1200                     raise CannotSplit("Split function returned an unchanged result")
1201
1202                 result.extend(
1203                     split_line(l, line_length=line_length, inner=True, py36=py36)
1204                 )
1205         except CannotSplit as cs:
1206             continue
1207
1208         else:
1209             yield from result
1210             break
1211
1212     else:
1213         yield line
1214
1215
1216 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1217     """Split line into many lines, starting with the first matching bracket pair.
1218
1219     Note: this usually looks weird, only use this for function definitions.
1220     Prefer RHS otherwise.
1221     """
1222     head = Line(depth=line.depth)
1223     body = Line(depth=line.depth + 1, inside_brackets=True)
1224     tail = Line(depth=line.depth)
1225     tail_leaves: List[Leaf] = []
1226     body_leaves: List[Leaf] = []
1227     head_leaves: List[Leaf] = []
1228     current_leaves = head_leaves
1229     matching_bracket = None
1230     for leaf in line.leaves:
1231         if (
1232             current_leaves is body_leaves
1233             and leaf.type in CLOSING_BRACKETS
1234             and leaf.opening_bracket is matching_bracket
1235         ):
1236             current_leaves = tail_leaves if body_leaves else head_leaves
1237         current_leaves.append(leaf)
1238         if current_leaves is head_leaves:
1239             if leaf.type in OPENING_BRACKETS:
1240                 matching_bracket = leaf
1241                 current_leaves = body_leaves
1242     # Since body is a new indent level, remove spurious leading whitespace.
1243     if body_leaves:
1244         normalize_prefix(body_leaves[0], inside_brackets=True)
1245     # Build the new lines.
1246     for result, leaves in (
1247         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1248     ):
1249         for leaf in leaves:
1250             result.append(leaf, preformatted=True)
1251             comment_after = line.comments.get(id(leaf))
1252             if comment_after:
1253                 result.append(comment_after, preformatted=True)
1254     split_succeeded_or_raise(head, body, tail)
1255     for result in (head, body, tail):
1256         if result:
1257             yield result
1258
1259
1260 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1261     """Split line into many lines, starting with the last matching bracket pair."""
1262     head = Line(depth=line.depth)
1263     body = Line(depth=line.depth + 1, inside_brackets=True)
1264     tail = Line(depth=line.depth)
1265     tail_leaves: List[Leaf] = []
1266     body_leaves: List[Leaf] = []
1267     head_leaves: List[Leaf] = []
1268     current_leaves = tail_leaves
1269     opening_bracket = None
1270     for leaf in reversed(line.leaves):
1271         if current_leaves is body_leaves:
1272             if leaf is opening_bracket:
1273                 current_leaves = head_leaves if body_leaves else tail_leaves
1274         current_leaves.append(leaf)
1275         if current_leaves is tail_leaves:
1276             if leaf.type in CLOSING_BRACKETS:
1277                 opening_bracket = leaf.opening_bracket
1278                 current_leaves = body_leaves
1279     tail_leaves.reverse()
1280     body_leaves.reverse()
1281     head_leaves.reverse()
1282     # Since body is a new indent level, remove spurious leading whitespace.
1283     if body_leaves:
1284         normalize_prefix(body_leaves[0], inside_brackets=True)
1285     # Build the new lines.
1286     for result, leaves in (
1287         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1288     ):
1289         for leaf in leaves:
1290             result.append(leaf, preformatted=True)
1291             comment_after = line.comments.get(id(leaf))
1292             if comment_after:
1293                 result.append(comment_after, preformatted=True)
1294     split_succeeded_or_raise(head, body, tail)
1295     for result in (head, body, tail):
1296         if result:
1297             yield result
1298
1299
1300 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1301     tail_len = len(str(tail).strip())
1302     if not body:
1303         if tail_len == 0:
1304             raise CannotSplit("Splitting brackets produced the same line")
1305
1306         elif tail_len < 3:
1307             raise CannotSplit(
1308                 f"Splitting brackets on an empty body to save "
1309                 f"{tail_len} characters is not worth it"
1310             )
1311
1312
1313 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1314     """Split according to delimiters of the highest priority.
1315
1316     This kind of split doesn't increase indentation.
1317     If `py36` is True, the split will add trailing commas also in function
1318     signatures that contain * and **.
1319     """
1320     try:
1321         last_leaf = line.leaves[-1]
1322     except IndexError:
1323         raise CannotSplit("Line empty")
1324
1325     delimiters = line.bracket_tracker.delimiters
1326     try:
1327         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1328     except ValueError:
1329         raise CannotSplit("No delimiters found")
1330
1331     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1332     lowest_depth = sys.maxsize
1333     trailing_comma_safe = True
1334     for leaf in line.leaves:
1335         current_line.append(leaf, preformatted=True)
1336         comment_after = line.comments.get(id(leaf))
1337         if comment_after:
1338             current_line.append(comment_after, preformatted=True)
1339         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1340         if (
1341             leaf.bracket_depth == lowest_depth
1342             and leaf.type == token.STAR
1343             or leaf.type == token.DOUBLESTAR
1344         ):
1345             trailing_comma_safe = trailing_comma_safe and py36
1346         leaf_priority = delimiters.get(id(leaf))
1347         if leaf_priority == delimiter_priority:
1348             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1349             yield current_line
1350
1351             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1352     if current_line:
1353         if (
1354             delimiter_priority == COMMA_PRIORITY
1355             and current_line.leaves[-1].type != token.COMMA
1356             and trailing_comma_safe
1357         ):
1358             current_line.append(Leaf(token.COMMA, ','))
1359         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1360         yield current_line
1361
1362
1363 def is_import(leaf: Leaf) -> bool:
1364     """Returns True if the given leaf starts an import statement."""
1365     p = leaf.parent
1366     t = leaf.type
1367     v = leaf.value
1368     return bool(
1369         t == token.NAME
1370         and (
1371             (v == 'import' and p and p.type == syms.import_name)
1372             or (v == 'from' and p and p.type == syms.import_from)
1373         )
1374     )
1375
1376
1377 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1378     """Leave existing extra newlines if not `inside_brackets`.
1379
1380     Remove everything else.  Note: don't use backslashes for formatting or
1381     you'll lose your voting rights.
1382     """
1383     if not inside_brackets:
1384         spl = leaf.prefix.split('#')
1385         if '\\' not in spl[0]:
1386             nl_count = spl[-1].count('\n')
1387             if len(spl) > 1:
1388                 nl_count -= 1
1389             leaf.prefix = '\n' * nl_count
1390             return
1391
1392     leaf.prefix = ''
1393
1394
1395 def is_python36(node: Node) -> bool:
1396     """Returns True if the current file is using Python 3.6+ features.
1397
1398     Currently looking for:
1399     - f-strings; and
1400     - trailing commas after * or ** in function signatures.
1401     """
1402     for n in node.pre_order():
1403         if n.type == token.STRING:
1404             value_head = n.value[:2]  # type: ignore
1405             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1406                 return True
1407
1408         elif (
1409             n.type == syms.typedargslist
1410             and n.children
1411             and n.children[-1].type == token.COMMA
1412         ):
1413             for ch in n.children:
1414                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1415                     return True
1416
1417     return False
1418
1419
1420 PYTHON_EXTENSIONS = {'.py'}
1421 BLACKLISTED_DIRECTORIES = {
1422     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1423 }
1424
1425
1426 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1427     for child in path.iterdir():
1428         if child.is_dir():
1429             if child.name in BLACKLISTED_DIRECTORIES:
1430                 continue
1431
1432             yield from gen_python_files_in_dir(child)
1433
1434         elif child.suffix in PYTHON_EXTENSIONS:
1435             yield child
1436
1437
1438 @dataclass
1439 class Report:
1440     """Provides a reformatting counter."""
1441     check: bool = False
1442     change_count: int = 0
1443     same_count: int = 0
1444     failure_count: int = 0
1445
1446     def done(self, src: Path, changed: bool) -> None:
1447         """Increment the counter for successful reformatting. Write out a message."""
1448         if changed:
1449             reformatted = 'would reformat' if self.check else 'reformatted'
1450             out(f'{reformatted} {src}')
1451             self.change_count += 1
1452         else:
1453             out(f'{src} already well formatted, good job.', bold=False)
1454             self.same_count += 1
1455
1456     def failed(self, src: Path, message: str) -> None:
1457         """Increment the counter for failed reformatting. Write out a message."""
1458         err(f'error: cannot format {src}: {message}')
1459         self.failure_count += 1
1460
1461     @property
1462     def return_code(self) -> int:
1463         """Which return code should the app use considering the current state."""
1464         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1465         # 126 we have special returncodes reserved by the shell.
1466         if self.failure_count:
1467             return 123
1468
1469         elif self.change_count and self.check:
1470             return 1
1471
1472         return 0
1473
1474     def __str__(self) -> str:
1475         """A color report of the current state.
1476
1477         Use `click.unstyle` to remove colors.
1478         """
1479         if self.check:
1480             reformatted = "would be reformatted"
1481             unchanged = "would be left unchanged"
1482             failed = "would fail to reformat"
1483         else:
1484             reformatted = "reformatted"
1485             unchanged = "left unchanged"
1486             failed = "failed to reformat"
1487         report = []
1488         if self.change_count:
1489             s = 's' if self.change_count > 1 else ''
1490             report.append(
1491                 click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
1492             )
1493         if self.same_count:
1494             s = 's' if self.same_count > 1 else ''
1495             report.append(f'{self.same_count} file{s} {unchanged}')
1496         if self.failure_count:
1497             s = 's' if self.failure_count > 1 else ''
1498             report.append(
1499                 click.style(f'{self.failure_count} file{s} {failed}', fg='red')
1500             )
1501         return ', '.join(report) + '.'
1502
1503
1504 def assert_equivalent(src: str, dst: str) -> None:
1505     """Raises AssertionError if `src` and `dst` aren't equivalent.
1506
1507     This is a temporary sanity check until Black becomes stable.
1508     """
1509
1510     import ast
1511     import traceback
1512
1513     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1514         """Simple visitor generating strings to compare ASTs by content."""
1515         yield f"{'  ' * depth}{node.__class__.__name__}("
1516
1517         for field in sorted(node._fields):
1518             try:
1519                 value = getattr(node, field)
1520             except AttributeError:
1521                 continue
1522
1523             yield f"{'  ' * (depth+1)}{field}="
1524
1525             if isinstance(value, list):
1526                 for item in value:
1527                     if isinstance(item, ast.AST):
1528                         yield from _v(item, depth + 2)
1529
1530             elif isinstance(value, ast.AST):
1531                 yield from _v(value, depth + 2)
1532
1533             else:
1534                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1535
1536         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1537
1538     try:
1539         src_ast = ast.parse(src)
1540     except Exception as exc:
1541         raise AssertionError(f"cannot parse source: {exc}") from None
1542
1543     try:
1544         dst_ast = ast.parse(dst)
1545     except Exception as exc:
1546         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1547         raise AssertionError(
1548             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1549             f"Please report a bug on https://github.com/ambv/black/issues.  "
1550             f"This invalid output might be helpful: {log}"
1551         ) from None
1552
1553     src_ast_str = '\n'.join(_v(src_ast))
1554     dst_ast_str = '\n'.join(_v(dst_ast))
1555     if src_ast_str != dst_ast_str:
1556         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1557         raise AssertionError(
1558             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1559             f"the source.  "
1560             f"Please report a bug on https://github.com/ambv/black/issues.  "
1561             f"This diff might be helpful: {log}"
1562         ) from None
1563
1564
1565 def assert_stable(src: str, dst: str, line_length: int) -> None:
1566     """Raises AssertionError if `dst` reformats differently the second time.
1567
1568     This is a temporary sanity check until Black becomes stable.
1569     """
1570     newdst = format_str(dst, line_length=line_length)
1571     if dst != newdst:
1572         log = dump_to_file(
1573             diff(src, dst, 'source', 'first pass'),
1574             diff(dst, newdst, 'first pass', 'second pass'),
1575         )
1576         raise AssertionError(
1577             f"INTERNAL ERROR: Black produced different code on the second pass "
1578             f"of the formatter.  "
1579             f"Please report a bug on https://github.com/ambv/black/issues.  "
1580             f"This diff might be helpful: {log}"
1581         ) from None
1582
1583
1584 def dump_to_file(*output: str) -> str:
1585     """Dumps `output` to a temporary file. Returns path to the file."""
1586     import tempfile
1587
1588     with tempfile.NamedTemporaryFile(
1589         mode='w', prefix='blk_', suffix='.log', delete=False
1590     ) as f:
1591         for lines in output:
1592             f.write(lines)
1593             f.write('\n')
1594     return f.name
1595
1596
1597 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1598     """Returns a udiff string between strings `a` and `b`."""
1599     import difflib
1600
1601     a_lines = [line + '\n' for line in a.split('\n')]
1602     b_lines = [line + '\n' for line in b.split('\n')]
1603     return ''.join(
1604         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1605     )
1606
1607
1608 if __name__ == '__main__':
1609     main()