]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Restore ability to format code with legacy usage of `async` as a name
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(
78         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
79     ),
80 )
81 @click.pass_context
82 def main(
83     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
84 ) -> None:
85     """The uncompromising code formatter."""
86     sources: List[Path] = []
87     for s in src:
88         p = Path(s)
89         if p.is_dir():
90             sources.extend(gen_python_files_in_dir(p))
91         elif p.is_file():
92             # if a file was explicitly given, we don't care about its extension
93             sources.append(p)
94         elif s == '-':
95             sources.append(Path('-'))
96         else:
97             err(f'invalid path: {s}')
98     if len(sources) == 0:
99         ctx.exit(0)
100     elif len(sources) == 1:
101         p = sources[0]
102         report = Report()
103         try:
104             if not p.is_file() and str(p) == '-':
105                 changed = format_stdin_to_stdout(line_length=line_length, fast=fast)
106             else:
107                 changed = format_file_in_place(
108                     p, line_length=line_length, fast=fast, write_back=not check
109                 )
110             report.done(p, changed)
111         except Exception as exc:
112             report.failed(p, str(exc))
113         ctx.exit(report.return_code)
114     else:
115         loop = asyncio.get_event_loop()
116         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
117         return_code = 1
118         try:
119             return_code = loop.run_until_complete(
120                 schedule_formatting(
121                     sources, line_length, not check, fast, loop, executor
122                 )
123             )
124         finally:
125             loop.close()
126             ctx.exit(return_code)
127
128
129 async def schedule_formatting(
130     sources: List[Path],
131     line_length: int,
132     write_back: bool,
133     fast: bool,
134     loop: BaseEventLoop,
135     executor: Executor,
136 ) -> int:
137     tasks = {
138         src: loop.run_in_executor(
139             executor, format_file_in_place, src, line_length, fast, write_back
140         )
141         for src in sources
142     }
143     await asyncio.wait(tasks.values())
144     cancelled = []
145     report = Report()
146     for src, task in tasks.items():
147         if not task.done():
148             report.failed(src, 'timed out, cancelling')
149             task.cancel()
150             cancelled.append(task)
151         elif task.exception():
152             report.failed(src, str(task.exception()))
153         else:
154             report.done(src, task.result())
155     if cancelled:
156         await asyncio.wait(cancelled, timeout=2)
157     out('All done! ✨ 🍰 ✨')
158     click.echo(str(report))
159     return report.return_code
160
161
162 def format_file_in_place(
163     src: Path, line_length: int, fast: bool, write_back: bool = False
164 ) -> bool:
165     """Format the file and rewrite if changed. Return True if changed."""
166     with tokenize.open(src) as src_buffer:
167         src_contents = src_buffer.read()
168     try:
169         contents = format_file_contents(
170             src_contents, line_length=line_length, fast=fast
171         )
172     except NothingChanged:
173         return False
174
175     if write_back:
176         with open(src, "w", encoding=src_buffer.encoding) as f:
177             f.write(contents)
178     return True
179
180
181 def format_stdin_to_stdout(line_length: int, fast: bool) -> bool:
182     """Format file on stdin and pipe output to stdout. Return True if changed."""
183     contents = sys.stdin.read()
184     try:
185         contents = format_file_contents(contents, line_length=line_length, fast=fast)
186         return True
187
188     except NothingChanged:
189         return False
190
191     finally:
192         sys.stdout.write(contents)
193
194
195 def format_file_contents(
196     src_contents: str, line_length: int, fast: bool
197 ) -> FileContent:
198     """Reformats a file and returns its contents and encoding."""
199     if src_contents.strip() == '':
200         raise NothingChanged
201
202     dst_contents = format_str(src_contents, line_length=line_length)
203     if src_contents == dst_contents:
204         raise NothingChanged
205
206     if not fast:
207         assert_equivalent(src_contents, dst_contents)
208         assert_stable(src_contents, dst_contents, line_length=line_length)
209     return dst_contents
210
211
212 def format_str(src_contents: str, line_length: int) -> FileContent:
213     """Reformats a string and returns new contents."""
214     src_node = lib2to3_parse(src_contents)
215     dst_contents = ""
216     comments: List[Line] = []
217     lines = LineGenerator()
218     elt = EmptyLineTracker()
219     py36 = is_python36(src_node)
220     empty_line = Line()
221     after = 0
222     for current_line in lines.visit(src_node):
223         for _ in range(after):
224             dst_contents += str(empty_line)
225         before, after = elt.maybe_empty_lines(current_line)
226         for _ in range(before):
227             dst_contents += str(empty_line)
228         if not current_line.is_comment:
229             for comment in comments:
230                 dst_contents += str(comment)
231             comments = []
232             for line in split_line(current_line, line_length=line_length, py36=py36):
233                 dst_contents += str(line)
234         else:
235             comments.append(current_line)
236     if comments:
237         if elt.previous_defs:
238             # Separate postscriptum comments from the last module-level def.
239             dst_contents += str(empty_line)
240             dst_contents += str(empty_line)
241         for comment in comments:
242             dst_contents += str(comment)
243     return dst_contents
244
245
246 def lib2to3_parse(src_txt: str) -> Node:
247     """Given a string with source, return the lib2to3 Node."""
248     grammar = pygram.python_grammar_no_print_statement
249     drv = driver.Driver(grammar, pytree.convert)
250     if src_txt[-1] != '\n':
251         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
252         src_txt += nl
253     try:
254         result = drv.parse_string(src_txt, True)
255     except ParseError as pe:
256         lineno, column = pe.context[1]
257         lines = src_txt.splitlines()
258         try:
259             faulty_line = lines[lineno - 1]
260         except IndexError:
261             faulty_line = "<line number missing in source>"
262         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
263
264     if isinstance(result, Leaf):
265         result = Node(syms.file_input, [result])
266     return result
267
268
269 def lib2to3_unparse(node: Node) -> str:
270     """Given a lib2to3 node, return its string representation."""
271     code = str(node)
272     return code
273
274
275 T = TypeVar('T')
276
277
278 class Visitor(Generic[T]):
279     """Basic lib2to3 visitor that yields things on visiting."""
280
281     def visit(self, node: LN) -> Iterator[T]:
282         if node.type < 256:
283             name = token.tok_name[node.type]
284         else:
285             name = type_repr(node.type)
286         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
287
288     def visit_default(self, node: LN) -> Iterator[T]:
289         if isinstance(node, Node):
290             for child in node.children:
291                 yield from self.visit(child)
292
293
294 @dataclass
295 class DebugVisitor(Visitor[T]):
296     tree_depth: int = 0
297
298     def visit_default(self, node: LN) -> Iterator[T]:
299         indent = ' ' * (2 * self.tree_depth)
300         if isinstance(node, Node):
301             _type = type_repr(node.type)
302             out(f'{indent}{_type}', fg='yellow')
303             self.tree_depth += 1
304             for child in node.children:
305                 yield from self.visit(child)
306
307             self.tree_depth -= 1
308             out(f'{indent}/{_type}', fg='yellow', bold=False)
309         else:
310             _type = token.tok_name.get(node.type, str(node.type))
311             out(f'{indent}{_type}', fg='blue', nl=False)
312             if node.prefix:
313                 # We don't have to handle prefixes for `Node` objects since
314                 # that delegates to the first child anyway.
315                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
316             out(f' {node.value!r}', fg='blue', bold=False)
317
318
319 KEYWORDS = set(keyword.kwlist)
320 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
321 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
322 STATEMENT = {
323     syms.if_stmt,
324     syms.while_stmt,
325     syms.for_stmt,
326     syms.try_stmt,
327     syms.except_clause,
328     syms.with_stmt,
329     syms.funcdef,
330     syms.classdef,
331 }
332 STANDALONE_COMMENT = 153
333 LOGIC_OPERATORS = {'and', 'or'}
334 COMPARATORS = {
335     token.LESS,
336     token.GREATER,
337     token.EQEQUAL,
338     token.NOTEQUAL,
339     token.LESSEQUAL,
340     token.GREATEREQUAL,
341 }
342 MATH_OPERATORS = {
343     token.PLUS,
344     token.MINUS,
345     token.STAR,
346     token.SLASH,
347     token.VBAR,
348     token.AMPER,
349     token.PERCENT,
350     token.CIRCUMFLEX,
351     token.LEFTSHIFT,
352     token.RIGHTSHIFT,
353     token.DOUBLESTAR,
354     token.DOUBLESLASH,
355 }
356 COMPREHENSION_PRIORITY = 20
357 COMMA_PRIORITY = 10
358 LOGIC_PRIORITY = 5
359 STRING_PRIORITY = 4
360 COMPARATOR_PRIORITY = 3
361 MATH_PRIORITY = 1
362
363
364 @dataclass
365 class BracketTracker:
366     depth: int = 0
367     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
368     delimiters: Dict[LeafID, Priority] = Factory(dict)
369     previous: Optional[Leaf] = None
370
371     def mark(self, leaf: Leaf) -> None:
372         if leaf.type == token.COMMENT:
373             return
374
375         if leaf.type in CLOSING_BRACKETS:
376             self.depth -= 1
377             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
378             leaf.opening_bracket = opening_bracket
379         leaf.bracket_depth = self.depth
380         if self.depth == 0:
381             delim = is_delimiter(leaf)
382             if delim:
383                 self.delimiters[id(leaf)] = delim
384             elif self.previous is not None:
385                 if leaf.type == token.STRING and self.previous.type == token.STRING:
386                     self.delimiters[id(self.previous)] = STRING_PRIORITY
387                 elif (
388                     leaf.type == token.NAME
389                     and leaf.value == 'for'
390                     and leaf.parent
391                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
392                 ):
393                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
394                 elif (
395                     leaf.type == token.NAME
396                     and leaf.value == 'if'
397                     and leaf.parent
398                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
399                 ):
400                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
401                 elif (
402                     leaf.type == token.NAME
403                     and leaf.value in LOGIC_OPERATORS
404                     and leaf.parent
405                 ):
406                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
407         if leaf.type in OPENING_BRACKETS:
408             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
409             self.depth += 1
410         self.previous = leaf
411
412     def any_open_brackets(self) -> bool:
413         """Returns True if there is an yet unmatched open bracket on the line."""
414         return bool(self.bracket_match)
415
416     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
417         """Returns the highest priority of a delimiter found on the line.
418
419         Values are consistent with what `is_delimiter()` returns.
420         """
421         return max(v for k, v in self.delimiters.items() if k not in exclude)
422
423
424 @dataclass
425 class Line:
426     depth: int = 0
427     leaves: List[Leaf] = Factory(list)
428     comments: Dict[LeafID, Leaf] = Factory(dict)
429     bracket_tracker: BracketTracker = Factory(BracketTracker)
430     inside_brackets: bool = False
431     has_for: bool = False
432     _for_loop_variable: bool = False
433
434     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
435         has_value = leaf.value.strip()
436         if not has_value:
437             return
438
439         if self.leaves and not preformatted:
440             # Note: at this point leaf.prefix should be empty except for
441             # imports, for which we only preserve newlines.
442             leaf.prefix += whitespace(leaf)
443         if self.inside_brackets or not preformatted:
444             self.maybe_decrement_after_for_loop_variable(leaf)
445             self.bracket_tracker.mark(leaf)
446             self.maybe_remove_trailing_comma(leaf)
447             self.maybe_increment_for_loop_variable(leaf)
448             if self.maybe_adapt_standalone_comment(leaf):
449                 return
450
451         if not self.append_comment(leaf):
452             self.leaves.append(leaf)
453
454     @property
455     def is_comment(self) -> bool:
456         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
457
458     @property
459     def is_decorator(self) -> bool:
460         return bool(self) and self.leaves[0].type == token.AT
461
462     @property
463     def is_import(self) -> bool:
464         return bool(self) and is_import(self.leaves[0])
465
466     @property
467     def is_class(self) -> bool:
468         return (
469             bool(self)
470             and self.leaves[0].type == token.NAME
471             and self.leaves[0].value == 'class'
472         )
473
474     @property
475     def is_def(self) -> bool:
476         """Also returns True for async defs."""
477         try:
478             first_leaf = self.leaves[0]
479         except IndexError:
480             return False
481
482         try:
483             second_leaf: Optional[Leaf] = self.leaves[1]
484         except IndexError:
485             second_leaf = None
486         return (
487             (first_leaf.type == token.NAME and first_leaf.value == 'def')
488             or (
489                 first_leaf.type == token.ASYNC
490                 and second_leaf is not None
491                 and second_leaf.type == token.NAME
492                 and second_leaf.value == 'def'
493             )
494         )
495
496     @property
497     def is_flow_control(self) -> bool:
498         return (
499             bool(self)
500             and self.leaves[0].type == token.NAME
501             and self.leaves[0].value in FLOW_CONTROL
502         )
503
504     @property
505     def is_yield(self) -> bool:
506         return (
507             bool(self)
508             and self.leaves[0].type == token.NAME
509             and self.leaves[0].value == 'yield'
510         )
511
512     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
513         if not (
514             self.leaves
515             and self.leaves[-1].type == token.COMMA
516             and closing.type in CLOSING_BRACKETS
517         ):
518             return False
519
520         if closing.type == token.RSQB or closing.type == token.RBRACE:
521             self.leaves.pop()
522             return True
523
524         # For parens let's check if it's safe to remove the comma.  If the
525         # trailing one is the only one, we might mistakenly change a tuple
526         # into a different type by removing the comma.
527         depth = closing.bracket_depth + 1
528         commas = 0
529         opening = closing.opening_bracket
530         for _opening_index, leaf in enumerate(self.leaves):
531             if leaf is opening:
532                 break
533
534         else:
535             return False
536
537         for leaf in self.leaves[_opening_index + 1:]:
538             if leaf is closing:
539                 break
540
541             bracket_depth = leaf.bracket_depth
542             if bracket_depth == depth and leaf.type == token.COMMA:
543                 commas += 1
544                 if leaf.parent and leaf.parent.type == syms.arglist:
545                     commas += 1
546                     break
547
548         if commas > 1:
549             self.leaves.pop()
550             return True
551
552         return False
553
554     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
555         """In a for loop, or comprehension, the variables are often unpacks.
556
557         To avoid splitting on the comma in this situation, we will increase
558         the depth of tokens between `for` and `in`.
559         """
560         if leaf.type == token.NAME and leaf.value == 'for':
561             self.has_for = True
562             self.bracket_tracker.depth += 1
563             self._for_loop_variable = True
564             return True
565
566         return False
567
568     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
569         # See `maybe_increment_for_loop_variable` above for explanation.
570         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
571             self.bracket_tracker.depth -= 1
572             self._for_loop_variable = False
573             return True
574
575         return False
576
577     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
578         """Hack a standalone comment to act as a trailing comment for line splitting.
579
580         If this line has brackets and a standalone `comment`, we need to adapt
581         it to be able to still reformat the line.
582
583         This is not perfect, the line to which the standalone comment gets
584         appended will appear "too long" when splitting.
585         """
586         if not (
587             comment.type == STANDALONE_COMMENT
588             and self.bracket_tracker.any_open_brackets()
589         ):
590             return False
591
592         comment.type = token.COMMENT
593         comment.prefix = '\n' + '    ' * (self.depth + 1)
594         return self.append_comment(comment)
595
596     def append_comment(self, comment: Leaf) -> bool:
597         if comment.type != token.COMMENT:
598             return False
599
600         try:
601             after = id(self.last_non_delimiter())
602         except LookupError:
603             comment.type = STANDALONE_COMMENT
604             comment.prefix = ''
605             return False
606
607         else:
608             if after in self.comments:
609                 self.comments[after].value += str(comment)
610             else:
611                 self.comments[after] = comment
612             return True
613
614     def last_non_delimiter(self) -> Leaf:
615         for i in range(len(self.leaves)):
616             last = self.leaves[-i - 1]
617             if not is_delimiter(last):
618                 return last
619
620         raise LookupError("No non-delimiters found")
621
622     def __str__(self) -> str:
623         if not self:
624             return '\n'
625
626         indent = '    ' * self.depth
627         leaves = iter(self.leaves)
628         first = next(leaves)
629         res = f'{first.prefix}{indent}{first.value}'
630         for leaf in leaves:
631             res += str(leaf)
632         for comment in self.comments.values():
633             res += str(comment)
634         return res + '\n'
635
636     def __bool__(self) -> bool:
637         return bool(self.leaves or self.comments)
638
639
640 @dataclass
641 class EmptyLineTracker:
642     """Provides a stateful method that returns the number of potential extra
643     empty lines needed before and after the currently processed line.
644
645     Note: this tracker works on lines that haven't been split yet.  It assumes
646     the prefix of the first leaf consists of optional newlines.  Those newlines
647     are consumed by `maybe_empty_lines()` and included in the computation.
648     """
649     previous_line: Optional[Line] = None
650     previous_after: int = 0
651     previous_defs: List[int] = Factory(list)
652
653     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
654         """Returns the number of extra empty lines before and after the `current_line`.
655
656         This is for separating `def`, `async def` and `class` with extra empty lines
657         (two on module-level), as well as providing an extra empty line after flow
658         control keywords to make them more prominent.
659         """
660         if current_line.is_comment:
661             # Don't count standalone comments towards previous empty lines.
662             return 0, 0
663
664         before, after = self._maybe_empty_lines(current_line)
665         before -= self.previous_after
666         self.previous_after = after
667         self.previous_line = current_line
668         return before, after
669
670     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
671         if current_line.leaves:
672             # Consume the first leaf's extra newlines.
673             first_leaf = current_line.leaves[0]
674             before = int('\n' in first_leaf.prefix)
675             first_leaf.prefix = ''
676         else:
677             before = 0
678         depth = current_line.depth
679         while self.previous_defs and self.previous_defs[-1] >= depth:
680             self.previous_defs.pop()
681             before = 1 if depth else 2
682         is_decorator = current_line.is_decorator
683         if is_decorator or current_line.is_def or current_line.is_class:
684             if not is_decorator:
685                 self.previous_defs.append(depth)
686             if self.previous_line is None:
687                 # Don't insert empty lines before the first line in the file.
688                 return 0, 0
689
690             if self.previous_line and self.previous_line.is_decorator:
691                 # Don't insert empty lines between decorators.
692                 return 0, 0
693
694             newlines = 2
695             if current_line.depth:
696                 newlines -= 1
697             return newlines, 0
698
699         if current_line.is_flow_control:
700             return before, 1
701
702         if (
703             self.previous_line
704             and self.previous_line.is_import
705             and not current_line.is_import
706             and depth == self.previous_line.depth
707         ):
708             return (before or 1), 0
709
710         if (
711             self.previous_line
712             and self.previous_line.is_yield
713             and (not current_line.is_yield or depth != self.previous_line.depth)
714         ):
715             return (before or 1), 0
716
717         return before, 0
718
719
720 @dataclass
721 class LineGenerator(Visitor[Line]):
722     """Generates reformatted Line objects.  Empty lines are not emitted.
723
724     Note: destroys the tree it's visiting by mutating prefixes of its leaves
725     in ways that will no longer stringify to valid Python code on the tree.
726     """
727     current_line: Line = Factory(Line)
728     standalone_comments: List[Leaf] = Factory(list)
729
730     def line(self, indent: int = 0) -> Iterator[Line]:
731         """Generate a line.
732
733         If the line is empty, only emit if it makes sense.
734         If the line is too long, split it first and then generate.
735
736         If any lines were generated, set up a new current_line.
737         """
738         if not self.current_line:
739             self.current_line.depth += indent
740             return  # Line is empty, don't emit. Creating a new one unnecessary.
741
742         complete_line = self.current_line
743         self.current_line = Line(depth=complete_line.depth + indent)
744         yield complete_line
745
746     def visit_default(self, node: LN) -> Iterator[Line]:
747         if isinstance(node, Leaf):
748             for comment in generate_comments(node):
749                 if self.current_line.bracket_tracker.any_open_brackets():
750                     # any comment within brackets is subject to splitting
751                     self.current_line.append(comment)
752                 elif comment.type == token.COMMENT:
753                     # regular trailing comment
754                     self.current_line.append(comment)
755                     yield from self.line()
756
757                 else:
758                     # regular standalone comment, to be processed later (see
759                     # docstring in `generate_comments()`
760                     self.standalone_comments.append(comment)
761             normalize_prefix(node)
762             if node.type not in WHITESPACE:
763                 for comment in self.standalone_comments:
764                     yield from self.line()
765
766                     self.current_line.append(comment)
767                     yield from self.line()
768
769                 self.standalone_comments = []
770                 self.current_line.append(node)
771         yield from super().visit_default(node)
772
773     def visit_suite(self, node: Node) -> Iterator[Line]:
774         """Body of a statement after a colon."""
775         children = iter(node.children)
776         # Process newline before indenting.  It might contain an inline
777         # comment that should go right after the colon.
778         newline = next(children)
779         yield from self.visit(newline)
780         yield from self.line(+1)
781
782         for child in children:
783             yield from self.visit(child)
784
785         yield from self.line(-1)
786
787     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
788         """Visit a statement.
789
790         The relevant Python language keywords for this statement are NAME leaves
791         within it.
792         """
793         for child in node.children:
794             if child.type == token.NAME and child.value in keywords:  # type: ignore
795                 yield from self.line()
796
797             yield from self.visit(child)
798
799     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
800         """A statement without nested statements."""
801         is_suite_like = node.parent and node.parent.type in STATEMENT
802         if is_suite_like:
803             yield from self.line(+1)
804             yield from self.visit_default(node)
805             yield from self.line(-1)
806
807         else:
808             yield from self.line()
809             yield from self.visit_default(node)
810
811     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
812         yield from self.line()
813
814         children = iter(node.children)
815         for child in children:
816             yield from self.visit(child)
817
818             if child.type == token.ASYNC:
819                 break
820
821         internal_stmt = next(children)
822         for child in internal_stmt.children:
823             yield from self.visit(child)
824
825     def visit_decorators(self, node: Node) -> Iterator[Line]:
826         for child in node.children:
827             yield from self.line()
828             yield from self.visit(child)
829
830     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
831         yield from self.line()
832
833     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
834         yield from self.visit_default(leaf)
835         yield from self.line()
836
837     def __attrs_post_init__(self) -> None:
838         """You are in a twisty little maze of passages."""
839         v = self.visit_stmt
840         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
841         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
842         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
843         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
844         self.visit_except_clause = partial(v, keywords={'except'})
845         self.visit_funcdef = partial(v, keywords={'def'})
846         self.visit_with_stmt = partial(v, keywords={'with'})
847         self.visit_classdef = partial(v, keywords={'class'})
848         self.visit_async_funcdef = self.visit_async_stmt
849         self.visit_decorated = self.visit_decorators
850
851
852 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
853 OPENING_BRACKETS = set(BRACKET.keys())
854 CLOSING_BRACKETS = set(BRACKET.values())
855 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
856 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
857
858
859 def whitespace(leaf: Leaf) -> str:  # noqa C901
860     """Return whitespace prefix if needed for the given `leaf`."""
861     NO = ''
862     SPACE = ' '
863     DOUBLESPACE = '  '
864     t = leaf.type
865     p = leaf.parent
866     v = leaf.value
867     if t in ALWAYS_NO_SPACE:
868         return NO
869
870     if t == token.COMMENT:
871         return DOUBLESPACE
872
873     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
874     if t == token.COLON and p.type != syms.subscript:
875         return NO
876
877     prev = leaf.prev_sibling
878     if not prev:
879         prevp = preceding_leaf(p)
880         if not prevp or prevp.type in OPENING_BRACKETS:
881             return NO
882
883         if t == token.COLON:
884             return SPACE if prevp.type == token.COMMA else NO
885
886         if prevp.type == token.EQUAL:
887             if prevp.parent and prevp.parent.type in {
888                 syms.typedargslist,
889                 syms.varargslist,
890                 syms.parameters,
891                 syms.arglist,
892                 syms.argument,
893             }:
894                 return NO
895
896         elif prevp.type == token.DOUBLESTAR:
897             if prevp.parent and prevp.parent.type in {
898                 syms.typedargslist,
899                 syms.varargslist,
900                 syms.parameters,
901                 syms.arglist,
902                 syms.dictsetmaker,
903             }:
904                 return NO
905
906         elif prevp.type == token.COLON:
907             if prevp.parent and prevp.parent.type == syms.subscript:
908                 return NO
909
910         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
911             return NO
912
913     elif prev.type in OPENING_BRACKETS:
914         return NO
915
916     if p.type in {syms.parameters, syms.arglist}:
917         # untyped function signatures or calls
918         if t == token.RPAR:
919             return NO
920
921         if not prev or prev.type != token.COMMA:
922             return NO
923
924     if p.type == syms.varargslist:
925         # lambdas
926         if t == token.RPAR:
927             return NO
928
929         if prev and prev.type != token.COMMA:
930             return NO
931
932     elif p.type == syms.typedargslist:
933         # typed function signatures
934         if not prev:
935             return NO
936
937         if t == token.EQUAL:
938             if prev.type != syms.tname:
939                 return NO
940
941         elif prev.type == token.EQUAL:
942             # A bit hacky: if the equal sign has whitespace, it means we
943             # previously found it's a typed argument.  So, we're using that, too.
944             return prev.prefix
945
946         elif prev.type != token.COMMA:
947             return NO
948
949     elif p.type == syms.tname:
950         # type names
951         if not prev:
952             prevp = preceding_leaf(p)
953             if not prevp or prevp.type != token.COMMA:
954                 return NO
955
956     elif p.type == syms.trailer:
957         # attributes and calls
958         if t == token.LPAR or t == token.RPAR:
959             return NO
960
961         if not prev:
962             if t == token.DOT:
963                 prevp = preceding_leaf(p)
964                 if not prevp or prevp.type != token.NUMBER:
965                     return NO
966
967             elif t == token.LSQB:
968                 return NO
969
970         elif prev.type != token.COMMA:
971             return NO
972
973     elif p.type == syms.argument:
974         # single argument
975         if t == token.EQUAL:
976             return NO
977
978         if not prev:
979             prevp = preceding_leaf(p)
980             if not prevp or prevp.type == token.LPAR:
981                 return NO
982
983         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
984             return NO
985
986     elif p.type == syms.decorator:
987         # decorators
988         return NO
989
990     elif p.type == syms.dotted_name:
991         if prev:
992             return NO
993
994         prevp = preceding_leaf(p)
995         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
996             return NO
997
998     elif p.type == syms.classdef:
999         if t == token.LPAR:
1000             return NO
1001
1002         if prev and prev.type == token.LPAR:
1003             return NO
1004
1005     elif p.type == syms.subscript:
1006         # indexing
1007         if not prev:
1008             assert p.parent is not None, "subscripts are always parented"
1009             if p.parent.type == syms.subscriptlist:
1010                 return SPACE
1011
1012             return NO
1013
1014         else:
1015             return NO
1016
1017     elif p.type == syms.atom:
1018         if prev and t == token.DOT:
1019             # dots, but not the first one.
1020             return NO
1021
1022     elif (
1023         p.type == syms.listmaker
1024         or p.type == syms.testlist_gexp
1025         or p.type == syms.subscriptlist
1026     ):
1027         # list interior, including unpacking
1028         if not prev:
1029             return NO
1030
1031     elif p.type == syms.dictsetmaker:
1032         # dict and set interior, including unpacking
1033         if not prev:
1034             return NO
1035
1036         if prev.type == token.DOUBLESTAR:
1037             return NO
1038
1039     elif p.type in {syms.factor, syms.star_expr}:
1040         # unary ops
1041         if not prev:
1042             prevp = preceding_leaf(p)
1043             if not prevp or prevp.type in OPENING_BRACKETS:
1044                 return NO
1045
1046             prevp_parent = prevp.parent
1047             assert prevp_parent is not None
1048             if prevp.type == token.COLON and prevp_parent.type in {
1049                 syms.subscript, syms.sliceop
1050             }:
1051                 return NO
1052
1053             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1054                 return NO
1055
1056         elif t == token.NAME or t == token.NUMBER:
1057             return NO
1058
1059     elif p.type == syms.import_from:
1060         if t == token.DOT:
1061             if prev and prev.type == token.DOT:
1062                 return NO
1063
1064         elif t == token.NAME:
1065             if v == 'import':
1066                 return SPACE
1067
1068             if prev and prev.type == token.DOT:
1069                 return NO
1070
1071     elif p.type == syms.sliceop:
1072         return NO
1073
1074     return SPACE
1075
1076
1077 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1078     """Returns the first leaf that precedes `node`, if any."""
1079     while node:
1080         res = node.prev_sibling
1081         if res:
1082             if isinstance(res, Leaf):
1083                 return res
1084
1085             try:
1086                 return list(res.leaves())[-1]
1087
1088             except IndexError:
1089                 return None
1090
1091         node = node.parent
1092     return None
1093
1094
1095 def is_delimiter(leaf: Leaf) -> int:
1096     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1097
1098     Higher numbers are higher priority.
1099     """
1100     if leaf.type == token.COMMA:
1101         return COMMA_PRIORITY
1102
1103     if leaf.type in COMPARATORS:
1104         return COMPARATOR_PRIORITY
1105
1106     if (
1107         leaf.type in MATH_OPERATORS
1108         and leaf.parent
1109         and leaf.parent.type not in {syms.factor, syms.star_expr}
1110     ):
1111         return MATH_PRIORITY
1112
1113     return 0
1114
1115
1116 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1117     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1118
1119     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1120     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1121     move because it does away with modifying the grammar to include all the
1122     possible places in which comments can be placed.
1123
1124     The sad consequence for us though is that comments don't "belong" anywhere.
1125     This is why this function generates simple parentless Leaf objects for
1126     comments.  We simply don't know what the correct parent should be.
1127
1128     No matter though, we can live without this.  We really only need to
1129     differentiate between inline and standalone comments.  The latter don't
1130     share the line with any code.
1131
1132     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1133     are emitted with a fake STANDALONE_COMMENT token identifier.
1134     """
1135     if not leaf.prefix:
1136         return
1137
1138     if '#' not in leaf.prefix:
1139         return
1140
1141     before_comment, content = leaf.prefix.split('#', 1)
1142     content = content.rstrip()
1143     if content and (content[0] not in {' ', '!', '#'}):
1144         content = ' ' + content
1145     is_standalone_comment = (
1146         '\n' in before_comment or '\n' in content or leaf.type == token.ENDMARKER
1147     )
1148     if not is_standalone_comment:
1149         # simple trailing comment
1150         yield Leaf(token.COMMENT, value='#' + content)
1151         return
1152
1153     for line in ('#' + content).split('\n'):
1154         line = line.lstrip()
1155         if not line.startswith('#'):
1156             continue
1157
1158         yield Leaf(STANDALONE_COMMENT, line)
1159
1160
1161 def split_line(
1162     line: Line, line_length: int, inner: bool = False, py36: bool = False
1163 ) -> Iterator[Line]:
1164     """Splits a `line` into potentially many lines.
1165
1166     They should fit in the allotted `line_length` but might not be able to.
1167     `inner` signifies that there were a pair of brackets somewhere around the
1168     current `line`, possibly transitively. This means we can fallback to splitting
1169     by delimiters if the LHS/RHS don't yield any results.
1170
1171     If `py36` is True, splitting may generate syntax that is only compatible
1172     with Python 3.6 and later.
1173     """
1174     line_str = str(line).strip('\n')
1175     if len(line_str) <= line_length and '\n' not in line_str:
1176         yield line
1177         return
1178
1179     if line.is_def:
1180         split_funcs = [left_hand_split]
1181     elif line.inside_brackets:
1182         split_funcs = [delimiter_split]
1183         if '\n' not in line_str:
1184             # Only attempt RHS if we don't have multiline strings or comments
1185             # on this line.
1186             split_funcs.append(right_hand_split)
1187     else:
1188         split_funcs = [right_hand_split]
1189     for split_func in split_funcs:
1190         # We are accumulating lines in `result` because we might want to abort
1191         # mission and return the original line in the end, or attempt a different
1192         # split altogether.
1193         result: List[Line] = []
1194         try:
1195             for l in split_func(line, py36=py36):
1196                 if str(l).strip('\n') == line_str:
1197                     raise CannotSplit("Split function returned an unchanged result")
1198
1199                 result.extend(
1200                     split_line(l, line_length=line_length, inner=True, py36=py36)
1201                 )
1202         except CannotSplit as cs:
1203             continue
1204
1205         else:
1206             yield from result
1207             break
1208
1209     else:
1210         yield line
1211
1212
1213 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1214     """Split line into many lines, starting with the first matching bracket pair.
1215
1216     Note: this usually looks weird, only use this for function definitions.
1217     Prefer RHS otherwise.
1218     """
1219     head = Line(depth=line.depth)
1220     body = Line(depth=line.depth + 1, inside_brackets=True)
1221     tail = Line(depth=line.depth)
1222     tail_leaves: List[Leaf] = []
1223     body_leaves: List[Leaf] = []
1224     head_leaves: List[Leaf] = []
1225     current_leaves = head_leaves
1226     matching_bracket = None
1227     for leaf in line.leaves:
1228         if (
1229             current_leaves is body_leaves
1230             and leaf.type in CLOSING_BRACKETS
1231             and leaf.opening_bracket is matching_bracket
1232         ):
1233             current_leaves = tail_leaves if body_leaves else head_leaves
1234         current_leaves.append(leaf)
1235         if current_leaves is head_leaves:
1236             if leaf.type in OPENING_BRACKETS:
1237                 matching_bracket = leaf
1238                 current_leaves = body_leaves
1239     # Since body is a new indent level, remove spurious leading whitespace.
1240     if body_leaves:
1241         normalize_prefix(body_leaves[0])
1242     # Build the new lines.
1243     for result, leaves in (
1244         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1245     ):
1246         for leaf in leaves:
1247             result.append(leaf, preformatted=True)
1248             comment_after = line.comments.get(id(leaf))
1249             if comment_after:
1250                 result.append(comment_after, preformatted=True)
1251     split_succeeded_or_raise(head, body, tail)
1252     for result in (head, body, tail):
1253         if result:
1254             yield result
1255
1256
1257 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1258     """Split line into many lines, starting with the last matching bracket pair."""
1259     head = Line(depth=line.depth)
1260     body = Line(depth=line.depth + 1, inside_brackets=True)
1261     tail = Line(depth=line.depth)
1262     tail_leaves: List[Leaf] = []
1263     body_leaves: List[Leaf] = []
1264     head_leaves: List[Leaf] = []
1265     current_leaves = tail_leaves
1266     opening_bracket = None
1267     for leaf in reversed(line.leaves):
1268         if current_leaves is body_leaves:
1269             if leaf is opening_bracket:
1270                 current_leaves = head_leaves if body_leaves else tail_leaves
1271         current_leaves.append(leaf)
1272         if current_leaves is tail_leaves:
1273             if leaf.type in CLOSING_BRACKETS:
1274                 opening_bracket = leaf.opening_bracket
1275                 current_leaves = body_leaves
1276     tail_leaves.reverse()
1277     body_leaves.reverse()
1278     head_leaves.reverse()
1279     # Since body is a new indent level, remove spurious leading whitespace.
1280     if body_leaves:
1281         normalize_prefix(body_leaves[0])
1282     # Build the new lines.
1283     for result, leaves in (
1284         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1285     ):
1286         for leaf in leaves:
1287             result.append(leaf, preformatted=True)
1288             comment_after = line.comments.get(id(leaf))
1289             if comment_after:
1290                 result.append(comment_after, preformatted=True)
1291     split_succeeded_or_raise(head, body, tail)
1292     for result in (head, body, tail):
1293         if result:
1294             yield result
1295
1296
1297 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1298     tail_len = len(str(tail).strip())
1299     if not body:
1300         if tail_len == 0:
1301             raise CannotSplit("Splitting brackets produced the same line")
1302
1303         elif tail_len < 3:
1304             raise CannotSplit(
1305                 f"Splitting brackets on an empty body to save "
1306                 f"{tail_len} characters is not worth it"
1307             )
1308
1309
1310 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1311     """Split according to delimiters of the highest priority.
1312
1313     This kind of split doesn't increase indentation.
1314     If `py36` is True, the split will add trailing commas also in function
1315     signatures that contain * and **.
1316     """
1317     try:
1318         last_leaf = line.leaves[-1]
1319     except IndexError:
1320         raise CannotSplit("Line empty")
1321
1322     delimiters = line.bracket_tracker.delimiters
1323     try:
1324         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1325     except ValueError:
1326         raise CannotSplit("No delimiters found")
1327
1328     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1329     lowest_depth = sys.maxsize
1330     trailing_comma_safe = True
1331     for leaf in line.leaves:
1332         current_line.append(leaf, preformatted=True)
1333         comment_after = line.comments.get(id(leaf))
1334         if comment_after:
1335             current_line.append(comment_after, preformatted=True)
1336         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1337         if (
1338             leaf.bracket_depth == lowest_depth
1339             and leaf.type == token.STAR
1340             or leaf.type == token.DOUBLESTAR
1341         ):
1342             trailing_comma_safe = trailing_comma_safe and py36
1343         leaf_priority = delimiters.get(id(leaf))
1344         if leaf_priority == delimiter_priority:
1345             normalize_prefix(current_line.leaves[0])
1346             yield current_line
1347
1348             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1349     if current_line:
1350         if (
1351             delimiter_priority == COMMA_PRIORITY
1352             and current_line.leaves[-1].type != token.COMMA
1353             and trailing_comma_safe
1354         ):
1355             current_line.append(Leaf(token.COMMA, ','))
1356         normalize_prefix(current_line.leaves[0])
1357         yield current_line
1358
1359
1360 def is_import(leaf: Leaf) -> bool:
1361     """Returns True if the given leaf starts an import statement."""
1362     p = leaf.parent
1363     t = leaf.type
1364     v = leaf.value
1365     return bool(
1366         t == token.NAME
1367         and (
1368             (v == 'import' and p and p.type == syms.import_name)
1369             or (v == 'from' and p and p.type == syms.import_from)
1370         )
1371     )
1372
1373
1374 def normalize_prefix(leaf: Leaf) -> None:
1375     """Leave existing extra newlines for imports.  Remove everything else."""
1376     if is_import(leaf):
1377         spl = leaf.prefix.split('#', 1)
1378         nl_count = spl[0].count('\n')
1379         leaf.prefix = '\n' * nl_count
1380         return
1381
1382     leaf.prefix = ''
1383
1384
1385 def is_python36(node: Node) -> bool:
1386     """Returns True if the current file is using Python 3.6+ features.
1387
1388     Currently looking for:
1389     - f-strings; and
1390     - trailing commas after * or ** in function signatures.
1391     """
1392     for n in node.pre_order():
1393         if n.type == token.STRING:
1394             value_head = n.value[:2]  # type: ignore
1395             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1396                 return True
1397
1398         elif (
1399             n.type == syms.typedargslist
1400             and n.children
1401             and n.children[-1].type == token.COMMA
1402         ):
1403             for ch in n.children:
1404                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1405                     return True
1406
1407     return False
1408
1409
1410 PYTHON_EXTENSIONS = {'.py'}
1411 BLACKLISTED_DIRECTORIES = {
1412     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1413 }
1414
1415
1416 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1417     for child in path.iterdir():
1418         if child.is_dir():
1419             if child.name in BLACKLISTED_DIRECTORIES:
1420                 continue
1421
1422             yield from gen_python_files_in_dir(child)
1423
1424         elif child.suffix in PYTHON_EXTENSIONS:
1425             yield child
1426
1427
1428 @dataclass
1429 class Report:
1430     """Provides a reformatting counter."""
1431     change_count: int = 0
1432     same_count: int = 0
1433     failure_count: int = 0
1434
1435     def done(self, src: Path, changed: bool) -> None:
1436         """Increment the counter for successful reformatting. Write out a message."""
1437         if changed:
1438             out(f'reformatted {src}')
1439             self.change_count += 1
1440         else:
1441             out(f'{src} already well formatted, good job.', bold=False)
1442             self.same_count += 1
1443
1444     def failed(self, src: Path, message: str) -> None:
1445         """Increment the counter for failed reformatting. Write out a message."""
1446         err(f'error: cannot format {src}: {message}')
1447         self.failure_count += 1
1448
1449     @property
1450     def return_code(self) -> int:
1451         """Which return code should the app use considering the current state."""
1452         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1453         # 126 we have special returncodes reserved by the shell.
1454         if self.failure_count:
1455             return 123
1456
1457         elif self.change_count:
1458             return 1
1459
1460         return 0
1461
1462     def __str__(self) -> str:
1463         """A color report of the current state.
1464
1465         Use `click.unstyle` to remove colors.
1466         """
1467         report = []
1468         if self.change_count:
1469             s = 's' if self.change_count > 1 else ''
1470             report.append(
1471                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1472             )
1473         if self.same_count:
1474             s = 's' if self.same_count > 1 else ''
1475             report.append(f'{self.same_count} file{s} left unchanged')
1476         if self.failure_count:
1477             s = 's' if self.failure_count > 1 else ''
1478             report.append(
1479                 click.style(
1480                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1481                 )
1482             )
1483         return ', '.join(report) + '.'
1484
1485
1486 def assert_equivalent(src: str, dst: str) -> None:
1487     """Raises AssertionError if `src` and `dst` aren't equivalent.
1488
1489     This is a temporary sanity check until Black becomes stable.
1490     """
1491
1492     import ast
1493     import traceback
1494
1495     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1496         """Simple visitor generating strings to compare ASTs by content."""
1497         yield f"{'  ' * depth}{node.__class__.__name__}("
1498
1499         for field in sorted(node._fields):
1500             try:
1501                 value = getattr(node, field)
1502             except AttributeError:
1503                 continue
1504
1505             yield f"{'  ' * (depth+1)}{field}="
1506
1507             if isinstance(value, list):
1508                 for item in value:
1509                     if isinstance(item, ast.AST):
1510                         yield from _v(item, depth + 2)
1511
1512             elif isinstance(value, ast.AST):
1513                 yield from _v(value, depth + 2)
1514
1515             else:
1516                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1517
1518         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1519
1520     try:
1521         src_ast = ast.parse(src)
1522     except Exception as exc:
1523         raise AssertionError(f"cannot parse source: {exc}") from None
1524
1525     try:
1526         dst_ast = ast.parse(dst)
1527     except Exception as exc:
1528         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1529         raise AssertionError(
1530             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1531             f"Please report a bug on https://github.com/ambv/black/issues.  "
1532             f"This invalid output might be helpful: {log}"
1533         ) from None
1534
1535     src_ast_str = '\n'.join(_v(src_ast))
1536     dst_ast_str = '\n'.join(_v(dst_ast))
1537     if src_ast_str != dst_ast_str:
1538         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1539         raise AssertionError(
1540             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1541             f"the source.  "
1542             f"Please report a bug on https://github.com/ambv/black/issues.  "
1543             f"This diff might be helpful: {log}"
1544         ) from None
1545
1546
1547 def assert_stable(src: str, dst: str, line_length: int) -> None:
1548     """Raises AssertionError if `dst` reformats differently the second time.
1549
1550     This is a temporary sanity check until Black becomes stable.
1551     """
1552     newdst = format_str(dst, line_length=line_length)
1553     if dst != newdst:
1554         log = dump_to_file(
1555             diff(src, dst, 'source', 'first pass'),
1556             diff(dst, newdst, 'first pass', 'second pass'),
1557         )
1558         raise AssertionError(
1559             f"INTERNAL ERROR: Black produced different code on the second pass "
1560             f"of the formatter.  "
1561             f"Please report a bug on https://github.com/ambv/black/issues.  "
1562             f"This diff might be helpful: {log}"
1563         ) from None
1564
1565
1566 def dump_to_file(*output: str) -> str:
1567     """Dumps `output` to a temporary file. Returns path to the file."""
1568     import tempfile
1569
1570     with tempfile.NamedTemporaryFile(
1571         mode='w', prefix='blk_', suffix='.log', delete=False
1572     ) as f:
1573         for lines in output:
1574             f.write(lines)
1575             f.write('\n')
1576     return f.name
1577
1578
1579 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1580     """Returns a udiff string between strings `a` and `b`."""
1581     import difflib
1582
1583     a_lines = [line + '\n' for line in a.split('\n')]
1584     b_lines = [line + '\n' for line in b.split('\n')]
1585     return ''.join(
1586         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1587     )
1588
1589
1590 if __name__ == '__main__':
1591     main()