]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

203fbfa9055be595a358e94bd1c16b9a3911f23f
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(
78         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
79     ),
80 )
81 @click.pass_context
82 def main(
83     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
84 ) -> None:
85     """The uncompromising code formatter."""
86     sources: List[Path] = []
87     for s in src:
88         p = Path(s)
89         if p.is_dir():
90             sources.extend(gen_python_files_in_dir(p))
91         elif p.is_file():
92             # if a file was explicitly given, we don't care about its extension
93             sources.append(p)
94         elif s == '-':
95             sources.append(Path('-'))
96         else:
97             err(f'invalid path: {s}')
98     if len(sources) == 0:
99         ctx.exit(0)
100     elif len(sources) == 1:
101         p = sources[0]
102         report = Report()
103         try:
104             if not p.is_file() and str(p) == '-':
105                 changed = format_stdin_to_stdout(line_length=line_length, fast=fast)
106             else:
107                 changed = format_file_in_place(
108                     p, line_length=line_length, fast=fast, write_back=not check
109                 )
110             report.done(p, changed)
111         except Exception as exc:
112             report.failed(p, str(exc))
113         ctx.exit(report.return_code)
114     else:
115         loop = asyncio.get_event_loop()
116         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
117         return_code = 1
118         try:
119             return_code = loop.run_until_complete(
120                 schedule_formatting(
121                     sources, line_length, not check, fast, loop, executor
122                 )
123             )
124         finally:
125             loop.close()
126             ctx.exit(return_code)
127
128
129 async def schedule_formatting(
130     sources: List[Path],
131     line_length: int,
132     write_back: bool,
133     fast: bool,
134     loop: BaseEventLoop,
135     executor: Executor,
136 ) -> int:
137     tasks = {
138         src: loop.run_in_executor(
139             executor, format_file_in_place, src, line_length, fast, write_back
140         )
141         for src in sources
142     }
143     await asyncio.wait(tasks.values())
144     cancelled = []
145     report = Report()
146     for src, task in tasks.items():
147         if not task.done():
148             report.failed(src, 'timed out, cancelling')
149             task.cancel()
150             cancelled.append(task)
151         elif task.exception():
152             report.failed(src, str(task.exception()))
153         else:
154             report.done(src, task.result())
155     if cancelled:
156         await asyncio.wait(cancelled, timeout=2)
157     out('All done! ✨ 🍰 ✨')
158     click.echo(str(report))
159     return report.return_code
160
161
162 def format_file_in_place(
163     src: Path, line_length: int, fast: bool, write_back: bool = False
164 ) -> bool:
165     """Format the file and rewrite if changed. Return True if changed."""
166     with tokenize.open(src) as src_buffer:
167         src_contents = src_buffer.read()
168     try:
169         contents = format_file_contents(
170             src_contents, line_length=line_length, fast=fast
171         )
172     except NothingChanged:
173         return False
174
175     if write_back:
176         with open(src, "w", encoding=src_buffer.encoding) as f:
177             f.write(contents)
178     return True
179
180
181 def format_stdin_to_stdout(line_length: int, fast: bool) -> bool:
182     """Format file on stdin and pipe output to stdout. Return True if changed."""
183     contents = sys.stdin.read()
184     try:
185         contents = format_file_contents(contents, line_length=line_length, fast=fast)
186         return True
187
188     except NothingChanged:
189         return False
190
191     finally:
192         sys.stdout.write(contents)
193
194
195 def format_file_contents(
196     src_contents: str, line_length: int, fast: bool
197 ) -> FileContent:
198     """Reformats a file and returns its contents and encoding."""
199     if src_contents.strip() == '':
200         raise NothingChanged
201
202     dst_contents = format_str(src_contents, line_length=line_length)
203     if src_contents == dst_contents:
204         raise NothingChanged
205
206     if not fast:
207         assert_equivalent(src_contents, dst_contents)
208         assert_stable(src_contents, dst_contents, line_length=line_length)
209     return dst_contents
210
211
212 def format_str(src_contents: str, line_length: int) -> FileContent:
213     """Reformats a string and returns new contents."""
214     src_node = lib2to3_parse(src_contents)
215     dst_contents = ""
216     comments: List[Line] = []
217     lines = LineGenerator()
218     elt = EmptyLineTracker()
219     py36 = is_python36(src_node)
220     empty_line = Line()
221     after = 0
222     for current_line in lines.visit(src_node):
223         for _ in range(after):
224             dst_contents += str(empty_line)
225         before, after = elt.maybe_empty_lines(current_line)
226         for _ in range(before):
227             dst_contents += str(empty_line)
228         if not current_line.is_comment:
229             for comment in comments:
230                 dst_contents += str(comment)
231             comments = []
232             for line in split_line(current_line, line_length=line_length, py36=py36):
233                 dst_contents += str(line)
234         else:
235             comments.append(current_line)
236     if comments:
237         if elt.previous_defs:
238             # Separate postscriptum comments from the last module-level def.
239             dst_contents += str(empty_line)
240             dst_contents += str(empty_line)
241         for comment in comments:
242             dst_contents += str(comment)
243     return dst_contents
244
245
246 def lib2to3_parse(src_txt: str) -> Node:
247     """Given a string with source, return the lib2to3 Node."""
248     grammar = pygram.python_grammar_no_print_statement
249     drv = driver.Driver(grammar, pytree.convert)
250     if src_txt[-1] != '\n':
251         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
252         src_txt += nl
253     try:
254         result = drv.parse_string(src_txt, True)
255     except ParseError as pe:
256         lineno, column = pe.context[1]
257         lines = src_txt.splitlines()
258         try:
259             faulty_line = lines[lineno - 1]
260         except IndexError:
261             faulty_line = "<line number missing in source>"
262         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
263
264     if isinstance(result, Leaf):
265         result = Node(syms.file_input, [result])
266     return result
267
268
269 def lib2to3_unparse(node: Node) -> str:
270     """Given a lib2to3 node, return its string representation."""
271     code = str(node)
272     return code
273
274
275 T = TypeVar('T')
276
277
278 class Visitor(Generic[T]):
279     """Basic lib2to3 visitor that yields things on visiting."""
280
281     def visit(self, node: LN) -> Iterator[T]:
282         if node.type < 256:
283             name = token.tok_name[node.type]
284         else:
285             name = type_repr(node.type)
286         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
287
288     def visit_default(self, node: LN) -> Iterator[T]:
289         if isinstance(node, Node):
290             for child in node.children:
291                 yield from self.visit(child)
292
293
294 @dataclass
295 class DebugVisitor(Visitor[T]):
296     tree_depth: int = 0
297
298     def visit_default(self, node: LN) -> Iterator[T]:
299         indent = ' ' * (2 * self.tree_depth)
300         if isinstance(node, Node):
301             _type = type_repr(node.type)
302             out(f'{indent}{_type}', fg='yellow')
303             self.tree_depth += 1
304             for child in node.children:
305                 yield from self.visit(child)
306
307             self.tree_depth -= 1
308             out(f'{indent}/{_type}', fg='yellow', bold=False)
309         else:
310             _type = token.tok_name.get(node.type, str(node.type))
311             out(f'{indent}{_type}', fg='blue', nl=False)
312             if node.prefix:
313                 # We don't have to handle prefixes for `Node` objects since
314                 # that delegates to the first child anyway.
315                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
316             out(f' {node.value!r}', fg='blue', bold=False)
317
318
319 KEYWORDS = set(keyword.kwlist)
320 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
321 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
322 STATEMENT = {
323     syms.if_stmt,
324     syms.while_stmt,
325     syms.for_stmt,
326     syms.try_stmt,
327     syms.except_clause,
328     syms.with_stmt,
329     syms.funcdef,
330     syms.classdef,
331 }
332 STANDALONE_COMMENT = 153
333 LOGIC_OPERATORS = {'and', 'or'}
334 COMPARATORS = {
335     token.LESS,
336     token.GREATER,
337     token.EQEQUAL,
338     token.NOTEQUAL,
339     token.LESSEQUAL,
340     token.GREATEREQUAL,
341 }
342 MATH_OPERATORS = {
343     token.PLUS,
344     token.MINUS,
345     token.STAR,
346     token.SLASH,
347     token.VBAR,
348     token.AMPER,
349     token.PERCENT,
350     token.CIRCUMFLEX,
351     token.LEFTSHIFT,
352     token.RIGHTSHIFT,
353     token.DOUBLESTAR,
354     token.DOUBLESLASH,
355 }
356 COMPREHENSION_PRIORITY = 20
357 COMMA_PRIORITY = 10
358 LOGIC_PRIORITY = 5
359 STRING_PRIORITY = 4
360 COMPARATOR_PRIORITY = 3
361 MATH_PRIORITY = 1
362
363
364 @dataclass
365 class BracketTracker:
366     depth: int = 0
367     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
368     delimiters: Dict[LeafID, Priority] = Factory(dict)
369     previous: Optional[Leaf] = None
370
371     def mark(self, leaf: Leaf) -> None:
372         if leaf.type == token.COMMENT:
373             return
374
375         if leaf.type in CLOSING_BRACKETS:
376             self.depth -= 1
377             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
378             leaf.opening_bracket = opening_bracket
379         leaf.bracket_depth = self.depth
380         if self.depth == 0:
381             delim = is_delimiter(leaf)
382             if delim:
383                 self.delimiters[id(leaf)] = delim
384             elif self.previous is not None:
385                 if leaf.type == token.STRING and self.previous.type == token.STRING:
386                     self.delimiters[id(self.previous)] = STRING_PRIORITY
387                 elif (
388                     leaf.type == token.NAME
389                     and leaf.value == 'for'
390                     and leaf.parent
391                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
392                 ):
393                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
394                 elif (
395                     leaf.type == token.NAME
396                     and leaf.value == 'if'
397                     and leaf.parent
398                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
399                 ):
400                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
401                 elif (
402                     leaf.type == token.NAME
403                     and leaf.value in LOGIC_OPERATORS
404                     and leaf.parent
405                 ):
406                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
407         if leaf.type in OPENING_BRACKETS:
408             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
409             self.depth += 1
410         self.previous = leaf
411
412     def any_open_brackets(self) -> bool:
413         """Returns True if there is an yet unmatched open bracket on the line."""
414         return bool(self.bracket_match)
415
416     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
417         """Returns the highest priority of a delimiter found on the line.
418
419         Values are consistent with what `is_delimiter()` returns.
420         """
421         return max(v for k, v in self.delimiters.items() if k not in exclude)
422
423
424 @dataclass
425 class Line:
426     depth: int = 0
427     leaves: List[Leaf] = Factory(list)
428     comments: Dict[LeafID, Leaf] = Factory(dict)
429     bracket_tracker: BracketTracker = Factory(BracketTracker)
430     inside_brackets: bool = False
431     has_for: bool = False
432     _for_loop_variable: bool = False
433
434     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
435         has_value = leaf.value.strip()
436         if not has_value:
437             return
438
439         if self.leaves and not preformatted:
440             # Note: at this point leaf.prefix should be empty except for
441             # imports, for which we only preserve newlines.
442             leaf.prefix += whitespace(leaf)
443         if self.inside_brackets or not preformatted:
444             self.maybe_decrement_after_for_loop_variable(leaf)
445             self.bracket_tracker.mark(leaf)
446             self.maybe_remove_trailing_comma(leaf)
447             self.maybe_increment_for_loop_variable(leaf)
448             if self.maybe_adapt_standalone_comment(leaf):
449                 return
450
451         if not self.append_comment(leaf):
452             self.leaves.append(leaf)
453
454     @property
455     def is_comment(self) -> bool:
456         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
457
458     @property
459     def is_decorator(self) -> bool:
460         return bool(self) and self.leaves[0].type == token.AT
461
462     @property
463     def is_import(self) -> bool:
464         return bool(self) and is_import(self.leaves[0])
465
466     @property
467     def is_class(self) -> bool:
468         return (
469             bool(self)
470             and self.leaves[0].type == token.NAME
471             and self.leaves[0].value == 'class'
472         )
473
474     @property
475     def is_def(self) -> bool:
476         """Also returns True for async defs."""
477         try:
478             first_leaf = self.leaves[0]
479         except IndexError:
480             return False
481
482         try:
483             second_leaf: Optional[Leaf] = self.leaves[1]
484         except IndexError:
485             second_leaf = None
486         return (
487             (first_leaf.type == token.NAME and first_leaf.value == 'def')
488             or (
489                 first_leaf.type == token.ASYNC
490                 and second_leaf is not None
491                 and second_leaf.type == token.NAME
492                 and second_leaf.value == 'def'
493             )
494         )
495
496     @property
497     def is_flow_control(self) -> bool:
498         return (
499             bool(self)
500             and self.leaves[0].type == token.NAME
501             and self.leaves[0].value in FLOW_CONTROL
502         )
503
504     @property
505     def is_yield(self) -> bool:
506         return (
507             bool(self)
508             and self.leaves[0].type == token.NAME
509             and self.leaves[0].value == 'yield'
510         )
511
512     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
513         if not (
514             self.leaves
515             and self.leaves[-1].type == token.COMMA
516             and closing.type in CLOSING_BRACKETS
517         ):
518             return False
519
520         if closing.type == token.RSQB or closing.type == token.RBRACE:
521             self.leaves.pop()
522             return True
523
524         # For parens let's check if it's safe to remove the comma.  If the
525         # trailing one is the only one, we might mistakenly change a tuple
526         # into a different type by removing the comma.
527         depth = closing.bracket_depth + 1
528         commas = 0
529         opening = closing.opening_bracket
530         for _opening_index, leaf in enumerate(self.leaves):
531             if leaf is opening:
532                 break
533
534         else:
535             return False
536
537         for leaf in self.leaves[_opening_index + 1:]:
538             if leaf is closing:
539                 break
540
541             bracket_depth = leaf.bracket_depth
542             if bracket_depth == depth and leaf.type == token.COMMA:
543                 commas += 1
544                 if leaf.parent and leaf.parent.type == syms.arglist:
545                     commas += 1
546                     break
547
548         if commas > 1:
549             self.leaves.pop()
550             return True
551
552         return False
553
554     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
555         """In a for loop, or comprehension, the variables are often unpacks.
556
557         To avoid splitting on the comma in this situation, we will increase
558         the depth of tokens between `for` and `in`.
559         """
560         if leaf.type == token.NAME and leaf.value == 'for':
561             self.has_for = True
562             self.bracket_tracker.depth += 1
563             self._for_loop_variable = True
564             return True
565
566         return False
567
568     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
569         # See `maybe_increment_for_loop_variable` above for explanation.
570         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
571             self.bracket_tracker.depth -= 1
572             self._for_loop_variable = False
573             return True
574
575         return False
576
577     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
578         """Hack a standalone comment to act as a trailing comment for line splitting.
579
580         If this line has brackets and a standalone `comment`, we need to adapt
581         it to be able to still reformat the line.
582
583         This is not perfect, the line to which the standalone comment gets
584         appended will appear "too long" when splitting.
585         """
586         if not (
587             comment.type == STANDALONE_COMMENT
588             and self.bracket_tracker.any_open_brackets()
589         ):
590             return False
591
592         comment.type = token.COMMENT
593         comment.prefix = '\n' + '    ' * (self.depth + 1)
594         return self.append_comment(comment)
595
596     def append_comment(self, comment: Leaf) -> bool:
597         if comment.type != token.COMMENT:
598             return False
599
600         try:
601             after = id(self.last_non_delimiter())
602         except LookupError:
603             comment.type = STANDALONE_COMMENT
604             comment.prefix = ''
605             return False
606
607         else:
608             if after in self.comments:
609                 self.comments[after].value += str(comment)
610             else:
611                 self.comments[after] = comment
612             return True
613
614     def last_non_delimiter(self) -> Leaf:
615         for i in range(len(self.leaves)):
616             last = self.leaves[-i - 1]
617             if not is_delimiter(last):
618                 return last
619
620         raise LookupError("No non-delimiters found")
621
622     def __str__(self) -> str:
623         if not self:
624             return '\n'
625
626         indent = '    ' * self.depth
627         leaves = iter(self.leaves)
628         first = next(leaves)
629         res = f'{first.prefix}{indent}{first.value}'
630         for leaf in leaves:
631             res += str(leaf)
632         for comment in self.comments.values():
633             res += str(comment)
634         return res + '\n'
635
636     def __bool__(self) -> bool:
637         return bool(self.leaves or self.comments)
638
639
640 @dataclass
641 class EmptyLineTracker:
642     """Provides a stateful method that returns the number of potential extra
643     empty lines needed before and after the currently processed line.
644
645     Note: this tracker works on lines that haven't been split yet.  It assumes
646     the prefix of the first leaf consists of optional newlines.  Those newlines
647     are consumed by `maybe_empty_lines()` and included in the computation.
648     """
649     previous_line: Optional[Line] = None
650     previous_after: int = 0
651     previous_defs: List[int] = Factory(list)
652
653     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
654         """Returns the number of extra empty lines before and after the `current_line`.
655
656         This is for separating `def`, `async def` and `class` with extra empty lines
657         (two on module-level), as well as providing an extra empty line after flow
658         control keywords to make them more prominent.
659         """
660         if current_line.is_comment:
661             # Don't count standalone comments towards previous empty lines.
662             return 0, 0
663
664         before, after = self._maybe_empty_lines(current_line)
665         before -= self.previous_after
666         self.previous_after = after
667         self.previous_line = current_line
668         return before, after
669
670     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
671         if current_line.leaves:
672             # Consume the first leaf's extra newlines.
673             first_leaf = current_line.leaves[0]
674             before = int('\n' in first_leaf.prefix)
675             first_leaf.prefix = ''
676         else:
677             before = 0
678         depth = current_line.depth
679         while self.previous_defs and self.previous_defs[-1] >= depth:
680             self.previous_defs.pop()
681             before = 1 if depth else 2
682         is_decorator = current_line.is_decorator
683         if is_decorator or current_line.is_def or current_line.is_class:
684             if not is_decorator:
685                 self.previous_defs.append(depth)
686             if self.previous_line is None:
687                 # Don't insert empty lines before the first line in the file.
688                 return 0, 0
689
690             if self.previous_line and self.previous_line.is_decorator:
691                 # Don't insert empty lines between decorators.
692                 return 0, 0
693
694             newlines = 2
695             if current_line.depth:
696                 newlines -= 1
697             return newlines, 0
698
699         if current_line.is_flow_control:
700             return before, 1
701
702         if (
703             self.previous_line
704             and self.previous_line.is_import
705             and not current_line.is_import
706             and depth == self.previous_line.depth
707         ):
708             return (before or 1), 0
709
710         if (
711             self.previous_line
712             and self.previous_line.is_yield
713             and (not current_line.is_yield or depth != self.previous_line.depth)
714         ):
715             return (before or 1), 0
716
717         return before, 0
718
719
720 @dataclass
721 class LineGenerator(Visitor[Line]):
722     """Generates reformatted Line objects.  Empty lines are not emitted.
723
724     Note: destroys the tree it's visiting by mutating prefixes of its leaves
725     in ways that will no longer stringify to valid Python code on the tree.
726     """
727     current_line: Line = Factory(Line)
728     standalone_comments: List[Leaf] = Factory(list)
729
730     def line(self, indent: int = 0) -> Iterator[Line]:
731         """Generate a line.
732
733         If the line is empty, only emit if it makes sense.
734         If the line is too long, split it first and then generate.
735
736         If any lines were generated, set up a new current_line.
737         """
738         if not self.current_line:
739             self.current_line.depth += indent
740             return  # Line is empty, don't emit. Creating a new one unnecessary.
741
742         complete_line = self.current_line
743         self.current_line = Line(depth=complete_line.depth + indent)
744         yield complete_line
745
746     def visit_default(self, node: LN) -> Iterator[Line]:
747         if isinstance(node, Leaf):
748             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
749             for comment in generate_comments(node):
750                 if any_open_brackets:
751                     # any comment within brackets is subject to splitting
752                     self.current_line.append(comment)
753                 elif comment.type == token.COMMENT:
754                     # regular trailing comment
755                     self.current_line.append(comment)
756                     yield from self.line()
757
758                 else:
759                     # regular standalone comment, to be processed later (see
760                     # docstring in `generate_comments()`
761                     self.standalone_comments.append(comment)
762             normalize_prefix(node, inside_brackets=any_open_brackets)
763             if node.type not in WHITESPACE:
764                 for comment in self.standalone_comments:
765                     yield from self.line()
766
767                     self.current_line.append(comment)
768                     yield from self.line()
769
770                 self.standalone_comments = []
771                 self.current_line.append(node)
772         yield from super().visit_default(node)
773
774     def visit_suite(self, node: Node) -> Iterator[Line]:
775         """Body of a statement after a colon."""
776         children = iter(node.children)
777         # Process newline before indenting.  It might contain an inline
778         # comment that should go right after the colon.
779         newline = next(children)
780         yield from self.visit(newline)
781         yield from self.line(+1)
782
783         for child in children:
784             yield from self.visit(child)
785
786         yield from self.line(-1)
787
788     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
789         """Visit a statement.
790
791         The relevant Python language keywords for this statement are NAME leaves
792         within it.
793         """
794         for child in node.children:
795             if child.type == token.NAME and child.value in keywords:  # type: ignore
796                 yield from self.line()
797
798             yield from self.visit(child)
799
800     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
801         """A statement without nested statements."""
802         is_suite_like = node.parent and node.parent.type in STATEMENT
803         if is_suite_like:
804             yield from self.line(+1)
805             yield from self.visit_default(node)
806             yield from self.line(-1)
807
808         else:
809             yield from self.line()
810             yield from self.visit_default(node)
811
812     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
813         yield from self.line()
814
815         children = iter(node.children)
816         for child in children:
817             yield from self.visit(child)
818
819             if child.type == token.ASYNC:
820                 break
821
822         internal_stmt = next(children)
823         for child in internal_stmt.children:
824             yield from self.visit(child)
825
826     def visit_decorators(self, node: Node) -> Iterator[Line]:
827         for child in node.children:
828             yield from self.line()
829             yield from self.visit(child)
830
831     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
832         yield from self.line()
833
834     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
835         yield from self.visit_default(leaf)
836         yield from self.line()
837
838     def __attrs_post_init__(self) -> None:
839         """You are in a twisty little maze of passages."""
840         v = self.visit_stmt
841         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
842         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
843         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
844         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
845         self.visit_except_clause = partial(v, keywords={'except'})
846         self.visit_funcdef = partial(v, keywords={'def'})
847         self.visit_with_stmt = partial(v, keywords={'with'})
848         self.visit_classdef = partial(v, keywords={'class'})
849         self.visit_async_funcdef = self.visit_async_stmt
850         self.visit_decorated = self.visit_decorators
851
852
853 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
854 OPENING_BRACKETS = set(BRACKET.keys())
855 CLOSING_BRACKETS = set(BRACKET.values())
856 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
857 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
858
859
860 def whitespace(leaf: Leaf) -> str:  # noqa C901
861     """Return whitespace prefix if needed for the given `leaf`."""
862     NO = ''
863     SPACE = ' '
864     DOUBLESPACE = '  '
865     t = leaf.type
866     p = leaf.parent
867     v = leaf.value
868     if t in ALWAYS_NO_SPACE:
869         return NO
870
871     if t == token.COMMENT:
872         return DOUBLESPACE
873
874     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
875     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
876         return NO
877
878     prev = leaf.prev_sibling
879     if not prev:
880         prevp = preceding_leaf(p)
881         if not prevp or prevp.type in OPENING_BRACKETS:
882             return NO
883
884         if t == token.COLON:
885             return SPACE if prevp.type == token.COMMA else NO
886
887         if prevp.type == token.EQUAL:
888             if prevp.parent and prevp.parent.type in {
889                 syms.typedargslist,
890                 syms.varargslist,
891                 syms.parameters,
892                 syms.arglist,
893                 syms.argument,
894             }:
895                 return NO
896
897         elif prevp.type == token.DOUBLESTAR:
898             if prevp.parent and prevp.parent.type in {
899                 syms.typedargslist,
900                 syms.varargslist,
901                 syms.parameters,
902                 syms.arglist,
903                 syms.dictsetmaker,
904             }:
905                 return NO
906
907         elif prevp.type == token.COLON:
908             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
909                 return NO
910
911         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
912             return NO
913
914     elif prev.type in OPENING_BRACKETS:
915         return NO
916
917     if p.type in {syms.parameters, syms.arglist}:
918         # untyped function signatures or calls
919         if t == token.RPAR:
920             return NO
921
922         if not prev or prev.type != token.COMMA:
923             return NO
924
925     if p.type == syms.varargslist:
926         # lambdas
927         if t == token.RPAR:
928             return NO
929
930         if prev and prev.type != token.COMMA:
931             return NO
932
933     elif p.type == syms.typedargslist:
934         # typed function signatures
935         if not prev:
936             return NO
937
938         if t == token.EQUAL:
939             if prev.type != syms.tname:
940                 return NO
941
942         elif prev.type == token.EQUAL:
943             # A bit hacky: if the equal sign has whitespace, it means we
944             # previously found it's a typed argument.  So, we're using that, too.
945             return prev.prefix
946
947         elif prev.type != token.COMMA:
948             return NO
949
950     elif p.type == syms.tname:
951         # type names
952         if not prev:
953             prevp = preceding_leaf(p)
954             if not prevp or prevp.type != token.COMMA:
955                 return NO
956
957     elif p.type == syms.trailer:
958         # attributes and calls
959         if t == token.LPAR or t == token.RPAR:
960             return NO
961
962         if not prev:
963             if t == token.DOT:
964                 prevp = preceding_leaf(p)
965                 if not prevp or prevp.type != token.NUMBER:
966                     return NO
967
968             elif t == token.LSQB:
969                 return NO
970
971         elif prev.type != token.COMMA:
972             return NO
973
974     elif p.type == syms.argument:
975         # single argument
976         if t == token.EQUAL:
977             return NO
978
979         if not prev:
980             prevp = preceding_leaf(p)
981             if not prevp or prevp.type == token.LPAR:
982                 return NO
983
984         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
985             return NO
986
987     elif p.type == syms.decorator:
988         # decorators
989         return NO
990
991     elif p.type == syms.dotted_name:
992         if prev:
993             return NO
994
995         prevp = preceding_leaf(p)
996         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
997             return NO
998
999     elif p.type == syms.classdef:
1000         if t == token.LPAR:
1001             return NO
1002
1003         if prev and prev.type == token.LPAR:
1004             return NO
1005
1006     elif p.type == syms.subscript:
1007         # indexing
1008         if not prev:
1009             assert p.parent is not None, "subscripts are always parented"
1010             if p.parent.type == syms.subscriptlist:
1011                 return SPACE
1012
1013             return NO
1014
1015         else:
1016             return NO
1017
1018     elif p.type == syms.atom:
1019         if prev and t == token.DOT:
1020             # dots, but not the first one.
1021             return NO
1022
1023     elif (
1024         p.type == syms.listmaker
1025         or p.type == syms.testlist_gexp
1026         or p.type == syms.subscriptlist
1027     ):
1028         # list interior, including unpacking
1029         if not prev:
1030             return NO
1031
1032     elif p.type == syms.dictsetmaker:
1033         # dict and set interior, including unpacking
1034         if not prev:
1035             return NO
1036
1037         if prev.type == token.DOUBLESTAR:
1038             return NO
1039
1040     elif p.type in {syms.factor, syms.star_expr}:
1041         # unary ops
1042         if not prev:
1043             prevp = preceding_leaf(p)
1044             if not prevp or prevp.type in OPENING_BRACKETS:
1045                 return NO
1046
1047             prevp_parent = prevp.parent
1048             assert prevp_parent is not None
1049             if prevp.type == token.COLON and prevp_parent.type in {
1050                 syms.subscript, syms.sliceop
1051             }:
1052                 return NO
1053
1054             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1055                 return NO
1056
1057         elif t == token.NAME or t == token.NUMBER:
1058             return NO
1059
1060     elif p.type == syms.import_from:
1061         if t == token.DOT:
1062             if prev and prev.type == token.DOT:
1063                 return NO
1064
1065         elif t == token.NAME:
1066             if v == 'import':
1067                 return SPACE
1068
1069             if prev and prev.type == token.DOT:
1070                 return NO
1071
1072     elif p.type == syms.sliceop:
1073         return NO
1074
1075     return SPACE
1076
1077
1078 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1079     """Returns the first leaf that precedes `node`, if any."""
1080     while node:
1081         res = node.prev_sibling
1082         if res:
1083             if isinstance(res, Leaf):
1084                 return res
1085
1086             try:
1087                 return list(res.leaves())[-1]
1088
1089             except IndexError:
1090                 return None
1091
1092         node = node.parent
1093     return None
1094
1095
1096 def is_delimiter(leaf: Leaf) -> int:
1097     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1098
1099     Higher numbers are higher priority.
1100     """
1101     if leaf.type == token.COMMA:
1102         return COMMA_PRIORITY
1103
1104     if leaf.type in COMPARATORS:
1105         return COMPARATOR_PRIORITY
1106
1107     if (
1108         leaf.type in MATH_OPERATORS
1109         and leaf.parent
1110         and leaf.parent.type not in {syms.factor, syms.star_expr}
1111     ):
1112         return MATH_PRIORITY
1113
1114     return 0
1115
1116
1117 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1118     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1119
1120     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1121     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1122     move because it does away with modifying the grammar to include all the
1123     possible places in which comments can be placed.
1124
1125     The sad consequence for us though is that comments don't "belong" anywhere.
1126     This is why this function generates simple parentless Leaf objects for
1127     comments.  We simply don't know what the correct parent should be.
1128
1129     No matter though, we can live without this.  We really only need to
1130     differentiate between inline and standalone comments.  The latter don't
1131     share the line with any code.
1132
1133     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1134     are emitted with a fake STANDALONE_COMMENT token identifier.
1135     """
1136     if not leaf.prefix:
1137         return
1138
1139     if '#' not in leaf.prefix:
1140         return
1141
1142     before_comment, content = leaf.prefix.split('#', 1)
1143     content = content.rstrip()
1144     if content and (content[0] not in {' ', '!', '#'}):
1145         content = ' ' + content
1146     is_standalone_comment = (
1147         '\n' in before_comment or '\n' in content or leaf.type == token.ENDMARKER
1148     )
1149     if not is_standalone_comment:
1150         # simple trailing comment
1151         yield Leaf(token.COMMENT, value='#' + content)
1152         return
1153
1154     for line in ('#' + content).split('\n'):
1155         line = line.lstrip()
1156         if not line.startswith('#'):
1157             continue
1158
1159         yield Leaf(STANDALONE_COMMENT, line)
1160
1161
1162 def split_line(
1163     line: Line, line_length: int, inner: bool = False, py36: bool = False
1164 ) -> Iterator[Line]:
1165     """Splits a `line` into potentially many lines.
1166
1167     They should fit in the allotted `line_length` but might not be able to.
1168     `inner` signifies that there were a pair of brackets somewhere around the
1169     current `line`, possibly transitively. This means we can fallback to splitting
1170     by delimiters if the LHS/RHS don't yield any results.
1171
1172     If `py36` is True, splitting may generate syntax that is only compatible
1173     with Python 3.6 and later.
1174     """
1175     line_str = str(line).strip('\n')
1176     if len(line_str) <= line_length and '\n' not in line_str:
1177         yield line
1178         return
1179
1180     if line.is_def:
1181         split_funcs = [left_hand_split]
1182     elif line.inside_brackets:
1183         split_funcs = [delimiter_split]
1184         if '\n' not in line_str:
1185             # Only attempt RHS if we don't have multiline strings or comments
1186             # on this line.
1187             split_funcs.append(right_hand_split)
1188     else:
1189         split_funcs = [right_hand_split]
1190     for split_func in split_funcs:
1191         # We are accumulating lines in `result` because we might want to abort
1192         # mission and return the original line in the end, or attempt a different
1193         # split altogether.
1194         result: List[Line] = []
1195         try:
1196             for l in split_func(line, py36=py36):
1197                 if str(l).strip('\n') == line_str:
1198                     raise CannotSplit("Split function returned an unchanged result")
1199
1200                 result.extend(
1201                     split_line(l, line_length=line_length, inner=True, py36=py36)
1202                 )
1203         except CannotSplit as cs:
1204             continue
1205
1206         else:
1207             yield from result
1208             break
1209
1210     else:
1211         yield line
1212
1213
1214 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1215     """Split line into many lines, starting with the first matching bracket pair.
1216
1217     Note: this usually looks weird, only use this for function definitions.
1218     Prefer RHS otherwise.
1219     """
1220     head = Line(depth=line.depth)
1221     body = Line(depth=line.depth + 1, inside_brackets=True)
1222     tail = Line(depth=line.depth)
1223     tail_leaves: List[Leaf] = []
1224     body_leaves: List[Leaf] = []
1225     head_leaves: List[Leaf] = []
1226     current_leaves = head_leaves
1227     matching_bracket = None
1228     for leaf in line.leaves:
1229         if (
1230             current_leaves is body_leaves
1231             and leaf.type in CLOSING_BRACKETS
1232             and leaf.opening_bracket is matching_bracket
1233         ):
1234             current_leaves = tail_leaves if body_leaves else head_leaves
1235         current_leaves.append(leaf)
1236         if current_leaves is head_leaves:
1237             if leaf.type in OPENING_BRACKETS:
1238                 matching_bracket = leaf
1239                 current_leaves = body_leaves
1240     # Since body is a new indent level, remove spurious leading whitespace.
1241     if body_leaves:
1242         normalize_prefix(body_leaves[0], inside_brackets=True)
1243     # Build the new lines.
1244     for result, leaves in (
1245         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1246     ):
1247         for leaf in leaves:
1248             result.append(leaf, preformatted=True)
1249             comment_after = line.comments.get(id(leaf))
1250             if comment_after:
1251                 result.append(comment_after, preformatted=True)
1252     split_succeeded_or_raise(head, body, tail)
1253     for result in (head, body, tail):
1254         if result:
1255             yield result
1256
1257
1258 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1259     """Split line into many lines, starting with the last matching bracket pair."""
1260     head = Line(depth=line.depth)
1261     body = Line(depth=line.depth + 1, inside_brackets=True)
1262     tail = Line(depth=line.depth)
1263     tail_leaves: List[Leaf] = []
1264     body_leaves: List[Leaf] = []
1265     head_leaves: List[Leaf] = []
1266     current_leaves = tail_leaves
1267     opening_bracket = None
1268     for leaf in reversed(line.leaves):
1269         if current_leaves is body_leaves:
1270             if leaf is opening_bracket:
1271                 current_leaves = head_leaves if body_leaves else tail_leaves
1272         current_leaves.append(leaf)
1273         if current_leaves is tail_leaves:
1274             if leaf.type in CLOSING_BRACKETS:
1275                 opening_bracket = leaf.opening_bracket
1276                 current_leaves = body_leaves
1277     tail_leaves.reverse()
1278     body_leaves.reverse()
1279     head_leaves.reverse()
1280     # Since body is a new indent level, remove spurious leading whitespace.
1281     if body_leaves:
1282         normalize_prefix(body_leaves[0], inside_brackets=True)
1283     # Build the new lines.
1284     for result, leaves in (
1285         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1286     ):
1287         for leaf in leaves:
1288             result.append(leaf, preformatted=True)
1289             comment_after = line.comments.get(id(leaf))
1290             if comment_after:
1291                 result.append(comment_after, preformatted=True)
1292     split_succeeded_or_raise(head, body, tail)
1293     for result in (head, body, tail):
1294         if result:
1295             yield result
1296
1297
1298 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1299     tail_len = len(str(tail).strip())
1300     if not body:
1301         if tail_len == 0:
1302             raise CannotSplit("Splitting brackets produced the same line")
1303
1304         elif tail_len < 3:
1305             raise CannotSplit(
1306                 f"Splitting brackets on an empty body to save "
1307                 f"{tail_len} characters is not worth it"
1308             )
1309
1310
1311 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1312     """Split according to delimiters of the highest priority.
1313
1314     This kind of split doesn't increase indentation.
1315     If `py36` is True, the split will add trailing commas also in function
1316     signatures that contain * and **.
1317     """
1318     try:
1319         last_leaf = line.leaves[-1]
1320     except IndexError:
1321         raise CannotSplit("Line empty")
1322
1323     delimiters = line.bracket_tracker.delimiters
1324     try:
1325         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1326     except ValueError:
1327         raise CannotSplit("No delimiters found")
1328
1329     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1330     lowest_depth = sys.maxsize
1331     trailing_comma_safe = True
1332     for leaf in line.leaves:
1333         current_line.append(leaf, preformatted=True)
1334         comment_after = line.comments.get(id(leaf))
1335         if comment_after:
1336             current_line.append(comment_after, preformatted=True)
1337         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1338         if (
1339             leaf.bracket_depth == lowest_depth
1340             and leaf.type == token.STAR
1341             or leaf.type == token.DOUBLESTAR
1342         ):
1343             trailing_comma_safe = trailing_comma_safe and py36
1344         leaf_priority = delimiters.get(id(leaf))
1345         if leaf_priority == delimiter_priority:
1346             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1347             yield current_line
1348
1349             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1350     if current_line:
1351         if (
1352             delimiter_priority == COMMA_PRIORITY
1353             and current_line.leaves[-1].type != token.COMMA
1354             and trailing_comma_safe
1355         ):
1356             current_line.append(Leaf(token.COMMA, ','))
1357         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1358         yield current_line
1359
1360
1361 def is_import(leaf: Leaf) -> bool:
1362     """Returns True if the given leaf starts an import statement."""
1363     p = leaf.parent
1364     t = leaf.type
1365     v = leaf.value
1366     return bool(
1367         t == token.NAME
1368         and (
1369             (v == 'import' and p and p.type == syms.import_name)
1370             or (v == 'from' and p and p.type == syms.import_from)
1371         )
1372     )
1373
1374
1375 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1376     """Leave existing extra newlines if not `inside_brackets`.
1377
1378     Remove everything else.  Note: don't use backslashes for formatting or
1379     you'll lose your voting rights.
1380     """
1381     if not inside_brackets:
1382         spl = leaf.prefix.split('#', 1)
1383         if '\\' not in spl[0]:
1384             nl_count = spl[0].count('\n')
1385             leaf.prefix = '\n' * nl_count
1386             return
1387
1388     leaf.prefix = ''
1389
1390
1391 def is_python36(node: Node) -> bool:
1392     """Returns True if the current file is using Python 3.6+ features.
1393
1394     Currently looking for:
1395     - f-strings; and
1396     - trailing commas after * or ** in function signatures.
1397     """
1398     for n in node.pre_order():
1399         if n.type == token.STRING:
1400             value_head = n.value[:2]  # type: ignore
1401             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1402                 return True
1403
1404         elif (
1405             n.type == syms.typedargslist
1406             and n.children
1407             and n.children[-1].type == token.COMMA
1408         ):
1409             for ch in n.children:
1410                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1411                     return True
1412
1413     return False
1414
1415
1416 PYTHON_EXTENSIONS = {'.py'}
1417 BLACKLISTED_DIRECTORIES = {
1418     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1419 }
1420
1421
1422 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1423     for child in path.iterdir():
1424         if child.is_dir():
1425             if child.name in BLACKLISTED_DIRECTORIES:
1426                 continue
1427
1428             yield from gen_python_files_in_dir(child)
1429
1430         elif child.suffix in PYTHON_EXTENSIONS:
1431             yield child
1432
1433
1434 @dataclass
1435 class Report:
1436     """Provides a reformatting counter."""
1437     change_count: int = 0
1438     same_count: int = 0
1439     failure_count: int = 0
1440
1441     def done(self, src: Path, changed: bool) -> None:
1442         """Increment the counter for successful reformatting. Write out a message."""
1443         if changed:
1444             out(f'reformatted {src}')
1445             self.change_count += 1
1446         else:
1447             out(f'{src} already well formatted, good job.', bold=False)
1448             self.same_count += 1
1449
1450     def failed(self, src: Path, message: str) -> None:
1451         """Increment the counter for failed reformatting. Write out a message."""
1452         err(f'error: cannot format {src}: {message}')
1453         self.failure_count += 1
1454
1455     @property
1456     def return_code(self) -> int:
1457         """Which return code should the app use considering the current state."""
1458         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1459         # 126 we have special returncodes reserved by the shell.
1460         if self.failure_count:
1461             return 123
1462
1463         elif self.change_count:
1464             return 1
1465
1466         return 0
1467
1468     def __str__(self) -> str:
1469         """A color report of the current state.
1470
1471         Use `click.unstyle` to remove colors.
1472         """
1473         report = []
1474         if self.change_count:
1475             s = 's' if self.change_count > 1 else ''
1476             report.append(
1477                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1478             )
1479         if self.same_count:
1480             s = 's' if self.same_count > 1 else ''
1481             report.append(f'{self.same_count} file{s} left unchanged')
1482         if self.failure_count:
1483             s = 's' if self.failure_count > 1 else ''
1484             report.append(
1485                 click.style(
1486                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1487                 )
1488             )
1489         return ', '.join(report) + '.'
1490
1491
1492 def assert_equivalent(src: str, dst: str) -> None:
1493     """Raises AssertionError if `src` and `dst` aren't equivalent.
1494
1495     This is a temporary sanity check until Black becomes stable.
1496     """
1497
1498     import ast
1499     import traceback
1500
1501     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1502         """Simple visitor generating strings to compare ASTs by content."""
1503         yield f"{'  ' * depth}{node.__class__.__name__}("
1504
1505         for field in sorted(node._fields):
1506             try:
1507                 value = getattr(node, field)
1508             except AttributeError:
1509                 continue
1510
1511             yield f"{'  ' * (depth+1)}{field}="
1512
1513             if isinstance(value, list):
1514                 for item in value:
1515                     if isinstance(item, ast.AST):
1516                         yield from _v(item, depth + 2)
1517
1518             elif isinstance(value, ast.AST):
1519                 yield from _v(value, depth + 2)
1520
1521             else:
1522                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1523
1524         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1525
1526     try:
1527         src_ast = ast.parse(src)
1528     except Exception as exc:
1529         raise AssertionError(f"cannot parse source: {exc}") from None
1530
1531     try:
1532         dst_ast = ast.parse(dst)
1533     except Exception as exc:
1534         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1535         raise AssertionError(
1536             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1537             f"Please report a bug on https://github.com/ambv/black/issues.  "
1538             f"This invalid output might be helpful: {log}"
1539         ) from None
1540
1541     src_ast_str = '\n'.join(_v(src_ast))
1542     dst_ast_str = '\n'.join(_v(dst_ast))
1543     if src_ast_str != dst_ast_str:
1544         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1545         raise AssertionError(
1546             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1547             f"the source.  "
1548             f"Please report a bug on https://github.com/ambv/black/issues.  "
1549             f"This diff might be helpful: {log}"
1550         ) from None
1551
1552
1553 def assert_stable(src: str, dst: str, line_length: int) -> None:
1554     """Raises AssertionError if `dst` reformats differently the second time.
1555
1556     This is a temporary sanity check until Black becomes stable.
1557     """
1558     newdst = format_str(dst, line_length=line_length)
1559     if dst != newdst:
1560         log = dump_to_file(
1561             diff(src, dst, 'source', 'first pass'),
1562             diff(dst, newdst, 'first pass', 'second pass'),
1563         )
1564         raise AssertionError(
1565             f"INTERNAL ERROR: Black produced different code on the second pass "
1566             f"of the formatter.  "
1567             f"Please report a bug on https://github.com/ambv/black/issues.  "
1568             f"This diff might be helpful: {log}"
1569         ) from None
1570
1571
1572 def dump_to_file(*output: str) -> str:
1573     """Dumps `output` to a temporary file. Returns path to the file."""
1574     import tempfile
1575
1576     with tempfile.NamedTemporaryFile(
1577         mode='w', prefix='blk_', suffix='.log', delete=False
1578     ) as f:
1579         for lines in output:
1580             f.write(lines)
1581             f.write('\n')
1582     return f.name
1583
1584
1585 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1586     """Returns a udiff string between strings `a` and `b`."""
1587     import difflib
1588
1589     a_lines = [line + '\n' for line in a.split('\n')]
1590     b_lines = [line + '\n' for line in b.split('\n')]
1591     return ''.join(
1592         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1593     )
1594
1595
1596 if __name__ == '__main__':
1597     main()