]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Don't omit whitespace when the factor is not a math operator
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
14 )
15
16 from attr import dataclass, Factory
17 import click
18
19 # lib2to3 fork
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
24
25 __version__ = "18.3a3"
26 DEFAULT_LINE_LENGTH = 88
27 # types
28 syms = pygram.python_symbols
29 FileContent = str
30 Encoding = str
31 Depth = int
32 NodeType = int
33 LeafID = int
34 Priority = int
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
38
39
40 class NothingChanged(UserWarning):
41     """Raised by `format_file` when the reformatted code is the same as source."""
42
43
44 class CannotSplit(Exception):
45     """A readable split that fits the allotted line length is impossible.
46
47     Raised by `left_hand_split()` and `right_hand_split()`.
48     """
49
50
51 @click.command()
52 @click.option(
53     '-l',
54     '--line-length',
55     type=int,
56     default=DEFAULT_LINE_LENGTH,
57     help='How many character per line to allow.',
58     show_default=True,
59 )
60 @click.option(
61     '--check',
62     is_flag=True,
63     help=(
64         "Don't write back the files, just return the status.  Return code 0 "
65         "means nothing changed.  Return code 1 means some files were "
66         "reformatted.  Return code 123 means there was an internal error."
67     ),
68 )
69 @click.option(
70     '--fast/--safe',
71     is_flag=True,
72     help='If --fast given, skip temporary sanity checks. [default: --safe]',
73 )
74 @click.version_option(version=__version__)
75 @click.argument(
76     'src',
77     nargs=-1,
78     type=click.Path(
79         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
80     ),
81 )
82 @click.pass_context
83 def main(
84     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
85 ) -> None:
86     """The uncompromising code formatter."""
87     sources: List[Path] = []
88     for s in src:
89         p = Path(s)
90         if p.is_dir():
91             sources.extend(gen_python_files_in_dir(p))
92         elif p.is_file():
93             # if a file was explicitly given, we don't care about its extension
94             sources.append(p)
95         elif s == '-':
96             sources.append(Path('-'))
97         else:
98             err(f'invalid path: {s}')
99     if len(sources) == 0:
100         ctx.exit(0)
101     elif len(sources) == 1:
102         p = sources[0]
103         report = Report()
104         try:
105             if not p.is_file() and str(p) == '-':
106                 changed = format_stdin_to_stdout(
107                     line_length=line_length, fast=fast, write_back=not check
108                 )
109             else:
110                 changed = format_file_in_place(
111                     p, line_length=line_length, fast=fast, write_back=not check
112                 )
113             report.done(p, changed)
114         except Exception as exc:
115             report.failed(p, str(exc))
116         ctx.exit(report.return_code)
117     else:
118         loop = asyncio.get_event_loop()
119         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
120         return_code = 1
121         try:
122             return_code = loop.run_until_complete(
123                 schedule_formatting(
124                     sources, line_length, not check, fast, loop, executor
125                 )
126             )
127         finally:
128             loop.close()
129             ctx.exit(return_code)
130
131
132 async def schedule_formatting(
133     sources: List[Path],
134     line_length: int,
135     write_back: bool,
136     fast: bool,
137     loop: BaseEventLoop,
138     executor: Executor,
139 ) -> int:
140     tasks = {
141         src: loop.run_in_executor(
142             executor, format_file_in_place, src, line_length, fast, write_back
143         )
144         for src in sources
145     }
146     await asyncio.wait(tasks.values())
147     cancelled = []
148     report = Report()
149     for src, task in tasks.items():
150         if not task.done():
151             report.failed(src, 'timed out, cancelling')
152             task.cancel()
153             cancelled.append(task)
154         elif task.exception():
155             report.failed(src, str(task.exception()))
156         else:
157             report.done(src, task.result())
158     if cancelled:
159         await asyncio.wait(cancelled, timeout=2)
160     out('All done! ✨ 🍰 ✨')
161     click.echo(str(report))
162     return report.return_code
163
164
165 def format_file_in_place(
166     src: Path, line_length: int, fast: bool, write_back: bool = False
167 ) -> bool:
168     """Format the file and rewrite if changed. Return True if changed."""
169     with tokenize.open(src) as src_buffer:
170         src_contents = src_buffer.read()
171     try:
172         contents = format_file_contents(
173             src_contents, line_length=line_length, fast=fast
174         )
175     except NothingChanged:
176         return False
177
178     if write_back:
179         with open(src, "w", encoding=src_buffer.encoding) as f:
180             f.write(contents)
181     return True
182
183
184 def format_stdin_to_stdout(
185     line_length: int, fast: bool, write_back: bool = False
186 ) -> bool:
187     """Format file on stdin and pipe output to stdout. Return True if changed."""
188     contents = sys.stdin.read()
189     try:
190         contents = format_file_contents(contents, line_length=line_length, fast=fast)
191         return True
192
193     except NothingChanged:
194         return False
195
196     finally:
197         if write_back:
198             sys.stdout.write(contents)
199
200
201 def format_file_contents(
202     src_contents: str, line_length: int, fast: bool
203 ) -> FileContent:
204     """Reformats a file and returns its contents and encoding."""
205     if src_contents.strip() == '':
206         raise NothingChanged
207
208     dst_contents = format_str(src_contents, line_length=line_length)
209     if src_contents == dst_contents:
210         raise NothingChanged
211
212     if not fast:
213         assert_equivalent(src_contents, dst_contents)
214         assert_stable(src_contents, dst_contents, line_length=line_length)
215     return dst_contents
216
217
218 def format_str(src_contents: str, line_length: int) -> FileContent:
219     """Reformats a string and returns new contents."""
220     src_node = lib2to3_parse(src_contents)
221     dst_contents = ""
222     lines = LineGenerator()
223     elt = EmptyLineTracker()
224     py36 = is_python36(src_node)
225     empty_line = Line()
226     after = 0
227     for current_line in lines.visit(src_node):
228         for _ in range(after):
229             dst_contents += str(empty_line)
230         before, after = elt.maybe_empty_lines(current_line)
231         for _ in range(before):
232             dst_contents += str(empty_line)
233         for line in split_line(current_line, line_length=line_length, py36=py36):
234             dst_contents += str(line)
235     return dst_contents
236
237
238 def lib2to3_parse(src_txt: str) -> Node:
239     """Given a string with source, return the lib2to3 Node."""
240     grammar = pygram.python_grammar_no_print_statement
241     drv = driver.Driver(grammar, pytree.convert)
242     if src_txt[-1] != '\n':
243         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
244         src_txt += nl
245     try:
246         result = drv.parse_string(src_txt, True)
247     except ParseError as pe:
248         lineno, column = pe.context[1]
249         lines = src_txt.splitlines()
250         try:
251             faulty_line = lines[lineno - 1]
252         except IndexError:
253             faulty_line = "<line number missing in source>"
254         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
255
256     if isinstance(result, Leaf):
257         result = Node(syms.file_input, [result])
258     return result
259
260
261 def lib2to3_unparse(node: Node) -> str:
262     """Given a lib2to3 node, return its string representation."""
263     code = str(node)
264     return code
265
266
267 T = TypeVar('T')
268
269
270 class Visitor(Generic[T]):
271     """Basic lib2to3 visitor that yields things on visiting."""
272
273     def visit(self, node: LN) -> Iterator[T]:
274         if node.type < 256:
275             name = token.tok_name[node.type]
276         else:
277             name = type_repr(node.type)
278         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
279
280     def visit_default(self, node: LN) -> Iterator[T]:
281         if isinstance(node, Node):
282             for child in node.children:
283                 yield from self.visit(child)
284
285
286 @dataclass
287 class DebugVisitor(Visitor[T]):
288     tree_depth: int = 0
289
290     def visit_default(self, node: LN) -> Iterator[T]:
291         indent = ' ' * (2 * self.tree_depth)
292         if isinstance(node, Node):
293             _type = type_repr(node.type)
294             out(f'{indent}{_type}', fg='yellow')
295             self.tree_depth += 1
296             for child in node.children:
297                 yield from self.visit(child)
298
299             self.tree_depth -= 1
300             out(f'{indent}/{_type}', fg='yellow', bold=False)
301         else:
302             _type = token.tok_name.get(node.type, str(node.type))
303             out(f'{indent}{_type}', fg='blue', nl=False)
304             if node.prefix:
305                 # We don't have to handle prefixes for `Node` objects since
306                 # that delegates to the first child anyway.
307                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
308             out(f' {node.value!r}', fg='blue', bold=False)
309
310
311 KEYWORDS = set(keyword.kwlist)
312 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
313 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
314 STATEMENT = {
315     syms.if_stmt,
316     syms.while_stmt,
317     syms.for_stmt,
318     syms.try_stmt,
319     syms.except_clause,
320     syms.with_stmt,
321     syms.funcdef,
322     syms.classdef,
323 }
324 STANDALONE_COMMENT = 153
325 LOGIC_OPERATORS = {'and', 'or'}
326 COMPARATORS = {
327     token.LESS,
328     token.GREATER,
329     token.EQEQUAL,
330     token.NOTEQUAL,
331     token.LESSEQUAL,
332     token.GREATEREQUAL,
333 }
334 MATH_OPERATORS = {
335     token.PLUS,
336     token.MINUS,
337     token.STAR,
338     token.SLASH,
339     token.VBAR,
340     token.AMPER,
341     token.PERCENT,
342     token.CIRCUMFLEX,
343     token.TILDE,
344     token.LEFTSHIFT,
345     token.RIGHTSHIFT,
346     token.DOUBLESTAR,
347     token.DOUBLESLASH,
348 }
349 COMPREHENSION_PRIORITY = 20
350 COMMA_PRIORITY = 10
351 LOGIC_PRIORITY = 5
352 STRING_PRIORITY = 4
353 COMPARATOR_PRIORITY = 3
354 MATH_PRIORITY = 1
355
356
357 @dataclass
358 class BracketTracker:
359     depth: int = 0
360     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
361     delimiters: Dict[LeafID, Priority] = Factory(dict)
362     previous: Optional[Leaf] = None
363
364     def mark(self, leaf: Leaf) -> None:
365         if leaf.type == token.COMMENT:
366             return
367
368         if leaf.type in CLOSING_BRACKETS:
369             self.depth -= 1
370             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
371             leaf.opening_bracket = opening_bracket
372         leaf.bracket_depth = self.depth
373         if self.depth == 0:
374             delim = is_delimiter(leaf)
375             if delim:
376                 self.delimiters[id(leaf)] = delim
377             elif self.previous is not None:
378                 if leaf.type == token.STRING and self.previous.type == token.STRING:
379                     self.delimiters[id(self.previous)] = STRING_PRIORITY
380                 elif (
381                     leaf.type == token.NAME
382                     and leaf.value == 'for'
383                     and leaf.parent
384                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
385                 ):
386                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
387                 elif (
388                     leaf.type == token.NAME
389                     and leaf.value == 'if'
390                     and leaf.parent
391                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
392                 ):
393                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
394                 elif (
395                     leaf.type == token.NAME
396                     and leaf.value in LOGIC_OPERATORS
397                     and leaf.parent
398                 ):
399                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
400         if leaf.type in OPENING_BRACKETS:
401             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
402             self.depth += 1
403         self.previous = leaf
404
405     def any_open_brackets(self) -> bool:
406         """Returns True if there is an yet unmatched open bracket on the line."""
407         return bool(self.bracket_match)
408
409     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
410         """Returns the highest priority of a delimiter found on the line.
411
412         Values are consistent with what `is_delimiter()` returns.
413         """
414         return max(v for k, v in self.delimiters.items() if k not in exclude)
415
416
417 @dataclass
418 class Line:
419     depth: int = 0
420     leaves: List[Leaf] = Factory(list)
421     comments: Dict[LeafID, Leaf] = Factory(dict)
422     bracket_tracker: BracketTracker = Factory(BracketTracker)
423     inside_brackets: bool = False
424     has_for: bool = False
425     _for_loop_variable: bool = False
426
427     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
428         has_value = leaf.value.strip()
429         if not has_value:
430             return
431
432         if self.leaves and not preformatted:
433             # Note: at this point leaf.prefix should be empty except for
434             # imports, for which we only preserve newlines.
435             leaf.prefix += whitespace(leaf)
436         if self.inside_brackets or not preformatted:
437             self.maybe_decrement_after_for_loop_variable(leaf)
438             self.bracket_tracker.mark(leaf)
439             self.maybe_remove_trailing_comma(leaf)
440             self.maybe_increment_for_loop_variable(leaf)
441             if self.maybe_adapt_standalone_comment(leaf):
442                 return
443
444         if not self.append_comment(leaf):
445             self.leaves.append(leaf)
446
447     @property
448     def is_comment(self) -> bool:
449         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
450
451     @property
452     def is_decorator(self) -> bool:
453         return bool(self) and self.leaves[0].type == token.AT
454
455     @property
456     def is_import(self) -> bool:
457         return bool(self) and is_import(self.leaves[0])
458
459     @property
460     def is_class(self) -> bool:
461         return (
462             bool(self)
463             and self.leaves[0].type == token.NAME
464             and self.leaves[0].value == 'class'
465         )
466
467     @property
468     def is_def(self) -> bool:
469         """Also returns True for async defs."""
470         try:
471             first_leaf = self.leaves[0]
472         except IndexError:
473             return False
474
475         try:
476             second_leaf: Optional[Leaf] = self.leaves[1]
477         except IndexError:
478             second_leaf = None
479         return (
480             (first_leaf.type == token.NAME and first_leaf.value == 'def')
481             or (
482                 first_leaf.type == token.ASYNC
483                 and second_leaf is not None
484                 and second_leaf.type == token.NAME
485                 and second_leaf.value == 'def'
486             )
487         )
488
489     @property
490     def is_flow_control(self) -> bool:
491         return (
492             bool(self)
493             and self.leaves[0].type == token.NAME
494             and self.leaves[0].value in FLOW_CONTROL
495         )
496
497     @property
498     def is_yield(self) -> bool:
499         return (
500             bool(self)
501             and self.leaves[0].type == token.NAME
502             and self.leaves[0].value == 'yield'
503         )
504
505     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
506         if not (
507             self.leaves
508             and self.leaves[-1].type == token.COMMA
509             and closing.type in CLOSING_BRACKETS
510         ):
511             return False
512
513         if closing.type == token.RSQB or closing.type == token.RBRACE:
514             self.leaves.pop()
515             return True
516
517         # For parens let's check if it's safe to remove the comma.  If the
518         # trailing one is the only one, we might mistakenly change a tuple
519         # into a different type by removing the comma.
520         depth = closing.bracket_depth + 1
521         commas = 0
522         opening = closing.opening_bracket
523         for _opening_index, leaf in enumerate(self.leaves):
524             if leaf is opening:
525                 break
526
527         else:
528             return False
529
530         for leaf in self.leaves[_opening_index + 1:]:
531             if leaf is closing:
532                 break
533
534             bracket_depth = leaf.bracket_depth
535             if bracket_depth == depth and leaf.type == token.COMMA:
536                 commas += 1
537                 if leaf.parent and leaf.parent.type == syms.arglist:
538                     commas += 1
539                     break
540
541         if commas > 1:
542             self.leaves.pop()
543             return True
544
545         return False
546
547     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
548         """In a for loop, or comprehension, the variables are often unpacks.
549
550         To avoid splitting on the comma in this situation, we will increase
551         the depth of tokens between `for` and `in`.
552         """
553         if leaf.type == token.NAME and leaf.value == 'for':
554             self.has_for = True
555             self.bracket_tracker.depth += 1
556             self._for_loop_variable = True
557             return True
558
559         return False
560
561     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
562         # See `maybe_increment_for_loop_variable` above for explanation.
563         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
564             self.bracket_tracker.depth -= 1
565             self._for_loop_variable = False
566             return True
567
568         return False
569
570     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
571         """Hack a standalone comment to act as a trailing comment for line splitting.
572
573         If this line has brackets and a standalone `comment`, we need to adapt
574         it to be able to still reformat the line.
575
576         This is not perfect, the line to which the standalone comment gets
577         appended will appear "too long" when splitting.
578         """
579         if not (
580             comment.type == STANDALONE_COMMENT
581             and self.bracket_tracker.any_open_brackets()
582         ):
583             return False
584
585         comment.type = token.COMMENT
586         comment.prefix = '\n' + '    ' * (self.depth + 1)
587         return self.append_comment(comment)
588
589     def append_comment(self, comment: Leaf) -> bool:
590         if comment.type != token.COMMENT:
591             return False
592
593         try:
594             after = id(self.last_non_delimiter())
595         except LookupError:
596             comment.type = STANDALONE_COMMENT
597             comment.prefix = ''
598             return False
599
600         else:
601             if after in self.comments:
602                 self.comments[after].value += str(comment)
603             else:
604                 self.comments[after] = comment
605             return True
606
607     def last_non_delimiter(self) -> Leaf:
608         for i in range(len(self.leaves)):
609             last = self.leaves[-i - 1]
610             if not is_delimiter(last):
611                 return last
612
613         raise LookupError("No non-delimiters found")
614
615     def __str__(self) -> str:
616         if not self:
617             return '\n'
618
619         indent = '    ' * self.depth
620         leaves = iter(self.leaves)
621         first = next(leaves)
622         res = f'{first.prefix}{indent}{first.value}'
623         for leaf in leaves:
624             res += str(leaf)
625         for comment in self.comments.values():
626             res += str(comment)
627         return res + '\n'
628
629     def __bool__(self) -> bool:
630         return bool(self.leaves or self.comments)
631
632
633 @dataclass
634 class EmptyLineTracker:
635     """Provides a stateful method that returns the number of potential extra
636     empty lines needed before and after the currently processed line.
637
638     Note: this tracker works on lines that haven't been split yet.  It assumes
639     the prefix of the first leaf consists of optional newlines.  Those newlines
640     are consumed by `maybe_empty_lines()` and included in the computation.
641     """
642     previous_line: Optional[Line] = None
643     previous_after: int = 0
644     previous_defs: List[int] = Factory(list)
645
646     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
647         """Returns the number of extra empty lines before and after the `current_line`.
648
649         This is for separating `def`, `async def` and `class` with extra empty lines
650         (two on module-level), as well as providing an extra empty line after flow
651         control keywords to make them more prominent.
652         """
653         before, after = self._maybe_empty_lines(current_line)
654         before -= self.previous_after
655         self.previous_after = after
656         self.previous_line = current_line
657         return before, after
658
659     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
660         max_allowed = 1
661         if current_line.is_comment and current_line.depth == 0:
662             max_allowed = 2
663         if current_line.leaves:
664             # Consume the first leaf's extra newlines.
665             first_leaf = current_line.leaves[0]
666             before = first_leaf.prefix.count('\n')
667             before = min(before, max(before, max_allowed))
668             first_leaf.prefix = ''
669         else:
670             before = 0
671         depth = current_line.depth
672         while self.previous_defs and self.previous_defs[-1] >= depth:
673             self.previous_defs.pop()
674             before = 1 if depth else 2
675         is_decorator = current_line.is_decorator
676         if is_decorator or current_line.is_def or current_line.is_class:
677             if not is_decorator:
678                 self.previous_defs.append(depth)
679             if self.previous_line is None:
680                 # Don't insert empty lines before the first line in the file.
681                 return 0, 0
682
683             if self.previous_line and self.previous_line.is_decorator:
684                 # Don't insert empty lines between decorators.
685                 return 0, 0
686
687             newlines = 2
688             if current_line.depth:
689                 newlines -= 1
690             return newlines, 0
691
692         if current_line.is_flow_control:
693             return before, 1
694
695         if (
696             self.previous_line
697             and self.previous_line.is_import
698             and not current_line.is_import
699             and depth == self.previous_line.depth
700         ):
701             return (before or 1), 0
702
703         if (
704             self.previous_line
705             and self.previous_line.is_yield
706             and (not current_line.is_yield or depth != self.previous_line.depth)
707         ):
708             return (before or 1), 0
709
710         return before, 0
711
712
713 @dataclass
714 class LineGenerator(Visitor[Line]):
715     """Generates reformatted Line objects.  Empty lines are not emitted.
716
717     Note: destroys the tree it's visiting by mutating prefixes of its leaves
718     in ways that will no longer stringify to valid Python code on the tree.
719     """
720     current_line: Line = Factory(Line)
721
722     def line(self, indent: int = 0) -> Iterator[Line]:
723         """Generate a line.
724
725         If the line is empty, only emit if it makes sense.
726         If the line is too long, split it first and then generate.
727
728         If any lines were generated, set up a new current_line.
729         """
730         if not self.current_line:
731             self.current_line.depth += indent
732             return  # Line is empty, don't emit. Creating a new one unnecessary.
733
734         complete_line = self.current_line
735         self.current_line = Line(depth=complete_line.depth + indent)
736         yield complete_line
737
738     def visit_default(self, node: LN) -> Iterator[Line]:
739         if isinstance(node, Leaf):
740             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
741             for comment in generate_comments(node):
742                 if any_open_brackets:
743                     # any comment within brackets is subject to splitting
744                     self.current_line.append(comment)
745                 elif comment.type == token.COMMENT:
746                     # regular trailing comment
747                     self.current_line.append(comment)
748                     yield from self.line()
749
750                 else:
751                     # regular standalone comment
752                     yield from self.line()
753
754                     self.current_line.append(comment)
755                     yield from self.line()
756
757             normalize_prefix(node, inside_brackets=any_open_brackets)
758             if node.type not in WHITESPACE:
759                 self.current_line.append(node)
760         yield from super().visit_default(node)
761
762     def visit_INDENT(self, node: Node) -> Iterator[Line]:
763         yield from self.line(+1)
764         yield from self.visit_default(node)
765
766     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
767         yield from self.line(-1)
768
769     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
770         """Visit a statement.
771
772         The relevant Python language keywords for this statement are NAME leaves
773         within it.
774         """
775         for child in node.children:
776             if child.type == token.NAME and child.value in keywords:  # type: ignore
777                 yield from self.line()
778
779             yield from self.visit(child)
780
781     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
782         """A statement without nested statements."""
783         is_suite_like = node.parent and node.parent.type in STATEMENT
784         if is_suite_like:
785             yield from self.line(+1)
786             yield from self.visit_default(node)
787             yield from self.line(-1)
788
789         else:
790             yield from self.line()
791             yield from self.visit_default(node)
792
793     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
794         yield from self.line()
795
796         children = iter(node.children)
797         for child in children:
798             yield from self.visit(child)
799
800             if child.type == token.ASYNC:
801                 break
802
803         internal_stmt = next(children)
804         for child in internal_stmt.children:
805             yield from self.visit(child)
806
807     def visit_decorators(self, node: Node) -> Iterator[Line]:
808         for child in node.children:
809             yield from self.line()
810             yield from self.visit(child)
811
812     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
813         yield from self.line()
814
815     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
816         yield from self.visit_default(leaf)
817         yield from self.line()
818
819     def __attrs_post_init__(self) -> None:
820         """You are in a twisty little maze of passages."""
821         v = self.visit_stmt
822         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
823         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
824         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
825         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
826         self.visit_except_clause = partial(v, keywords={'except'})
827         self.visit_funcdef = partial(v, keywords={'def'})
828         self.visit_with_stmt = partial(v, keywords={'with'})
829         self.visit_classdef = partial(v, keywords={'class'})
830         self.visit_async_funcdef = self.visit_async_stmt
831         self.visit_decorated = self.visit_decorators
832
833
834 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
835 OPENING_BRACKETS = set(BRACKET.keys())
836 CLOSING_BRACKETS = set(BRACKET.values())
837 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
838 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
839
840
841 def whitespace(leaf: Leaf) -> str:  # noqa C901
842     """Return whitespace prefix if needed for the given `leaf`."""
843     NO = ''
844     SPACE = ' '
845     DOUBLESPACE = '  '
846     t = leaf.type
847     p = leaf.parent
848     v = leaf.value
849     if t in ALWAYS_NO_SPACE:
850         return NO
851
852     if t == token.COMMENT:
853         return DOUBLESPACE
854
855     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
856     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
857         return NO
858
859     prev = leaf.prev_sibling
860     if not prev:
861         prevp = preceding_leaf(p)
862         if not prevp or prevp.type in OPENING_BRACKETS:
863             return NO
864
865         if t == token.COLON:
866             return SPACE if prevp.type == token.COMMA else NO
867
868         if prevp.type == token.EQUAL:
869             if prevp.parent and prevp.parent.type in {
870                 syms.typedargslist,
871                 syms.varargslist,
872                 syms.parameters,
873                 syms.arglist,
874                 syms.argument,
875             }:
876                 return NO
877
878         elif prevp.type == token.DOUBLESTAR:
879             if prevp.parent and prevp.parent.type in {
880                 syms.typedargslist,
881                 syms.varargslist,
882                 syms.parameters,
883                 syms.arglist,
884                 syms.dictsetmaker,
885             }:
886                 return NO
887
888         elif prevp.type == token.COLON:
889             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
890                 return NO
891
892         elif (
893             prevp.parent
894             and prevp.parent.type in {syms.factor, syms.star_expr}
895             and prevp.type in MATH_OPERATORS
896         ):
897             return NO
898
899     elif prev.type in OPENING_BRACKETS:
900         return NO
901
902     if p.type in {syms.parameters, syms.arglist}:
903         # untyped function signatures or calls
904         if t == token.RPAR:
905             return NO
906
907         if not prev or prev.type != token.COMMA:
908             return NO
909
910     if p.type == syms.varargslist:
911         # lambdas
912         if t == token.RPAR:
913             return NO
914
915         if prev and prev.type != token.COMMA:
916             return NO
917
918     elif p.type == syms.typedargslist:
919         # typed function signatures
920         if not prev:
921             return NO
922
923         if t == token.EQUAL:
924             if prev.type != syms.tname:
925                 return NO
926
927         elif prev.type == token.EQUAL:
928             # A bit hacky: if the equal sign has whitespace, it means we
929             # previously found it's a typed argument.  So, we're using that, too.
930             return prev.prefix
931
932         elif prev.type != token.COMMA:
933             return NO
934
935     elif p.type == syms.tname:
936         # type names
937         if not prev:
938             prevp = preceding_leaf(p)
939             if not prevp or prevp.type != token.COMMA:
940                 return NO
941
942     elif p.type == syms.trailer:
943         # attributes and calls
944         if t == token.LPAR or t == token.RPAR:
945             return NO
946
947         if not prev:
948             if t == token.DOT:
949                 prevp = preceding_leaf(p)
950                 if not prevp or prevp.type != token.NUMBER:
951                     return NO
952
953             elif t == token.LSQB:
954                 return NO
955
956         elif prev.type != token.COMMA:
957             return NO
958
959     elif p.type == syms.argument:
960         # single argument
961         if t == token.EQUAL:
962             return NO
963
964         if not prev:
965             prevp = preceding_leaf(p)
966             if not prevp or prevp.type == token.LPAR:
967                 return NO
968
969         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
970             return NO
971
972     elif p.type == syms.decorator:
973         # decorators
974         return NO
975
976     elif p.type == syms.dotted_name:
977         if prev:
978             return NO
979
980         prevp = preceding_leaf(p)
981         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
982             return NO
983
984     elif p.type == syms.classdef:
985         if t == token.LPAR:
986             return NO
987
988         if prev and prev.type == token.LPAR:
989             return NO
990
991     elif p.type == syms.subscript:
992         # indexing
993         if not prev:
994             assert p.parent is not None, "subscripts are always parented"
995             if p.parent.type == syms.subscriptlist:
996                 return SPACE
997
998             return NO
999
1000         else:
1001             return NO
1002
1003     elif p.type == syms.atom:
1004         if prev and t == token.DOT:
1005             # dots, but not the first one.
1006             return NO
1007
1008     elif (
1009         p.type == syms.listmaker
1010         or p.type == syms.testlist_gexp
1011         or p.type == syms.subscriptlist
1012     ):
1013         # list interior, including unpacking
1014         if not prev:
1015             return NO
1016
1017     elif p.type == syms.dictsetmaker:
1018         # dict and set interior, including unpacking
1019         if not prev:
1020             return NO
1021
1022         if prev.type == token.DOUBLESTAR:
1023             return NO
1024
1025     elif p.type in {syms.factor, syms.star_expr}:
1026         # unary ops
1027         if not prev:
1028             prevp = preceding_leaf(p)
1029             if not prevp or prevp.type in OPENING_BRACKETS:
1030                 return NO
1031
1032             prevp_parent = prevp.parent
1033             assert prevp_parent is not None
1034             if prevp.type == token.COLON and prevp_parent.type in {
1035                 syms.subscript, syms.sliceop
1036             }:
1037                 return NO
1038
1039             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1040                 return NO
1041
1042         elif t == token.NAME or t == token.NUMBER:
1043             return NO
1044
1045     elif p.type == syms.import_from:
1046         if t == token.DOT:
1047             if prev and prev.type == token.DOT:
1048                 return NO
1049
1050         elif t == token.NAME:
1051             if v == 'import':
1052                 return SPACE
1053
1054             if prev and prev.type == token.DOT:
1055                 return NO
1056
1057     elif p.type == syms.sliceop:
1058         return NO
1059
1060     return SPACE
1061
1062
1063 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1064     """Returns the first leaf that precedes `node`, if any."""
1065     while node:
1066         res = node.prev_sibling
1067         if res:
1068             if isinstance(res, Leaf):
1069                 return res
1070
1071             try:
1072                 return list(res.leaves())[-1]
1073
1074             except IndexError:
1075                 return None
1076
1077         node = node.parent
1078     return None
1079
1080
1081 def is_delimiter(leaf: Leaf) -> int:
1082     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1083
1084     Higher numbers are higher priority.
1085     """
1086     if leaf.type == token.COMMA:
1087         return COMMA_PRIORITY
1088
1089     if leaf.type in COMPARATORS:
1090         return COMPARATOR_PRIORITY
1091
1092     if (
1093         leaf.type in MATH_OPERATORS
1094         and leaf.parent
1095         and leaf.parent.type not in {syms.factor, syms.star_expr}
1096     ):
1097         return MATH_PRIORITY
1098
1099     return 0
1100
1101
1102 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1103     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1104
1105     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1106     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1107     move because it does away with modifying the grammar to include all the
1108     possible places in which comments can be placed.
1109
1110     The sad consequence for us though is that comments don't "belong" anywhere.
1111     This is why this function generates simple parentless Leaf objects for
1112     comments.  We simply don't know what the correct parent should be.
1113
1114     No matter though, we can live without this.  We really only need to
1115     differentiate between inline and standalone comments.  The latter don't
1116     share the line with any code.
1117
1118     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1119     are emitted with a fake STANDALONE_COMMENT token identifier.
1120     """
1121     p = leaf.prefix
1122     if not p:
1123         return
1124
1125     if '#' not in p:
1126         return
1127
1128     nlines = 0
1129     for index, line in enumerate(p.split('\n')):
1130         line = line.lstrip()
1131         if not line:
1132             nlines += 1
1133         if not line.startswith('#'):
1134             continue
1135
1136         if index == 0 and leaf.type != token.ENDMARKER:
1137             comment_type = token.COMMENT  # simple trailing comment
1138         else:
1139             comment_type = STANDALONE_COMMENT
1140         yield Leaf(comment_type, make_comment(line), prefix='\n' * nlines)
1141
1142         nlines = 0
1143
1144
1145 def make_comment(content: str) -> str:
1146     content = content.rstrip()
1147     if not content:
1148         return '#'
1149
1150     if content[0] == '#':
1151         content = content[1:]
1152     if content and content[0] not in {' ', '!', '#'}:
1153         content = ' ' + content
1154     return '#' + content
1155
1156
1157 def split_line(
1158     line: Line, line_length: int, inner: bool = False, py36: bool = False
1159 ) -> Iterator[Line]:
1160     """Splits a `line` into potentially many lines.
1161
1162     They should fit in the allotted `line_length` but might not be able to.
1163     `inner` signifies that there were a pair of brackets somewhere around the
1164     current `line`, possibly transitively. This means we can fallback to splitting
1165     by delimiters if the LHS/RHS don't yield any results.
1166
1167     If `py36` is True, splitting may generate syntax that is only compatible
1168     with Python 3.6 and later.
1169     """
1170     line_str = str(line).strip('\n')
1171     if len(line_str) <= line_length and '\n' not in line_str:
1172         yield line
1173         return
1174
1175     if line.is_def:
1176         split_funcs = [left_hand_split]
1177     elif line.inside_brackets:
1178         split_funcs = [delimiter_split]
1179         if '\n' not in line_str:
1180             # Only attempt RHS if we don't have multiline strings or comments
1181             # on this line.
1182             split_funcs.append(right_hand_split)
1183     else:
1184         split_funcs = [right_hand_split]
1185     for split_func in split_funcs:
1186         # We are accumulating lines in `result` because we might want to abort
1187         # mission and return the original line in the end, or attempt a different
1188         # split altogether.
1189         result: List[Line] = []
1190         try:
1191             for l in split_func(line, py36=py36):
1192                 if str(l).strip('\n') == line_str:
1193                     raise CannotSplit("Split function returned an unchanged result")
1194
1195                 result.extend(
1196                     split_line(l, line_length=line_length, inner=True, py36=py36)
1197                 )
1198         except CannotSplit as cs:
1199             continue
1200
1201         else:
1202             yield from result
1203             break
1204
1205     else:
1206         yield line
1207
1208
1209 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1210     """Split line into many lines, starting with the first matching bracket pair.
1211
1212     Note: this usually looks weird, only use this for function definitions.
1213     Prefer RHS otherwise.
1214     """
1215     head = Line(depth=line.depth)
1216     body = Line(depth=line.depth + 1, inside_brackets=True)
1217     tail = Line(depth=line.depth)
1218     tail_leaves: List[Leaf] = []
1219     body_leaves: List[Leaf] = []
1220     head_leaves: List[Leaf] = []
1221     current_leaves = head_leaves
1222     matching_bracket = None
1223     for leaf in line.leaves:
1224         if (
1225             current_leaves is body_leaves
1226             and leaf.type in CLOSING_BRACKETS
1227             and leaf.opening_bracket is matching_bracket
1228         ):
1229             current_leaves = tail_leaves if body_leaves else head_leaves
1230         current_leaves.append(leaf)
1231         if current_leaves is head_leaves:
1232             if leaf.type in OPENING_BRACKETS:
1233                 matching_bracket = leaf
1234                 current_leaves = body_leaves
1235     # Since body is a new indent level, remove spurious leading whitespace.
1236     if body_leaves:
1237         normalize_prefix(body_leaves[0], inside_brackets=True)
1238     # Build the new lines.
1239     for result, leaves in (
1240         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1241     ):
1242         for leaf in leaves:
1243             result.append(leaf, preformatted=True)
1244             comment_after = line.comments.get(id(leaf))
1245             if comment_after:
1246                 result.append(comment_after, preformatted=True)
1247     split_succeeded_or_raise(head, body, tail)
1248     for result in (head, body, tail):
1249         if result:
1250             yield result
1251
1252
1253 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1254     """Split line into many lines, starting with the last matching bracket pair."""
1255     head = Line(depth=line.depth)
1256     body = Line(depth=line.depth + 1, inside_brackets=True)
1257     tail = Line(depth=line.depth)
1258     tail_leaves: List[Leaf] = []
1259     body_leaves: List[Leaf] = []
1260     head_leaves: List[Leaf] = []
1261     current_leaves = tail_leaves
1262     opening_bracket = None
1263     for leaf in reversed(line.leaves):
1264         if current_leaves is body_leaves:
1265             if leaf is opening_bracket:
1266                 current_leaves = head_leaves if body_leaves else tail_leaves
1267         current_leaves.append(leaf)
1268         if current_leaves is tail_leaves:
1269             if leaf.type in CLOSING_BRACKETS:
1270                 opening_bracket = leaf.opening_bracket
1271                 current_leaves = body_leaves
1272     tail_leaves.reverse()
1273     body_leaves.reverse()
1274     head_leaves.reverse()
1275     # Since body is a new indent level, remove spurious leading whitespace.
1276     if body_leaves:
1277         normalize_prefix(body_leaves[0], inside_brackets=True)
1278     # Build the new lines.
1279     for result, leaves in (
1280         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1281     ):
1282         for leaf in leaves:
1283             result.append(leaf, preformatted=True)
1284             comment_after = line.comments.get(id(leaf))
1285             if comment_after:
1286                 result.append(comment_after, preformatted=True)
1287     split_succeeded_or_raise(head, body, tail)
1288     for result in (head, body, tail):
1289         if result:
1290             yield result
1291
1292
1293 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1294     tail_len = len(str(tail).strip())
1295     if not body:
1296         if tail_len == 0:
1297             raise CannotSplit("Splitting brackets produced the same line")
1298
1299         elif tail_len < 3:
1300             raise CannotSplit(
1301                 f"Splitting brackets on an empty body to save "
1302                 f"{tail_len} characters is not worth it"
1303             )
1304
1305
1306 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1307     """Split according to delimiters of the highest priority.
1308
1309     This kind of split doesn't increase indentation.
1310     If `py36` is True, the split will add trailing commas also in function
1311     signatures that contain * and **.
1312     """
1313     try:
1314         last_leaf = line.leaves[-1]
1315     except IndexError:
1316         raise CannotSplit("Line empty")
1317
1318     delimiters = line.bracket_tracker.delimiters
1319     try:
1320         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1321     except ValueError:
1322         raise CannotSplit("No delimiters found")
1323
1324     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1325     lowest_depth = sys.maxsize
1326     trailing_comma_safe = True
1327     for leaf in line.leaves:
1328         current_line.append(leaf, preformatted=True)
1329         comment_after = line.comments.get(id(leaf))
1330         if comment_after:
1331             current_line.append(comment_after, preformatted=True)
1332         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1333         if (
1334             leaf.bracket_depth == lowest_depth
1335             and leaf.type == token.STAR
1336             or leaf.type == token.DOUBLESTAR
1337         ):
1338             trailing_comma_safe = trailing_comma_safe and py36
1339         leaf_priority = delimiters.get(id(leaf))
1340         if leaf_priority == delimiter_priority:
1341             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1342             yield current_line
1343
1344             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1345     if current_line:
1346         if (
1347             delimiter_priority == COMMA_PRIORITY
1348             and current_line.leaves[-1].type != token.COMMA
1349             and trailing_comma_safe
1350         ):
1351             current_line.append(Leaf(token.COMMA, ','))
1352         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1353         yield current_line
1354
1355
1356 def is_import(leaf: Leaf) -> bool:
1357     """Returns True if the given leaf starts an import statement."""
1358     p = leaf.parent
1359     t = leaf.type
1360     v = leaf.value
1361     return bool(
1362         t == token.NAME
1363         and (
1364             (v == 'import' and p and p.type == syms.import_name)
1365             or (v == 'from' and p and p.type == syms.import_from)
1366         )
1367     )
1368
1369
1370 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1371     """Leave existing extra newlines if not `inside_brackets`.
1372
1373     Remove everything else.  Note: don't use backslashes for formatting or
1374     you'll lose your voting rights.
1375     """
1376     if not inside_brackets:
1377         spl = leaf.prefix.split('#')
1378         if '\\' not in spl[0]:
1379             nl_count = spl[-1].count('\n')
1380             if len(spl) > 1:
1381                 nl_count -= 1
1382             leaf.prefix = '\n' * nl_count
1383             return
1384
1385     leaf.prefix = ''
1386
1387
1388 def is_python36(node: Node) -> bool:
1389     """Returns True if the current file is using Python 3.6+ features.
1390
1391     Currently looking for:
1392     - f-strings; and
1393     - trailing commas after * or ** in function signatures.
1394     """
1395     for n in node.pre_order():
1396         if n.type == token.STRING:
1397             value_head = n.value[:2]  # type: ignore
1398             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1399                 return True
1400
1401         elif (
1402             n.type == syms.typedargslist
1403             and n.children
1404             and n.children[-1].type == token.COMMA
1405         ):
1406             for ch in n.children:
1407                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1408                     return True
1409
1410     return False
1411
1412
1413 PYTHON_EXTENSIONS = {'.py'}
1414 BLACKLISTED_DIRECTORIES = {
1415     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1416 }
1417
1418
1419 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1420     for child in path.iterdir():
1421         if child.is_dir():
1422             if child.name in BLACKLISTED_DIRECTORIES:
1423                 continue
1424
1425             yield from gen_python_files_in_dir(child)
1426
1427         elif child.suffix in PYTHON_EXTENSIONS:
1428             yield child
1429
1430
1431 @dataclass
1432 class Report:
1433     """Provides a reformatting counter."""
1434     change_count: int = 0
1435     same_count: int = 0
1436     failure_count: int = 0
1437
1438     def done(self, src: Path, changed: bool) -> None:
1439         """Increment the counter for successful reformatting. Write out a message."""
1440         if changed:
1441             out(f'reformatted {src}')
1442             self.change_count += 1
1443         else:
1444             out(f'{src} already well formatted, good job.', bold=False)
1445             self.same_count += 1
1446
1447     def failed(self, src: Path, message: str) -> None:
1448         """Increment the counter for failed reformatting. Write out a message."""
1449         err(f'error: cannot format {src}: {message}')
1450         self.failure_count += 1
1451
1452     @property
1453     def return_code(self) -> int:
1454         """Which return code should the app use considering the current state."""
1455         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1456         # 126 we have special returncodes reserved by the shell.
1457         if self.failure_count:
1458             return 123
1459
1460         elif self.change_count:
1461             return 1
1462
1463         return 0
1464
1465     def __str__(self) -> str:
1466         """A color report of the current state.
1467
1468         Use `click.unstyle` to remove colors.
1469         """
1470         report = []
1471         if self.change_count:
1472             s = 's' if self.change_count > 1 else ''
1473             report.append(
1474                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1475             )
1476         if self.same_count:
1477             s = 's' if self.same_count > 1 else ''
1478             report.append(f'{self.same_count} file{s} left unchanged')
1479         if self.failure_count:
1480             s = 's' if self.failure_count > 1 else ''
1481             report.append(
1482                 click.style(
1483                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1484                 )
1485             )
1486         return ', '.join(report) + '.'
1487
1488
1489 def assert_equivalent(src: str, dst: str) -> None:
1490     """Raises AssertionError if `src` and `dst` aren't equivalent.
1491
1492     This is a temporary sanity check until Black becomes stable.
1493     """
1494
1495     import ast
1496     import traceback
1497
1498     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1499         """Simple visitor generating strings to compare ASTs by content."""
1500         yield f"{'  ' * depth}{node.__class__.__name__}("
1501
1502         for field in sorted(node._fields):
1503             try:
1504                 value = getattr(node, field)
1505             except AttributeError:
1506                 continue
1507
1508             yield f"{'  ' * (depth+1)}{field}="
1509
1510             if isinstance(value, list):
1511                 for item in value:
1512                     if isinstance(item, ast.AST):
1513                         yield from _v(item, depth + 2)
1514
1515             elif isinstance(value, ast.AST):
1516                 yield from _v(value, depth + 2)
1517
1518             else:
1519                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1520
1521         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1522
1523     try:
1524         src_ast = ast.parse(src)
1525     except Exception as exc:
1526         raise AssertionError(f"cannot parse source: {exc}") from None
1527
1528     try:
1529         dst_ast = ast.parse(dst)
1530     except Exception as exc:
1531         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1532         raise AssertionError(
1533             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1534             f"Please report a bug on https://github.com/ambv/black/issues.  "
1535             f"This invalid output might be helpful: {log}"
1536         ) from None
1537
1538     src_ast_str = '\n'.join(_v(src_ast))
1539     dst_ast_str = '\n'.join(_v(dst_ast))
1540     if src_ast_str != dst_ast_str:
1541         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1542         raise AssertionError(
1543             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1544             f"the source.  "
1545             f"Please report a bug on https://github.com/ambv/black/issues.  "
1546             f"This diff might be helpful: {log}"
1547         ) from None
1548
1549
1550 def assert_stable(src: str, dst: str, line_length: int) -> None:
1551     """Raises AssertionError if `dst` reformats differently the second time.
1552
1553     This is a temporary sanity check until Black becomes stable.
1554     """
1555     newdst = format_str(dst, line_length=line_length)
1556     if dst != newdst:
1557         log = dump_to_file(
1558             diff(src, dst, 'source', 'first pass'),
1559             diff(dst, newdst, 'first pass', 'second pass'),
1560         )
1561         raise AssertionError(
1562             f"INTERNAL ERROR: Black produced different code on the second pass "
1563             f"of the formatter.  "
1564             f"Please report a bug on https://github.com/ambv/black/issues.  "
1565             f"This diff might be helpful: {log}"
1566         ) from None
1567
1568
1569 def dump_to_file(*output: str) -> str:
1570     """Dumps `output` to a temporary file. Returns path to the file."""
1571     import tempfile
1572
1573     with tempfile.NamedTemporaryFile(
1574         mode='w', prefix='blk_', suffix='.log', delete=False
1575     ) as f:
1576         for lines in output:
1577             f.write(lines)
1578             f.write('\n')
1579     return f.name
1580
1581
1582 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1583     """Returns a udiff string between strings `a` and `b`."""
1584     import difflib
1585
1586     a_lines = [line + '\n' for line in a.split('\n')]
1587     b_lines = [line + '\n' for line in b.split('\n')]
1588     return ''.join(
1589         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1590     )
1591
1592
1593 if __name__ == '__main__':
1594     main()