]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

b343da7b6d7363e52bc4a29e8715b247886a6a02
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import attrib, dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
78 )
79 @click.pass_context
80 def main(
81     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
82 ) -> None:
83     """The uncompromising code formatter."""
84     sources: List[Path] = []
85     for s in src:
86         p = Path(s)
87         if p.is_dir():
88             sources.extend(gen_python_files_in_dir(p))
89         elif p.is_file():
90             # if a file was explicitly given, we don't care about its extension
91             sources.append(p)
92         else:
93             err(f'invalid path: {s}')
94     if len(sources) == 0:
95         ctx.exit(0)
96     elif len(sources) == 1:
97         p = sources[0]
98         report = Report()
99         try:
100             changed = format_file_in_place(
101                 p, line_length=line_length, fast=fast, write_back=not check
102             )
103             report.done(p, changed)
104         except Exception as exc:
105             report.failed(p, str(exc))
106         ctx.exit(report.return_code)
107     else:
108         loop = asyncio.get_event_loop()
109         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
110         return_code = 1
111         try:
112             return_code = loop.run_until_complete(
113                 schedule_formatting(
114                     sources, line_length, not check, fast, loop, executor
115                 )
116             )
117         finally:
118             loop.close()
119             ctx.exit(return_code)
120
121
122 async def schedule_formatting(
123     sources: List[Path],
124     line_length: int,
125     write_back: bool,
126     fast: bool,
127     loop: BaseEventLoop,
128     executor: Executor,
129 ) -> int:
130     tasks = {
131         src: loop.run_in_executor(
132             executor, format_file_in_place, src, line_length, fast, write_back
133         )
134         for src in sources
135     }
136     await asyncio.wait(tasks.values())
137     cancelled = []
138     report = Report()
139     for src, task in tasks.items():
140         if not task.done():
141             report.failed(src, 'timed out, cancelling')
142             task.cancel()
143             cancelled.append(task)
144         elif task.exception():
145             report.failed(src, str(task.exception()))
146         else:
147             report.done(src, task.result())
148     if cancelled:
149         await asyncio.wait(cancelled, timeout=2)
150     out('All done! ✨ 🍰 ✨')
151     click.echo(str(report))
152     return report.return_code
153
154
155 def format_file_in_place(
156     src: Path, line_length: int, fast: bool, write_back: bool = False
157 ) -> bool:
158     """Format the file and rewrite if changed. Return True if changed."""
159     try:
160         contents, encoding = format_file(src, line_length=line_length, fast=fast)
161     except NothingChanged:
162         return False
163
164     if write_back:
165         with open(src, "w", encoding=encoding) as f:
166             f.write(contents)
167     return True
168
169
170 def format_file(
171     src: Path, line_length: int, fast: bool
172 ) -> Tuple[FileContent, Encoding]:
173     """Reformats a file and returns its contents and encoding."""
174     with tokenize.open(src) as src_buffer:
175         src_contents = src_buffer.read()
176     if src_contents.strip() == '':
177         raise NothingChanged(src)
178
179     dst_contents = format_str(src_contents, line_length=line_length)
180     if src_contents == dst_contents:
181         raise NothingChanged(src)
182
183     if not fast:
184         assert_equivalent(src_contents, dst_contents)
185         assert_stable(src_contents, dst_contents, line_length=line_length)
186     return dst_contents, src_buffer.encoding
187
188
189 def format_str(src_contents: str, line_length: int) -> FileContent:
190     """Reformats a string and returns new contents."""
191     src_node = lib2to3_parse(src_contents)
192     dst_contents = ""
193     comments: List[Line] = []
194     lines = LineGenerator()
195     elt = EmptyLineTracker()
196     py36 = is_python36(src_node)
197     empty_line = Line()
198     after = 0
199     for current_line in lines.visit(src_node):
200         for _ in range(after):
201             dst_contents += str(empty_line)
202         before, after = elt.maybe_empty_lines(current_line)
203         for _ in range(before):
204             dst_contents += str(empty_line)
205         if not current_line.is_comment:
206             for comment in comments:
207                 dst_contents += str(comment)
208             comments = []
209             for line in split_line(current_line, line_length=line_length, py36=py36):
210                 dst_contents += str(line)
211         else:
212             comments.append(current_line)
213     for comment in comments:
214         dst_contents += str(comment)
215     return dst_contents
216
217
218 def lib2to3_parse(src_txt: str) -> Node:
219     """Given a string with source, return the lib2to3 Node."""
220     grammar = pygram.python_grammar_no_print_statement
221     drv = driver.Driver(grammar, pytree.convert)
222     if src_txt[-1] != '\n':
223         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
224         src_txt += nl
225     try:
226         result = drv.parse_string(src_txt, True)
227     except ParseError as pe:
228         lineno, column = pe.context[1]
229         lines = src_txt.splitlines()
230         try:
231             faulty_line = lines[lineno - 1]
232         except IndexError:
233             faulty_line = "<line number missing in source>"
234         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
235
236     if isinstance(result, Leaf):
237         result = Node(syms.file_input, [result])
238     return result
239
240
241 def lib2to3_unparse(node: Node) -> str:
242     """Given a lib2to3 node, return its string representation."""
243     code = str(node)
244     return code
245
246
247 T = TypeVar('T')
248
249
250 class Visitor(Generic[T]):
251     """Basic lib2to3 visitor that yields things on visiting."""
252
253     def visit(self, node: LN) -> Iterator[T]:
254         if node.type < 256:
255             name = token.tok_name[node.type]
256         else:
257             name = type_repr(node.type)
258         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
259
260     def visit_default(self, node: LN) -> Iterator[T]:
261         if isinstance(node, Node):
262             for child in node.children:
263                 yield from self.visit(child)
264
265
266 @dataclass
267 class DebugVisitor(Visitor[T]):
268     tree_depth: int = attrib(default=0)
269
270     def visit_default(self, node: LN) -> Iterator[T]:
271         indent = ' ' * (2 * self.tree_depth)
272         if isinstance(node, Node):
273             _type = type_repr(node.type)
274             out(f'{indent}{_type}', fg='yellow')
275             self.tree_depth += 1
276             for child in node.children:
277                 yield from self.visit(child)
278
279             self.tree_depth -= 1
280             out(f'{indent}/{_type}', fg='yellow', bold=False)
281         else:
282             _type = token.tok_name.get(node.type, str(node.type))
283             out(f'{indent}{_type}', fg='blue', nl=False)
284             if node.prefix:
285                 # We don't have to handle prefixes for `Node` objects since
286                 # that delegates to the first child anyway.
287                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
288             out(f' {node.value!r}', fg='blue', bold=False)
289
290
291 KEYWORDS = set(keyword.kwlist)
292 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
293 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
294 STATEMENT = {
295     syms.if_stmt,
296     syms.while_stmt,
297     syms.for_stmt,
298     syms.try_stmt,
299     syms.except_clause,
300     syms.with_stmt,
301     syms.funcdef,
302     syms.classdef,
303 }
304 STANDALONE_COMMENT = 153
305 LOGIC_OPERATORS = {'and', 'or'}
306 COMPARATORS = {
307     token.LESS,
308     token.GREATER,
309     token.EQEQUAL,
310     token.NOTEQUAL,
311     token.LESSEQUAL,
312     token.GREATEREQUAL,
313 }
314 MATH_OPERATORS = {
315     token.PLUS,
316     token.MINUS,
317     token.STAR,
318     token.SLASH,
319     token.VBAR,
320     token.AMPER,
321     token.PERCENT,
322     token.CIRCUMFLEX,
323     token.LEFTSHIFT,
324     token.RIGHTSHIFT,
325     token.DOUBLESTAR,
326     token.DOUBLESLASH,
327 }
328 COMPREHENSION_PRIORITY = 20
329 COMMA_PRIORITY = 10
330 LOGIC_PRIORITY = 5
331 STRING_PRIORITY = 4
332 COMPARATOR_PRIORITY = 3
333 MATH_PRIORITY = 1
334
335
336 @dataclass
337 class BracketTracker:
338     depth: int = attrib(default=0)
339     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict))
340     delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict))
341     previous: Optional[Leaf] = attrib(default=None)
342
343     def mark(self, leaf: Leaf) -> None:
344         if leaf.type == token.COMMENT:
345             return
346
347         if leaf.type in CLOSING_BRACKETS:
348             self.depth -= 1
349             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
350             leaf.opening_bracket = opening_bracket
351         leaf.bracket_depth = self.depth
352         if self.depth == 0:
353             delim = is_delimiter(leaf)
354             if delim:
355                 self.delimiters[id(leaf)] = delim
356             elif self.previous is not None:
357                 if leaf.type == token.STRING and self.previous.type == token.STRING:
358                     self.delimiters[id(self.previous)] = STRING_PRIORITY
359                 elif (
360                     leaf.type == token.NAME and
361                     leaf.value == 'for' and
362                     leaf.parent and
363                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
364                 ):
365                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
366                 elif (
367                     leaf.type == token.NAME and
368                     leaf.value == 'if' and
369                     leaf.parent and
370                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
371                 ):
372                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
373         if leaf.type in OPENING_BRACKETS:
374             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
375             self.depth += 1
376         self.previous = leaf
377
378     def any_open_brackets(self) -> bool:
379         """Returns True if there is an yet unmatched open bracket on the line."""
380         return bool(self.bracket_match)
381
382     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
383         """Returns the highest priority of a delimiter found on the line.
384
385         Values are consistent with what `is_delimiter()` returns.
386         """
387         return max(v for k, v in self.delimiters.items() if k not in exclude)
388
389
390 @dataclass
391 class Line:
392     depth: int = attrib(default=0)
393     leaves: List[Leaf] = attrib(default=Factory(list))
394     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
395     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
396     inside_brackets: bool = attrib(default=False)
397     has_for: bool = attrib(default=False)
398     _for_loop_variable: bool = attrib(default=False, init=False)
399
400     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
401         has_value = leaf.value.strip()
402         if not has_value:
403             return
404
405         if self.leaves and not preformatted:
406             # Note: at this point leaf.prefix should be empty except for
407             # imports, for which we only preserve newlines.
408             leaf.prefix += whitespace(leaf)
409         if self.inside_brackets or not preformatted:
410             self.maybe_decrement_after_for_loop_variable(leaf)
411             self.bracket_tracker.mark(leaf)
412             self.maybe_remove_trailing_comma(leaf)
413             self.maybe_increment_for_loop_variable(leaf)
414             if self.maybe_adapt_standalone_comment(leaf):
415                 return
416
417         if not self.append_comment(leaf):
418             self.leaves.append(leaf)
419
420     @property
421     def is_comment(self) -> bool:
422         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
423
424     @property
425     def is_decorator(self) -> bool:
426         return bool(self) and self.leaves[0].type == token.AT
427
428     @property
429     def is_import(self) -> bool:
430         return bool(self) and is_import(self.leaves[0])
431
432     @property
433     def is_class(self) -> bool:
434         return (
435             bool(self) and
436             self.leaves[0].type == token.NAME and
437             self.leaves[0].value == 'class'
438         )
439
440     @property
441     def is_def(self) -> bool:
442         """Also returns True for async defs."""
443         try:
444             first_leaf = self.leaves[0]
445         except IndexError:
446             return False
447
448         try:
449             second_leaf: Optional[Leaf] = self.leaves[1]
450         except IndexError:
451             second_leaf = None
452         return (
453             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
454             (
455                 first_leaf.type == token.NAME and
456                 first_leaf.value == 'async' and
457                 second_leaf is not None and
458                 second_leaf.type == token.NAME and
459                 second_leaf.value == 'def'
460             )
461         )
462
463     @property
464     def is_flow_control(self) -> bool:
465         return (
466             bool(self) and
467             self.leaves[0].type == token.NAME and
468             self.leaves[0].value in FLOW_CONTROL
469         )
470
471     @property
472     def is_yield(self) -> bool:
473         return (
474             bool(self) and
475             self.leaves[0].type == token.NAME and
476             self.leaves[0].value == 'yield'
477         )
478
479     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
480         if not (
481             self.leaves and
482             self.leaves[-1].type == token.COMMA and
483             closing.type in CLOSING_BRACKETS
484         ):
485             return False
486
487         if closing.type == token.RSQB or closing.type == token.RBRACE:
488             self.leaves.pop()
489             return True
490
491         # For parens let's check if it's safe to remove the comma.  If the
492         # trailing one is the only one, we might mistakenly change a tuple
493         # into a different type by removing the comma.
494         depth = closing.bracket_depth + 1
495         commas = 0
496         opening = closing.opening_bracket
497         for _opening_index, leaf in enumerate(self.leaves):
498             if leaf is opening:
499                 break
500
501         else:
502             return False
503
504         for leaf in self.leaves[_opening_index + 1:]:
505             if leaf is closing:
506                 break
507
508             bracket_depth = leaf.bracket_depth
509             if bracket_depth == depth and leaf.type == token.COMMA:
510                 commas += 1
511         if commas > 1:
512             self.leaves.pop()
513             return True
514
515         return False
516
517     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
518         """In a for loop, or comprehension, the variables are often unpacks.
519
520         To avoid splitting on the comma in this situation, we will increase
521         the depth of tokens between `for` and `in`.
522         """
523         if leaf.type == token.NAME and leaf.value == 'for':
524             self.has_for = True
525             self.bracket_tracker.depth += 1
526             self._for_loop_variable = True
527             return True
528
529         return False
530
531     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
532         # See `maybe_increment_for_loop_variable` above for explanation.
533         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
534             self.bracket_tracker.depth -= 1
535             self._for_loop_variable = False
536             return True
537
538         return False
539
540     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
541         """Hack a standalone comment to act as a trailing comment for line splitting.
542
543         If this line has brackets and a standalone `comment`, we need to adapt
544         it to be able to still reformat the line.
545
546         This is not perfect, the line to which the standalone comment gets
547         appended will appear "too long" when splitting.
548         """
549         if not (
550             comment.type == STANDALONE_COMMENT and
551             self.bracket_tracker.any_open_brackets()
552         ):
553             return False
554
555         comment.type = token.COMMENT
556         comment.prefix = '\n' + '    ' * (self.depth + 1)
557         return self.append_comment(comment)
558
559     def append_comment(self, comment: Leaf) -> bool:
560         if comment.type != token.COMMENT:
561             return False
562
563         try:
564             after = id(self.last_non_delimiter())
565         except LookupError:
566             comment.type = STANDALONE_COMMENT
567             comment.prefix = ''
568             return False
569
570         else:
571             if after in self.comments:
572                 self.comments[after].value += str(comment)
573             else:
574                 self.comments[after] = comment
575             return True
576
577     def last_non_delimiter(self) -> Leaf:
578         for i in range(len(self.leaves)):
579             last = self.leaves[-i - 1]
580             if not is_delimiter(last):
581                 return last
582
583         raise LookupError("No non-delimiters found")
584
585     def __str__(self) -> str:
586         if not self:
587             return '\n'
588
589         indent = '    ' * self.depth
590         leaves = iter(self.leaves)
591         first = next(leaves)
592         res = f'{first.prefix}{indent}{first.value}'
593         for leaf in leaves:
594             res += str(leaf)
595         for comment in self.comments.values():
596             res += str(comment)
597         return res + '\n'
598
599     def __bool__(self) -> bool:
600         return bool(self.leaves or self.comments)
601
602
603 @dataclass
604 class EmptyLineTracker:
605     """Provides a stateful method that returns the number of potential extra
606     empty lines needed before and after the currently processed line.
607
608     Note: this tracker works on lines that haven't been split yet.
609     """
610     previous_line: Optional[Line] = attrib(default=None)
611     previous_after: int = attrib(default=0)
612     previous_defs: List[int] = attrib(default=Factory(list))
613
614     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
615         """Returns the number of extra empty lines before and after the `current_line`.
616
617         This is for separating `def`, `async def` and `class` with extra empty lines
618         (two on module-level), as well as providing an extra empty line after flow
619         control keywords to make them more prominent.
620         """
621         before, after = self._maybe_empty_lines(current_line)
622         self.previous_after = after
623         self.previous_line = current_line
624         return before, after
625
626     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
627         before = 0
628         depth = current_line.depth
629         while self.previous_defs and self.previous_defs[-1] >= depth:
630             self.previous_defs.pop()
631             before = (1 if depth else 2) - self.previous_after
632         is_decorator = current_line.is_decorator
633         if is_decorator or current_line.is_def or current_line.is_class:
634             if not is_decorator:
635                 self.previous_defs.append(depth)
636             if self.previous_line is None:
637                 # Don't insert empty lines before the first line in the file.
638                 return 0, 0
639
640             if self.previous_line and self.previous_line.is_decorator:
641                 # Don't insert empty lines between decorators.
642                 return 0, 0
643
644             newlines = 2
645             if current_line.depth:
646                 newlines -= 1
647             newlines -= self.previous_after
648             return newlines, 0
649
650         if current_line.is_flow_control:
651             return before, 1
652
653         if (
654             self.previous_line and
655             self.previous_line.is_import and
656             not current_line.is_import and
657             depth == self.previous_line.depth
658         ):
659             return (before or 1), 0
660
661         if (
662             self.previous_line and
663             self.previous_line.is_yield and
664             (not current_line.is_yield or depth != self.previous_line.depth)
665         ):
666             return (before or 1), 0
667
668         return before, 0
669
670
671 @dataclass
672 class LineGenerator(Visitor[Line]):
673     """Generates reformatted Line objects.  Empty lines are not emitted.
674
675     Note: destroys the tree it's visiting by mutating prefixes of its leaves
676     in ways that will no longer stringify to valid Python code on the tree.
677     """
678     current_line: Line = attrib(default=Factory(Line))
679     standalone_comments: List[Leaf] = attrib(default=Factory(list))
680
681     def line(self, indent: int = 0) -> Iterator[Line]:
682         """Generate a line.
683
684         If the line is empty, only emit if it makes sense.
685         If the line is too long, split it first and then generate.
686
687         If any lines were generated, set up a new current_line.
688         """
689         if not self.current_line:
690             self.current_line.depth += indent
691             return  # Line is empty, don't emit. Creating a new one unnecessary.
692
693         complete_line = self.current_line
694         self.current_line = Line(depth=complete_line.depth + indent)
695         yield complete_line
696
697     def visit_default(self, node: LN) -> Iterator[Line]:
698         if isinstance(node, Leaf):
699             for comment in generate_comments(node):
700                 if self.current_line.bracket_tracker.any_open_brackets():
701                     # any comment within brackets is subject to splitting
702                     self.current_line.append(comment)
703                 elif comment.type == token.COMMENT:
704                     # regular trailing comment
705                     self.current_line.append(comment)
706                     yield from self.line()
707
708                 else:
709                     # regular standalone comment, to be processed later (see
710                     # docstring in `generate_comments()`
711                     self.standalone_comments.append(comment)
712             normalize_prefix(node)
713             if node.type not in WHITESPACE:
714                 for comment in self.standalone_comments:
715                     yield from self.line()
716
717                     self.current_line.append(comment)
718                     yield from self.line()
719
720                 self.standalone_comments = []
721                 self.current_line.append(node)
722         yield from super().visit_default(node)
723
724     def visit_suite(self, node: Node) -> Iterator[Line]:
725         """Body of a statement after a colon."""
726         children = iter(node.children)
727         # Process newline before indenting.  It might contain an inline
728         # comment that should go right after the colon.
729         newline = next(children)
730         yield from self.visit(newline)
731         yield from self.line(+1)
732
733         for child in children:
734             yield from self.visit(child)
735
736         yield from self.line(-1)
737
738     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
739         """Visit a statement.
740
741         The relevant Python language keywords for this statement are NAME leaves
742         within it.
743         """
744         for child in node.children:
745             if child.type == token.NAME and child.value in keywords:  # type: ignore
746                 yield from self.line()
747
748             yield from self.visit(child)
749
750     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
751         """A statement without nested statements."""
752         is_suite_like = node.parent and node.parent.type in STATEMENT
753         if is_suite_like:
754             yield from self.line(+1)
755             yield from self.visit_default(node)
756             yield from self.line(-1)
757
758         else:
759             yield from self.line()
760             yield from self.visit_default(node)
761
762     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
763         yield from self.line()
764
765         children = iter(node.children)
766         for child in children:
767             yield from self.visit(child)
768
769             if child.type == token.NAME and child.value == 'async':  # type: ignore
770                 break
771
772         internal_stmt = next(children)
773         for child in internal_stmt.children:
774             yield from self.visit(child)
775
776     def visit_decorators(self, node: Node) -> Iterator[Line]:
777         for child in node.children:
778             yield from self.line()
779             yield from self.visit(child)
780
781     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
782         yield from self.line()
783
784     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
785         yield from self.visit_default(leaf)
786         yield from self.line()
787
788     def __attrs_post_init__(self) -> None:
789         """You are in a twisty little maze of passages."""
790         v = self.visit_stmt
791         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
792         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
793         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
794         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
795         self.visit_except_clause = partial(v, keywords={'except'})
796         self.visit_funcdef = partial(v, keywords={'def'})
797         self.visit_with_stmt = partial(v, keywords={'with'})
798         self.visit_classdef = partial(v, keywords={'class'})
799         self.visit_async_funcdef = self.visit_async_stmt
800         self.visit_decorated = self.visit_decorators
801
802
803 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
804 OPENING_BRACKETS = set(BRACKET.keys())
805 CLOSING_BRACKETS = set(BRACKET.values())
806 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
807 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
808
809
810 def whitespace(leaf: Leaf) -> str:  # noqa C901
811     """Return whitespace prefix if needed for the given `leaf`."""
812     NO = ''
813     SPACE = ' '
814     DOUBLESPACE = '  '
815     t = leaf.type
816     p = leaf.parent
817     v = leaf.value
818     if t in ALWAYS_NO_SPACE:
819         return NO
820
821     if t == token.COMMENT:
822         return DOUBLESPACE
823
824     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
825     prev = leaf.prev_sibling
826     if not prev:
827         prevp = preceding_leaf(p)
828         if not prevp or prevp.type in OPENING_BRACKETS:
829             return NO
830
831         if prevp.type == token.EQUAL:
832             if prevp.parent and prevp.parent.type in {
833                 syms.typedargslist,
834                 syms.varargslist,
835                 syms.parameters,
836                 syms.arglist,
837                 syms.argument,
838             }:
839                 return NO
840
841         elif prevp.type == token.DOUBLESTAR:
842             if prevp.parent and prevp.parent.type in {
843                 syms.typedargslist,
844                 syms.varargslist,
845                 syms.parameters,
846                 syms.arglist,
847                 syms.dictsetmaker,
848             }:
849                 return NO
850
851         elif prevp.type == token.COLON:
852             if prevp.parent and prevp.parent.type == syms.subscript:
853                 return NO
854
855         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
856             return NO
857
858     elif prev.type in OPENING_BRACKETS:
859         return NO
860
861     if p.type in {syms.parameters, syms.arglist}:
862         # untyped function signatures or calls
863         if t == token.RPAR:
864             return NO
865
866         if not prev or prev.type != token.COMMA:
867             return NO
868
869     if p.type == syms.varargslist:
870         # lambdas
871         if t == token.RPAR:
872             return NO
873
874         if prev and prev.type != token.COMMA:
875             return NO
876
877     elif p.type == syms.typedargslist:
878         # typed function signatures
879         if not prev:
880             return NO
881
882         if t == token.EQUAL:
883             if prev.type != syms.tname:
884                 return NO
885
886         elif prev.type == token.EQUAL:
887             # A bit hacky: if the equal sign has whitespace, it means we
888             # previously found it's a typed argument.  So, we're using that, too.
889             return prev.prefix
890
891         elif prev.type != token.COMMA:
892             return NO
893
894     elif p.type == syms.tname:
895         # type names
896         if not prev:
897             prevp = preceding_leaf(p)
898             if not prevp or prevp.type != token.COMMA:
899                 return NO
900
901     elif p.type == syms.trailer:
902         # attributes and calls
903         if t == token.LPAR or t == token.RPAR:
904             return NO
905
906         if not prev:
907             if t == token.DOT:
908                 prevp = preceding_leaf(p)
909                 if not prevp or prevp.type != token.NUMBER:
910                     return NO
911
912             elif t == token.LSQB:
913                 return NO
914
915         elif prev.type != token.COMMA:
916             return NO
917
918     elif p.type == syms.argument:
919         # single argument
920         if t == token.EQUAL:
921             return NO
922
923         if not prev:
924             prevp = preceding_leaf(p)
925             if not prevp or prevp.type == token.LPAR:
926                 return NO
927
928         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
929             return NO
930
931     elif p.type == syms.decorator:
932         # decorators
933         return NO
934
935     elif p.type == syms.dotted_name:
936         if prev:
937             return NO
938
939         prevp = preceding_leaf(p)
940         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
941             return NO
942
943     elif p.type == syms.classdef:
944         if t == token.LPAR:
945             return NO
946
947         if prev and prev.type == token.LPAR:
948             return NO
949
950     elif p.type == syms.subscript:
951         # indexing
952         if not prev:
953             assert p.parent is not None, "subscripts are always parented"
954             if p.parent.type == syms.subscriptlist:
955                 return SPACE
956
957             return NO
958
959         elif prev.type == token.COLON:
960             return NO
961
962     elif p.type == syms.atom:
963         if prev and t == token.DOT:
964             # dots, but not the first one.
965             return NO
966
967     elif (
968         p.type == syms.listmaker or
969         p.type == syms.testlist_gexp or
970         p.type == syms.subscriptlist
971     ):
972         # list interior, including unpacking
973         if not prev:
974             return NO
975
976     elif p.type == syms.dictsetmaker:
977         # dict and set interior, including unpacking
978         if not prev:
979             return NO
980
981         if prev.type == token.DOUBLESTAR:
982             return NO
983
984     elif p.type in {syms.factor, syms.star_expr}:
985         # unary ops
986         if not prev:
987             prevp = preceding_leaf(p)
988             if not prevp or prevp.type in OPENING_BRACKETS:
989                 return NO
990
991             prevp_parent = prevp.parent
992             assert prevp_parent is not None
993             if prevp.type == token.COLON and prevp_parent.type in {
994                 syms.subscript, syms.sliceop
995             }:
996                 return NO
997
998             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
999                 return NO
1000
1001         elif t == token.NAME or t == token.NUMBER:
1002             return NO
1003
1004     elif p.type == syms.import_from:
1005         if t == token.DOT:
1006             if prev and prev.type == token.DOT:
1007                 return NO
1008
1009         elif t == token.NAME:
1010             if v == 'import':
1011                 return SPACE
1012
1013             if prev and prev.type == token.DOT:
1014                 return NO
1015
1016     elif p.type == syms.sliceop:
1017         return NO
1018
1019     return SPACE
1020
1021
1022 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1023     """Returns the first leaf that precedes `node`, if any."""
1024     while node:
1025         res = node.prev_sibling
1026         if res:
1027             if isinstance(res, Leaf):
1028                 return res
1029
1030             try:
1031                 return list(res.leaves())[-1]
1032
1033             except IndexError:
1034                 return None
1035
1036         node = node.parent
1037     return None
1038
1039
1040 def is_delimiter(leaf: Leaf) -> int:
1041     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1042
1043     Higher numbers are higher priority.
1044     """
1045     if leaf.type == token.COMMA:
1046         return COMMA_PRIORITY
1047
1048     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1049         return LOGIC_PRIORITY
1050
1051     if leaf.type in COMPARATORS:
1052         return COMPARATOR_PRIORITY
1053
1054     if (
1055         leaf.type in MATH_OPERATORS and
1056         leaf.parent and
1057         leaf.parent.type not in {syms.factor, syms.star_expr}
1058     ):
1059         return MATH_PRIORITY
1060
1061     return 0
1062
1063
1064 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1065     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1066
1067     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1068     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1069     move because it does away with modifying the grammar to include all the
1070     possible places in which comments can be placed.
1071
1072     The sad consequence for us though is that comments don't "belong" anywhere.
1073     This is why this function generates simple parentless Leaf objects for
1074     comments.  We simply don't know what the correct parent should be.
1075
1076     No matter though, we can live without this.  We really only need to
1077     differentiate between inline and standalone comments.  The latter don't
1078     share the line with any code.
1079
1080     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1081     are emitted with a fake STANDALONE_COMMENT token identifier.
1082     """
1083     if not leaf.prefix:
1084         return
1085
1086     if '#' not in leaf.prefix:
1087         return
1088
1089     before_comment, content = leaf.prefix.split('#', 1)
1090     content = content.rstrip()
1091     if content and (content[0] not in {' ', '!', '#'}):
1092         content = ' ' + content
1093     is_standalone_comment = (
1094         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1095     )
1096     if not is_standalone_comment:
1097         # simple trailing comment
1098         yield Leaf(token.COMMENT, value='#' + content)
1099         return
1100
1101     for line in ('#' + content).split('\n'):
1102         line = line.lstrip()
1103         if not line.startswith('#'):
1104             continue
1105
1106         yield Leaf(STANDALONE_COMMENT, line)
1107
1108
1109 def split_line(
1110     line: Line, line_length: int, inner: bool = False, py36: bool = False
1111 ) -> Iterator[Line]:
1112     """Splits a `line` into potentially many lines.
1113
1114     They should fit in the allotted `line_length` but might not be able to.
1115     `inner` signifies that there were a pair of brackets somewhere around the
1116     current `line`, possibly transitively. This means we can fallback to splitting
1117     by delimiters if the LHS/RHS don't yield any results.
1118
1119     If `py36` is True, splitting may generate syntax that is only compatible
1120     with Python 3.6 and later.
1121     """
1122     line_str = str(line).strip('\n')
1123     if len(line_str) <= line_length and '\n' not in line_str:
1124         yield line
1125         return
1126
1127     if line.is_def:
1128         split_funcs = [left_hand_split]
1129     elif line.inside_brackets:
1130         split_funcs = [delimiter_split]
1131         if '\n' not in line_str:
1132             # Only attempt RHS if we don't have multiline strings or comments
1133             # on this line.
1134             split_funcs.append(right_hand_split)
1135     else:
1136         split_funcs = [right_hand_split]
1137     for split_func in split_funcs:
1138         # We are accumulating lines in `result` because we might want to abort
1139         # mission and return the original line in the end, or attempt a different
1140         # split altogether.
1141         result: List[Line] = []
1142         try:
1143             for l in split_func(line, py36=py36):
1144                 if str(l).strip('\n') == line_str:
1145                     raise CannotSplit("Split function returned an unchanged result")
1146
1147                 result.extend(
1148                     split_line(l, line_length=line_length, inner=True, py36=py36)
1149                 )
1150         except CannotSplit as cs:
1151             continue
1152
1153         else:
1154             yield from result
1155             break
1156
1157     else:
1158         yield line
1159
1160
1161 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1162     """Split line into many lines, starting with the first matching bracket pair.
1163
1164     Note: this usually looks weird, only use this for function definitions.
1165     Prefer RHS otherwise.
1166     """
1167     head = Line(depth=line.depth)
1168     body = Line(depth=line.depth + 1, inside_brackets=True)
1169     tail = Line(depth=line.depth)
1170     tail_leaves: List[Leaf] = []
1171     body_leaves: List[Leaf] = []
1172     head_leaves: List[Leaf] = []
1173     current_leaves = head_leaves
1174     matching_bracket = None
1175     for leaf in line.leaves:
1176         if (
1177             current_leaves is body_leaves and
1178             leaf.type in CLOSING_BRACKETS and
1179             leaf.opening_bracket is matching_bracket
1180         ):
1181             current_leaves = tail_leaves if body_leaves else head_leaves
1182         current_leaves.append(leaf)
1183         if current_leaves is head_leaves:
1184             if leaf.type in OPENING_BRACKETS:
1185                 matching_bracket = leaf
1186                 current_leaves = body_leaves
1187     # Since body is a new indent level, remove spurious leading whitespace.
1188     if body_leaves:
1189         normalize_prefix(body_leaves[0])
1190     # Build the new lines.
1191     for result, leaves in (
1192         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1193     ):
1194         for leaf in leaves:
1195             result.append(leaf, preformatted=True)
1196             comment_after = line.comments.get(id(leaf))
1197             if comment_after:
1198                 result.append(comment_after, preformatted=True)
1199     split_succeeded_or_raise(head, body, tail)
1200     for result in (head, body, tail):
1201         if result:
1202             yield result
1203
1204
1205 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1206     """Split line into many lines, starting with the last matching bracket pair."""
1207     head = Line(depth=line.depth)
1208     body = Line(depth=line.depth + 1, inside_brackets=True)
1209     tail = Line(depth=line.depth)
1210     tail_leaves: List[Leaf] = []
1211     body_leaves: List[Leaf] = []
1212     head_leaves: List[Leaf] = []
1213     current_leaves = tail_leaves
1214     opening_bracket = None
1215     for leaf in reversed(line.leaves):
1216         if current_leaves is body_leaves:
1217             if leaf is opening_bracket:
1218                 current_leaves = head_leaves if body_leaves else tail_leaves
1219         current_leaves.append(leaf)
1220         if current_leaves is tail_leaves:
1221             if leaf.type in CLOSING_BRACKETS:
1222                 opening_bracket = leaf.opening_bracket
1223                 current_leaves = body_leaves
1224     tail_leaves.reverse()
1225     body_leaves.reverse()
1226     head_leaves.reverse()
1227     # Since body is a new indent level, remove spurious leading whitespace.
1228     if body_leaves:
1229         normalize_prefix(body_leaves[0])
1230     # Build the new lines.
1231     for result, leaves in (
1232         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1233     ):
1234         for leaf in leaves:
1235             result.append(leaf, preformatted=True)
1236             comment_after = line.comments.get(id(leaf))
1237             if comment_after:
1238                 result.append(comment_after, preformatted=True)
1239     split_succeeded_or_raise(head, body, tail)
1240     for result in (head, body, tail):
1241         if result:
1242             yield result
1243
1244
1245 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1246     tail_len = len(str(tail).strip())
1247     if not body:
1248         if tail_len == 0:
1249             raise CannotSplit("Splitting brackets produced the same line")
1250
1251         elif tail_len < 3:
1252             raise CannotSplit(
1253                 f"Splitting brackets on an empty body to save "
1254                 f"{tail_len} characters is not worth it"
1255             )
1256
1257
1258 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1259     """Split according to delimiters of the highest priority.
1260
1261     This kind of split doesn't increase indentation.
1262     If `py36` is True, the split will add trailing commas also in function
1263     signatures that contain * and **.
1264     """
1265     try:
1266         last_leaf = line.leaves[-1]
1267     except IndexError:
1268         raise CannotSplit("Line empty")
1269
1270     delimiters = line.bracket_tracker.delimiters
1271     try:
1272         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1273     except ValueError:
1274         raise CannotSplit("No delimiters found")
1275
1276     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1277     lowest_depth = sys.maxsize
1278     trailing_comma_safe = True
1279     for leaf in line.leaves:
1280         current_line.append(leaf, preformatted=True)
1281         comment_after = line.comments.get(id(leaf))
1282         if comment_after:
1283             current_line.append(comment_after, preformatted=True)
1284         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1285         if (
1286             leaf.bracket_depth == lowest_depth and
1287             leaf.type == token.STAR or
1288             leaf.type == token.DOUBLESTAR
1289         ):
1290             trailing_comma_safe = trailing_comma_safe and py36
1291         leaf_priority = delimiters.get(id(leaf))
1292         if leaf_priority == delimiter_priority:
1293             normalize_prefix(current_line.leaves[0])
1294             yield current_line
1295
1296             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1297     if current_line:
1298         if (
1299             delimiter_priority == COMMA_PRIORITY and
1300             current_line.leaves[-1].type != token.COMMA and
1301             trailing_comma_safe
1302         ):
1303             current_line.append(Leaf(token.COMMA, ','))
1304         normalize_prefix(current_line.leaves[0])
1305         yield current_line
1306
1307
1308 def is_import(leaf: Leaf) -> bool:
1309     """Returns True if the given leaf starts an import statement."""
1310     p = leaf.parent
1311     t = leaf.type
1312     v = leaf.value
1313     return bool(
1314         t == token.NAME and
1315         (
1316             (v == 'import' and p and p.type == syms.import_name) or
1317             (v == 'from' and p and p.type == syms.import_from)
1318         )
1319     )
1320
1321
1322 def normalize_prefix(leaf: Leaf) -> None:
1323     """Leave existing extra newlines for imports.  Remove everything else."""
1324     if is_import(leaf):
1325         spl = leaf.prefix.split('#', 1)
1326         nl_count = spl[0].count('\n')
1327         if len(spl) > 1:
1328             # Skip one newline since it was for a standalone comment.
1329             nl_count -= 1
1330         leaf.prefix = '\n' * nl_count
1331         return
1332
1333     leaf.prefix = ''
1334
1335
1336 def is_python36(node: Node) -> bool:
1337     """Returns True if the current file is using Python 3.6+ features.
1338
1339     Currently looking for:
1340     - f-strings; and
1341     - trailing commas after * or ** in function signatures.
1342     """
1343     for n in node.pre_order():
1344         if n.type == token.STRING:
1345             value_head = n.value[:2]  # type: ignore
1346             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1347                 return True
1348
1349         elif (
1350             n.type == syms.typedargslist and
1351             n.children and
1352             n.children[-1].type == token.COMMA
1353         ):
1354             for ch in n.children:
1355                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1356                     return True
1357
1358     return False
1359
1360
1361 PYTHON_EXTENSIONS = {'.py'}
1362 BLACKLISTED_DIRECTORIES = {
1363     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1364 }
1365
1366
1367 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1368     for child in path.iterdir():
1369         if child.is_dir():
1370             if child.name in BLACKLISTED_DIRECTORIES:
1371                 continue
1372
1373             yield from gen_python_files_in_dir(child)
1374
1375         elif child.suffix in PYTHON_EXTENSIONS:
1376             yield child
1377
1378
1379 @dataclass
1380 class Report:
1381     """Provides a reformatting counter."""
1382     change_count: int = attrib(default=0)
1383     same_count: int = attrib(default=0)
1384     failure_count: int = attrib(default=0)
1385
1386     def done(self, src: Path, changed: bool) -> None:
1387         """Increment the counter for successful reformatting. Write out a message."""
1388         if changed:
1389             out(f'reformatted {src}')
1390             self.change_count += 1
1391         else:
1392             out(f'{src} already well formatted, good job.', bold=False)
1393             self.same_count += 1
1394
1395     def failed(self, src: Path, message: str) -> None:
1396         """Increment the counter for failed reformatting. Write out a message."""
1397         err(f'error: cannot format {src}: {message}')
1398         self.failure_count += 1
1399
1400     @property
1401     def return_code(self) -> int:
1402         """Which return code should the app use considering the current state."""
1403         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1404         # 126 we have special returncodes reserved by the shell.
1405         if self.failure_count:
1406             return 123
1407
1408         elif self.change_count:
1409             return 1
1410
1411         return 0
1412
1413     def __str__(self) -> str:
1414         """A color report of the current state.
1415
1416         Use `click.unstyle` to remove colors.
1417         """
1418         report = []
1419         if self.change_count:
1420             s = 's' if self.change_count > 1 else ''
1421             report.append(
1422                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1423             )
1424         if self.same_count:
1425             s = 's' if self.same_count > 1 else ''
1426             report.append(f'{self.same_count} file{s} left unchanged')
1427         if self.failure_count:
1428             s = 's' if self.failure_count > 1 else ''
1429             report.append(
1430                 click.style(
1431                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1432                 )
1433             )
1434         return ', '.join(report) + '.'
1435
1436
1437 def assert_equivalent(src: str, dst: str) -> None:
1438     """Raises AssertionError if `src` and `dst` aren't equivalent.
1439
1440     This is a temporary sanity check until Black becomes stable.
1441     """
1442
1443     import ast
1444     import traceback
1445
1446     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1447         """Simple visitor generating strings to compare ASTs by content."""
1448         yield f"{'  ' * depth}{node.__class__.__name__}("
1449
1450         for field in sorted(node._fields):
1451             try:
1452                 value = getattr(node, field)
1453             except AttributeError:
1454                 continue
1455
1456             yield f"{'  ' * (depth+1)}{field}="
1457
1458             if isinstance(value, list):
1459                 for item in value:
1460                     if isinstance(item, ast.AST):
1461                         yield from _v(item, depth + 2)
1462
1463             elif isinstance(value, ast.AST):
1464                 yield from _v(value, depth + 2)
1465
1466             else:
1467                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1468
1469         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1470
1471     try:
1472         src_ast = ast.parse(src)
1473     except Exception as exc:
1474         raise AssertionError(f"cannot parse source: {exc}") from None
1475
1476     try:
1477         dst_ast = ast.parse(dst)
1478     except Exception as exc:
1479         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1480         raise AssertionError(
1481             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1482             f"Please report a bug on https://github.com/ambv/black/issues.  "
1483             f"This invalid output might be helpful: {log}",
1484         ) from None
1485
1486     src_ast_str = '\n'.join(_v(src_ast))
1487     dst_ast_str = '\n'.join(_v(dst_ast))
1488     if src_ast_str != dst_ast_str:
1489         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1490         raise AssertionError(
1491             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1492             f"the source.  "
1493             f"Please report a bug on https://github.com/ambv/black/issues.  "
1494             f"This diff might be helpful: {log}",
1495         ) from None
1496
1497
1498 def assert_stable(src: str, dst: str, line_length: int) -> None:
1499     """Raises AssertionError if `dst` reformats differently the second time.
1500
1501     This is a temporary sanity check until Black becomes stable.
1502     """
1503     newdst = format_str(dst, line_length=line_length)
1504     if dst != newdst:
1505         log = dump_to_file(
1506             diff(src, dst, 'source', 'first pass'),
1507             diff(dst, newdst, 'first pass', 'second pass'),
1508         )
1509         raise AssertionError(
1510             f"INTERNAL ERROR: Black produced different code on the second pass "
1511             f"of the formatter.  "
1512             f"Please report a bug on https://github.com/ambv/black/issues.  "
1513             f"This diff might be helpful: {log}",
1514         ) from None
1515
1516
1517 def dump_to_file(*output: str) -> str:
1518     """Dumps `output` to a temporary file. Returns path to the file."""
1519     import tempfile
1520
1521     with tempfile.NamedTemporaryFile(
1522         mode='w', prefix='blk_', suffix='.log', delete=False
1523     ) as f:
1524         for lines in output:
1525             f.write(lines)
1526             f.write('\n')
1527     return f.name
1528
1529
1530 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1531     """Returns a udiff string between strings `a` and `b`."""
1532     import difflib
1533
1534     a_lines = [line + '\n' for line in a.split('\n')]
1535     b_lines = [line + '\n' for line in b.split('\n')]
1536     return ''.join(
1537         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1538     )
1539
1540
1541 if __name__ == '__main__':
1542     main()