]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

39f25d3eb73058a4b28cb14e48ede96266cc89fe
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import attrib, dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a1"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
78 )
79 @click.pass_context
80 def main(
81     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
82 ) -> None:
83     """The uncompromising code formatter."""
84     sources: List[Path] = []
85     for s in src:
86         p = Path(s)
87         if p.is_dir():
88             sources.extend(gen_python_files_in_dir(p))
89         elif p.is_file():
90             # if a file was explicitly given, we don't care about its extension
91             sources.append(p)
92         else:
93             err(f'invalid path: {s}')
94     if len(sources) == 0:
95         ctx.exit(0)
96     elif len(sources) == 1:
97         p = sources[0]
98         report = Report()
99         try:
100             changed = format_file_in_place(
101                 p, line_length=line_length, fast=fast, write_back=not check
102             )
103             report.done(p, changed)
104         except Exception as exc:
105             report.failed(p, str(exc))
106         ctx.exit(report.return_code)
107     else:
108         loop = asyncio.get_event_loop()
109         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
110         return_code = 1
111         try:
112             return_code = loop.run_until_complete(
113                 schedule_formatting(
114                     sources, line_length, not check, fast, loop, executor
115                 )
116             )
117         finally:
118             loop.close()
119             ctx.exit(return_code)
120
121
122 async def schedule_formatting(
123     sources: List[Path],
124     line_length: int,
125     write_back: bool,
126     fast: bool,
127     loop: BaseEventLoop,
128     executor: Executor,
129 ) -> int:
130     tasks = {
131         src: loop.run_in_executor(
132             executor, format_file_in_place, src, line_length, fast, write_back
133         )
134         for src in sources
135     }
136     await asyncio.wait(tasks.values())
137     cancelled = []
138     report = Report()
139     for src, task in tasks.items():
140         if not task.done():
141             report.failed(src, 'timed out, cancelling')
142             task.cancel()
143             cancelled.append(task)
144         elif task.exception():
145             report.failed(src, str(task.exception()))
146         else:
147             report.done(src, task.result())
148     if cancelled:
149         await asyncio.wait(cancelled, timeout=2)
150     out('All done! ✨ 🍰 ✨')
151     click.echo(str(report))
152     return report.return_code
153
154
155 def format_file_in_place(
156     src: Path, line_length: int, fast: bool, write_back: bool = False
157 ) -> bool:
158     """Format the file and rewrite if changed. Return True if changed."""
159     try:
160         contents, encoding = format_file(src, line_length=line_length, fast=fast)
161     except NothingChanged:
162         return False
163
164     if write_back:
165         with open(src, "w", encoding=encoding) as f:
166             f.write(contents)
167     return True
168
169
170 def format_file(
171     src: Path, line_length: int, fast: bool
172 ) -> Tuple[FileContent, Encoding]:
173     """Reformats a file and returns its contents and encoding."""
174     with tokenize.open(src) as src_buffer:
175         src_contents = src_buffer.read()
176     if src_contents.strip() == '':
177         raise NothingChanged(src)
178
179     dst_contents = format_str(src_contents, line_length=line_length)
180     if src_contents == dst_contents:
181         raise NothingChanged(src)
182
183     if not fast:
184         assert_equivalent(src_contents, dst_contents)
185         assert_stable(src_contents, dst_contents, line_length=line_length)
186     return dst_contents, src_buffer.encoding
187
188
189 def format_str(src_contents: str, line_length: int) -> FileContent:
190     """Reformats a string and returns new contents."""
191     src_node = lib2to3_parse(src_contents)
192     dst_contents = ""
193     comments: List[Line] = []
194     lines = LineGenerator()
195     elt = EmptyLineTracker()
196     py36 = is_python36(src_node)
197     empty_line = Line()
198     after = 0
199     for current_line in lines.visit(src_node):
200         for _ in range(after):
201             dst_contents += str(empty_line)
202         before, after = elt.maybe_empty_lines(current_line)
203         for _ in range(before):
204             dst_contents += str(empty_line)
205         if not current_line.is_comment:
206             for comment in comments:
207                 dst_contents += str(comment)
208             comments = []
209             for line in split_line(current_line, line_length=line_length, py36=py36):
210                 dst_contents += str(line)
211         else:
212             comments.append(current_line)
213     for comment in comments:
214         dst_contents += str(comment)
215     return dst_contents
216
217
218 def lib2to3_parse(src_txt: str) -> Node:
219     """Given a string with source, return the lib2to3 Node."""
220     grammar = pygram.python_grammar_no_print_statement
221     drv = driver.Driver(grammar, pytree.convert)
222     if src_txt[-1] != '\n':
223         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
224         src_txt += nl
225     try:
226         result = drv.parse_string(src_txt, True)
227     except ParseError as pe:
228         lineno, column = pe.context[1]
229         lines = src_txt.splitlines()
230         try:
231             faulty_line = lines[lineno - 1]
232         except IndexError:
233             faulty_line = "<line number missing in source>"
234         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
235
236     if isinstance(result, Leaf):
237         result = Node(syms.file_input, [result])
238     return result
239
240
241 def lib2to3_unparse(node: Node) -> str:
242     """Given a lib2to3 node, return its string representation."""
243     code = str(node)
244     return code
245
246
247 T = TypeVar('T')
248
249
250 class Visitor(Generic[T]):
251     """Basic lib2to3 visitor that yields things on visiting."""
252
253     def visit(self, node: LN) -> Iterator[T]:
254         if node.type < 256:
255             name = token.tok_name[node.type]
256         else:
257             name = type_repr(node.type)
258         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
259
260     def visit_default(self, node: LN) -> Iterator[T]:
261         if isinstance(node, Node):
262             for child in node.children:
263                 yield from self.visit(child)
264
265
266 @dataclass
267 class DebugVisitor(Visitor[T]):
268     tree_depth: int = attrib(default=0)
269
270     def visit_default(self, node: LN) -> Iterator[T]:
271         indent = ' ' * (2 * self.tree_depth)
272         if isinstance(node, Node):
273             _type = type_repr(node.type)
274             out(f'{indent}{_type}', fg='yellow')
275             self.tree_depth += 1
276             for child in node.children:
277                 yield from self.visit(child)
278
279             self.tree_depth -= 1
280             out(f'{indent}/{_type}', fg='yellow', bold=False)
281         else:
282             _type = token.tok_name.get(node.type, str(node.type))
283             out(f'{indent}{_type}', fg='blue', nl=False)
284             if node.prefix:
285                 # We don't have to handle prefixes for `Node` objects since
286                 # that delegates to the first child anyway.
287                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
288             out(f' {node.value!r}', fg='blue', bold=False)
289
290
291 KEYWORDS = set(keyword.kwlist)
292 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
293 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
294 STATEMENT = {
295     syms.if_stmt,
296     syms.while_stmt,
297     syms.for_stmt,
298     syms.try_stmt,
299     syms.except_clause,
300     syms.with_stmt,
301     syms.funcdef,
302     syms.classdef,
303 }
304 STANDALONE_COMMENT = 153
305 LOGIC_OPERATORS = {'and', 'or'}
306 COMPARATORS = {
307     token.LESS,
308     token.GREATER,
309     token.EQEQUAL,
310     token.NOTEQUAL,
311     token.LESSEQUAL,
312     token.GREATEREQUAL,
313 }
314 MATH_OPERATORS = {
315     token.PLUS,
316     token.MINUS,
317     token.STAR,
318     token.SLASH,
319     token.VBAR,
320     token.AMPER,
321     token.PERCENT,
322     token.CIRCUMFLEX,
323     token.LEFTSHIFT,
324     token.RIGHTSHIFT,
325     token.DOUBLESTAR,
326     token.DOUBLESLASH,
327 }
328 COMPREHENSION_PRIORITY = 20
329 COMMA_PRIORITY = 10
330 LOGIC_PRIORITY = 5
331 STRING_PRIORITY = 4
332 COMPARATOR_PRIORITY = 3
333 MATH_PRIORITY = 1
334
335
336 @dataclass
337 class BracketTracker:
338     depth: int = attrib(default=0)
339     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict))
340     delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict))
341     previous: Optional[Leaf] = attrib(default=None)
342
343     def mark(self, leaf: Leaf) -> None:
344         if leaf.type == token.COMMENT:
345             return
346
347         if leaf.type in CLOSING_BRACKETS:
348             self.depth -= 1
349             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
350             leaf.opening_bracket = opening_bracket
351         leaf.bracket_depth = self.depth
352         if self.depth == 0:
353             delim = is_delimiter(leaf)
354             if delim:
355                 self.delimiters[id(leaf)] = delim
356             elif self.previous is not None:
357                 if leaf.type == token.STRING and self.previous.type == token.STRING:
358                     self.delimiters[id(self.previous)] = STRING_PRIORITY
359                 elif (
360                     leaf.type == token.NAME and
361                     leaf.value == 'for' and
362                     leaf.parent and
363                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
364                 ):
365                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
366                 elif (
367                     leaf.type == token.NAME and
368                     leaf.value == 'if' and
369                     leaf.parent and
370                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
371                 ):
372                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
373         if leaf.type in OPENING_BRACKETS:
374             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
375             self.depth += 1
376         self.previous = leaf
377
378     def any_open_brackets(self) -> bool:
379         """Returns True if there is an yet unmatched open bracket on the line."""
380         return bool(self.bracket_match)
381
382     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
383         """Returns the highest priority of a delimiter found on the line.
384
385         Values are consistent with what `is_delimiter()` returns.
386         """
387         return max(v for k, v in self.delimiters.items() if k not in exclude)
388
389
390 @dataclass
391 class Line:
392     depth: int = attrib(default=0)
393     leaves: List[Leaf] = attrib(default=Factory(list))
394     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
395     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
396     inside_brackets: bool = attrib(default=False)
397     has_for: bool = attrib(default=False)
398     _for_loop_variable: bool = attrib(default=False, init=False)
399
400     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
401         has_value = leaf.value.strip()
402         if not has_value:
403             return
404
405         if self.leaves and not preformatted:
406             # Note: at this point leaf.prefix should be empty except for
407             # imports, for which we only preserve newlines.
408             leaf.prefix += whitespace(leaf)
409         if self.inside_brackets or not preformatted:
410             self.maybe_decrement_after_for_loop_variable(leaf)
411             self.bracket_tracker.mark(leaf)
412             self.maybe_remove_trailing_comma(leaf)
413             self.maybe_increment_for_loop_variable(leaf)
414             if self.maybe_adapt_standalone_comment(leaf):
415                 return
416
417         if not self.append_comment(leaf):
418             self.leaves.append(leaf)
419
420     @property
421     def is_comment(self) -> bool:
422         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
423
424     @property
425     def is_decorator(self) -> bool:
426         return bool(self) and self.leaves[0].type == token.AT
427
428     @property
429     def is_import(self) -> bool:
430         return bool(self) and is_import(self.leaves[0])
431
432     @property
433     def is_class(self) -> bool:
434         return (
435             bool(self) and
436             self.leaves[0].type == token.NAME and
437             self.leaves[0].value == 'class'
438         )
439
440     @property
441     def is_def(self) -> bool:
442         """Also returns True for async defs."""
443         try:
444             first_leaf = self.leaves[0]
445         except IndexError:
446             return False
447
448         try:
449             second_leaf: Optional[Leaf] = self.leaves[1]
450         except IndexError:
451             second_leaf = None
452         return (
453             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
454             (
455                 first_leaf.type == token.NAME and
456                 first_leaf.value == 'async' and
457                 second_leaf is not None and
458                 second_leaf.type == token.NAME and
459                 second_leaf.value == 'def'
460             )
461         )
462
463     @property
464     def is_flow_control(self) -> bool:
465         return (
466             bool(self) and
467             self.leaves[0].type == token.NAME and
468             self.leaves[0].value in FLOW_CONTROL
469         )
470
471     @property
472     def is_yield(self) -> bool:
473         return (
474             bool(self) and
475             self.leaves[0].type == token.NAME and
476             self.leaves[0].value == 'yield'
477         )
478
479     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
480         if not (
481             self.leaves and
482             self.leaves[-1].type == token.COMMA and
483             closing.type in CLOSING_BRACKETS
484         ):
485             return False
486
487         if closing.type == token.RSQB or closing.type == token.RBRACE:
488             self.leaves.pop()
489             return True
490
491         # For parens let's check if it's safe to remove the comma.  If the
492         # trailing one is the only one, we might mistakenly change a tuple
493         # into a different type by removing the comma.
494         depth = closing.bracket_depth + 1
495         commas = 0
496         opening = closing.opening_bracket
497         for _opening_index, leaf in enumerate(self.leaves):
498             if leaf is opening:
499                 break
500
501         else:
502             return False
503
504         for leaf in self.leaves[_opening_index + 1:]:
505             if leaf is closing:
506                 break
507
508             bracket_depth = leaf.bracket_depth
509             if bracket_depth == depth and leaf.type == token.COMMA:
510                 commas += 1
511         if commas > 1:
512             self.leaves.pop()
513             return True
514
515         return False
516
517     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
518         """In a for loop, or comprehension, the variables are often unpacks.
519
520         To avoid splitting on the comma in this situation, we will increase
521         the depth of tokens between `for` and `in`.
522         """
523         if leaf.type == token.NAME and leaf.value == 'for':
524             self.has_for = True
525             self.bracket_tracker.depth += 1
526             self._for_loop_variable = True
527             return True
528
529         return False
530
531     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
532         # See `maybe_increment_for_loop_variable` above for explanation.
533         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
534             self.bracket_tracker.depth -= 1
535             self._for_loop_variable = False
536             return True
537
538         return False
539
540     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
541         """Hack a standalone comment to act as a trailing comment for line splitting.
542
543         If this line has brackets and a standalone `comment`, we need to adapt
544         it to be able to still reformat the line.
545
546         This is not perfect, the line to which the standalone comment gets
547         appended will appear "too long" when splitting.
548         """
549         if not (
550             comment.type == STANDALONE_COMMENT and
551             self.bracket_tracker.any_open_brackets()
552         ):
553             return False
554
555         comment.type = token.COMMENT
556         comment.prefix = '\n' + '    ' * (self.depth + 1)
557         return self.append_comment(comment)
558
559     def append_comment(self, comment: Leaf) -> bool:
560         if comment.type != token.COMMENT:
561             return False
562
563         try:
564             after = id(self.last_non_delimiter())
565         except LookupError:
566             comment.type = STANDALONE_COMMENT
567             comment.prefix = ''
568             return False
569
570         else:
571             if after in self.comments:
572                 self.comments[after].value += str(comment)
573             else:
574                 self.comments[after] = comment
575             return True
576
577     def last_non_delimiter(self) -> Leaf:
578         for i in range(len(self.leaves)):
579             last = self.leaves[-i - 1]
580             if not is_delimiter(last):
581                 return last
582
583         raise LookupError("No non-delimiters found")
584
585     def __str__(self) -> str:
586         if not self:
587             return '\n'
588
589         indent = '    ' * self.depth
590         leaves = iter(self.leaves)
591         first = next(leaves)
592         res = f'{first.prefix}{indent}{first.value}'
593         for leaf in leaves:
594             res += str(leaf)
595         for comment in self.comments.values():
596             res += str(comment)
597         return res + '\n'
598
599     def __bool__(self) -> bool:
600         return bool(self.leaves or self.comments)
601
602
603 @dataclass
604 class EmptyLineTracker:
605     """Provides a stateful method that returns the number of potential extra
606     empty lines needed before and after the currently processed line.
607
608     Note: this tracker works on lines that haven't been split yet.
609     """
610     previous_line: Optional[Line] = attrib(default=None)
611     previous_after: int = attrib(default=0)
612     previous_defs: List[int] = attrib(default=Factory(list))
613
614     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
615         """Returns the number of extra empty lines before and after the `current_line`.
616
617         This is for separating `def`, `async def` and `class` with extra empty lines
618         (two on module-level), as well as providing an extra empty line after flow
619         control keywords to make them more prominent.
620         """
621         before, after = self._maybe_empty_lines(current_line)
622         self.previous_after = after
623         self.previous_line = current_line
624         return before, after
625
626     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
627         before = 0
628         depth = current_line.depth
629         while self.previous_defs and self.previous_defs[-1] >= depth:
630             self.previous_defs.pop()
631             before = (1 if depth else 2) - self.previous_after
632         is_decorator = current_line.is_decorator
633         if is_decorator or current_line.is_def or current_line.is_class:
634             if not is_decorator:
635                 self.previous_defs.append(depth)
636             if self.previous_line is None:
637                 # Don't insert empty lines before the first line in the file.
638                 return 0, 0
639
640             if self.previous_line and self.previous_line.is_decorator:
641                 # Don't insert empty lines between decorators.
642                 return 0, 0
643
644             newlines = 2
645             if current_line.depth:
646                 newlines -= 1
647             newlines -= self.previous_after
648             return newlines, 0
649
650         if current_line.is_flow_control:
651             return before, 1
652
653         if (
654             self.previous_line and
655             self.previous_line.is_import and
656             not current_line.is_import and
657             depth == self.previous_line.depth
658         ):
659             return (before or 1), 0
660
661         if (
662             self.previous_line and
663             self.previous_line.is_yield and
664             (not current_line.is_yield or depth != self.previous_line.depth)
665         ):
666             return (before or 1), 0
667
668         return before, 0
669
670
671 @dataclass
672 class LineGenerator(Visitor[Line]):
673     """Generates reformatted Line objects.  Empty lines are not emitted.
674
675     Note: destroys the tree it's visiting by mutating prefixes of its leaves
676     in ways that will no longer stringify to valid Python code on the tree.
677     """
678     current_line: Line = attrib(default=Factory(Line))
679     standalone_comments: List[Leaf] = attrib(default=Factory(list))
680
681     def line(self, indent: int = 0) -> Iterator[Line]:
682         """Generate a line.
683
684         If the line is empty, only emit if it makes sense.
685         If the line is too long, split it first and then generate.
686
687         If any lines were generated, set up a new current_line.
688         """
689         if not self.current_line:
690             self.current_line.depth += indent
691             return  # Line is empty, don't emit. Creating a new one unnecessary.
692
693         complete_line = self.current_line
694         self.current_line = Line(depth=complete_line.depth + indent)
695         yield complete_line
696
697     def visit_default(self, node: LN) -> Iterator[Line]:
698         if isinstance(node, Leaf):
699             for comment in generate_comments(node):
700                 if self.current_line.bracket_tracker.any_open_brackets():
701                     # any comment within brackets is subject to splitting
702                     self.current_line.append(comment)
703                 elif comment.type == token.COMMENT:
704                     # regular trailing comment
705                     self.current_line.append(comment)
706                     yield from self.line()
707
708                 else:
709                     # regular standalone comment, to be processed later (see
710                     # docstring in `generate_comments()`
711                     self.standalone_comments.append(comment)
712             normalize_prefix(node)
713             if node.type not in WHITESPACE:
714                 for comment in self.standalone_comments:
715                     yield from self.line()
716
717                     self.current_line.append(comment)
718                     yield from self.line()
719
720                 self.standalone_comments = []
721                 self.current_line.append(node)
722         yield from super().visit_default(node)
723
724     def visit_suite(self, node: Node) -> Iterator[Line]:
725         """Body of a statement after a colon."""
726         children = iter(node.children)
727         # Process newline before indenting.  It might contain an inline
728         # comment that should go right after the colon.
729         newline = next(children)
730         yield from self.visit(newline)
731         yield from self.line(+1)
732
733         for child in children:
734             yield from self.visit(child)
735
736         yield from self.line(-1)
737
738     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
739         """Visit a statement.
740
741         The relevant Python language keywords for this statement are NAME leaves
742         within it.
743         """
744         for child in node.children:
745             if child.type == token.NAME and child.value in keywords:  # type: ignore
746                 yield from self.line()
747
748             yield from self.visit(child)
749
750     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
751         """A statement without nested statements."""
752         is_suite_like = node.parent and node.parent.type in STATEMENT
753         if is_suite_like:
754             yield from self.line(+1)
755             yield from self.visit_default(node)
756             yield from self.line(-1)
757
758         else:
759             yield from self.line()
760             yield from self.visit_default(node)
761
762     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
763         yield from self.line()
764
765         children = iter(node.children)
766         for child in children:
767             yield from self.visit(child)
768
769             if child.type == token.NAME and child.value == 'async':  # type: ignore
770                 break
771
772         internal_stmt = next(children)
773         for child in internal_stmt.children:
774             yield from self.visit(child)
775
776     def visit_decorators(self, node: Node) -> Iterator[Line]:
777         for child in node.children:
778             yield from self.line()
779             yield from self.visit(child)
780
781     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
782         yield from self.line()
783
784     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
785         yield from self.visit_default(leaf)
786         yield from self.line()
787
788     def __attrs_post_init__(self) -> None:
789         """You are in a twisty little maze of passages."""
790         v = self.visit_stmt
791         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
792         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
793         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
794         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
795         self.visit_except_clause = partial(v, keywords={'except'})
796         self.visit_funcdef = partial(v, keywords={'def'})
797         self.visit_with_stmt = partial(v, keywords={'with'})
798         self.visit_classdef = partial(v, keywords={'class'})
799         self.visit_async_funcdef = self.visit_async_stmt
800         self.visit_decorated = self.visit_decorators
801
802
803 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
804 OPENING_BRACKETS = set(BRACKET.keys())
805 CLOSING_BRACKETS = set(BRACKET.values())
806 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
807 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
808
809
810 def whitespace(leaf: Leaf) -> str:
811     """Return whitespace prefix if needed for the given `leaf`."""
812     NO = ''
813     SPACE = ' '
814     DOUBLESPACE = '  '
815     t = leaf.type
816     p = leaf.parent
817     v = leaf.value
818     if t in ALWAYS_NO_SPACE:
819         return NO
820
821     if t == token.COMMENT:
822         return DOUBLESPACE
823
824     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
825     prev = leaf.prev_sibling
826     if not prev:
827         prevp = preceding_leaf(p)
828         if not prevp or prevp.type in OPENING_BRACKETS:
829             return NO
830
831         if prevp.type == token.EQUAL:
832             if prevp.parent and prevp.parent.type in {
833                 syms.typedargslist,
834                 syms.varargslist,
835                 syms.parameters,
836                 syms.arglist,
837                 syms.argument,
838             }:
839                 return NO
840
841         elif prevp.type == token.DOUBLESTAR:
842             if prevp.parent and prevp.parent.type in {
843                 syms.typedargslist,
844                 syms.varargslist,
845                 syms.parameters,
846                 syms.arglist,
847                 syms.dictsetmaker,
848             }:
849                 return NO
850
851         elif prevp.type == token.COLON:
852             if prevp.parent and prevp.parent.type == syms.subscript:
853                 return NO
854
855         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
856             return NO
857
858     elif prev.type in OPENING_BRACKETS:
859         return NO
860
861     if p.type in {syms.parameters, syms.arglist}:
862         # untyped function signatures or calls
863         if t == token.RPAR:
864             return NO
865
866         if not prev or prev.type != token.COMMA:
867             return NO
868
869     if p.type == syms.varargslist:
870         # lambdas
871         if t == token.RPAR:
872             return NO
873
874         if prev and prev.type != token.COMMA:
875             return NO
876
877     elif p.type == syms.typedargslist:
878         # typed function signatures
879         if not prev:
880             return NO
881
882         if t == token.EQUAL:
883             if prev.type != syms.tname:
884                 return NO
885
886         elif prev.type == token.EQUAL:
887             # A bit hacky: if the equal sign has whitespace, it means we
888             # previously found it's a typed argument.  So, we're using that, too.
889             return prev.prefix
890
891         elif prev.type != token.COMMA:
892             return NO
893
894     elif p.type == syms.tname:
895         # type names
896         if not prev:
897             prevp = preceding_leaf(p)
898             if not prevp or prevp.type != token.COMMA:
899                 return NO
900
901     elif p.type == syms.trailer:
902         # attributes and calls
903         if t == token.LPAR or t == token.RPAR:
904             return NO
905
906         if not prev:
907             if t == token.DOT:
908                 prevp = preceding_leaf(p)
909                 if not prevp or prevp.type != token.NUMBER:
910                     return NO
911
912             elif t == token.LSQB:
913                 return NO
914
915         elif prev.type != token.COMMA:
916             return NO
917
918     elif p.type == syms.argument:
919         # single argument
920         if t == token.EQUAL:
921             return NO
922
923         if not prev:
924             prevp = preceding_leaf(p)
925             if not prevp or prevp.type == token.LPAR:
926                 return NO
927
928         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
929             return NO
930
931     elif p.type == syms.decorator:
932         # decorators
933         return NO
934
935     elif p.type == syms.dotted_name:
936         if prev:
937             return NO
938
939         prevp = preceding_leaf(p)
940         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
941             return NO
942
943     elif p.type == syms.classdef:
944         if t == token.LPAR:
945             return NO
946
947         if prev and prev.type == token.LPAR:
948             return NO
949
950     elif p.type == syms.subscript:
951         # indexing
952         if not prev:
953             assert p.parent is not None, "subscripts are always parented"
954             if p.parent.type == syms.subscriptlist:
955                 return SPACE
956
957             return NO
958
959         elif prev.type == token.COLON:
960             return NO
961
962     elif p.type == syms.atom:
963         if prev and t == token.DOT:
964             # dots, but not the first one.
965             return NO
966
967     elif (
968         p.type == syms.listmaker or
969         p.type == syms.testlist_gexp or
970         p.type == syms.subscriptlist
971     ):
972         # list interior, including unpacking
973         if not prev:
974             return NO
975
976     elif p.type == syms.dictsetmaker:
977         # dict and set interior, including unpacking
978         if not prev:
979             return NO
980
981         if prev.type == token.DOUBLESTAR:
982             return NO
983
984     elif p.type in {syms.factor, syms.star_expr}:
985         # unary ops
986         if not prev:
987             prevp = preceding_leaf(p)
988             if not prevp or prevp.type in OPENING_BRACKETS:
989                 return NO
990
991             prevp_parent = prevp.parent
992             assert prevp_parent is not None
993             if prevp.type == token.COLON and prevp_parent.type in {
994                 syms.subscript, syms.sliceop
995             }:
996                 return NO
997
998             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
999                 return NO
1000
1001         elif t == token.NAME or t == token.NUMBER:
1002             return NO
1003
1004     elif p.type == syms.import_from:
1005         if t == token.DOT:
1006             if prev and prev.type == token.DOT:
1007                 return NO
1008
1009         elif t == token.NAME:
1010             if v == 'import':
1011                 return SPACE
1012
1013             if prev and prev.type == token.DOT:
1014                 return NO
1015
1016     elif p.type == syms.sliceop:
1017         return NO
1018
1019     return SPACE
1020
1021
1022 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1023     """Returns the first leaf that precedes `node`, if any."""
1024     while node:
1025         res = node.prev_sibling
1026         if res:
1027             if isinstance(res, Leaf):
1028                 return res
1029
1030             try:
1031                 return list(res.leaves())[-1]
1032
1033             except IndexError:
1034                 return None
1035
1036         node = node.parent
1037     return None
1038
1039
1040 def is_delimiter(leaf: Leaf) -> int:
1041     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1042
1043     Higher numbers are higher priority.
1044     """
1045     if leaf.type == token.COMMA:
1046         return COMMA_PRIORITY
1047
1048     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1049         return LOGIC_PRIORITY
1050
1051     if leaf.type in COMPARATORS:
1052         return COMPARATOR_PRIORITY
1053
1054     if (
1055         leaf.type in MATH_OPERATORS and
1056         leaf.parent and
1057         leaf.parent.type not in {syms.factor, syms.star_expr}
1058     ):
1059         return MATH_PRIORITY
1060
1061     return 0
1062
1063
1064 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1065     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1066
1067     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1068     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1069     move because it does away with modifying the grammar to include all the
1070     possible places in which comments can be placed.
1071
1072     The sad consequence for us though is that comments don't "belong" anywhere.
1073     This is why this function generates simple parentless Leaf objects for
1074     comments.  We simply don't know what the correct parent should be.
1075
1076     No matter though, we can live without this.  We really only need to
1077     differentiate between inline and standalone comments.  The latter don't
1078     share the line with any code.
1079
1080     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1081     are emitted with a fake STANDALONE_COMMENT token identifier.
1082     """
1083     if not leaf.prefix:
1084         return
1085
1086     if '#' not in leaf.prefix:
1087         return
1088
1089     before_comment, content = leaf.prefix.split('#', 1)
1090     content = content.rstrip()
1091     if content and (content[0] not in {' ', '!', '#'}):
1092         content = ' ' + content
1093     is_standalone_comment = (
1094         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1095     )
1096     if not is_standalone_comment:
1097         # simple trailing comment
1098         yield Leaf(token.COMMENT, value='#' + content)
1099         return
1100
1101     for line in ('#' + content).split('\n'):
1102         line = line.lstrip()
1103         if not line.startswith('#'):
1104             continue
1105
1106         yield Leaf(STANDALONE_COMMENT, line)
1107
1108
1109 def split_line(
1110     line: Line, line_length: int, inner: bool = False, py36: bool = False
1111 ) -> Iterator[Line]:
1112     """Splits a `line` into potentially many lines.
1113
1114     They should fit in the allotted `line_length` but might not be able to.
1115     `inner` signifies that there were a pair of brackets somewhere around the
1116     current `line`, possibly transitively. This means we can fallback to splitting
1117     by delimiters if the LHS/RHS don't yield any results.
1118
1119     If `py36` is True, splitting may generate syntax that is only compatible
1120     with Python 3.6 and later.
1121     """
1122     line_str = str(line).strip('\n')
1123     if len(line_str) <= line_length and '\n' not in line_str:
1124         yield line
1125         return
1126
1127     if line.is_def:
1128         split_funcs = [left_hand_split]
1129     elif line.inside_brackets:
1130         split_funcs = [delimiter_split]
1131         if '\n' not in line_str:
1132             # Only attempt RHS if we don't have multiline strings or comments
1133             # on this line.
1134             split_funcs.append(right_hand_split)
1135     else:
1136         split_funcs = [right_hand_split]
1137     for split_func in split_funcs:
1138         # We are accumulating lines in `result` because we might want to abort
1139         # mission and return the original line in the end, or attempt a different
1140         # split altogether.
1141         result: List[Line] = []
1142         try:
1143             for l in split_func(line, py36=py36):
1144                 if str(l).strip('\n') == line_str:
1145                     raise CannotSplit("Split function returned an unchanged result")
1146
1147                 result.extend(
1148                     split_line(l, line_length=line_length, inner=True, py36=py36)
1149                 )
1150         except CannotSplit as cs:
1151             continue
1152
1153         else:
1154             yield from result
1155             break
1156
1157     else:
1158         yield line
1159
1160
1161 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1162     """Split line into many lines, starting with the first matching bracket pair.
1163
1164     Note: this usually looks weird, only use this for function definitions.
1165     Prefer RHS otherwise.
1166     """
1167     head = Line(depth=line.depth)
1168     body = Line(depth=line.depth + 1, inside_brackets=True)
1169     tail = Line(depth=line.depth)
1170     tail_leaves: List[Leaf] = []
1171     body_leaves: List[Leaf] = []
1172     head_leaves: List[Leaf] = []
1173     current_leaves = head_leaves
1174     matching_bracket = None
1175     for leaf in line.leaves:
1176         if (
1177             current_leaves is body_leaves and
1178             leaf.type in CLOSING_BRACKETS and
1179             leaf.opening_bracket is matching_bracket
1180         ):
1181             current_leaves = tail_leaves
1182         current_leaves.append(leaf)
1183         if current_leaves is head_leaves:
1184             if leaf.type in OPENING_BRACKETS:
1185                 matching_bracket = leaf
1186                 current_leaves = body_leaves
1187     # Since body is a new indent level, remove spurious leading whitespace.
1188     if body_leaves:
1189         normalize_prefix(body_leaves[0])
1190     # Build the new lines.
1191     for result, leaves in (
1192         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1193     ):
1194         for leaf in leaves:
1195             result.append(leaf, preformatted=True)
1196             comment_after = line.comments.get(id(leaf))
1197             if comment_after:
1198                 result.append(comment_after, preformatted=True)
1199     # Check if the split succeeded.
1200     tail_len = len(str(tail))
1201     if not body:
1202         if tail_len == 0:
1203             raise CannotSplit("Splitting brackets produced the same line")
1204
1205         elif tail_len < 3:
1206             raise CannotSplit(
1207                 f"Splitting brackets on an empty body to save "
1208                 f"{tail_len} characters is not worth it"
1209             )
1210
1211     for result in (head, body, tail):
1212         if result:
1213             yield result
1214
1215
1216 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1217     """Split line into many lines, starting with the last matching bracket pair."""
1218     head = Line(depth=line.depth)
1219     body = Line(depth=line.depth + 1, inside_brackets=True)
1220     tail = Line(depth=line.depth)
1221     tail_leaves: List[Leaf] = []
1222     body_leaves: List[Leaf] = []
1223     head_leaves: List[Leaf] = []
1224     current_leaves = tail_leaves
1225     opening_bracket = None
1226     for leaf in reversed(line.leaves):
1227         if current_leaves is body_leaves:
1228             if leaf is opening_bracket:
1229                 current_leaves = head_leaves
1230         current_leaves.append(leaf)
1231         if current_leaves is tail_leaves:
1232             if leaf.type in CLOSING_BRACKETS:
1233                 opening_bracket = leaf.opening_bracket
1234                 current_leaves = body_leaves
1235     tail_leaves.reverse()
1236     body_leaves.reverse()
1237     head_leaves.reverse()
1238     # Since body is a new indent level, remove spurious leading whitespace.
1239     if body_leaves:
1240         normalize_prefix(body_leaves[0])
1241     # Build the new lines.
1242     for result, leaves in (
1243         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1244     ):
1245         for leaf in leaves:
1246             result.append(leaf, preformatted=True)
1247             comment_after = line.comments.get(id(leaf))
1248             if comment_after:
1249                 result.append(comment_after, preformatted=True)
1250     # Check if the split succeeded.
1251     tail_len = len(str(tail).strip('\n'))
1252     if not body:
1253         if tail_len == 0:
1254             raise CannotSplit("Splitting brackets produced the same line")
1255
1256         elif tail_len < 3:
1257             raise CannotSplit(
1258                 f"Splitting brackets on an empty body to save "
1259                 f"{tail_len} characters is not worth it"
1260             )
1261
1262     for result in (head, body, tail):
1263         if result:
1264             yield result
1265
1266
1267 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1268     """Split according to delimiters of the highest priority.
1269
1270     This kind of split doesn't increase indentation.
1271     If `py36` is True, the split will add trailing commas also in function
1272     signatures that contain * and **.
1273     """
1274     try:
1275         last_leaf = line.leaves[-1]
1276     except IndexError:
1277         raise CannotSplit("Line empty")
1278
1279     delimiters = line.bracket_tracker.delimiters
1280     try:
1281         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1282     except ValueError:
1283         raise CannotSplit("No delimiters found")
1284
1285     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1286     lowest_depth = sys.maxsize
1287     trailing_comma_safe = True
1288     for leaf in line.leaves:
1289         current_line.append(leaf, preformatted=True)
1290         comment_after = line.comments.get(id(leaf))
1291         if comment_after:
1292             current_line.append(comment_after, preformatted=True)
1293         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1294         if (
1295             leaf.bracket_depth == lowest_depth and
1296             leaf.type == token.STAR or
1297             leaf.type == token.DOUBLESTAR
1298         ):
1299             trailing_comma_safe = trailing_comma_safe and py36
1300         leaf_priority = delimiters.get(id(leaf))
1301         if leaf_priority == delimiter_priority:
1302             normalize_prefix(current_line.leaves[0])
1303             yield current_line
1304
1305             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1306     if current_line:
1307         if (
1308             delimiter_priority == COMMA_PRIORITY and
1309             current_line.leaves[-1].type != token.COMMA and
1310             trailing_comma_safe
1311         ):
1312             current_line.append(Leaf(token.COMMA, ','))
1313         normalize_prefix(current_line.leaves[0])
1314         yield current_line
1315
1316
1317 def is_import(leaf: Leaf) -> bool:
1318     """Returns True if the given leaf starts an import statement."""
1319     p = leaf.parent
1320     t = leaf.type
1321     v = leaf.value
1322     return bool(
1323         t == token.NAME and
1324         (
1325             (v == 'import' and p and p.type == syms.import_name) or
1326             (v == 'from' and p and p.type == syms.import_from)
1327         )
1328     )
1329
1330
1331 def normalize_prefix(leaf: Leaf) -> None:
1332     """Leave existing extra newlines for imports.  Remove everything else."""
1333     if is_import(leaf):
1334         spl = leaf.prefix.split('#', 1)
1335         nl_count = spl[0].count('\n')
1336         if len(spl) > 1:
1337             # Skip one newline since it was for a standalone comment.
1338             nl_count -= 1
1339         leaf.prefix = '\n' * nl_count
1340         return
1341
1342     leaf.prefix = ''
1343
1344
1345 def is_python36(node: Node) -> bool:
1346     """Returns True if the current file is using Python 3.6+ features.
1347
1348     Currently looking for:
1349     - f-strings; and
1350     - trailing commas after * or ** in function signatures.
1351     """
1352     for n in node.pre_order():
1353         if n.type == token.STRING:
1354             value_head = n.value[:2]  # type: ignore
1355             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1356                 return True
1357
1358         elif (
1359             n.type == syms.typedargslist and
1360             n.children and
1361             n.children[-1].type == token.COMMA
1362         ):
1363             for ch in n.children:
1364                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1365                     return True
1366
1367     return False
1368
1369
1370 PYTHON_EXTENSIONS = {'.py'}
1371 BLACKLISTED_DIRECTORIES = {
1372     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1373 }
1374
1375
1376 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1377     for child in path.iterdir():
1378         if child.is_dir():
1379             if child.name in BLACKLISTED_DIRECTORIES:
1380                 continue
1381
1382             yield from gen_python_files_in_dir(child)
1383
1384         elif child.suffix in PYTHON_EXTENSIONS:
1385             yield child
1386
1387
1388 @dataclass
1389 class Report:
1390     """Provides a reformatting counter."""
1391     change_count: int = attrib(default=0)
1392     same_count: int = attrib(default=0)
1393     failure_count: int = attrib(default=0)
1394
1395     def done(self, src: Path, changed: bool) -> None:
1396         """Increment the counter for successful reformatting. Write out a message."""
1397         if changed:
1398             out(f'reformatted {src}')
1399             self.change_count += 1
1400         else:
1401             out(f'{src} already well formatted, good job.', bold=False)
1402             self.same_count += 1
1403
1404     def failed(self, src: Path, message: str) -> None:
1405         """Increment the counter for failed reformatting. Write out a message."""
1406         err(f'error: cannot format {src}: {message}')
1407         self.failure_count += 1
1408
1409     @property
1410     def return_code(self) -> int:
1411         """Which return code should the app use considering the current state."""
1412         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1413         # 126 we have special returncodes reserved by the shell.
1414         if self.failure_count:
1415             return 123
1416
1417         elif self.change_count:
1418             return 1
1419
1420         return 0
1421
1422     def __str__(self) -> str:
1423         """A color report of the current state.
1424
1425         Use `click.unstyle` to remove colors.
1426         """
1427         report = []
1428         if self.change_count:
1429             s = 's' if self.change_count > 1 else ''
1430             report.append(
1431                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1432             )
1433         if self.same_count:
1434             s = 's' if self.same_count > 1 else ''
1435             report.append(f'{self.same_count} file{s} left unchanged')
1436         if self.failure_count:
1437             s = 's' if self.failure_count > 1 else ''
1438             report.append(
1439                 click.style(
1440                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1441                 )
1442             )
1443         return ', '.join(report) + '.'
1444
1445
1446 def assert_equivalent(src: str, dst: str) -> None:
1447     """Raises AssertionError if `src` and `dst` aren't equivalent.
1448
1449     This is a temporary sanity check until Black becomes stable.
1450     """
1451
1452     import ast
1453     import traceback
1454
1455     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1456         """Simple visitor generating strings to compare ASTs by content."""
1457         yield f"{'  ' * depth}{node.__class__.__name__}("
1458
1459         for field in sorted(node._fields):
1460             try:
1461                 value = getattr(node, field)
1462             except AttributeError:
1463                 continue
1464
1465             yield f"{'  ' * (depth+1)}{field}="
1466
1467             if isinstance(value, list):
1468                 for item in value:
1469                     if isinstance(item, ast.AST):
1470                         yield from _v(item, depth + 2)
1471
1472             elif isinstance(value, ast.AST):
1473                 yield from _v(value, depth + 2)
1474
1475             else:
1476                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1477
1478         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1479
1480     try:
1481         src_ast = ast.parse(src)
1482     except Exception as exc:
1483         raise AssertionError(f"cannot parse source: {exc}") from None
1484
1485     try:
1486         dst_ast = ast.parse(dst)
1487     except Exception as exc:
1488         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1489         raise AssertionError(
1490             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1491             f"Please report a bug on https://github.com/ambv/black/issues.  "
1492             f"This invalid output might be helpful: {log}",
1493         ) from None
1494
1495     src_ast_str = '\n'.join(_v(src_ast))
1496     dst_ast_str = '\n'.join(_v(dst_ast))
1497     if src_ast_str != dst_ast_str:
1498         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1499         raise AssertionError(
1500             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1501             f"the source.  "
1502             f"Please report a bug on https://github.com/ambv/black/issues.  "
1503             f"This diff might be helpful: {log}",
1504         ) from None
1505
1506
1507 def assert_stable(src: str, dst: str, line_length: int) -> None:
1508     """Raises AssertionError if `dst` reformats differently the second time.
1509
1510     This is a temporary sanity check until Black becomes stable.
1511     """
1512     newdst = format_str(dst, line_length=line_length)
1513     if dst != newdst:
1514         log = dump_to_file(
1515             diff(src, dst, 'source', 'first pass'),
1516             diff(dst, newdst, 'first pass', 'second pass'),
1517         )
1518         raise AssertionError(
1519             f"INTERNAL ERROR: Black produced different code on the second pass "
1520             f"of the formatter.  "
1521             f"Please report a bug on https://github.com/ambv/black/issues.  "
1522             f"This diff might be helpful: {log}",
1523         ) from None
1524
1525
1526 def dump_to_file(*output: str) -> str:
1527     """Dumps `output` to a temporary file. Returns path to the file."""
1528     import tempfile
1529
1530     with tempfile.NamedTemporaryFile(
1531         mode='w', prefix='blk_', suffix='.log', delete=False
1532     ) as f:
1533         for lines in output:
1534             f.write(lines)
1535             f.write('\n')
1536     return f.name
1537
1538
1539 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1540     """Returns a udiff string between strings `a` and `b`."""
1541     import difflib
1542
1543     a_lines = [line + '\n' for line in a.split('\n')]
1544     b_lines = [line + '\n' for line in b.split('\n')]
1545     return ''.join(
1546         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1547     )
1548
1549
1550 if __name__ == '__main__':
1551     main()