]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Omit extra space in kwarg unpacking if it's an argument
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
14 )
15
16 from attr import dataclass, Factory
17 import click
18
19 # lib2to3 fork
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
24
25 __version__ = "18.3a3"
26 DEFAULT_LINE_LENGTH = 88
27 # types
28 syms = pygram.python_symbols
29 FileContent = str
30 Encoding = str
31 Depth = int
32 NodeType = int
33 LeafID = int
34 Priority = int
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
38
39
40 class NothingChanged(UserWarning):
41     """Raised by `format_file` when the reformatted code is the same as source."""
42
43
44 class CannotSplit(Exception):
45     """A readable split that fits the allotted line length is impossible.
46
47     Raised by `left_hand_split()` and `right_hand_split()`.
48     """
49
50
51 @click.command()
52 @click.option(
53     '-l',
54     '--line-length',
55     type=int,
56     default=DEFAULT_LINE_LENGTH,
57     help='How many character per line to allow.',
58     show_default=True,
59 )
60 @click.option(
61     '--check',
62     is_flag=True,
63     help=(
64         "Don't write back the files, just return the status.  Return code 0 "
65         "means nothing changed.  Return code 1 means some files were "
66         "reformatted.  Return code 123 means there was an internal error."
67     ),
68 )
69 @click.option(
70     '--fast/--safe',
71     is_flag=True,
72     help='If --fast given, skip temporary sanity checks. [default: --safe]',
73 )
74 @click.version_option(version=__version__)
75 @click.argument(
76     'src',
77     nargs=-1,
78     type=click.Path(
79         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
80     ),
81 )
82 @click.pass_context
83 def main(
84     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
85 ) -> None:
86     """The uncompromising code formatter."""
87     sources: List[Path] = []
88     for s in src:
89         p = Path(s)
90         if p.is_dir():
91             sources.extend(gen_python_files_in_dir(p))
92         elif p.is_file():
93             # if a file was explicitly given, we don't care about its extension
94             sources.append(p)
95         elif s == '-':
96             sources.append(Path('-'))
97         else:
98             err(f'invalid path: {s}')
99     if len(sources) == 0:
100         ctx.exit(0)
101     elif len(sources) == 1:
102         p = sources[0]
103         report = Report()
104         try:
105             if not p.is_file() and str(p) == '-':
106                 changed = format_stdin_to_stdout(
107                     line_length=line_length, fast=fast, write_back=not check
108                 )
109             else:
110                 changed = format_file_in_place(
111                     p, line_length=line_length, fast=fast, write_back=not check
112                 )
113             report.done(p, changed)
114         except Exception as exc:
115             report.failed(p, str(exc))
116         ctx.exit(report.return_code)
117     else:
118         loop = asyncio.get_event_loop()
119         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
120         return_code = 1
121         try:
122             return_code = loop.run_until_complete(
123                 schedule_formatting(
124                     sources, line_length, not check, fast, loop, executor
125                 )
126             )
127         finally:
128             loop.close()
129             ctx.exit(return_code)
130
131
132 async def schedule_formatting(
133     sources: List[Path],
134     line_length: int,
135     write_back: bool,
136     fast: bool,
137     loop: BaseEventLoop,
138     executor: Executor,
139 ) -> int:
140     tasks = {
141         src: loop.run_in_executor(
142             executor, format_file_in_place, src, line_length, fast, write_back
143         )
144         for src in sources
145     }
146     await asyncio.wait(tasks.values())
147     cancelled = []
148     report = Report()
149     for src, task in tasks.items():
150         if not task.done():
151             report.failed(src, 'timed out, cancelling')
152             task.cancel()
153             cancelled.append(task)
154         elif task.exception():
155             report.failed(src, str(task.exception()))
156         else:
157             report.done(src, task.result())
158     if cancelled:
159         await asyncio.wait(cancelled, timeout=2)
160     out('All done! ✨ 🍰 ✨')
161     click.echo(str(report))
162     return report.return_code
163
164
165 def format_file_in_place(
166     src: Path, line_length: int, fast: bool, write_back: bool = False
167 ) -> bool:
168     """Format the file and rewrite if changed. Return True if changed."""
169     with tokenize.open(src) as src_buffer:
170         src_contents = src_buffer.read()
171     try:
172         contents = format_file_contents(
173             src_contents, line_length=line_length, fast=fast
174         )
175     except NothingChanged:
176         return False
177
178     if write_back:
179         with open(src, "w", encoding=src_buffer.encoding) as f:
180             f.write(contents)
181     return True
182
183
184 def format_stdin_to_stdout(
185     line_length: int, fast: bool, write_back: bool = False
186 ) -> bool:
187     """Format file on stdin and pipe output to stdout. Return True if changed."""
188     contents = sys.stdin.read()
189     try:
190         contents = format_file_contents(contents, line_length=line_length, fast=fast)
191         return True
192
193     except NothingChanged:
194         return False
195
196     finally:
197         if write_back:
198             sys.stdout.write(contents)
199
200
201 def format_file_contents(
202     src_contents: str, line_length: int, fast: bool
203 ) -> FileContent:
204     """Reformats a file and returns its contents and encoding."""
205     if src_contents.strip() == '':
206         raise NothingChanged
207
208     dst_contents = format_str(src_contents, line_length=line_length)
209     if src_contents == dst_contents:
210         raise NothingChanged
211
212     if not fast:
213         assert_equivalent(src_contents, dst_contents)
214         assert_stable(src_contents, dst_contents, line_length=line_length)
215     return dst_contents
216
217
218 def format_str(src_contents: str, line_length: int) -> FileContent:
219     """Reformats a string and returns new contents."""
220     src_node = lib2to3_parse(src_contents)
221     dst_contents = ""
222     lines = LineGenerator()
223     elt = EmptyLineTracker()
224     py36 = is_python36(src_node)
225     empty_line = Line()
226     after = 0
227     for current_line in lines.visit(src_node):
228         for _ in range(after):
229             dst_contents += str(empty_line)
230         before, after = elt.maybe_empty_lines(current_line)
231         for _ in range(before):
232             dst_contents += str(empty_line)
233         for line in split_line(current_line, line_length=line_length, py36=py36):
234             dst_contents += str(line)
235     return dst_contents
236
237
238 def lib2to3_parse(src_txt: str) -> Node:
239     """Given a string with source, return the lib2to3 Node."""
240     grammar = pygram.python_grammar_no_print_statement
241     drv = driver.Driver(grammar, pytree.convert)
242     if src_txt[-1] != '\n':
243         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
244         src_txt += nl
245     try:
246         result = drv.parse_string(src_txt, True)
247     except ParseError as pe:
248         lineno, column = pe.context[1]
249         lines = src_txt.splitlines()
250         try:
251             faulty_line = lines[lineno - 1]
252         except IndexError:
253             faulty_line = "<line number missing in source>"
254         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
255
256     if isinstance(result, Leaf):
257         result = Node(syms.file_input, [result])
258     return result
259
260
261 def lib2to3_unparse(node: Node) -> str:
262     """Given a lib2to3 node, return its string representation."""
263     code = str(node)
264     return code
265
266
267 T = TypeVar('T')
268
269
270 class Visitor(Generic[T]):
271     """Basic lib2to3 visitor that yields things on visiting."""
272
273     def visit(self, node: LN) -> Iterator[T]:
274         if node.type < 256:
275             name = token.tok_name[node.type]
276         else:
277             name = type_repr(node.type)
278         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
279
280     def visit_default(self, node: LN) -> Iterator[T]:
281         if isinstance(node, Node):
282             for child in node.children:
283                 yield from self.visit(child)
284
285
286 @dataclass
287 class DebugVisitor(Visitor[T]):
288     tree_depth: int = 0
289
290     def visit_default(self, node: LN) -> Iterator[T]:
291         indent = ' ' * (2 * self.tree_depth)
292         if isinstance(node, Node):
293             _type = type_repr(node.type)
294             out(f'{indent}{_type}', fg='yellow')
295             self.tree_depth += 1
296             for child in node.children:
297                 yield from self.visit(child)
298
299             self.tree_depth -= 1
300             out(f'{indent}/{_type}', fg='yellow', bold=False)
301         else:
302             _type = token.tok_name.get(node.type, str(node.type))
303             out(f'{indent}{_type}', fg='blue', nl=False)
304             if node.prefix:
305                 # We don't have to handle prefixes for `Node` objects since
306                 # that delegates to the first child anyway.
307                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
308             out(f' {node.value!r}', fg='blue', bold=False)
309
310
311 KEYWORDS = set(keyword.kwlist)
312 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
313 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
314 STATEMENT = {
315     syms.if_stmt,
316     syms.while_stmt,
317     syms.for_stmt,
318     syms.try_stmt,
319     syms.except_clause,
320     syms.with_stmt,
321     syms.funcdef,
322     syms.classdef,
323 }
324 STANDALONE_COMMENT = 153
325 LOGIC_OPERATORS = {'and', 'or'}
326 COMPARATORS = {
327     token.LESS,
328     token.GREATER,
329     token.EQEQUAL,
330     token.NOTEQUAL,
331     token.LESSEQUAL,
332     token.GREATEREQUAL,
333 }
334 MATH_OPERATORS = {
335     token.PLUS,
336     token.MINUS,
337     token.STAR,
338     token.SLASH,
339     token.VBAR,
340     token.AMPER,
341     token.PERCENT,
342     token.CIRCUMFLEX,
343     token.TILDE,
344     token.LEFTSHIFT,
345     token.RIGHTSHIFT,
346     token.DOUBLESTAR,
347     token.DOUBLESLASH,
348 }
349 COMPREHENSION_PRIORITY = 20
350 COMMA_PRIORITY = 10
351 LOGIC_PRIORITY = 5
352 STRING_PRIORITY = 4
353 COMPARATOR_PRIORITY = 3
354 MATH_PRIORITY = 1
355
356
357 @dataclass
358 class BracketTracker:
359     depth: int = 0
360     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
361     delimiters: Dict[LeafID, Priority] = Factory(dict)
362     previous: Optional[Leaf] = None
363
364     def mark(self, leaf: Leaf) -> None:
365         if leaf.type == token.COMMENT:
366             return
367
368         if leaf.type in CLOSING_BRACKETS:
369             self.depth -= 1
370             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
371             leaf.opening_bracket = opening_bracket
372         leaf.bracket_depth = self.depth
373         if self.depth == 0:
374             delim = is_delimiter(leaf)
375             if delim:
376                 self.delimiters[id(leaf)] = delim
377             elif self.previous is not None:
378                 if leaf.type == token.STRING and self.previous.type == token.STRING:
379                     self.delimiters[id(self.previous)] = STRING_PRIORITY
380                 elif (
381                     leaf.type == token.NAME
382                     and leaf.value == 'for'
383                     and leaf.parent
384                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
385                 ):
386                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
387                 elif (
388                     leaf.type == token.NAME
389                     and leaf.value == 'if'
390                     and leaf.parent
391                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
392                 ):
393                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
394                 elif (
395                     leaf.type == token.NAME
396                     and leaf.value in LOGIC_OPERATORS
397                     and leaf.parent
398                 ):
399                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
400         if leaf.type in OPENING_BRACKETS:
401             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
402             self.depth += 1
403         self.previous = leaf
404
405     def any_open_brackets(self) -> bool:
406         """Returns True if there is an yet unmatched open bracket on the line."""
407         return bool(self.bracket_match)
408
409     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
410         """Returns the highest priority of a delimiter found on the line.
411
412         Values are consistent with what `is_delimiter()` returns.
413         """
414         return max(v for k, v in self.delimiters.items() if k not in exclude)
415
416
417 @dataclass
418 class Line:
419     depth: int = 0
420     leaves: List[Leaf] = Factory(list)
421     comments: Dict[LeafID, Leaf] = Factory(dict)
422     bracket_tracker: BracketTracker = Factory(BracketTracker)
423     inside_brackets: bool = False
424     has_for: bool = False
425     _for_loop_variable: bool = False
426
427     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
428         has_value = leaf.value.strip()
429         if not has_value:
430             return
431
432         if self.leaves and not preformatted:
433             # Note: at this point leaf.prefix should be empty except for
434             # imports, for which we only preserve newlines.
435             leaf.prefix += whitespace(leaf)
436         if self.inside_brackets or not preformatted:
437             self.maybe_decrement_after_for_loop_variable(leaf)
438             self.bracket_tracker.mark(leaf)
439             self.maybe_remove_trailing_comma(leaf)
440             self.maybe_increment_for_loop_variable(leaf)
441             if self.maybe_adapt_standalone_comment(leaf):
442                 return
443
444         if not self.append_comment(leaf):
445             self.leaves.append(leaf)
446
447     @property
448     def is_comment(self) -> bool:
449         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
450
451     @property
452     def is_decorator(self) -> bool:
453         return bool(self) and self.leaves[0].type == token.AT
454
455     @property
456     def is_import(self) -> bool:
457         return bool(self) and is_import(self.leaves[0])
458
459     @property
460     def is_class(self) -> bool:
461         return (
462             bool(self)
463             and self.leaves[0].type == token.NAME
464             and self.leaves[0].value == 'class'
465         )
466
467     @property
468     def is_def(self) -> bool:
469         """Also returns True for async defs."""
470         try:
471             first_leaf = self.leaves[0]
472         except IndexError:
473             return False
474
475         try:
476             second_leaf: Optional[Leaf] = self.leaves[1]
477         except IndexError:
478             second_leaf = None
479         return (
480             (first_leaf.type == token.NAME and first_leaf.value == 'def')
481             or (
482                 first_leaf.type == token.ASYNC
483                 and second_leaf is not None
484                 and second_leaf.type == token.NAME
485                 and second_leaf.value == 'def'
486             )
487         )
488
489     @property
490     def is_flow_control(self) -> bool:
491         return (
492             bool(self)
493             and self.leaves[0].type == token.NAME
494             and self.leaves[0].value in FLOW_CONTROL
495         )
496
497     @property
498     def is_yield(self) -> bool:
499         return (
500             bool(self)
501             and self.leaves[0].type == token.NAME
502             and self.leaves[0].value == 'yield'
503         )
504
505     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
506         if not (
507             self.leaves
508             and self.leaves[-1].type == token.COMMA
509             and closing.type in CLOSING_BRACKETS
510         ):
511             return False
512
513         if closing.type == token.RSQB or closing.type == token.RBRACE:
514             self.leaves.pop()
515             return True
516
517         # For parens let's check if it's safe to remove the comma.  If the
518         # trailing one is the only one, we might mistakenly change a tuple
519         # into a different type by removing the comma.
520         depth = closing.bracket_depth + 1
521         commas = 0
522         opening = closing.opening_bracket
523         for _opening_index, leaf in enumerate(self.leaves):
524             if leaf is opening:
525                 break
526
527         else:
528             return False
529
530         for leaf in self.leaves[_opening_index + 1:]:
531             if leaf is closing:
532                 break
533
534             bracket_depth = leaf.bracket_depth
535             if bracket_depth == depth and leaf.type == token.COMMA:
536                 commas += 1
537                 if leaf.parent and leaf.parent.type == syms.arglist:
538                     commas += 1
539                     break
540
541         if commas > 1:
542             self.leaves.pop()
543             return True
544
545         return False
546
547     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
548         """In a for loop, or comprehension, the variables are often unpacks.
549
550         To avoid splitting on the comma in this situation, we will increase
551         the depth of tokens between `for` and `in`.
552         """
553         if leaf.type == token.NAME and leaf.value == 'for':
554             self.has_for = True
555             self.bracket_tracker.depth += 1
556             self._for_loop_variable = True
557             return True
558
559         return False
560
561     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
562         # See `maybe_increment_for_loop_variable` above for explanation.
563         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
564             self.bracket_tracker.depth -= 1
565             self._for_loop_variable = False
566             return True
567
568         return False
569
570     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
571         """Hack a standalone comment to act as a trailing comment for line splitting.
572
573         If this line has brackets and a standalone `comment`, we need to adapt
574         it to be able to still reformat the line.
575
576         This is not perfect, the line to which the standalone comment gets
577         appended will appear "too long" when splitting.
578         """
579         if not (
580             comment.type == STANDALONE_COMMENT
581             and self.bracket_tracker.any_open_brackets()
582         ):
583             return False
584
585         comment.type = token.COMMENT
586         comment.prefix = '\n' + '    ' * (self.depth + 1)
587         return self.append_comment(comment)
588
589     def append_comment(self, comment: Leaf) -> bool:
590         if comment.type != token.COMMENT:
591             return False
592
593         try:
594             after = id(self.last_non_delimiter())
595         except LookupError:
596             comment.type = STANDALONE_COMMENT
597             comment.prefix = ''
598             return False
599
600         else:
601             if after in self.comments:
602                 self.comments[after].value += str(comment)
603             else:
604                 self.comments[after] = comment
605             return True
606
607     def last_non_delimiter(self) -> Leaf:
608         for i in range(len(self.leaves)):
609             last = self.leaves[-i - 1]
610             if not is_delimiter(last):
611                 return last
612
613         raise LookupError("No non-delimiters found")
614
615     def __str__(self) -> str:
616         if not self:
617             return '\n'
618
619         indent = '    ' * self.depth
620         leaves = iter(self.leaves)
621         first = next(leaves)
622         res = f'{first.prefix}{indent}{first.value}'
623         for leaf in leaves:
624             res += str(leaf)
625         for comment in self.comments.values():
626             res += str(comment)
627         return res + '\n'
628
629     def __bool__(self) -> bool:
630         return bool(self.leaves or self.comments)
631
632
633 @dataclass
634 class EmptyLineTracker:
635     """Provides a stateful method that returns the number of potential extra
636     empty lines needed before and after the currently processed line.
637
638     Note: this tracker works on lines that haven't been split yet.  It assumes
639     the prefix of the first leaf consists of optional newlines.  Those newlines
640     are consumed by `maybe_empty_lines()` and included in the computation.
641     """
642     previous_line: Optional[Line] = None
643     previous_after: int = 0
644     previous_defs: List[int] = Factory(list)
645
646     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
647         """Returns the number of extra empty lines before and after the `current_line`.
648
649         This is for separating `def`, `async def` and `class` with extra empty lines
650         (two on module-level), as well as providing an extra empty line after flow
651         control keywords to make them more prominent.
652         """
653         before, after = self._maybe_empty_lines(current_line)
654         before -= self.previous_after
655         self.previous_after = after
656         self.previous_line = current_line
657         return before, after
658
659     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
660         max_allowed = 1
661         if current_line.is_comment and current_line.depth == 0:
662             max_allowed = 2
663         if current_line.leaves:
664             # Consume the first leaf's extra newlines.
665             first_leaf = current_line.leaves[0]
666             before = first_leaf.prefix.count('\n')
667             before = min(before, max(before, max_allowed))
668             first_leaf.prefix = ''
669         else:
670             before = 0
671         depth = current_line.depth
672         while self.previous_defs and self.previous_defs[-1] >= depth:
673             self.previous_defs.pop()
674             before = 1 if depth else 2
675         is_decorator = current_line.is_decorator
676         if is_decorator or current_line.is_def or current_line.is_class:
677             if not is_decorator:
678                 self.previous_defs.append(depth)
679             if self.previous_line is None:
680                 # Don't insert empty lines before the first line in the file.
681                 return 0, 0
682
683             if self.previous_line and self.previous_line.is_decorator:
684                 # Don't insert empty lines between decorators.
685                 return 0, 0
686
687             newlines = 2
688             if current_line.depth:
689                 newlines -= 1
690             return newlines, 0
691
692         if current_line.is_flow_control:
693             return before, 1
694
695         if (
696             self.previous_line
697             and self.previous_line.is_import
698             and not current_line.is_import
699             and depth == self.previous_line.depth
700         ):
701             return (before or 1), 0
702
703         if (
704             self.previous_line
705             and self.previous_line.is_yield
706             and (not current_line.is_yield or depth != self.previous_line.depth)
707         ):
708             return (before or 1), 0
709
710         return before, 0
711
712
713 @dataclass
714 class LineGenerator(Visitor[Line]):
715     """Generates reformatted Line objects.  Empty lines are not emitted.
716
717     Note: destroys the tree it's visiting by mutating prefixes of its leaves
718     in ways that will no longer stringify to valid Python code on the tree.
719     """
720     current_line: Line = Factory(Line)
721
722     def line(self, indent: int = 0) -> Iterator[Line]:
723         """Generate a line.
724
725         If the line is empty, only emit if it makes sense.
726         If the line is too long, split it first and then generate.
727
728         If any lines were generated, set up a new current_line.
729         """
730         if not self.current_line:
731             self.current_line.depth += indent
732             return  # Line is empty, don't emit. Creating a new one unnecessary.
733
734         complete_line = self.current_line
735         self.current_line = Line(depth=complete_line.depth + indent)
736         yield complete_line
737
738     def visit_default(self, node: LN) -> Iterator[Line]:
739         if isinstance(node, Leaf):
740             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
741             for comment in generate_comments(node):
742                 if any_open_brackets:
743                     # any comment within brackets is subject to splitting
744                     self.current_line.append(comment)
745                 elif comment.type == token.COMMENT:
746                     # regular trailing comment
747                     self.current_line.append(comment)
748                     yield from self.line()
749
750                 else:
751                     # regular standalone comment
752                     yield from self.line()
753
754                     self.current_line.append(comment)
755                     yield from self.line()
756
757             normalize_prefix(node, inside_brackets=any_open_brackets)
758             if node.type not in WHITESPACE:
759                 self.current_line.append(node)
760         yield from super().visit_default(node)
761
762     def visit_INDENT(self, node: Node) -> Iterator[Line]:
763         yield from self.line(+1)
764         yield from self.visit_default(node)
765
766     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
767         yield from self.line(-1)
768
769     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
770         """Visit a statement.
771
772         The relevant Python language keywords for this statement are NAME leaves
773         within it.
774         """
775         for child in node.children:
776             if child.type == token.NAME and child.value in keywords:  # type: ignore
777                 yield from self.line()
778
779             yield from self.visit(child)
780
781     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
782         """A statement without nested statements."""
783         is_suite_like = node.parent and node.parent.type in STATEMENT
784         if is_suite_like:
785             yield from self.line(+1)
786             yield from self.visit_default(node)
787             yield from self.line(-1)
788
789         else:
790             yield from self.line()
791             yield from self.visit_default(node)
792
793     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
794         yield from self.line()
795
796         children = iter(node.children)
797         for child in children:
798             yield from self.visit(child)
799
800             if child.type == token.ASYNC:
801                 break
802
803         internal_stmt = next(children)
804         for child in internal_stmt.children:
805             yield from self.visit(child)
806
807     def visit_decorators(self, node: Node) -> Iterator[Line]:
808         for child in node.children:
809             yield from self.line()
810             yield from self.visit(child)
811
812     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
813         yield from self.line()
814
815     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
816         yield from self.visit_default(leaf)
817         yield from self.line()
818
819     def __attrs_post_init__(self) -> None:
820         """You are in a twisty little maze of passages."""
821         v = self.visit_stmt
822         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
823         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
824         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
825         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
826         self.visit_except_clause = partial(v, keywords={'except'})
827         self.visit_funcdef = partial(v, keywords={'def'})
828         self.visit_with_stmt = partial(v, keywords={'with'})
829         self.visit_classdef = partial(v, keywords={'class'})
830         self.visit_async_funcdef = self.visit_async_stmt
831         self.visit_decorated = self.visit_decorators
832
833
834 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
835 OPENING_BRACKETS = set(BRACKET.keys())
836 CLOSING_BRACKETS = set(BRACKET.values())
837 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
838 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
839
840
841 def whitespace(leaf: Leaf) -> str:  # noqa C901
842     """Return whitespace prefix if needed for the given `leaf`."""
843     NO = ''
844     SPACE = ' '
845     DOUBLESPACE = '  '
846     t = leaf.type
847     p = leaf.parent
848     v = leaf.value
849     if t in ALWAYS_NO_SPACE:
850         return NO
851
852     if t == token.COMMENT:
853         return DOUBLESPACE
854
855     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
856     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
857         return NO
858
859     prev = leaf.prev_sibling
860     if not prev:
861         prevp = preceding_leaf(p)
862         if not prevp or prevp.type in OPENING_BRACKETS:
863             return NO
864
865         if t == token.COLON:
866             return SPACE if prevp.type == token.COMMA else NO
867
868         if prevp.type == token.EQUAL:
869             if prevp.parent and prevp.parent.type in {
870                 syms.arglist,
871                 syms.argument,
872                 syms.parameters,
873                 syms.typedargslist,
874                 syms.varargslist,
875             }:
876                 return NO
877
878         elif prevp.type == token.DOUBLESTAR:
879             if prevp.parent and prevp.parent.type in {
880                 syms.arglist,
881                 syms.argument,
882                 syms.dictsetmaker,
883                 syms.parameters,
884                 syms.typedargslist,
885                 syms.varargslist,
886             }:
887                 return NO
888
889         elif prevp.type == token.COLON:
890             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
891                 return NO
892
893         elif (
894             prevp.parent
895             and prevp.parent.type in {syms.factor, syms.star_expr}
896             and prevp.type in MATH_OPERATORS
897         ):
898             return NO
899
900     elif prev.type in OPENING_BRACKETS:
901         return NO
902
903     if p.type in {syms.parameters, syms.arglist}:
904         # untyped function signatures or calls
905         if t == token.RPAR:
906             return NO
907
908         if not prev or prev.type != token.COMMA:
909             return NO
910
911     if p.type == syms.varargslist:
912         # lambdas
913         if t == token.RPAR:
914             return NO
915
916         if prev and prev.type != token.COMMA:
917             return NO
918
919     elif p.type == syms.typedargslist:
920         # typed function signatures
921         if not prev:
922             return NO
923
924         if t == token.EQUAL:
925             if prev.type != syms.tname:
926                 return NO
927
928         elif prev.type == token.EQUAL:
929             # A bit hacky: if the equal sign has whitespace, it means we
930             # previously found it's a typed argument.  So, we're using that, too.
931             return prev.prefix
932
933         elif prev.type != token.COMMA:
934             return NO
935
936     elif p.type == syms.tname:
937         # type names
938         if not prev:
939             prevp = preceding_leaf(p)
940             if not prevp or prevp.type != token.COMMA:
941                 return NO
942
943     elif p.type == syms.trailer:
944         # attributes and calls
945         if t == token.LPAR or t == token.RPAR:
946             return NO
947
948         if not prev:
949             if t == token.DOT:
950                 prevp = preceding_leaf(p)
951                 if not prevp or prevp.type != token.NUMBER:
952                     return NO
953
954             elif t == token.LSQB:
955                 return NO
956
957         elif prev.type != token.COMMA:
958             return NO
959
960     elif p.type == syms.argument:
961         # single argument
962         if t == token.EQUAL:
963             return NO
964
965         if not prev:
966             prevp = preceding_leaf(p)
967             if not prevp or prevp.type == token.LPAR:
968                 return NO
969
970         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
971             return NO
972
973     elif p.type == syms.decorator:
974         # decorators
975         return NO
976
977     elif p.type == syms.dotted_name:
978         if prev:
979             return NO
980
981         prevp = preceding_leaf(p)
982         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
983             return NO
984
985     elif p.type == syms.classdef:
986         if t == token.LPAR:
987             return NO
988
989         if prev and prev.type == token.LPAR:
990             return NO
991
992     elif p.type == syms.subscript:
993         # indexing
994         if not prev:
995             assert p.parent is not None, "subscripts are always parented"
996             if p.parent.type == syms.subscriptlist:
997                 return SPACE
998
999             return NO
1000
1001         else:
1002             return NO
1003
1004     elif p.type == syms.atom:
1005         if prev and t == token.DOT:
1006             # dots, but not the first one.
1007             return NO
1008
1009     elif (
1010         p.type == syms.listmaker
1011         or p.type == syms.testlist_gexp
1012         or p.type == syms.subscriptlist
1013     ):
1014         # list interior, including unpacking
1015         if not prev:
1016             return NO
1017
1018     elif p.type == syms.dictsetmaker:
1019         # dict and set interior, including unpacking
1020         if not prev:
1021             return NO
1022
1023         if prev.type == token.DOUBLESTAR:
1024             return NO
1025
1026     elif p.type in {syms.factor, syms.star_expr}:
1027         # unary ops
1028         if not prev:
1029             prevp = preceding_leaf(p)
1030             if not prevp or prevp.type in OPENING_BRACKETS:
1031                 return NO
1032
1033             prevp_parent = prevp.parent
1034             assert prevp_parent is not None
1035             if prevp.type == token.COLON and prevp_parent.type in {
1036                 syms.subscript, syms.sliceop
1037             }:
1038                 return NO
1039
1040             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1041                 return NO
1042
1043         elif t == token.NAME or t == token.NUMBER:
1044             return NO
1045
1046     elif p.type == syms.import_from:
1047         if t == token.DOT:
1048             if prev and prev.type == token.DOT:
1049                 return NO
1050
1051         elif t == token.NAME:
1052             if v == 'import':
1053                 return SPACE
1054
1055             if prev and prev.type == token.DOT:
1056                 return NO
1057
1058     elif p.type == syms.sliceop:
1059         return NO
1060
1061     return SPACE
1062
1063
1064 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1065     """Returns the first leaf that precedes `node`, if any."""
1066     while node:
1067         res = node.prev_sibling
1068         if res:
1069             if isinstance(res, Leaf):
1070                 return res
1071
1072             try:
1073                 return list(res.leaves())[-1]
1074
1075             except IndexError:
1076                 return None
1077
1078         node = node.parent
1079     return None
1080
1081
1082 def is_delimiter(leaf: Leaf) -> int:
1083     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1084
1085     Higher numbers are higher priority.
1086     """
1087     if leaf.type == token.COMMA:
1088         return COMMA_PRIORITY
1089
1090     if leaf.type in COMPARATORS:
1091         return COMPARATOR_PRIORITY
1092
1093     if (
1094         leaf.type in MATH_OPERATORS
1095         and leaf.parent
1096         and leaf.parent.type not in {syms.factor, syms.star_expr}
1097     ):
1098         return MATH_PRIORITY
1099
1100     return 0
1101
1102
1103 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1104     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1105
1106     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1107     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1108     move because it does away with modifying the grammar to include all the
1109     possible places in which comments can be placed.
1110
1111     The sad consequence for us though is that comments don't "belong" anywhere.
1112     This is why this function generates simple parentless Leaf objects for
1113     comments.  We simply don't know what the correct parent should be.
1114
1115     No matter though, we can live without this.  We really only need to
1116     differentiate between inline and standalone comments.  The latter don't
1117     share the line with any code.
1118
1119     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1120     are emitted with a fake STANDALONE_COMMENT token identifier.
1121     """
1122     p = leaf.prefix
1123     if not p:
1124         return
1125
1126     if '#' not in p:
1127         return
1128
1129     nlines = 0
1130     for index, line in enumerate(p.split('\n')):
1131         line = line.lstrip()
1132         if not line:
1133             nlines += 1
1134         if not line.startswith('#'):
1135             continue
1136
1137         if index == 0 and leaf.type != token.ENDMARKER:
1138             comment_type = token.COMMENT  # simple trailing comment
1139         else:
1140             comment_type = STANDALONE_COMMENT
1141         yield Leaf(comment_type, make_comment(line), prefix='\n' * nlines)
1142
1143         nlines = 0
1144
1145
1146 def make_comment(content: str) -> str:
1147     content = content.rstrip()
1148     if not content:
1149         return '#'
1150
1151     if content[0] == '#':
1152         content = content[1:]
1153     if content and content[0] not in {' ', '!', '#'}:
1154         content = ' ' + content
1155     return '#' + content
1156
1157
1158 def split_line(
1159     line: Line, line_length: int, inner: bool = False, py36: bool = False
1160 ) -> Iterator[Line]:
1161     """Splits a `line` into potentially many lines.
1162
1163     They should fit in the allotted `line_length` but might not be able to.
1164     `inner` signifies that there were a pair of brackets somewhere around the
1165     current `line`, possibly transitively. This means we can fallback to splitting
1166     by delimiters if the LHS/RHS don't yield any results.
1167
1168     If `py36` is True, splitting may generate syntax that is only compatible
1169     with Python 3.6 and later.
1170     """
1171     line_str = str(line).strip('\n')
1172     if len(line_str) <= line_length and '\n' not in line_str:
1173         yield line
1174         return
1175
1176     if line.is_def:
1177         split_funcs = [left_hand_split]
1178     elif line.inside_brackets:
1179         split_funcs = [delimiter_split]
1180         if '\n' not in line_str:
1181             # Only attempt RHS if we don't have multiline strings or comments
1182             # on this line.
1183             split_funcs.append(right_hand_split)
1184     else:
1185         split_funcs = [right_hand_split]
1186     for split_func in split_funcs:
1187         # We are accumulating lines in `result` because we might want to abort
1188         # mission and return the original line in the end, or attempt a different
1189         # split altogether.
1190         result: List[Line] = []
1191         try:
1192             for l in split_func(line, py36=py36):
1193                 if str(l).strip('\n') == line_str:
1194                     raise CannotSplit("Split function returned an unchanged result")
1195
1196                 result.extend(
1197                     split_line(l, line_length=line_length, inner=True, py36=py36)
1198                 )
1199         except CannotSplit as cs:
1200             continue
1201
1202         else:
1203             yield from result
1204             break
1205
1206     else:
1207         yield line
1208
1209
1210 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1211     """Split line into many lines, starting with the first matching bracket pair.
1212
1213     Note: this usually looks weird, only use this for function definitions.
1214     Prefer RHS otherwise.
1215     """
1216     head = Line(depth=line.depth)
1217     body = Line(depth=line.depth + 1, inside_brackets=True)
1218     tail = Line(depth=line.depth)
1219     tail_leaves: List[Leaf] = []
1220     body_leaves: List[Leaf] = []
1221     head_leaves: List[Leaf] = []
1222     current_leaves = head_leaves
1223     matching_bracket = None
1224     for leaf in line.leaves:
1225         if (
1226             current_leaves is body_leaves
1227             and leaf.type in CLOSING_BRACKETS
1228             and leaf.opening_bracket is matching_bracket
1229         ):
1230             current_leaves = tail_leaves if body_leaves else head_leaves
1231         current_leaves.append(leaf)
1232         if current_leaves is head_leaves:
1233             if leaf.type in OPENING_BRACKETS:
1234                 matching_bracket = leaf
1235                 current_leaves = body_leaves
1236     # Since body is a new indent level, remove spurious leading whitespace.
1237     if body_leaves:
1238         normalize_prefix(body_leaves[0], inside_brackets=True)
1239     # Build the new lines.
1240     for result, leaves in (
1241         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1242     ):
1243         for leaf in leaves:
1244             result.append(leaf, preformatted=True)
1245             comment_after = line.comments.get(id(leaf))
1246             if comment_after:
1247                 result.append(comment_after, preformatted=True)
1248     split_succeeded_or_raise(head, body, tail)
1249     for result in (head, body, tail):
1250         if result:
1251             yield result
1252
1253
1254 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1255     """Split line into many lines, starting with the last matching bracket pair."""
1256     head = Line(depth=line.depth)
1257     body = Line(depth=line.depth + 1, inside_brackets=True)
1258     tail = Line(depth=line.depth)
1259     tail_leaves: List[Leaf] = []
1260     body_leaves: List[Leaf] = []
1261     head_leaves: List[Leaf] = []
1262     current_leaves = tail_leaves
1263     opening_bracket = None
1264     for leaf in reversed(line.leaves):
1265         if current_leaves is body_leaves:
1266             if leaf is opening_bracket:
1267                 current_leaves = head_leaves if body_leaves else tail_leaves
1268         current_leaves.append(leaf)
1269         if current_leaves is tail_leaves:
1270             if leaf.type in CLOSING_BRACKETS:
1271                 opening_bracket = leaf.opening_bracket
1272                 current_leaves = body_leaves
1273     tail_leaves.reverse()
1274     body_leaves.reverse()
1275     head_leaves.reverse()
1276     # Since body is a new indent level, remove spurious leading whitespace.
1277     if body_leaves:
1278         normalize_prefix(body_leaves[0], inside_brackets=True)
1279     # Build the new lines.
1280     for result, leaves in (
1281         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1282     ):
1283         for leaf in leaves:
1284             result.append(leaf, preformatted=True)
1285             comment_after = line.comments.get(id(leaf))
1286             if comment_after:
1287                 result.append(comment_after, preformatted=True)
1288     split_succeeded_or_raise(head, body, tail)
1289     for result in (head, body, tail):
1290         if result:
1291             yield result
1292
1293
1294 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1295     tail_len = len(str(tail).strip())
1296     if not body:
1297         if tail_len == 0:
1298             raise CannotSplit("Splitting brackets produced the same line")
1299
1300         elif tail_len < 3:
1301             raise CannotSplit(
1302                 f"Splitting brackets on an empty body to save "
1303                 f"{tail_len} characters is not worth it"
1304             )
1305
1306
1307 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1308     """Split according to delimiters of the highest priority.
1309
1310     This kind of split doesn't increase indentation.
1311     If `py36` is True, the split will add trailing commas also in function
1312     signatures that contain * and **.
1313     """
1314     try:
1315         last_leaf = line.leaves[-1]
1316     except IndexError:
1317         raise CannotSplit("Line empty")
1318
1319     delimiters = line.bracket_tracker.delimiters
1320     try:
1321         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1322     except ValueError:
1323         raise CannotSplit("No delimiters found")
1324
1325     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1326     lowest_depth = sys.maxsize
1327     trailing_comma_safe = True
1328     for leaf in line.leaves:
1329         current_line.append(leaf, preformatted=True)
1330         comment_after = line.comments.get(id(leaf))
1331         if comment_after:
1332             current_line.append(comment_after, preformatted=True)
1333         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1334         if (
1335             leaf.bracket_depth == lowest_depth
1336             and leaf.type == token.STAR
1337             or leaf.type == token.DOUBLESTAR
1338         ):
1339             trailing_comma_safe = trailing_comma_safe and py36
1340         leaf_priority = delimiters.get(id(leaf))
1341         if leaf_priority == delimiter_priority:
1342             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1343             yield current_line
1344
1345             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1346     if current_line:
1347         if (
1348             delimiter_priority == COMMA_PRIORITY
1349             and current_line.leaves[-1].type != token.COMMA
1350             and trailing_comma_safe
1351         ):
1352             current_line.append(Leaf(token.COMMA, ','))
1353         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1354         yield current_line
1355
1356
1357 def is_import(leaf: Leaf) -> bool:
1358     """Returns True if the given leaf starts an import statement."""
1359     p = leaf.parent
1360     t = leaf.type
1361     v = leaf.value
1362     return bool(
1363         t == token.NAME
1364         and (
1365             (v == 'import' and p and p.type == syms.import_name)
1366             or (v == 'from' and p and p.type == syms.import_from)
1367         )
1368     )
1369
1370
1371 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1372     """Leave existing extra newlines if not `inside_brackets`.
1373
1374     Remove everything else.  Note: don't use backslashes for formatting or
1375     you'll lose your voting rights.
1376     """
1377     if not inside_brackets:
1378         spl = leaf.prefix.split('#')
1379         if '\\' not in spl[0]:
1380             nl_count = spl[-1].count('\n')
1381             if len(spl) > 1:
1382                 nl_count -= 1
1383             leaf.prefix = '\n' * nl_count
1384             return
1385
1386     leaf.prefix = ''
1387
1388
1389 def is_python36(node: Node) -> bool:
1390     """Returns True if the current file is using Python 3.6+ features.
1391
1392     Currently looking for:
1393     - f-strings; and
1394     - trailing commas after * or ** in function signatures.
1395     """
1396     for n in node.pre_order():
1397         if n.type == token.STRING:
1398             value_head = n.value[:2]  # type: ignore
1399             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1400                 return True
1401
1402         elif (
1403             n.type == syms.typedargslist
1404             and n.children
1405             and n.children[-1].type == token.COMMA
1406         ):
1407             for ch in n.children:
1408                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1409                     return True
1410
1411     return False
1412
1413
1414 PYTHON_EXTENSIONS = {'.py'}
1415 BLACKLISTED_DIRECTORIES = {
1416     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1417 }
1418
1419
1420 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1421     for child in path.iterdir():
1422         if child.is_dir():
1423             if child.name in BLACKLISTED_DIRECTORIES:
1424                 continue
1425
1426             yield from gen_python_files_in_dir(child)
1427
1428         elif child.suffix in PYTHON_EXTENSIONS:
1429             yield child
1430
1431
1432 @dataclass
1433 class Report:
1434     """Provides a reformatting counter."""
1435     change_count: int = 0
1436     same_count: int = 0
1437     failure_count: int = 0
1438
1439     def done(self, src: Path, changed: bool) -> None:
1440         """Increment the counter for successful reformatting. Write out a message."""
1441         if changed:
1442             out(f'reformatted {src}')
1443             self.change_count += 1
1444         else:
1445             out(f'{src} already well formatted, good job.', bold=False)
1446             self.same_count += 1
1447
1448     def failed(self, src: Path, message: str) -> None:
1449         """Increment the counter for failed reformatting. Write out a message."""
1450         err(f'error: cannot format {src}: {message}')
1451         self.failure_count += 1
1452
1453     @property
1454     def return_code(self) -> int:
1455         """Which return code should the app use considering the current state."""
1456         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1457         # 126 we have special returncodes reserved by the shell.
1458         if self.failure_count:
1459             return 123
1460
1461         elif self.change_count:
1462             return 1
1463
1464         return 0
1465
1466     def __str__(self) -> str:
1467         """A color report of the current state.
1468
1469         Use `click.unstyle` to remove colors.
1470         """
1471         report = []
1472         if self.change_count:
1473             s = 's' if self.change_count > 1 else ''
1474             report.append(
1475                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1476             )
1477         if self.same_count:
1478             s = 's' if self.same_count > 1 else ''
1479             report.append(f'{self.same_count} file{s} left unchanged')
1480         if self.failure_count:
1481             s = 's' if self.failure_count > 1 else ''
1482             report.append(
1483                 click.style(
1484                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1485                 )
1486             )
1487         return ', '.join(report) + '.'
1488
1489
1490 def assert_equivalent(src: str, dst: str) -> None:
1491     """Raises AssertionError if `src` and `dst` aren't equivalent.
1492
1493     This is a temporary sanity check until Black becomes stable.
1494     """
1495
1496     import ast
1497     import traceback
1498
1499     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1500         """Simple visitor generating strings to compare ASTs by content."""
1501         yield f"{'  ' * depth}{node.__class__.__name__}("
1502
1503         for field in sorted(node._fields):
1504             try:
1505                 value = getattr(node, field)
1506             except AttributeError:
1507                 continue
1508
1509             yield f"{'  ' * (depth+1)}{field}="
1510
1511             if isinstance(value, list):
1512                 for item in value:
1513                     if isinstance(item, ast.AST):
1514                         yield from _v(item, depth + 2)
1515
1516             elif isinstance(value, ast.AST):
1517                 yield from _v(value, depth + 2)
1518
1519             else:
1520                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1521
1522         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1523
1524     try:
1525         src_ast = ast.parse(src)
1526     except Exception as exc:
1527         raise AssertionError(f"cannot parse source: {exc}") from None
1528
1529     try:
1530         dst_ast = ast.parse(dst)
1531     except Exception as exc:
1532         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1533         raise AssertionError(
1534             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1535             f"Please report a bug on https://github.com/ambv/black/issues.  "
1536             f"This invalid output might be helpful: {log}"
1537         ) from None
1538
1539     src_ast_str = '\n'.join(_v(src_ast))
1540     dst_ast_str = '\n'.join(_v(dst_ast))
1541     if src_ast_str != dst_ast_str:
1542         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1543         raise AssertionError(
1544             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1545             f"the source.  "
1546             f"Please report a bug on https://github.com/ambv/black/issues.  "
1547             f"This diff might be helpful: {log}"
1548         ) from None
1549
1550
1551 def assert_stable(src: str, dst: str, line_length: int) -> None:
1552     """Raises AssertionError if `dst` reformats differently the second time.
1553
1554     This is a temporary sanity check until Black becomes stable.
1555     """
1556     newdst = format_str(dst, line_length=line_length)
1557     if dst != newdst:
1558         log = dump_to_file(
1559             diff(src, dst, 'source', 'first pass'),
1560             diff(dst, newdst, 'first pass', 'second pass'),
1561         )
1562         raise AssertionError(
1563             f"INTERNAL ERROR: Black produced different code on the second pass "
1564             f"of the formatter.  "
1565             f"Please report a bug on https://github.com/ambv/black/issues.  "
1566             f"This diff might be helpful: {log}"
1567         ) from None
1568
1569
1570 def dump_to_file(*output: str) -> str:
1571     """Dumps `output` to a temporary file. Returns path to the file."""
1572     import tempfile
1573
1574     with tempfile.NamedTemporaryFile(
1575         mode='w', prefix='blk_', suffix='.log', delete=False
1576     ) as f:
1577         for lines in output:
1578             f.write(lines)
1579             f.write('\n')
1580     return f.name
1581
1582
1583 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1584     """Returns a udiff string between strings `a` and `b`."""
1585     import difflib
1586
1587     a_lines = [line + '\n' for line in a.split('\n')]
1588     b_lines = [line + '\n' for line in b.split('\n')]
1589     return ''.join(
1590         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1591     )
1592
1593
1594 if __name__ == '__main__':
1595     main()