]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

d011483747546de1b3248dff21a88b52f9ad4649
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import attrib, dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a1"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
78 )
79 @click.pass_context
80 def main(
81     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
82 ) -> None:
83     """The uncompromising code formatter."""
84     sources: List[Path] = []
85     for s in src:
86         p = Path(s)
87         if p.is_dir():
88             sources.extend(gen_python_files_in_dir(p))
89         elif p.is_file():
90             # if a file was explicitly given, we don't care about its extension
91             sources.append(p)
92         else:
93             err(f'invalid path: {s}')
94     if len(sources) == 0:
95         ctx.exit(0)
96     elif len(sources) == 1:
97         p = sources[0]
98         report = Report()
99         try:
100             changed = format_file_in_place(
101                 p, line_length=line_length, fast=fast, write_back=not check
102             )
103             report.done(p, changed)
104         except Exception as exc:
105             report.failed(p, str(exc))
106         ctx.exit(report.return_code)
107     else:
108         loop = asyncio.get_event_loop()
109         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
110         return_code = 1
111         try:
112             return_code = loop.run_until_complete(
113                 schedule_formatting(
114                     sources, line_length, not check, fast, loop, executor
115                 )
116             )
117         finally:
118             loop.close()
119             ctx.exit(return_code)
120
121
122 async def schedule_formatting(
123     sources: List[Path],
124     line_length: int,
125     write_back: bool,
126     fast: bool,
127     loop: BaseEventLoop,
128     executor: Executor,
129 ) -> int:
130     tasks = {
131         src: loop.run_in_executor(
132             executor, format_file_in_place, src, line_length, fast, write_back
133         )
134         for src in sources
135     }
136     await asyncio.wait(tasks.values())
137     cancelled = []
138     report = Report()
139     for src, task in tasks.items():
140         if not task.done():
141             report.failed(src, 'timed out, cancelling')
142             task.cancel()
143             cancelled.append(task)
144         elif task.exception():
145             report.failed(src, str(task.exception()))
146         else:
147             report.done(src, task.result())
148     if cancelled:
149         await asyncio.wait(cancelled, timeout=2)
150     out('All done! ✨ 🍰 ✨')
151     click.echo(str(report))
152     return report.return_code
153
154
155 def format_file_in_place(
156     src: Path, line_length: int, fast: bool, write_back: bool = False
157 ) -> bool:
158     """Format the file and rewrite if changed. Return True if changed."""
159     try:
160         contents, encoding = format_file(src, line_length=line_length, fast=fast)
161     except NothingChanged:
162         return False
163
164     if write_back:
165         with open(src, "w", encoding=encoding) as f:
166             f.write(contents)
167     return True
168
169
170 def format_file(
171     src: Path, line_length: int, fast: bool
172 ) -> Tuple[FileContent, Encoding]:
173     """Reformats a file and returns its contents and encoding."""
174     with tokenize.open(src) as src_buffer:
175         src_contents = src_buffer.read()
176     if src_contents.strip() == '':
177         raise NothingChanged(src)
178
179     dst_contents = format_str(src_contents, line_length=line_length)
180     if src_contents == dst_contents:
181         raise NothingChanged(src)
182
183     if not fast:
184         assert_equivalent(src_contents, dst_contents)
185         assert_stable(src_contents, dst_contents, line_length=line_length)
186     return dst_contents, src_buffer.encoding
187
188
189 def format_str(src_contents: str, line_length: int) -> FileContent:
190     """Reformats a string and returns new contents."""
191     src_node = lib2to3_parse(src_contents)
192     dst_contents = ""
193     comments: List[Line] = []
194     lines = LineGenerator()
195     elt = EmptyLineTracker()
196     py36 = is_python36(src_node)
197     empty_line = Line()
198     after = 0
199     for current_line in lines.visit(src_node):
200         for _ in range(after):
201             dst_contents += str(empty_line)
202         before, after = elt.maybe_empty_lines(current_line)
203         for _ in range(before):
204             dst_contents += str(empty_line)
205         if not current_line.is_comment:
206             for comment in comments:
207                 dst_contents += str(comment)
208             comments = []
209             for line in split_line(current_line, line_length=line_length, py36=py36):
210                 dst_contents += str(line)
211         else:
212             comments.append(current_line)
213     for comment in comments:
214         dst_contents += str(comment)
215     return dst_contents
216
217
218 def lib2to3_parse(src_txt: str) -> Node:
219     """Given a string with source, return the lib2to3 Node."""
220     grammar = pygram.python_grammar_no_print_statement
221     drv = driver.Driver(grammar, pytree.convert)
222     if src_txt[-1] != '\n':
223         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
224         src_txt += nl
225     try:
226         result = drv.parse_string(src_txt, True)
227     except ParseError as pe:
228         lineno, column = pe.context[1]
229         lines = src_txt.splitlines()
230         try:
231             faulty_line = lines[lineno - 1]
232         except IndexError:
233             faulty_line = "<line number missing in source>"
234         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
235
236     if isinstance(result, Leaf):
237         result = Node(syms.file_input, [result])
238     return result
239
240
241 def lib2to3_unparse(node: Node) -> str:
242     """Given a lib2to3 node, return its string representation."""
243     code = str(node)
244     return code
245
246
247 T = TypeVar('T')
248
249
250 class Visitor(Generic[T]):
251     """Basic lib2to3 visitor that yields things on visiting."""
252
253     def visit(self, node: LN) -> Iterator[T]:
254         if node.type < 256:
255             name = token.tok_name[node.type]
256         else:
257             name = type_repr(node.type)
258         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
259
260     def visit_default(self, node: LN) -> Iterator[T]:
261         if isinstance(node, Node):
262             for child in node.children:
263                 yield from self.visit(child)
264
265
266 @dataclass
267 class DebugVisitor(Visitor[T]):
268     tree_depth: int = attrib(default=0)
269
270     def visit_default(self, node: LN) -> Iterator[T]:
271         indent = ' ' * (2 * self.tree_depth)
272         if isinstance(node, Node):
273             _type = type_repr(node.type)
274             out(f'{indent}{_type}', fg='yellow')
275             self.tree_depth += 1
276             for child in node.children:
277                 yield from self.visit(child)
278
279             self.tree_depth -= 1
280             out(f'{indent}/{_type}', fg='yellow', bold=False)
281         else:
282             _type = token.tok_name.get(node.type, str(node.type))
283             out(f'{indent}{_type}', fg='blue', nl=False)
284             if node.prefix:
285                 # We don't have to handle prefixes for `Node` objects since
286                 # that delegates to the first child anyway.
287                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
288             out(f' {node.value!r}', fg='blue', bold=False)
289
290
291 KEYWORDS = set(keyword.kwlist)
292 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
293 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
294 STATEMENT = {
295     syms.if_stmt,
296     syms.while_stmt,
297     syms.for_stmt,
298     syms.try_stmt,
299     syms.except_clause,
300     syms.with_stmt,
301     syms.funcdef,
302     syms.classdef,
303 }
304 STANDALONE_COMMENT = 153
305 LOGIC_OPERATORS = {'and', 'or'}
306 COMPARATORS = {
307     token.LESS,
308     token.GREATER,
309     token.EQEQUAL,
310     token.NOTEQUAL,
311     token.LESSEQUAL,
312     token.GREATEREQUAL,
313 }
314 MATH_OPERATORS = {
315     token.PLUS,
316     token.MINUS,
317     token.STAR,
318     token.SLASH,
319     token.VBAR,
320     token.AMPER,
321     token.PERCENT,
322     token.CIRCUMFLEX,
323     token.LEFTSHIFT,
324     token.RIGHTSHIFT,
325     token.DOUBLESTAR,
326     token.DOUBLESLASH,
327 }
328 COMPREHENSION_PRIORITY = 20
329 COMMA_PRIORITY = 10
330 LOGIC_PRIORITY = 5
331 STRING_PRIORITY = 4
332 COMPARATOR_PRIORITY = 3
333 MATH_PRIORITY = 1
334
335
336 @dataclass
337 class BracketTracker:
338     depth: int = attrib(default=0)
339     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict))
340     delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict))
341     previous: Optional[Leaf] = attrib(default=None)
342
343     def mark(self, leaf: Leaf) -> None:
344         if leaf.type == token.COMMENT:
345             return
346
347         if leaf.type in CLOSING_BRACKETS:
348             self.depth -= 1
349             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
350             leaf.opening_bracket = opening_bracket
351         leaf.bracket_depth = self.depth
352         if self.depth == 0:
353             delim = is_delimiter(leaf)
354             if delim:
355                 self.delimiters[id(leaf)] = delim
356             elif self.previous is not None:
357                 if leaf.type == token.STRING and self.previous.type == token.STRING:
358                     self.delimiters[id(self.previous)] = STRING_PRIORITY
359                 elif (
360                     leaf.type == token.NAME and
361                     leaf.value == 'for' and
362                     leaf.parent and
363                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
364                 ):
365                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
366                 elif (
367                     leaf.type == token.NAME and
368                     leaf.value == 'if' and
369                     leaf.parent and
370                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
371                 ):
372                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
373         if leaf.type in OPENING_BRACKETS:
374             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
375             self.depth += 1
376         self.previous = leaf
377
378     def any_open_brackets(self) -> bool:
379         """Returns True if there is an yet unmatched open bracket on the line."""
380         return bool(self.bracket_match)
381
382     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
383         """Returns the highest priority of a delimiter found on the line.
384
385         Values are consistent with what `is_delimiter()` returns.
386         """
387         return max(v for k, v in self.delimiters.items() if k not in exclude)
388
389
390 @dataclass
391 class Line:
392     depth: int = attrib(default=0)
393     leaves: List[Leaf] = attrib(default=Factory(list))
394     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
395     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
396     inside_brackets: bool = attrib(default=False)
397     has_for: bool = attrib(default=False)
398     _for_loop_variable: bool = attrib(default=False, init=False)
399
400     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
401         has_value = leaf.value.strip()
402         if not has_value:
403             return
404
405         if self.leaves and not preformatted:
406             # Note: at this point leaf.prefix should be empty except for
407             # imports, for which we only preserve newlines.
408             leaf.prefix += whitespace(leaf)
409         if self.inside_brackets or not preformatted:
410             self.maybe_decrement_after_for_loop_variable(leaf)
411             self.bracket_tracker.mark(leaf)
412             self.maybe_remove_trailing_comma(leaf)
413             self.maybe_increment_for_loop_variable(leaf)
414             if self.maybe_adapt_standalone_comment(leaf):
415                 return
416
417         if not self.append_comment(leaf):
418             self.leaves.append(leaf)
419
420     @property
421     def is_comment(self) -> bool:
422         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
423
424     @property
425     def is_decorator(self) -> bool:
426         return bool(self) and self.leaves[0].type == token.AT
427
428     @property
429     def is_import(self) -> bool:
430         return bool(self) and is_import(self.leaves[0])
431
432     @property
433     def is_class(self) -> bool:
434         return (
435             bool(self) and
436             self.leaves[0].type == token.NAME and
437             self.leaves[0].value == 'class'
438         )
439
440     @property
441     def is_def(self) -> bool:
442         """Also returns True for async defs."""
443         try:
444             first_leaf = self.leaves[0]
445         except IndexError:
446             return False
447
448         try:
449             second_leaf: Optional[Leaf] = self.leaves[1]
450         except IndexError:
451             second_leaf = None
452         return (
453             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
454             (
455                 first_leaf.type == token.NAME and
456                 first_leaf.value == 'async' and
457                 second_leaf is not None and
458                 second_leaf.type == token.NAME and
459                 second_leaf.value == 'def'
460             )
461         )
462
463     @property
464     def is_flow_control(self) -> bool:
465         return (
466             bool(self) and
467             self.leaves[0].type == token.NAME and
468             self.leaves[0].value in FLOW_CONTROL
469         )
470
471     @property
472     def is_yield(self) -> bool:
473         return (
474             bool(self) and
475             self.leaves[0].type == token.NAME and
476             self.leaves[0].value == 'yield'
477         )
478
479     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
480         if not (
481             self.leaves and
482             self.leaves[-1].type == token.COMMA and
483             closing.type in CLOSING_BRACKETS
484         ):
485             return False
486
487         if closing.type == token.RSQB or closing.type == token.RBRACE:
488             self.leaves.pop()
489             return True
490
491         # For parens let's check if it's safe to remove the comma.  If the
492         # trailing one is the only one, we might mistakenly change a tuple
493         # into a different type by removing the comma.
494         depth = closing.bracket_depth + 1
495         commas = 0
496         opening = closing.opening_bracket
497         for _opening_index, leaf in enumerate(self.leaves):
498             if leaf is opening:
499                 break
500
501         else:
502             return False
503
504         for leaf in self.leaves[_opening_index + 1:]:
505             if leaf is closing:
506                 break
507
508             bracket_depth = leaf.bracket_depth
509             if bracket_depth == depth and leaf.type == token.COMMA:
510                 commas += 1
511         if commas > 1:
512             self.leaves.pop()
513             return True
514
515         return False
516
517     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
518         """In a for loop, or comprehension, the variables are often unpacks.
519
520         To avoid splitting on the comma in this situation, we will increase
521         the depth of tokens between `for` and `in`.
522         """
523         if leaf.type == token.NAME and leaf.value == 'for':
524             self.has_for = True
525             self.bracket_tracker.depth += 1
526             self._for_loop_variable = True
527             return True
528
529         return False
530
531     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
532         # See `maybe_increment_for_loop_variable` above for explanation.
533         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
534             self.bracket_tracker.depth -= 1
535             self._for_loop_variable = False
536             return True
537
538         return False
539
540     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
541         """Hack a standalone comment to act as a trailing comment for line splitting.
542
543         If this line has brackets and a standalone `comment`, we need to adapt
544         it to be able to still reformat the line.
545
546         This is not perfect, the line to which the standalone comment gets
547         appended will appear "too long" when splitting.
548         """
549         if not (
550             comment.type == STANDALONE_COMMENT and
551             self.bracket_tracker.any_open_brackets()
552         ):
553             return False
554
555         comment.type = token.COMMENT
556         comment.prefix = '\n' + '    ' * (self.depth + 1)
557         return self.append_comment(comment)
558
559     def append_comment(self, comment: Leaf) -> bool:
560         if comment.type != token.COMMENT:
561             return False
562
563         try:
564             after = id(self.last_non_delimiter())
565         except LookupError:
566             comment.type = STANDALONE_COMMENT
567             comment.prefix = ''
568             return False
569
570         else:
571             if after in self.comments:
572                 self.comments[after].value += str(comment)
573             else:
574                 self.comments[after] = comment
575             return True
576
577     def last_non_delimiter(self) -> Leaf:
578         for i in range(len(self.leaves)):
579             last = self.leaves[-i - 1]
580             if not is_delimiter(last):
581                 return last
582
583         raise LookupError("No non-delimiters found")
584
585     def __str__(self) -> str:
586         if not self:
587             return '\n'
588
589         indent = '    ' * self.depth
590         leaves = iter(self.leaves)
591         first = next(leaves)
592         res = f'{first.prefix}{indent}{first.value}'
593         for leaf in leaves:
594             res += str(leaf)
595         for comment in self.comments.values():
596             res += str(comment)
597         return res + '\n'
598
599     def __bool__(self) -> bool:
600         return bool(self.leaves or self.comments)
601
602
603 @dataclass
604 class EmptyLineTracker:
605     """Provides a stateful method that returns the number of potential extra
606     empty lines needed before and after the currently processed line.
607
608     Note: this tracker works on lines that haven't been split yet.
609     """
610     previous_line: Optional[Line] = attrib(default=None)
611     previous_after: int = attrib(default=0)
612     previous_defs: List[int] = attrib(default=Factory(list))
613
614     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
615         """Returns the number of extra empty lines before and after the `current_line`.
616
617         This is for separating `def`, `async def` and `class` with extra empty lines
618         (two on module-level), as well as providing an extra empty line after flow
619         control keywords to make them more prominent.
620         """
621         before, after = self._maybe_empty_lines(current_line)
622         self.previous_after = after
623         self.previous_line = current_line
624         return before, after
625
626     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
627         before = 0
628         depth = current_line.depth
629         while self.previous_defs and self.previous_defs[-1] >= depth:
630             self.previous_defs.pop()
631             before = (1 if depth else 2) - self.previous_after
632         is_decorator = current_line.is_decorator
633         if is_decorator or current_line.is_def or current_line.is_class:
634             if not is_decorator:
635                 self.previous_defs.append(depth)
636             if self.previous_line is None:
637                 # Don't insert empty lines before the first line in the file.
638                 return 0, 0
639
640             if self.previous_line and self.previous_line.is_decorator:
641                 # Don't insert empty lines between decorators.
642                 return 0, 0
643
644             newlines = 2
645             if current_line.depth:
646                 newlines -= 1
647             newlines -= self.previous_after
648             return newlines, 0
649
650         if current_line.is_flow_control:
651             return before, 1
652
653         if (
654             self.previous_line and
655             self.previous_line.is_import and
656             not current_line.is_import and
657             depth == self.previous_line.depth
658         ):
659             return (before or 1), 0
660
661         if (
662             self.previous_line and
663             self.previous_line.is_yield and
664             (not current_line.is_yield or depth != self.previous_line.depth)
665         ):
666             return (before or 1), 0
667
668         return before, 0
669
670
671 @dataclass
672 class LineGenerator(Visitor[Line]):
673     """Generates reformatted Line objects.  Empty lines are not emitted.
674
675     Note: destroys the tree it's visiting by mutating prefixes of its leaves
676     in ways that will no longer stringify to valid Python code on the tree.
677     """
678     current_line: Line = attrib(default=Factory(Line))
679     standalone_comments: List[Leaf] = attrib(default=Factory(list))
680
681     def line(self, indent: int = 0) -> Iterator[Line]:
682         """Generate a line.
683
684         If the line is empty, only emit if it makes sense.
685         If the line is too long, split it first and then generate.
686
687         If any lines were generated, set up a new current_line.
688         """
689         if not self.current_line:
690             self.current_line.depth += indent
691             return  # Line is empty, don't emit. Creating a new one unnecessary.
692
693         complete_line = self.current_line
694         self.current_line = Line(depth=complete_line.depth + indent)
695         yield complete_line
696
697     def visit_default(self, node: LN) -> Iterator[Line]:
698         if isinstance(node, Leaf):
699             for comment in generate_comments(node):
700                 if self.current_line.bracket_tracker.any_open_brackets():
701                     # any comment within brackets is subject to splitting
702                     self.current_line.append(comment)
703                 elif comment.type == token.COMMENT:
704                     # regular trailing comment
705                     self.current_line.append(comment)
706                     yield from self.line()
707
708                 else:
709                     # regular standalone comment, to be processed later (see
710                     # docstring in `generate_comments()`
711                     self.standalone_comments.append(comment)
712             normalize_prefix(node)
713             if node.type not in WHITESPACE:
714                 for comment in self.standalone_comments:
715                     yield from self.line()
716
717                     self.current_line.append(comment)
718                     yield from self.line()
719
720                 self.standalone_comments = []
721                 self.current_line.append(node)
722         yield from super().visit_default(node)
723
724     def visit_suite(self, node: Node) -> Iterator[Line]:
725         """Body of a statement after a colon."""
726         children = iter(node.children)
727         # Process newline before indenting.  It might contain an inline
728         # comment that should go right after the colon.
729         newline = next(children)
730         yield from self.visit(newline)
731         yield from self.line(+1)
732
733         for child in children:
734             yield from self.visit(child)
735
736         yield from self.line(-1)
737
738     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
739         """Visit a statement.
740
741         The relevant Python language keywords for this statement are NAME leaves
742         within it.
743         """
744         for child in node.children:
745             if child.type == token.NAME and child.value in keywords:  # type: ignore
746                 yield from self.line()
747
748             yield from self.visit(child)
749
750     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
751         """A statement without nested statements."""
752         is_suite_like = node.parent and node.parent.type in STATEMENT
753         if is_suite_like:
754             yield from self.line(+1)
755             yield from self.visit_default(node)
756             yield from self.line(-1)
757
758         else:
759             yield from self.line()
760             yield from self.visit_default(node)
761
762     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
763         yield from self.line()
764
765         children = iter(node.children)
766         for child in children:
767             yield from self.visit(child)
768
769             if child.type == token.NAME and child.value == 'async':  # type: ignore
770                 break
771
772         internal_stmt = next(children)
773         for child in internal_stmt.children:
774             yield from self.visit(child)
775
776     def visit_decorators(self, node: Node) -> Iterator[Line]:
777         for child in node.children:
778             yield from self.line()
779             yield from self.visit(child)
780
781     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
782         yield from self.line()
783
784     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
785         yield from self.visit_default(leaf)
786         yield from self.line()
787
788     def __attrs_post_init__(self) -> None:
789         """You are in a twisty little maze of passages."""
790         v = self.visit_stmt
791         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
792         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
793         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
794         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
795         self.visit_except_clause = partial(v, keywords={'except'})
796         self.visit_funcdef = partial(v, keywords={'def'})
797         self.visit_with_stmt = partial(v, keywords={'with'})
798         self.visit_classdef = partial(v, keywords={'class'})
799         self.visit_async_funcdef = self.visit_async_stmt
800         self.visit_decorated = self.visit_decorators
801
802
803 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
804 OPENING_BRACKETS = set(BRACKET.keys())
805 CLOSING_BRACKETS = set(BRACKET.values())
806 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
807
808
809 def whitespace(leaf: Leaf) -> str:
810     """Return whitespace prefix if needed for the given `leaf`."""
811     NO = ''
812     SPACE = ' '
813     DOUBLESPACE = '  '
814     t = leaf.type
815     p = leaf.parent
816     v = leaf.value
817     if t == token.COLON:
818         return NO
819
820     if t == token.COMMA:
821         return NO
822
823     if t == token.RPAR:
824         return NO
825
826     if t == token.COMMENT:
827         return DOUBLESPACE
828
829     if t == STANDALONE_COMMENT:
830         return NO
831
832     if t in CLOSING_BRACKETS:
833         return NO
834
835     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
836     prev = leaf.prev_sibling
837     if not prev:
838         prevp = preceding_leaf(p)
839         if not prevp or prevp.type in OPENING_BRACKETS:
840             return NO
841
842         if prevp.type == token.EQUAL:
843             if prevp.parent and prevp.parent.type in {
844                 syms.typedargslist,
845                 syms.varargslist,
846                 syms.parameters,
847                 syms.arglist,
848                 syms.argument,
849             }:
850                 return NO
851
852         elif prevp.type == token.DOUBLESTAR:
853             if prevp.parent and prevp.parent.type in {
854                 syms.typedargslist,
855                 syms.varargslist,
856                 syms.parameters,
857                 syms.arglist,
858                 syms.dictsetmaker,
859             }:
860                 return NO
861
862         elif prevp.type == token.COLON:
863             if prevp.parent and prevp.parent.type == syms.subscript:
864                 return NO
865
866         elif prevp.parent and prevp.parent.type == syms.factor:
867             return NO
868
869     elif prev.type in OPENING_BRACKETS:
870         return NO
871
872     if p.type in {syms.parameters, syms.arglist}:
873         # untyped function signatures or calls
874         if t == token.RPAR:
875             return NO
876
877         if not prev or prev.type != token.COMMA:
878             return NO
879
880     if p.type == syms.varargslist:
881         # lambdas
882         if t == token.RPAR:
883             return NO
884
885         if prev and prev.type != token.COMMA:
886             return NO
887
888     elif p.type == syms.typedargslist:
889         # typed function signatures
890         if not prev:
891             return NO
892
893         if t == token.EQUAL:
894             if prev.type != syms.tname:
895                 return NO
896
897         elif prev.type == token.EQUAL:
898             # A bit hacky: if the equal sign has whitespace, it means we
899             # previously found it's a typed argument.  So, we're using that, too.
900             return prev.prefix
901
902         elif prev.type != token.COMMA:
903             return NO
904
905     elif p.type == syms.tname:
906         # type names
907         if not prev:
908             prevp = preceding_leaf(p)
909             if not prevp or prevp.type != token.COMMA:
910                 return NO
911
912     elif p.type == syms.trailer:
913         # attributes and calls
914         if t == token.LPAR or t == token.RPAR:
915             return NO
916
917         if not prev:
918             if t == token.DOT:
919                 prevp = preceding_leaf(p)
920                 if not prevp or prevp.type != token.NUMBER:
921                     return NO
922
923             elif t == token.LSQB:
924                 return NO
925
926         elif prev.type != token.COMMA:
927             return NO
928
929     elif p.type == syms.argument:
930         # single argument
931         if t == token.EQUAL:
932             return NO
933
934         if not prev:
935             prevp = preceding_leaf(p)
936             if not prevp or prevp.type == token.LPAR:
937                 return NO
938
939         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
940             return NO
941
942     elif p.type == syms.decorator:
943         # decorators
944         return NO
945
946     elif p.type == syms.dotted_name:
947         if prev:
948             return NO
949
950         prevp = preceding_leaf(p)
951         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
952             return NO
953
954     elif p.type == syms.classdef:
955         if t == token.LPAR:
956             return NO
957
958         if prev and prev.type == token.LPAR:
959             return NO
960
961     elif p.type == syms.subscript:
962         # indexing
963         if not prev:
964             assert p.parent is not None, "subscripts are always parented"
965             if p.parent.type == syms.subscriptlist:
966                 return SPACE
967
968             return NO
969
970         elif prev.type == token.COLON:
971             return NO
972
973     elif p.type == syms.atom:
974         if prev and t == token.DOT:
975             # dots, but not the first one.
976             return NO
977
978     elif (
979         p.type == syms.listmaker or
980         p.type == syms.testlist_gexp or
981         p.type == syms.subscriptlist
982     ):
983         # list interior, including unpacking
984         if not prev:
985             return NO
986
987     elif p.type == syms.dictsetmaker:
988         # dict and set interior, including unpacking
989         if not prev:
990             return NO
991
992         if prev.type == token.DOUBLESTAR:
993             return NO
994
995     elif p.type == syms.factor or p.type == syms.star_expr:
996         # unary ops
997         if not prev:
998             prevp = preceding_leaf(p)
999             if not prevp or prevp.type in OPENING_BRACKETS:
1000                 return NO
1001
1002             prevp_parent = prevp.parent
1003             assert prevp_parent is not None
1004             if prevp.type == token.COLON and prevp_parent.type in {
1005                 syms.subscript, syms.sliceop
1006             }:
1007                 return NO
1008
1009             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1010                 return NO
1011
1012         elif t == token.NAME or t == token.NUMBER:
1013             return NO
1014
1015     elif p.type == syms.import_from:
1016         if t == token.DOT:
1017             if prev and prev.type == token.DOT:
1018                 return NO
1019
1020         elif t == token.NAME:
1021             if v == 'import':
1022                 return SPACE
1023
1024             if prev and prev.type == token.DOT:
1025                 return NO
1026
1027     elif p.type == syms.sliceop:
1028         return NO
1029
1030     return SPACE
1031
1032
1033 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1034     """Returns the first leaf that precedes `node`, if any."""
1035     while node:
1036         res = node.prev_sibling
1037         if res:
1038             if isinstance(res, Leaf):
1039                 return res
1040
1041             try:
1042                 return list(res.leaves())[-1]
1043
1044             except IndexError:
1045                 return None
1046
1047         node = node.parent
1048     return None
1049
1050
1051 def is_delimiter(leaf: Leaf) -> int:
1052     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1053
1054     Higher numbers are higher priority.
1055     """
1056     if leaf.type == token.COMMA:
1057         return COMMA_PRIORITY
1058
1059     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1060         return LOGIC_PRIORITY
1061
1062     if leaf.type in COMPARATORS:
1063         return COMPARATOR_PRIORITY
1064
1065     if (
1066         leaf.type in MATH_OPERATORS and
1067         leaf.parent and
1068         leaf.parent.type not in {syms.factor, syms.star_expr}
1069     ):
1070         return MATH_PRIORITY
1071
1072     return 0
1073
1074
1075 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1076     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1077
1078     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1079     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1080     move because it does away with modifying the grammar to include all the
1081     possible places in which comments can be placed.
1082
1083     The sad consequence for us though is that comments don't "belong" anywhere.
1084     This is why this function generates simple parentless Leaf objects for
1085     comments.  We simply don't know what the correct parent should be.
1086
1087     No matter though, we can live without this.  We really only need to
1088     differentiate between inline and standalone comments.  The latter don't
1089     share the line with any code.
1090
1091     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1092     are emitted with a fake STANDALONE_COMMENT token identifier.
1093     """
1094     if not leaf.prefix:
1095         return
1096
1097     if '#' not in leaf.prefix:
1098         return
1099
1100     before_comment, content = leaf.prefix.split('#', 1)
1101     content = content.rstrip()
1102     if content and (content[0] not in {' ', '!', '#'}):
1103         content = ' ' + content
1104     is_standalone_comment = (
1105         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1106     )
1107     if not is_standalone_comment:
1108         # simple trailing comment
1109         yield Leaf(token.COMMENT, value='#' + content)
1110         return
1111
1112     for line in ('#' + content).split('\n'):
1113         line = line.lstrip()
1114         if not line.startswith('#'):
1115             continue
1116
1117         yield Leaf(STANDALONE_COMMENT, line)
1118
1119
1120 def split_line(
1121     line: Line, line_length: int, inner: bool = False, py36: bool = False
1122 ) -> Iterator[Line]:
1123     """Splits a `line` into potentially many lines.
1124
1125     They should fit in the allotted `line_length` but might not be able to.
1126     `inner` signifies that there were a pair of brackets somewhere around the
1127     current `line`, possibly transitively. This means we can fallback to splitting
1128     by delimiters if the LHS/RHS don't yield any results.
1129
1130     If `py36` is True, splitting may generate syntax that is only compatible
1131     with Python 3.6 and later.
1132     """
1133     line_str = str(line).strip('\n')
1134     if len(line_str) <= line_length and '\n' not in line_str:
1135         yield line
1136         return
1137
1138     if line.is_def:
1139         split_funcs = [left_hand_split]
1140     elif line.inside_brackets:
1141         split_funcs = [delimiter_split]
1142         if '\n' not in line_str:
1143             # Only attempt RHS if we don't have multiline strings or comments
1144             # on this line.
1145             split_funcs.append(right_hand_split)
1146     else:
1147         split_funcs = [right_hand_split]
1148     for split_func in split_funcs:
1149         # We are accumulating lines in `result` because we might want to abort
1150         # mission and return the original line in the end, or attempt a different
1151         # split altogether.
1152         result: List[Line] = []
1153         try:
1154             for l in split_func(line, py36=py36):
1155                 if str(l).strip('\n') == line_str:
1156                     raise CannotSplit("Split function returned an unchanged result")
1157
1158                 result.extend(
1159                     split_line(l, line_length=line_length, inner=True, py36=py36)
1160                 )
1161         except CannotSplit as cs:
1162             continue
1163
1164         else:
1165             yield from result
1166             break
1167
1168     else:
1169         yield line
1170
1171
1172 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1173     """Split line into many lines, starting with the first matching bracket pair.
1174
1175     Note: this usually looks weird, only use this for function definitions.
1176     Prefer RHS otherwise.
1177     """
1178     head = Line(depth=line.depth)
1179     body = Line(depth=line.depth + 1, inside_brackets=True)
1180     tail = Line(depth=line.depth)
1181     tail_leaves: List[Leaf] = []
1182     body_leaves: List[Leaf] = []
1183     head_leaves: List[Leaf] = []
1184     current_leaves = head_leaves
1185     matching_bracket = None
1186     for leaf in line.leaves:
1187         if (
1188             current_leaves is body_leaves and
1189             leaf.type in CLOSING_BRACKETS and
1190             leaf.opening_bracket is matching_bracket
1191         ):
1192             current_leaves = tail_leaves
1193         current_leaves.append(leaf)
1194         if current_leaves is head_leaves:
1195             if leaf.type in OPENING_BRACKETS:
1196                 matching_bracket = leaf
1197                 current_leaves = body_leaves
1198     # Since body is a new indent level, remove spurious leading whitespace.
1199     if body_leaves:
1200         normalize_prefix(body_leaves[0])
1201     # Build the new lines.
1202     for result, leaves in (
1203         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1204     ):
1205         for leaf in leaves:
1206             result.append(leaf, preformatted=True)
1207             comment_after = line.comments.get(id(leaf))
1208             if comment_after:
1209                 result.append(comment_after, preformatted=True)
1210     # Check if the split succeeded.
1211     tail_len = len(str(tail))
1212     if not body:
1213         if tail_len == 0:
1214             raise CannotSplit("Splitting brackets produced the same line")
1215
1216         elif tail_len < 3:
1217             raise CannotSplit(
1218                 f"Splitting brackets on an empty body to save "
1219                 f"{tail_len} characters is not worth it"
1220             )
1221
1222     for result in (head, body, tail):
1223         if result:
1224             yield result
1225
1226
1227 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1228     """Split line into many lines, starting with the last matching bracket pair."""
1229     head = Line(depth=line.depth)
1230     body = Line(depth=line.depth + 1, inside_brackets=True)
1231     tail = Line(depth=line.depth)
1232     tail_leaves: List[Leaf] = []
1233     body_leaves: List[Leaf] = []
1234     head_leaves: List[Leaf] = []
1235     current_leaves = tail_leaves
1236     opening_bracket = None
1237     for leaf in reversed(line.leaves):
1238         if current_leaves is body_leaves:
1239             if leaf is opening_bracket:
1240                 current_leaves = head_leaves
1241         current_leaves.append(leaf)
1242         if current_leaves is tail_leaves:
1243             if leaf.type in CLOSING_BRACKETS:
1244                 opening_bracket = leaf.opening_bracket
1245                 current_leaves = body_leaves
1246     tail_leaves.reverse()
1247     body_leaves.reverse()
1248     head_leaves.reverse()
1249     # Since body is a new indent level, remove spurious leading whitespace.
1250     if body_leaves:
1251         normalize_prefix(body_leaves[0])
1252     # Build the new lines.
1253     for result, leaves in (
1254         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1255     ):
1256         for leaf in leaves:
1257             result.append(leaf, preformatted=True)
1258             comment_after = line.comments.get(id(leaf))
1259             if comment_after:
1260                 result.append(comment_after, preformatted=True)
1261     # Check if the split succeeded.
1262     tail_len = len(str(tail).strip('\n'))
1263     if not body:
1264         if tail_len == 0:
1265             raise CannotSplit("Splitting brackets produced the same line")
1266
1267         elif tail_len < 3:
1268             raise CannotSplit(
1269                 f"Splitting brackets on an empty body to save "
1270                 f"{tail_len} characters is not worth it"
1271             )
1272
1273     for result in (head, body, tail):
1274         if result:
1275             yield result
1276
1277
1278 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1279     """Split according to delimiters of the highest priority.
1280
1281     This kind of split doesn't increase indentation.
1282     If `py36` is True, the split will add trailing commas also in function
1283     signatures that contain * and **.
1284     """
1285     try:
1286         last_leaf = line.leaves[-1]
1287     except IndexError:
1288         raise CannotSplit("Line empty")
1289
1290     delimiters = line.bracket_tracker.delimiters
1291     try:
1292         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1293     except ValueError:
1294         raise CannotSplit("No delimiters found")
1295
1296     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1297     lowest_depth = sys.maxsize
1298     trailing_comma_safe = True
1299     for leaf in line.leaves:
1300         current_line.append(leaf, preformatted=True)
1301         comment_after = line.comments.get(id(leaf))
1302         if comment_after:
1303             current_line.append(comment_after, preformatted=True)
1304         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1305         if (
1306             leaf.bracket_depth == lowest_depth and
1307             leaf.type == token.STAR or
1308             leaf.type == token.DOUBLESTAR
1309         ):
1310             trailing_comma_safe = trailing_comma_safe and py36
1311         leaf_priority = delimiters.get(id(leaf))
1312         if leaf_priority == delimiter_priority:
1313             normalize_prefix(current_line.leaves[0])
1314             yield current_line
1315
1316             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1317     if current_line:
1318         if (
1319             delimiter_priority == COMMA_PRIORITY and
1320             current_line.leaves[-1].type != token.COMMA and
1321             trailing_comma_safe
1322         ):
1323             current_line.append(Leaf(token.COMMA, ','))
1324         normalize_prefix(current_line.leaves[0])
1325         yield current_line
1326
1327
1328 def is_import(leaf: Leaf) -> bool:
1329     """Returns True if the given leaf starts an import statement."""
1330     p = leaf.parent
1331     t = leaf.type
1332     v = leaf.value
1333     return bool(
1334         t == token.NAME and
1335         (
1336             (v == 'import' and p and p.type == syms.import_name) or
1337             (v == 'from' and p and p.type == syms.import_from)
1338         )
1339     )
1340
1341
1342 def normalize_prefix(leaf: Leaf) -> None:
1343     """Leave existing extra newlines for imports.  Remove everything else."""
1344     if is_import(leaf):
1345         spl = leaf.prefix.split('#', 1)
1346         nl_count = spl[0].count('\n')
1347         if len(spl) > 1:
1348             # Skip one newline since it was for a standalone comment.
1349             nl_count -= 1
1350         leaf.prefix = '\n' * nl_count
1351         return
1352
1353     leaf.prefix = ''
1354
1355
1356 def is_python36(node: Node) -> bool:
1357     """Returns True if the current file is using Python 3.6+ features.
1358
1359     Currently looking for:
1360     - f-strings; and
1361     - trailing commas after * or ** in function signatures.
1362     """
1363     for n in node.pre_order():
1364         if n.type == token.STRING:
1365             value_head = n.value[:2]  # type: ignore
1366             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1367                 return True
1368
1369         elif (
1370             n.type == syms.typedargslist and
1371             n.children and
1372             n.children[-1].type == token.COMMA
1373         ):
1374             for ch in n.children:
1375                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1376                     return True
1377
1378     return False
1379
1380
1381 PYTHON_EXTENSIONS = {'.py'}
1382 BLACKLISTED_DIRECTORIES = {
1383     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1384 }
1385
1386
1387 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1388     for child in path.iterdir():
1389         if child.is_dir():
1390             if child.name in BLACKLISTED_DIRECTORIES:
1391                 continue
1392
1393             yield from gen_python_files_in_dir(child)
1394
1395         elif child.suffix in PYTHON_EXTENSIONS:
1396             yield child
1397
1398
1399 @dataclass
1400 class Report:
1401     """Provides a reformatting counter."""
1402     change_count: int = attrib(default=0)
1403     same_count: int = attrib(default=0)
1404     failure_count: int = attrib(default=0)
1405
1406     def done(self, src: Path, changed: bool) -> None:
1407         """Increment the counter for successful reformatting. Write out a message."""
1408         if changed:
1409             out(f'reformatted {src}')
1410             self.change_count += 1
1411         else:
1412             out(f'{src} already well formatted, good job.', bold=False)
1413             self.same_count += 1
1414
1415     def failed(self, src: Path, message: str) -> None:
1416         """Increment the counter for failed reformatting. Write out a message."""
1417         err(f'error: cannot format {src}: {message}')
1418         self.failure_count += 1
1419
1420     @property
1421     def return_code(self) -> int:
1422         """Which return code should the app use considering the current state."""
1423         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1424         # 126 we have special returncodes reserved by the shell.
1425         if self.failure_count:
1426             return 123
1427
1428         elif self.change_count:
1429             return 1
1430
1431         return 0
1432
1433     def __str__(self) -> str:
1434         """A color report of the current state.
1435
1436         Use `click.unstyle` to remove colors.
1437         """
1438         report = []
1439         if self.change_count:
1440             s = 's' if self.change_count > 1 else ''
1441             report.append(
1442                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1443             )
1444         if self.same_count:
1445             s = 's' if self.same_count > 1 else ''
1446             report.append(f'{self.same_count} file{s} left unchanged')
1447         if self.failure_count:
1448             s = 's' if self.failure_count > 1 else ''
1449             report.append(
1450                 click.style(
1451                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1452                 )
1453             )
1454         return ', '.join(report) + '.'
1455
1456
1457 def assert_equivalent(src: str, dst: str) -> None:
1458     """Raises AssertionError if `src` and `dst` aren't equivalent.
1459
1460     This is a temporary sanity check until Black becomes stable.
1461     """
1462
1463     import ast
1464     import traceback
1465
1466     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1467         """Simple visitor generating strings to compare ASTs by content."""
1468         yield f"{'  ' * depth}{node.__class__.__name__}("
1469
1470         for field in sorted(node._fields):
1471             try:
1472                 value = getattr(node, field)
1473             except AttributeError:
1474                 continue
1475
1476             yield f"{'  ' * (depth+1)}{field}="
1477
1478             if isinstance(value, list):
1479                 for item in value:
1480                     if isinstance(item, ast.AST):
1481                         yield from _v(item, depth + 2)
1482
1483             elif isinstance(value, ast.AST):
1484                 yield from _v(value, depth + 2)
1485
1486             else:
1487                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1488
1489         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1490
1491     try:
1492         src_ast = ast.parse(src)
1493     except Exception as exc:
1494         raise AssertionError(f"cannot parse source: {exc}") from None
1495
1496     try:
1497         dst_ast = ast.parse(dst)
1498     except Exception as exc:
1499         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1500         raise AssertionError(
1501             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1502             f"Please report a bug on https://github.com/ambv/black/issues.  "
1503             f"This invalid output might be helpful: {log}",
1504         ) from None
1505
1506     src_ast_str = '\n'.join(_v(src_ast))
1507     dst_ast_str = '\n'.join(_v(dst_ast))
1508     if src_ast_str != dst_ast_str:
1509         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1510         raise AssertionError(
1511             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1512             f"the source.  "
1513             f"Please report a bug on https://github.com/ambv/black/issues.  "
1514             f"This diff might be helpful: {log}",
1515         ) from None
1516
1517
1518 def assert_stable(src: str, dst: str, line_length: int) -> None:
1519     """Raises AssertionError if `dst` reformats differently the second time.
1520
1521     This is a temporary sanity check until Black becomes stable.
1522     """
1523     newdst = format_str(dst, line_length=line_length)
1524     if dst != newdst:
1525         log = dump_to_file(
1526             diff(src, dst, 'source', 'first pass'),
1527             diff(dst, newdst, 'first pass', 'second pass'),
1528         )
1529         raise AssertionError(
1530             f"INTERNAL ERROR: Black produced different code on the second pass "
1531             f"of the formatter.  "
1532             f"Please report a bug on https://github.com/ambv/black/issues.  "
1533             f"This diff might be helpful: {log}",
1534         ) from None
1535
1536
1537 def dump_to_file(*output: str) -> str:
1538     """Dumps `output` to a temporary file. Returns path to the file."""
1539     import tempfile
1540
1541     with tempfile.NamedTemporaryFile(
1542         mode='w', prefix='blk_', suffix='.log', delete=False
1543     ) as f:
1544         for lines in output:
1545             f.write(lines)
1546             f.write('\n')
1547     return f.name
1548
1549
1550 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1551     """Returns a udiff string between strings `a` and `b`."""
1552     import difflib
1553
1554     a_lines = [line + '\n' for line in a.split('\n')]
1555     b_lines = [line + '\n' for line in b.split('\n')]
1556     return ''.join(
1557         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1558     )
1559
1560
1561 if __name__ == '__main__':
1562     main()