]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix tests on 3.7
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
14 )
15
16 from attr import dataclass, Factory
17 import click
18
19 # lib2to3 fork
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
24
25 __version__ = "18.3a2"
26 DEFAULT_LINE_LENGTH = 88
27 # types
28 syms = pygram.python_symbols
29 FileContent = str
30 Encoding = str
31 Depth = int
32 NodeType = int
33 LeafID = int
34 Priority = int
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
38
39
40 class NothingChanged(UserWarning):
41     """Raised by `format_file` when the reformatted code is the same as source."""
42
43
44 class CannotSplit(Exception):
45     """A readable split that fits the allotted line length is impossible.
46
47     Raised by `left_hand_split()` and `right_hand_split()`.
48     """
49
50
51 @click.command()
52 @click.option(
53     '-l',
54     '--line-length',
55     type=int,
56     default=DEFAULT_LINE_LENGTH,
57     help='How many character per line to allow.',
58     show_default=True,
59 )
60 @click.option(
61     '--check',
62     is_flag=True,
63     help=(
64         "Don't write back the files, just return the status.  Return code 0 "
65         "means nothing changed.  Return code 1 means some files were "
66         "reformatted.  Return code 123 means there was an internal error."
67     ),
68 )
69 @click.option(
70     '--fast/--safe',
71     is_flag=True,
72     help='If --fast given, skip temporary sanity checks. [default: --safe]',
73 )
74 @click.version_option(version=__version__)
75 @click.argument(
76     'src',
77     nargs=-1,
78     type=click.Path(
79         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
80     ),
81 )
82 @click.pass_context
83 def main(
84     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
85 ) -> None:
86     """The uncompromising code formatter."""
87     sources: List[Path] = []
88     for s in src:
89         p = Path(s)
90         if p.is_dir():
91             sources.extend(gen_python_files_in_dir(p))
92         elif p.is_file():
93             # if a file was explicitly given, we don't care about its extension
94             sources.append(p)
95         elif s == '-':
96             sources.append(Path('-'))
97         else:
98             err(f'invalid path: {s}')
99     if len(sources) == 0:
100         ctx.exit(0)
101     elif len(sources) == 1:
102         p = sources[0]
103         report = Report()
104         try:
105             if not p.is_file() and str(p) == '-':
106                 changed = format_stdin_to_stdout(
107                     line_length=line_length, fast=fast, write_back=not check
108                 )
109             else:
110                 changed = format_file_in_place(
111                     p, line_length=line_length, fast=fast, write_back=not check
112                 )
113             report.done(p, changed)
114         except Exception as exc:
115             report.failed(p, str(exc))
116         ctx.exit(report.return_code)
117     else:
118         loop = asyncio.get_event_loop()
119         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
120         return_code = 1
121         try:
122             return_code = loop.run_until_complete(
123                 schedule_formatting(
124                     sources, line_length, not check, fast, loop, executor
125                 )
126             )
127         finally:
128             loop.close()
129             ctx.exit(return_code)
130
131
132 async def schedule_formatting(
133     sources: List[Path],
134     line_length: int,
135     write_back: bool,
136     fast: bool,
137     loop: BaseEventLoop,
138     executor: Executor,
139 ) -> int:
140     tasks = {
141         src: loop.run_in_executor(
142             executor, format_file_in_place, src, line_length, fast, write_back
143         )
144         for src in sources
145     }
146     await asyncio.wait(tasks.values())
147     cancelled = []
148     report = Report()
149     for src, task in tasks.items():
150         if not task.done():
151             report.failed(src, 'timed out, cancelling')
152             task.cancel()
153             cancelled.append(task)
154         elif task.exception():
155             report.failed(src, str(task.exception()))
156         else:
157             report.done(src, task.result())
158     if cancelled:
159         await asyncio.wait(cancelled, timeout=2)
160     out('All done! ✨ 🍰 ✨')
161     click.echo(str(report))
162     return report.return_code
163
164
165 def format_file_in_place(
166     src: Path, line_length: int, fast: bool, write_back: bool = False
167 ) -> bool:
168     """Format the file and rewrite if changed. Return True if changed."""
169     with tokenize.open(src) as src_buffer:
170         src_contents = src_buffer.read()
171     try:
172         contents = format_file_contents(
173             src_contents, line_length=line_length, fast=fast
174         )
175     except NothingChanged:
176         return False
177
178     if write_back:
179         with open(src, "w", encoding=src_buffer.encoding) as f:
180             f.write(contents)
181     return True
182
183
184 def format_stdin_to_stdout(
185     line_length: int, fast: bool, write_back: bool = False
186 ) -> bool:
187     """Format file on stdin and pipe output to stdout. Return True if changed."""
188     contents = sys.stdin.read()
189     try:
190         contents = format_file_contents(contents, line_length=line_length, fast=fast)
191         return True
192
193     except NothingChanged:
194         return False
195
196     finally:
197         if write_back:
198             sys.stdout.write(contents)
199
200
201 def format_file_contents(
202     src_contents: str, line_length: int, fast: bool
203 ) -> FileContent:
204     """Reformats a file and returns its contents and encoding."""
205     if src_contents.strip() == '':
206         raise NothingChanged
207
208     dst_contents = format_str(src_contents, line_length=line_length)
209     if src_contents == dst_contents:
210         raise NothingChanged
211
212     if not fast:
213         assert_equivalent(src_contents, dst_contents)
214         assert_stable(src_contents, dst_contents, line_length=line_length)
215     return dst_contents
216
217
218 def format_str(src_contents: str, line_length: int) -> FileContent:
219     """Reformats a string and returns new contents."""
220     src_node = lib2to3_parse(src_contents)
221     dst_contents = ""
222     lines = LineGenerator()
223     elt = EmptyLineTracker()
224     py36 = is_python36(src_node)
225     empty_line = Line()
226     after = 0
227     for current_line in lines.visit(src_node):
228         for _ in range(after):
229             dst_contents += str(empty_line)
230         before, after = elt.maybe_empty_lines(current_line)
231         for _ in range(before):
232             dst_contents += str(empty_line)
233         for line in split_line(current_line, line_length=line_length, py36=py36):
234             dst_contents += str(line)
235     return dst_contents
236
237
238 def lib2to3_parse(src_txt: str) -> Node:
239     """Given a string with source, return the lib2to3 Node."""
240     grammar = pygram.python_grammar_no_print_statement
241     drv = driver.Driver(grammar, pytree.convert)
242     if src_txt[-1] != '\n':
243         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
244         src_txt += nl
245     try:
246         result = drv.parse_string(src_txt, True)
247     except ParseError as pe:
248         lineno, column = pe.context[1]
249         lines = src_txt.splitlines()
250         try:
251             faulty_line = lines[lineno - 1]
252         except IndexError:
253             faulty_line = "<line number missing in source>"
254         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
255
256     if isinstance(result, Leaf):
257         result = Node(syms.file_input, [result])
258     return result
259
260
261 def lib2to3_unparse(node: Node) -> str:
262     """Given a lib2to3 node, return its string representation."""
263     code = str(node)
264     return code
265
266
267 T = TypeVar('T')
268
269
270 class Visitor(Generic[T]):
271     """Basic lib2to3 visitor that yields things on visiting."""
272
273     def visit(self, node: LN) -> Iterator[T]:
274         if node.type < 256:
275             name = token.tok_name[node.type]
276         else:
277             name = type_repr(node.type)
278         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
279
280     def visit_default(self, node: LN) -> Iterator[T]:
281         if isinstance(node, Node):
282             for child in node.children:
283                 yield from self.visit(child)
284
285
286 @dataclass
287 class DebugVisitor(Visitor[T]):
288     tree_depth: int = 0
289
290     def visit_default(self, node: LN) -> Iterator[T]:
291         indent = ' ' * (2 * self.tree_depth)
292         if isinstance(node, Node):
293             _type = type_repr(node.type)
294             out(f'{indent}{_type}', fg='yellow')
295             self.tree_depth += 1
296             for child in node.children:
297                 yield from self.visit(child)
298
299             self.tree_depth -= 1
300             out(f'{indent}/{_type}', fg='yellow', bold=False)
301         else:
302             _type = token.tok_name.get(node.type, str(node.type))
303             out(f'{indent}{_type}', fg='blue', nl=False)
304             if node.prefix:
305                 # We don't have to handle prefixes for `Node` objects since
306                 # that delegates to the first child anyway.
307                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
308             out(f' {node.value!r}', fg='blue', bold=False)
309
310
311 KEYWORDS = set(keyword.kwlist)
312 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
313 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
314 STATEMENT = {
315     syms.if_stmt,
316     syms.while_stmt,
317     syms.for_stmt,
318     syms.try_stmt,
319     syms.except_clause,
320     syms.with_stmt,
321     syms.funcdef,
322     syms.classdef,
323 }
324 STANDALONE_COMMENT = 153
325 LOGIC_OPERATORS = {'and', 'or'}
326 COMPARATORS = {
327     token.LESS,
328     token.GREATER,
329     token.EQEQUAL,
330     token.NOTEQUAL,
331     token.LESSEQUAL,
332     token.GREATEREQUAL,
333 }
334 MATH_OPERATORS = {
335     token.PLUS,
336     token.MINUS,
337     token.STAR,
338     token.SLASH,
339     token.VBAR,
340     token.AMPER,
341     token.PERCENT,
342     token.CIRCUMFLEX,
343     token.LEFTSHIFT,
344     token.RIGHTSHIFT,
345     token.DOUBLESTAR,
346     token.DOUBLESLASH,
347 }
348 COMPREHENSION_PRIORITY = 20
349 COMMA_PRIORITY = 10
350 LOGIC_PRIORITY = 5
351 STRING_PRIORITY = 4
352 COMPARATOR_PRIORITY = 3
353 MATH_PRIORITY = 1
354
355
356 @dataclass
357 class BracketTracker:
358     depth: int = 0
359     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
360     delimiters: Dict[LeafID, Priority] = Factory(dict)
361     previous: Optional[Leaf] = None
362
363     def mark(self, leaf: Leaf) -> None:
364         if leaf.type == token.COMMENT:
365             return
366
367         if leaf.type in CLOSING_BRACKETS:
368             self.depth -= 1
369             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
370             leaf.opening_bracket = opening_bracket
371         leaf.bracket_depth = self.depth
372         if self.depth == 0:
373             delim = is_delimiter(leaf)
374             if delim:
375                 self.delimiters[id(leaf)] = delim
376             elif self.previous is not None:
377                 if leaf.type == token.STRING and self.previous.type == token.STRING:
378                     self.delimiters[id(self.previous)] = STRING_PRIORITY
379                 elif (
380                     leaf.type == token.NAME
381                     and leaf.value == 'for'
382                     and leaf.parent
383                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
384                 ):
385                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
386                 elif (
387                     leaf.type == token.NAME
388                     and leaf.value == 'if'
389                     and leaf.parent
390                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
391                 ):
392                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
393                 elif (
394                     leaf.type == token.NAME
395                     and leaf.value in LOGIC_OPERATORS
396                     and leaf.parent
397                 ):
398                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
399         if leaf.type in OPENING_BRACKETS:
400             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
401             self.depth += 1
402         self.previous = leaf
403
404     def any_open_brackets(self) -> bool:
405         """Returns True if there is an yet unmatched open bracket on the line."""
406         return bool(self.bracket_match)
407
408     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
409         """Returns the highest priority of a delimiter found on the line.
410
411         Values are consistent with what `is_delimiter()` returns.
412         """
413         return max(v for k, v in self.delimiters.items() if k not in exclude)
414
415
416 @dataclass
417 class Line:
418     depth: int = 0
419     leaves: List[Leaf] = Factory(list)
420     comments: Dict[LeafID, Leaf] = Factory(dict)
421     bracket_tracker: BracketTracker = Factory(BracketTracker)
422     inside_brackets: bool = False
423     has_for: bool = False
424     _for_loop_variable: bool = False
425
426     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
427         has_value = leaf.value.strip()
428         if not has_value:
429             return
430
431         if self.leaves and not preformatted:
432             # Note: at this point leaf.prefix should be empty except for
433             # imports, for which we only preserve newlines.
434             leaf.prefix += whitespace(leaf)
435         if self.inside_brackets or not preformatted:
436             self.maybe_decrement_after_for_loop_variable(leaf)
437             self.bracket_tracker.mark(leaf)
438             self.maybe_remove_trailing_comma(leaf)
439             self.maybe_increment_for_loop_variable(leaf)
440             if self.maybe_adapt_standalone_comment(leaf):
441                 return
442
443         if not self.append_comment(leaf):
444             self.leaves.append(leaf)
445
446     @property
447     def is_comment(self) -> bool:
448         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
449
450     @property
451     def is_decorator(self) -> bool:
452         return bool(self) and self.leaves[0].type == token.AT
453
454     @property
455     def is_import(self) -> bool:
456         return bool(self) and is_import(self.leaves[0])
457
458     @property
459     def is_class(self) -> bool:
460         return (
461             bool(self)
462             and self.leaves[0].type == token.NAME
463             and self.leaves[0].value == 'class'
464         )
465
466     @property
467     def is_def(self) -> bool:
468         """Also returns True for async defs."""
469         try:
470             first_leaf = self.leaves[0]
471         except IndexError:
472             return False
473
474         try:
475             second_leaf: Optional[Leaf] = self.leaves[1]
476         except IndexError:
477             second_leaf = None
478         return (
479             (first_leaf.type == token.NAME and first_leaf.value == 'def')
480             or (
481                 first_leaf.type == token.ASYNC
482                 and second_leaf is not None
483                 and second_leaf.type == token.NAME
484                 and second_leaf.value == 'def'
485             )
486         )
487
488     @property
489     def is_flow_control(self) -> bool:
490         return (
491             bool(self)
492             and self.leaves[0].type == token.NAME
493             and self.leaves[0].value in FLOW_CONTROL
494         )
495
496     @property
497     def is_yield(self) -> bool:
498         return (
499             bool(self)
500             and self.leaves[0].type == token.NAME
501             and self.leaves[0].value == 'yield'
502         )
503
504     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
505         if not (
506             self.leaves
507             and self.leaves[-1].type == token.COMMA
508             and closing.type in CLOSING_BRACKETS
509         ):
510             return False
511
512         if closing.type == token.RSQB or closing.type == token.RBRACE:
513             self.leaves.pop()
514             return True
515
516         # For parens let's check if it's safe to remove the comma.  If the
517         # trailing one is the only one, we might mistakenly change a tuple
518         # into a different type by removing the comma.
519         depth = closing.bracket_depth + 1
520         commas = 0
521         opening = closing.opening_bracket
522         for _opening_index, leaf in enumerate(self.leaves):
523             if leaf is opening:
524                 break
525
526         else:
527             return False
528
529         for leaf in self.leaves[_opening_index + 1:]:
530             if leaf is closing:
531                 break
532
533             bracket_depth = leaf.bracket_depth
534             if bracket_depth == depth and leaf.type == token.COMMA:
535                 commas += 1
536                 if leaf.parent and leaf.parent.type == syms.arglist:
537                     commas += 1
538                     break
539
540         if commas > 1:
541             self.leaves.pop()
542             return True
543
544         return False
545
546     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
547         """In a for loop, or comprehension, the variables are often unpacks.
548
549         To avoid splitting on the comma in this situation, we will increase
550         the depth of tokens between `for` and `in`.
551         """
552         if leaf.type == token.NAME and leaf.value == 'for':
553             self.has_for = True
554             self.bracket_tracker.depth += 1
555             self._for_loop_variable = True
556             return True
557
558         return False
559
560     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
561         # See `maybe_increment_for_loop_variable` above for explanation.
562         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
563             self.bracket_tracker.depth -= 1
564             self._for_loop_variable = False
565             return True
566
567         return False
568
569     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
570         """Hack a standalone comment to act as a trailing comment for line splitting.
571
572         If this line has brackets and a standalone `comment`, we need to adapt
573         it to be able to still reformat the line.
574
575         This is not perfect, the line to which the standalone comment gets
576         appended will appear "too long" when splitting.
577         """
578         if not (
579             comment.type == STANDALONE_COMMENT
580             and self.bracket_tracker.any_open_brackets()
581         ):
582             return False
583
584         comment.type = token.COMMENT
585         comment.prefix = '\n' + '    ' * (self.depth + 1)
586         return self.append_comment(comment)
587
588     def append_comment(self, comment: Leaf) -> bool:
589         if comment.type != token.COMMENT:
590             return False
591
592         try:
593             after = id(self.last_non_delimiter())
594         except LookupError:
595             comment.type = STANDALONE_COMMENT
596             comment.prefix = ''
597             return False
598
599         else:
600             if after in self.comments:
601                 self.comments[after].value += str(comment)
602             else:
603                 self.comments[after] = comment
604             return True
605
606     def last_non_delimiter(self) -> Leaf:
607         for i in range(len(self.leaves)):
608             last = self.leaves[-i - 1]
609             if not is_delimiter(last):
610                 return last
611
612         raise LookupError("No non-delimiters found")
613
614     def __str__(self) -> str:
615         if not self:
616             return '\n'
617
618         indent = '    ' * self.depth
619         leaves = iter(self.leaves)
620         first = next(leaves)
621         res = f'{first.prefix}{indent}{first.value}'
622         for leaf in leaves:
623             res += str(leaf)
624         for comment in self.comments.values():
625             res += str(comment)
626         return res + '\n'
627
628     def __bool__(self) -> bool:
629         return bool(self.leaves or self.comments)
630
631
632 @dataclass
633 class EmptyLineTracker:
634     """Provides a stateful method that returns the number of potential extra
635     empty lines needed before and after the currently processed line.
636
637     Note: this tracker works on lines that haven't been split yet.  It assumes
638     the prefix of the first leaf consists of optional newlines.  Those newlines
639     are consumed by `maybe_empty_lines()` and included in the computation.
640     """
641     previous_line: Optional[Line] = None
642     previous_after: int = 0
643     previous_defs: List[int] = Factory(list)
644
645     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
646         """Returns the number of extra empty lines before and after the `current_line`.
647
648         This is for separating `def`, `async def` and `class` with extra empty lines
649         (two on module-level), as well as providing an extra empty line after flow
650         control keywords to make them more prominent.
651         """
652         before, after = self._maybe_empty_lines(current_line)
653         before -= self.previous_after
654         self.previous_after = after
655         self.previous_line = current_line
656         return before, after
657
658     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
659         max_allowed = 1
660         if current_line.is_comment and current_line.depth == 0:
661             max_allowed = 2
662         if current_line.leaves:
663             # Consume the first leaf's extra newlines.
664             first_leaf = current_line.leaves[0]
665             before = first_leaf.prefix.count('\n')
666             before = min(before, max(before, max_allowed))
667             first_leaf.prefix = ''
668         else:
669             before = 0
670         depth = current_line.depth
671         while self.previous_defs and self.previous_defs[-1] >= depth:
672             self.previous_defs.pop()
673             before = 1 if depth else 2
674         is_decorator = current_line.is_decorator
675         if is_decorator or current_line.is_def or current_line.is_class:
676             if not is_decorator:
677                 self.previous_defs.append(depth)
678             if self.previous_line is None:
679                 # Don't insert empty lines before the first line in the file.
680                 return 0, 0
681
682             if self.previous_line and self.previous_line.is_decorator:
683                 # Don't insert empty lines between decorators.
684                 return 0, 0
685
686             newlines = 2
687             if current_line.depth:
688                 newlines -= 1
689             return newlines, 0
690
691         if current_line.is_flow_control:
692             return before, 1
693
694         if (
695             self.previous_line
696             and self.previous_line.is_import
697             and not current_line.is_import
698             and depth == self.previous_line.depth
699         ):
700             return (before or 1), 0
701
702         if (
703             self.previous_line
704             and self.previous_line.is_yield
705             and (not current_line.is_yield or depth != self.previous_line.depth)
706         ):
707             return (before or 1), 0
708
709         return before, 0
710
711
712 @dataclass
713 class LineGenerator(Visitor[Line]):
714     """Generates reformatted Line objects.  Empty lines are not emitted.
715
716     Note: destroys the tree it's visiting by mutating prefixes of its leaves
717     in ways that will no longer stringify to valid Python code on the tree.
718     """
719     current_line: Line = Factory(Line)
720
721     def line(self, indent: int = 0) -> Iterator[Line]:
722         """Generate a line.
723
724         If the line is empty, only emit if it makes sense.
725         If the line is too long, split it first and then generate.
726
727         If any lines were generated, set up a new current_line.
728         """
729         if not self.current_line:
730             self.current_line.depth += indent
731             return  # Line is empty, don't emit. Creating a new one unnecessary.
732
733         complete_line = self.current_line
734         self.current_line = Line(depth=complete_line.depth + indent)
735         yield complete_line
736
737     def visit_default(self, node: LN) -> Iterator[Line]:
738         if isinstance(node, Leaf):
739             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
740             for comment in generate_comments(node):
741                 if any_open_brackets:
742                     # any comment within brackets is subject to splitting
743                     self.current_line.append(comment)
744                 elif comment.type == token.COMMENT:
745                     # regular trailing comment
746                     self.current_line.append(comment)
747                     yield from self.line()
748
749                 else:
750                     # regular standalone comment
751                     yield from self.line()
752
753                     self.current_line.append(comment)
754                     yield from self.line()
755
756             normalize_prefix(node, inside_brackets=any_open_brackets)
757             if node.type not in WHITESPACE:
758                 self.current_line.append(node)
759         yield from super().visit_default(node)
760
761     def visit_INDENT(self, node: Node) -> Iterator[Line]:
762         yield from self.line(+1)
763         yield from self.visit_default(node)
764
765     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
766         yield from self.line(-1)
767
768     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
769         """Visit a statement.
770
771         The relevant Python language keywords for this statement are NAME leaves
772         within it.
773         """
774         for child in node.children:
775             if child.type == token.NAME and child.value in keywords:  # type: ignore
776                 yield from self.line()
777
778             yield from self.visit(child)
779
780     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
781         """A statement without nested statements."""
782         is_suite_like = node.parent and node.parent.type in STATEMENT
783         if is_suite_like:
784             yield from self.line(+1)
785             yield from self.visit_default(node)
786             yield from self.line(-1)
787
788         else:
789             yield from self.line()
790             yield from self.visit_default(node)
791
792     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
793         yield from self.line()
794
795         children = iter(node.children)
796         for child in children:
797             yield from self.visit(child)
798
799             if child.type == token.ASYNC:
800                 break
801
802         internal_stmt = next(children)
803         for child in internal_stmt.children:
804             yield from self.visit(child)
805
806     def visit_decorators(self, node: Node) -> Iterator[Line]:
807         for child in node.children:
808             yield from self.line()
809             yield from self.visit(child)
810
811     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
812         yield from self.line()
813
814     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
815         yield from self.visit_default(leaf)
816         yield from self.line()
817
818     def __attrs_post_init__(self) -> None:
819         """You are in a twisty little maze of passages."""
820         v = self.visit_stmt
821         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
822         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
823         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
824         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
825         self.visit_except_clause = partial(v, keywords={'except'})
826         self.visit_funcdef = partial(v, keywords={'def'})
827         self.visit_with_stmt = partial(v, keywords={'with'})
828         self.visit_classdef = partial(v, keywords={'class'})
829         self.visit_async_funcdef = self.visit_async_stmt
830         self.visit_decorated = self.visit_decorators
831
832
833 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
834 OPENING_BRACKETS = set(BRACKET.keys())
835 CLOSING_BRACKETS = set(BRACKET.values())
836 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
837 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
838
839
840 def whitespace(leaf: Leaf) -> str:  # noqa C901
841     """Return whitespace prefix if needed for the given `leaf`."""
842     NO = ''
843     SPACE = ' '
844     DOUBLESPACE = '  '
845     t = leaf.type
846     p = leaf.parent
847     v = leaf.value
848     if t in ALWAYS_NO_SPACE:
849         return NO
850
851     if t == token.COMMENT:
852         return DOUBLESPACE
853
854     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
855     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
856         return NO
857
858     prev = leaf.prev_sibling
859     if not prev:
860         prevp = preceding_leaf(p)
861         if not prevp or prevp.type in OPENING_BRACKETS:
862             return NO
863
864         if t == token.COLON:
865             return SPACE if prevp.type == token.COMMA else NO
866
867         if prevp.type == token.EQUAL:
868             if prevp.parent and prevp.parent.type in {
869                 syms.typedargslist,
870                 syms.varargslist,
871                 syms.parameters,
872                 syms.arglist,
873                 syms.argument,
874             }:
875                 return NO
876
877         elif prevp.type == token.DOUBLESTAR:
878             if prevp.parent and prevp.parent.type in {
879                 syms.typedargslist,
880                 syms.varargslist,
881                 syms.parameters,
882                 syms.arglist,
883                 syms.dictsetmaker,
884             }:
885                 return NO
886
887         elif prevp.type == token.COLON:
888             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
889                 return NO
890
891         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
892             return NO
893
894     elif prev.type in OPENING_BRACKETS:
895         return NO
896
897     if p.type in {syms.parameters, syms.arglist}:
898         # untyped function signatures or calls
899         if t == token.RPAR:
900             return NO
901
902         if not prev or prev.type != token.COMMA:
903             return NO
904
905     if p.type == syms.varargslist:
906         # lambdas
907         if t == token.RPAR:
908             return NO
909
910         if prev and prev.type != token.COMMA:
911             return NO
912
913     elif p.type == syms.typedargslist:
914         # typed function signatures
915         if not prev:
916             return NO
917
918         if t == token.EQUAL:
919             if prev.type != syms.tname:
920                 return NO
921
922         elif prev.type == token.EQUAL:
923             # A bit hacky: if the equal sign has whitespace, it means we
924             # previously found it's a typed argument.  So, we're using that, too.
925             return prev.prefix
926
927         elif prev.type != token.COMMA:
928             return NO
929
930     elif p.type == syms.tname:
931         # type names
932         if not prev:
933             prevp = preceding_leaf(p)
934             if not prevp or prevp.type != token.COMMA:
935                 return NO
936
937     elif p.type == syms.trailer:
938         # attributes and calls
939         if t == token.LPAR or t == token.RPAR:
940             return NO
941
942         if not prev:
943             if t == token.DOT:
944                 prevp = preceding_leaf(p)
945                 if not prevp or prevp.type != token.NUMBER:
946                     return NO
947
948             elif t == token.LSQB:
949                 return NO
950
951         elif prev.type != token.COMMA:
952             return NO
953
954     elif p.type == syms.argument:
955         # single argument
956         if t == token.EQUAL:
957             return NO
958
959         if not prev:
960             prevp = preceding_leaf(p)
961             if not prevp or prevp.type == token.LPAR:
962                 return NO
963
964         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
965             return NO
966
967     elif p.type == syms.decorator:
968         # decorators
969         return NO
970
971     elif p.type == syms.dotted_name:
972         if prev:
973             return NO
974
975         prevp = preceding_leaf(p)
976         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
977             return NO
978
979     elif p.type == syms.classdef:
980         if t == token.LPAR:
981             return NO
982
983         if prev and prev.type == token.LPAR:
984             return NO
985
986     elif p.type == syms.subscript:
987         # indexing
988         if not prev:
989             assert p.parent is not None, "subscripts are always parented"
990             if p.parent.type == syms.subscriptlist:
991                 return SPACE
992
993             return NO
994
995         else:
996             return NO
997
998     elif p.type == syms.atom:
999         if prev and t == token.DOT:
1000             # dots, but not the first one.
1001             return NO
1002
1003     elif (
1004         p.type == syms.listmaker
1005         or p.type == syms.testlist_gexp
1006         or p.type == syms.subscriptlist
1007     ):
1008         # list interior, including unpacking
1009         if not prev:
1010             return NO
1011
1012     elif p.type == syms.dictsetmaker:
1013         # dict and set interior, including unpacking
1014         if not prev:
1015             return NO
1016
1017         if prev.type == token.DOUBLESTAR:
1018             return NO
1019
1020     elif p.type in {syms.factor, syms.star_expr}:
1021         # unary ops
1022         if not prev:
1023             prevp = preceding_leaf(p)
1024             if not prevp or prevp.type in OPENING_BRACKETS:
1025                 return NO
1026
1027             prevp_parent = prevp.parent
1028             assert prevp_parent is not None
1029             if prevp.type == token.COLON and prevp_parent.type in {
1030                 syms.subscript, syms.sliceop
1031             }:
1032                 return NO
1033
1034             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1035                 return NO
1036
1037         elif t == token.NAME or t == token.NUMBER:
1038             return NO
1039
1040     elif p.type == syms.import_from:
1041         if t == token.DOT:
1042             if prev and prev.type == token.DOT:
1043                 return NO
1044
1045         elif t == token.NAME:
1046             if v == 'import':
1047                 return SPACE
1048
1049             if prev and prev.type == token.DOT:
1050                 return NO
1051
1052     elif p.type == syms.sliceop:
1053         return NO
1054
1055     return SPACE
1056
1057
1058 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1059     """Returns the first leaf that precedes `node`, if any."""
1060     while node:
1061         res = node.prev_sibling
1062         if res:
1063             if isinstance(res, Leaf):
1064                 return res
1065
1066             try:
1067                 return list(res.leaves())[-1]
1068
1069             except IndexError:
1070                 return None
1071
1072         node = node.parent
1073     return None
1074
1075
1076 def is_delimiter(leaf: Leaf) -> int:
1077     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1078
1079     Higher numbers are higher priority.
1080     """
1081     if leaf.type == token.COMMA:
1082         return COMMA_PRIORITY
1083
1084     if leaf.type in COMPARATORS:
1085         return COMPARATOR_PRIORITY
1086
1087     if (
1088         leaf.type in MATH_OPERATORS
1089         and leaf.parent
1090         and leaf.parent.type not in {syms.factor, syms.star_expr}
1091     ):
1092         return MATH_PRIORITY
1093
1094     return 0
1095
1096
1097 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1098     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1099
1100     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1101     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1102     move because it does away with modifying the grammar to include all the
1103     possible places in which comments can be placed.
1104
1105     The sad consequence for us though is that comments don't "belong" anywhere.
1106     This is why this function generates simple parentless Leaf objects for
1107     comments.  We simply don't know what the correct parent should be.
1108
1109     No matter though, we can live without this.  We really only need to
1110     differentiate between inline and standalone comments.  The latter don't
1111     share the line with any code.
1112
1113     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1114     are emitted with a fake STANDALONE_COMMENT token identifier.
1115     """
1116     p = leaf.prefix
1117     if not p:
1118         return
1119
1120     if '#' not in p:
1121         return
1122
1123     nlines = 0
1124     for index, line in enumerate(p.split('\n')):
1125         line = line.lstrip()
1126         if not line:
1127             nlines += 1
1128         if not line.startswith('#'):
1129             continue
1130
1131         if index == 0 and leaf.type != token.ENDMARKER:
1132             comment_type = token.COMMENT  # simple trailing comment
1133         else:
1134             comment_type = STANDALONE_COMMENT
1135         yield Leaf(comment_type, make_comment(line), prefix='\n' * nlines)
1136
1137         nlines = 0
1138
1139
1140 def make_comment(content: str) -> str:
1141     content = content.rstrip()
1142     if not content:
1143         return '#'
1144
1145     if content[0] == '#':
1146         content = content[1:]
1147     if content and content[0] not in {' ', '!', '#'}:
1148         content = ' ' + content
1149     return '#' + content
1150
1151
1152 def split_line(
1153     line: Line, line_length: int, inner: bool = False, py36: bool = False
1154 ) -> Iterator[Line]:
1155     """Splits a `line` into potentially many lines.
1156
1157     They should fit in the allotted `line_length` but might not be able to.
1158     `inner` signifies that there were a pair of brackets somewhere around the
1159     current `line`, possibly transitively. This means we can fallback to splitting
1160     by delimiters if the LHS/RHS don't yield any results.
1161
1162     If `py36` is True, splitting may generate syntax that is only compatible
1163     with Python 3.6 and later.
1164     """
1165     line_str = str(line).strip('\n')
1166     if len(line_str) <= line_length and '\n' not in line_str:
1167         yield line
1168         return
1169
1170     if line.is_def:
1171         split_funcs = [left_hand_split]
1172     elif line.inside_brackets:
1173         split_funcs = [delimiter_split]
1174         if '\n' not in line_str:
1175             # Only attempt RHS if we don't have multiline strings or comments
1176             # on this line.
1177             split_funcs.append(right_hand_split)
1178     else:
1179         split_funcs = [right_hand_split]
1180     for split_func in split_funcs:
1181         # We are accumulating lines in `result` because we might want to abort
1182         # mission and return the original line in the end, or attempt a different
1183         # split altogether.
1184         result: List[Line] = []
1185         try:
1186             for l in split_func(line, py36=py36):
1187                 if str(l).strip('\n') == line_str:
1188                     raise CannotSplit("Split function returned an unchanged result")
1189
1190                 result.extend(
1191                     split_line(l, line_length=line_length, inner=True, py36=py36)
1192                 )
1193         except CannotSplit as cs:
1194             continue
1195
1196         else:
1197             yield from result
1198             break
1199
1200     else:
1201         yield line
1202
1203
1204 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1205     """Split line into many lines, starting with the first matching bracket pair.
1206
1207     Note: this usually looks weird, only use this for function definitions.
1208     Prefer RHS otherwise.
1209     """
1210     head = Line(depth=line.depth)
1211     body = Line(depth=line.depth + 1, inside_brackets=True)
1212     tail = Line(depth=line.depth)
1213     tail_leaves: List[Leaf] = []
1214     body_leaves: List[Leaf] = []
1215     head_leaves: List[Leaf] = []
1216     current_leaves = head_leaves
1217     matching_bracket = None
1218     for leaf in line.leaves:
1219         if (
1220             current_leaves is body_leaves
1221             and leaf.type in CLOSING_BRACKETS
1222             and leaf.opening_bracket is matching_bracket
1223         ):
1224             current_leaves = tail_leaves if body_leaves else head_leaves
1225         current_leaves.append(leaf)
1226         if current_leaves is head_leaves:
1227             if leaf.type in OPENING_BRACKETS:
1228                 matching_bracket = leaf
1229                 current_leaves = body_leaves
1230     # Since body is a new indent level, remove spurious leading whitespace.
1231     if body_leaves:
1232         normalize_prefix(body_leaves[0], inside_brackets=True)
1233     # Build the new lines.
1234     for result, leaves in (
1235         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1236     ):
1237         for leaf in leaves:
1238             result.append(leaf, preformatted=True)
1239             comment_after = line.comments.get(id(leaf))
1240             if comment_after:
1241                 result.append(comment_after, preformatted=True)
1242     split_succeeded_or_raise(head, body, tail)
1243     for result in (head, body, tail):
1244         if result:
1245             yield result
1246
1247
1248 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1249     """Split line into many lines, starting with the last matching bracket pair."""
1250     head = Line(depth=line.depth)
1251     body = Line(depth=line.depth + 1, inside_brackets=True)
1252     tail = Line(depth=line.depth)
1253     tail_leaves: List[Leaf] = []
1254     body_leaves: List[Leaf] = []
1255     head_leaves: List[Leaf] = []
1256     current_leaves = tail_leaves
1257     opening_bracket = None
1258     for leaf in reversed(line.leaves):
1259         if current_leaves is body_leaves:
1260             if leaf is opening_bracket:
1261                 current_leaves = head_leaves if body_leaves else tail_leaves
1262         current_leaves.append(leaf)
1263         if current_leaves is tail_leaves:
1264             if leaf.type in CLOSING_BRACKETS:
1265                 opening_bracket = leaf.opening_bracket
1266                 current_leaves = body_leaves
1267     tail_leaves.reverse()
1268     body_leaves.reverse()
1269     head_leaves.reverse()
1270     # Since body is a new indent level, remove spurious leading whitespace.
1271     if body_leaves:
1272         normalize_prefix(body_leaves[0], inside_brackets=True)
1273     # Build the new lines.
1274     for result, leaves in (
1275         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1276     ):
1277         for leaf in leaves:
1278             result.append(leaf, preformatted=True)
1279             comment_after = line.comments.get(id(leaf))
1280             if comment_after:
1281                 result.append(comment_after, preformatted=True)
1282     split_succeeded_or_raise(head, body, tail)
1283     for result in (head, body, tail):
1284         if result:
1285             yield result
1286
1287
1288 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1289     tail_len = len(str(tail).strip())
1290     if not body:
1291         if tail_len == 0:
1292             raise CannotSplit("Splitting brackets produced the same line")
1293
1294         elif tail_len < 3:
1295             raise CannotSplit(
1296                 f"Splitting brackets on an empty body to save "
1297                 f"{tail_len} characters is not worth it"
1298             )
1299
1300
1301 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1302     """Split according to delimiters of the highest priority.
1303
1304     This kind of split doesn't increase indentation.
1305     If `py36` is True, the split will add trailing commas also in function
1306     signatures that contain * and **.
1307     """
1308     try:
1309         last_leaf = line.leaves[-1]
1310     except IndexError:
1311         raise CannotSplit("Line empty")
1312
1313     delimiters = line.bracket_tracker.delimiters
1314     try:
1315         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1316     except ValueError:
1317         raise CannotSplit("No delimiters found")
1318
1319     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1320     lowest_depth = sys.maxsize
1321     trailing_comma_safe = True
1322     for leaf in line.leaves:
1323         current_line.append(leaf, preformatted=True)
1324         comment_after = line.comments.get(id(leaf))
1325         if comment_after:
1326             current_line.append(comment_after, preformatted=True)
1327         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1328         if (
1329             leaf.bracket_depth == lowest_depth
1330             and leaf.type == token.STAR
1331             or leaf.type == token.DOUBLESTAR
1332         ):
1333             trailing_comma_safe = trailing_comma_safe and py36
1334         leaf_priority = delimiters.get(id(leaf))
1335         if leaf_priority == delimiter_priority:
1336             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1337             yield current_line
1338
1339             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1340     if current_line:
1341         if (
1342             delimiter_priority == COMMA_PRIORITY
1343             and current_line.leaves[-1].type != token.COMMA
1344             and trailing_comma_safe
1345         ):
1346             current_line.append(Leaf(token.COMMA, ','))
1347         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1348         yield current_line
1349
1350
1351 def is_import(leaf: Leaf) -> bool:
1352     """Returns True if the given leaf starts an import statement."""
1353     p = leaf.parent
1354     t = leaf.type
1355     v = leaf.value
1356     return bool(
1357         t == token.NAME
1358         and (
1359             (v == 'import' and p and p.type == syms.import_name)
1360             or (v == 'from' and p and p.type == syms.import_from)
1361         )
1362     )
1363
1364
1365 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1366     """Leave existing extra newlines if not `inside_brackets`.
1367
1368     Remove everything else.  Note: don't use backslashes for formatting or
1369     you'll lose your voting rights.
1370     """
1371     if not inside_brackets:
1372         spl = leaf.prefix.split('#')
1373         if '\\' not in spl[0]:
1374             nl_count = spl[-1].count('\n')
1375             if len(spl) > 1:
1376                 nl_count -= 1
1377             leaf.prefix = '\n' * nl_count
1378             return
1379
1380     leaf.prefix = ''
1381
1382
1383 def is_python36(node: Node) -> bool:
1384     """Returns True if the current file is using Python 3.6+ features.
1385
1386     Currently looking for:
1387     - f-strings; and
1388     - trailing commas after * or ** in function signatures.
1389     """
1390     for n in node.pre_order():
1391         if n.type == token.STRING:
1392             value_head = n.value[:2]  # type: ignore
1393             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1394                 return True
1395
1396         elif (
1397             n.type == syms.typedargslist
1398             and n.children
1399             and n.children[-1].type == token.COMMA
1400         ):
1401             for ch in n.children:
1402                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1403                     return True
1404
1405     return False
1406
1407
1408 PYTHON_EXTENSIONS = {'.py'}
1409 BLACKLISTED_DIRECTORIES = {
1410     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1411 }
1412
1413
1414 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1415     for child in path.iterdir():
1416         if child.is_dir():
1417             if child.name in BLACKLISTED_DIRECTORIES:
1418                 continue
1419
1420             yield from gen_python_files_in_dir(child)
1421
1422         elif child.suffix in PYTHON_EXTENSIONS:
1423             yield child
1424
1425
1426 @dataclass
1427 class Report:
1428     """Provides a reformatting counter."""
1429     change_count: int = 0
1430     same_count: int = 0
1431     failure_count: int = 0
1432
1433     def done(self, src: Path, changed: bool) -> None:
1434         """Increment the counter for successful reformatting. Write out a message."""
1435         if changed:
1436             out(f'reformatted {src}')
1437             self.change_count += 1
1438         else:
1439             out(f'{src} already well formatted, good job.', bold=False)
1440             self.same_count += 1
1441
1442     def failed(self, src: Path, message: str) -> None:
1443         """Increment the counter for failed reformatting. Write out a message."""
1444         err(f'error: cannot format {src}: {message}')
1445         self.failure_count += 1
1446
1447     @property
1448     def return_code(self) -> int:
1449         """Which return code should the app use considering the current state."""
1450         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1451         # 126 we have special returncodes reserved by the shell.
1452         if self.failure_count:
1453             return 123
1454
1455         elif self.change_count:
1456             return 1
1457
1458         return 0
1459
1460     def __str__(self) -> str:
1461         """A color report of the current state.
1462
1463         Use `click.unstyle` to remove colors.
1464         """
1465         report = []
1466         if self.change_count:
1467             s = 's' if self.change_count > 1 else ''
1468             report.append(
1469                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1470             )
1471         if self.same_count:
1472             s = 's' if self.same_count > 1 else ''
1473             report.append(f'{self.same_count} file{s} left unchanged')
1474         if self.failure_count:
1475             s = 's' if self.failure_count > 1 else ''
1476             report.append(
1477                 click.style(
1478                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1479                 )
1480             )
1481         return ', '.join(report) + '.'
1482
1483
1484 def assert_equivalent(src: str, dst: str) -> None:
1485     """Raises AssertionError if `src` and `dst` aren't equivalent.
1486
1487     This is a temporary sanity check until Black becomes stable.
1488     """
1489
1490     import ast
1491     import traceback
1492
1493     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1494         """Simple visitor generating strings to compare ASTs by content."""
1495         yield f"{'  ' * depth}{node.__class__.__name__}("
1496
1497         for field in sorted(node._fields):
1498             try:
1499                 value = getattr(node, field)
1500             except AttributeError:
1501                 continue
1502
1503             yield f"{'  ' * (depth+1)}{field}="
1504
1505             if isinstance(value, list):
1506                 for item in value:
1507                     if isinstance(item, ast.AST):
1508                         yield from _v(item, depth + 2)
1509
1510             elif isinstance(value, ast.AST):
1511                 yield from _v(value, depth + 2)
1512
1513             else:
1514                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1515
1516         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1517
1518     try:
1519         src_ast = ast.parse(src)
1520     except Exception as exc:
1521         raise AssertionError(f"cannot parse source: {exc}") from None
1522
1523     try:
1524         dst_ast = ast.parse(dst)
1525     except Exception as exc:
1526         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1527         raise AssertionError(
1528             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1529             f"Please report a bug on https://github.com/ambv/black/issues.  "
1530             f"This invalid output might be helpful: {log}"
1531         ) from None
1532
1533     src_ast_str = '\n'.join(_v(src_ast))
1534     dst_ast_str = '\n'.join(_v(dst_ast))
1535     if src_ast_str != dst_ast_str:
1536         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1537         raise AssertionError(
1538             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1539             f"the source.  "
1540             f"Please report a bug on https://github.com/ambv/black/issues.  "
1541             f"This diff might be helpful: {log}"
1542         ) from None
1543
1544
1545 def assert_stable(src: str, dst: str, line_length: int) -> None:
1546     """Raises AssertionError if `dst` reformats differently the second time.
1547
1548     This is a temporary sanity check until Black becomes stable.
1549     """
1550     newdst = format_str(dst, line_length=line_length)
1551     if dst != newdst:
1552         log = dump_to_file(
1553             diff(src, dst, 'source', 'first pass'),
1554             diff(dst, newdst, 'first pass', 'second pass'),
1555         )
1556         raise AssertionError(
1557             f"INTERNAL ERROR: Black produced different code on the second pass "
1558             f"of the formatter.  "
1559             f"Please report a bug on https://github.com/ambv/black/issues.  "
1560             f"This diff might be helpful: {log}"
1561         ) from None
1562
1563
1564 def dump_to_file(*output: str) -> str:
1565     """Dumps `output` to a temporary file. Returns path to the file."""
1566     import tempfile
1567
1568     with tempfile.NamedTemporaryFile(
1569         mode='w', prefix='blk_', suffix='.log', delete=False
1570     ) as f:
1571         for lines in output:
1572             f.write(lines)
1573             f.write('\n')
1574     return f.name
1575
1576
1577 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1578     """Returns a udiff string between strings `a` and `b`."""
1579     import difflib
1580
1581     a_lines = [line + '\n' for line in a.split('\n')]
1582     b_lines = [line + '\n' for line in b.split('\n')]
1583     return ''.join(
1584         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1585     )
1586
1587
1588 if __name__ == '__main__':
1589     main()