]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

8d4b0952d63549537e85846b551c4d9c6afac0d1
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
78 )
79 @click.pass_context
80 def main(
81     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
82 ) -> None:
83     """The uncompromising code formatter."""
84     sources: List[Path] = []
85     for s in src:
86         p = Path(s)
87         if p.is_dir():
88             sources.extend(gen_python_files_in_dir(p))
89         elif p.is_file():
90             # if a file was explicitly given, we don't care about its extension
91             sources.append(p)
92         else:
93             err(f'invalid path: {s}')
94     if len(sources) == 0:
95         ctx.exit(0)
96     elif len(sources) == 1:
97         p = sources[0]
98         report = Report()
99         try:
100             changed = format_file_in_place(
101                 p, line_length=line_length, fast=fast, write_back=not check
102             )
103             report.done(p, changed)
104         except Exception as exc:
105             report.failed(p, str(exc))
106         ctx.exit(report.return_code)
107     else:
108         loop = asyncio.get_event_loop()
109         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
110         return_code = 1
111         try:
112             return_code = loop.run_until_complete(
113                 schedule_formatting(
114                     sources, line_length, not check, fast, loop, executor
115                 )
116             )
117         finally:
118             loop.close()
119             ctx.exit(return_code)
120
121
122 async def schedule_formatting(
123     sources: List[Path],
124     line_length: int,
125     write_back: bool,
126     fast: bool,
127     loop: BaseEventLoop,
128     executor: Executor,
129 ) -> int:
130     tasks = {
131         src: loop.run_in_executor(
132             executor, format_file_in_place, src, line_length, fast, write_back
133         )
134         for src in sources
135     }
136     await asyncio.wait(tasks.values())
137     cancelled = []
138     report = Report()
139     for src, task in tasks.items():
140         if not task.done():
141             report.failed(src, 'timed out, cancelling')
142             task.cancel()
143             cancelled.append(task)
144         elif task.exception():
145             report.failed(src, str(task.exception()))
146         else:
147             report.done(src, task.result())
148     if cancelled:
149         await asyncio.wait(cancelled, timeout=2)
150     out('All done! ✨ 🍰 ✨')
151     click.echo(str(report))
152     return report.return_code
153
154
155 def format_file_in_place(
156     src: Path, line_length: int, fast: bool, write_back: bool = False
157 ) -> bool:
158     """Format the file and rewrite if changed. Return True if changed."""
159     try:
160         contents, encoding = format_file(src, line_length=line_length, fast=fast)
161     except NothingChanged:
162         return False
163
164     if write_back:
165         with open(src, "w", encoding=encoding) as f:
166             f.write(contents)
167     return True
168
169
170 def format_file(
171     src: Path, line_length: int, fast: bool
172 ) -> Tuple[FileContent, Encoding]:
173     """Reformats a file and returns its contents and encoding."""
174     with tokenize.open(src) as src_buffer:
175         src_contents = src_buffer.read()
176     if src_contents.strip() == '':
177         raise NothingChanged(src)
178
179     dst_contents = format_str(src_contents, line_length=line_length)
180     if src_contents == dst_contents:
181         raise NothingChanged(src)
182
183     if not fast:
184         assert_equivalent(src_contents, dst_contents)
185         assert_stable(src_contents, dst_contents, line_length=line_length)
186     return dst_contents, src_buffer.encoding
187
188
189 def format_str(src_contents: str, line_length: int) -> FileContent:
190     """Reformats a string and returns new contents."""
191     src_node = lib2to3_parse(src_contents)
192     dst_contents = ""
193     comments: List[Line] = []
194     lines = LineGenerator()
195     elt = EmptyLineTracker()
196     py36 = is_python36(src_node)
197     empty_line = Line()
198     after = 0
199     for current_line in lines.visit(src_node):
200         for _ in range(after):
201             dst_contents += str(empty_line)
202         before, after = elt.maybe_empty_lines(current_line)
203         for _ in range(before):
204             dst_contents += str(empty_line)
205         if not current_line.is_comment:
206             for comment in comments:
207                 dst_contents += str(comment)
208             comments = []
209             for line in split_line(current_line, line_length=line_length, py36=py36):
210                 dst_contents += str(line)
211         else:
212             comments.append(current_line)
213     for comment in comments:
214         dst_contents += str(comment)
215     return dst_contents
216
217
218 def lib2to3_parse(src_txt: str) -> Node:
219     """Given a string with source, return the lib2to3 Node."""
220     grammar = pygram.python_grammar_no_print_statement
221     drv = driver.Driver(grammar, pytree.convert)
222     if src_txt[-1] != '\n':
223         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
224         src_txt += nl
225     try:
226         result = drv.parse_string(src_txt, True)
227     except ParseError as pe:
228         lineno, column = pe.context[1]
229         lines = src_txt.splitlines()
230         try:
231             faulty_line = lines[lineno - 1]
232         except IndexError:
233             faulty_line = "<line number missing in source>"
234         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
235
236     if isinstance(result, Leaf):
237         result = Node(syms.file_input, [result])
238     return result
239
240
241 def lib2to3_unparse(node: Node) -> str:
242     """Given a lib2to3 node, return its string representation."""
243     code = str(node)
244     return code
245
246
247 T = TypeVar('T')
248
249
250 class Visitor(Generic[T]):
251     """Basic lib2to3 visitor that yields things on visiting."""
252
253     def visit(self, node: LN) -> Iterator[T]:
254         if node.type < 256:
255             name = token.tok_name[node.type]
256         else:
257             name = type_repr(node.type)
258         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
259
260     def visit_default(self, node: LN) -> Iterator[T]:
261         if isinstance(node, Node):
262             for child in node.children:
263                 yield from self.visit(child)
264
265
266 @dataclass
267 class DebugVisitor(Visitor[T]):
268     tree_depth: int = 0
269
270     def visit_default(self, node: LN) -> Iterator[T]:
271         indent = ' ' * (2 * self.tree_depth)
272         if isinstance(node, Node):
273             _type = type_repr(node.type)
274             out(f'{indent}{_type}', fg='yellow')
275             self.tree_depth += 1
276             for child in node.children:
277                 yield from self.visit(child)
278
279             self.tree_depth -= 1
280             out(f'{indent}/{_type}', fg='yellow', bold=False)
281         else:
282             _type = token.tok_name.get(node.type, str(node.type))
283             out(f'{indent}{_type}', fg='blue', nl=False)
284             if node.prefix:
285                 # We don't have to handle prefixes for `Node` objects since
286                 # that delegates to the first child anyway.
287                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
288             out(f' {node.value!r}', fg='blue', bold=False)
289
290
291 KEYWORDS = set(keyword.kwlist)
292 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
293 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
294 STATEMENT = {
295     syms.if_stmt,
296     syms.while_stmt,
297     syms.for_stmt,
298     syms.try_stmt,
299     syms.except_clause,
300     syms.with_stmt,
301     syms.funcdef,
302     syms.classdef,
303 }
304 STANDALONE_COMMENT = 153
305 LOGIC_OPERATORS = {'and', 'or'}
306 COMPARATORS = {
307     token.LESS,
308     token.GREATER,
309     token.EQEQUAL,
310     token.NOTEQUAL,
311     token.LESSEQUAL,
312     token.GREATEREQUAL,
313 }
314 MATH_OPERATORS = {
315     token.PLUS,
316     token.MINUS,
317     token.STAR,
318     token.SLASH,
319     token.VBAR,
320     token.AMPER,
321     token.PERCENT,
322     token.CIRCUMFLEX,
323     token.LEFTSHIFT,
324     token.RIGHTSHIFT,
325     token.DOUBLESTAR,
326     token.DOUBLESLASH,
327 }
328 COMPREHENSION_PRIORITY = 20
329 COMMA_PRIORITY = 10
330 LOGIC_PRIORITY = 5
331 STRING_PRIORITY = 4
332 COMPARATOR_PRIORITY = 3
333 MATH_PRIORITY = 1
334
335
336 @dataclass
337 class BracketTracker:
338     depth: int = 0
339     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
340     delimiters: Dict[LeafID, Priority] = Factory(dict)
341     previous: Optional[Leaf] = None
342
343     def mark(self, leaf: Leaf) -> None:
344         if leaf.type == token.COMMENT:
345             return
346
347         if leaf.type in CLOSING_BRACKETS:
348             self.depth -= 1
349             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
350             leaf.opening_bracket = opening_bracket
351         leaf.bracket_depth = self.depth
352         if self.depth == 0:
353             delim = is_delimiter(leaf)
354             if delim:
355                 self.delimiters[id(leaf)] = delim
356             elif self.previous is not None:
357                 if leaf.type == token.STRING and self.previous.type == token.STRING:
358                     self.delimiters[id(self.previous)] = STRING_PRIORITY
359                 elif (
360                     leaf.type == token.NAME and
361                     leaf.value == 'for' and
362                     leaf.parent and
363                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
364                 ):
365                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
366                 elif (
367                     leaf.type == token.NAME and
368                     leaf.value == 'if' and
369                     leaf.parent and
370                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
371                 ):
372                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
373         if leaf.type in OPENING_BRACKETS:
374             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
375             self.depth += 1
376         self.previous = leaf
377
378     def any_open_brackets(self) -> bool:
379         """Returns True if there is an yet unmatched open bracket on the line."""
380         return bool(self.bracket_match)
381
382     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
383         """Returns the highest priority of a delimiter found on the line.
384
385         Values are consistent with what `is_delimiter()` returns.
386         """
387         return max(v for k, v in self.delimiters.items() if k not in exclude)
388
389
390 @dataclass
391 class Line:
392     depth: int = 0
393     leaves: List[Leaf] = Factory(list)
394     comments: Dict[LeafID, Leaf] = Factory(dict)
395     bracket_tracker: BracketTracker = Factory(BracketTracker)
396     inside_brackets: bool = False
397     has_for: bool = False
398     _for_loop_variable: bool = False
399
400     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
401         has_value = leaf.value.strip()
402         if not has_value:
403             return
404
405         if self.leaves and not preformatted:
406             # Note: at this point leaf.prefix should be empty except for
407             # imports, for which we only preserve newlines.
408             leaf.prefix += whitespace(leaf)
409         if self.inside_brackets or not preformatted:
410             self.maybe_decrement_after_for_loop_variable(leaf)
411             self.bracket_tracker.mark(leaf)
412             self.maybe_remove_trailing_comma(leaf)
413             self.maybe_increment_for_loop_variable(leaf)
414             if self.maybe_adapt_standalone_comment(leaf):
415                 return
416
417         if not self.append_comment(leaf):
418             self.leaves.append(leaf)
419
420     @property
421     def is_comment(self) -> bool:
422         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
423
424     @property
425     def is_decorator(self) -> bool:
426         return bool(self) and self.leaves[0].type == token.AT
427
428     @property
429     def is_import(self) -> bool:
430         return bool(self) and is_import(self.leaves[0])
431
432     @property
433     def is_class(self) -> bool:
434         return (
435             bool(self) and
436             self.leaves[0].type == token.NAME and
437             self.leaves[0].value == 'class'
438         )
439
440     @property
441     def is_def(self) -> bool:
442         """Also returns True for async defs."""
443         try:
444             first_leaf = self.leaves[0]
445         except IndexError:
446             return False
447
448         try:
449             second_leaf: Optional[Leaf] = self.leaves[1]
450         except IndexError:
451             second_leaf = None
452         return (
453             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
454             (
455                 first_leaf.type == token.NAME and
456                 first_leaf.value == 'async' and
457                 second_leaf is not None and
458                 second_leaf.type == token.NAME and
459                 second_leaf.value == 'def'
460             )
461         )
462
463     @property
464     def is_flow_control(self) -> bool:
465         return (
466             bool(self) and
467             self.leaves[0].type == token.NAME and
468             self.leaves[0].value in FLOW_CONTROL
469         )
470
471     @property
472     def is_yield(self) -> bool:
473         return (
474             bool(self) and
475             self.leaves[0].type == token.NAME and
476             self.leaves[0].value == 'yield'
477         )
478
479     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
480         if not (
481             self.leaves and
482             self.leaves[-1].type == token.COMMA and
483             closing.type in CLOSING_BRACKETS
484         ):
485             return False
486
487         if closing.type == token.RSQB or closing.type == token.RBRACE:
488             self.leaves.pop()
489             return True
490
491         # For parens let's check if it's safe to remove the comma.  If the
492         # trailing one is the only one, we might mistakenly change a tuple
493         # into a different type by removing the comma.
494         depth = closing.bracket_depth + 1
495         commas = 0
496         opening = closing.opening_bracket
497         for _opening_index, leaf in enumerate(self.leaves):
498             if leaf is opening:
499                 break
500
501         else:
502             return False
503
504         for leaf in self.leaves[_opening_index + 1:]:
505             if leaf is closing:
506                 break
507
508             bracket_depth = leaf.bracket_depth
509             if bracket_depth == depth and leaf.type == token.COMMA:
510                 commas += 1
511                 if leaf.parent and leaf.parent.type == syms.arglist:
512                     commas += 1
513                     break
514
515         if commas > 1:
516             self.leaves.pop()
517             return True
518
519         return False
520
521     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
522         """In a for loop, or comprehension, the variables are often unpacks.
523
524         To avoid splitting on the comma in this situation, we will increase
525         the depth of tokens between `for` and `in`.
526         """
527         if leaf.type == token.NAME and leaf.value == 'for':
528             self.has_for = True
529             self.bracket_tracker.depth += 1
530             self._for_loop_variable = True
531             return True
532
533         return False
534
535     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
536         # See `maybe_increment_for_loop_variable` above for explanation.
537         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
538             self.bracket_tracker.depth -= 1
539             self._for_loop_variable = False
540             return True
541
542         return False
543
544     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
545         """Hack a standalone comment to act as a trailing comment for line splitting.
546
547         If this line has brackets and a standalone `comment`, we need to adapt
548         it to be able to still reformat the line.
549
550         This is not perfect, the line to which the standalone comment gets
551         appended will appear "too long" when splitting.
552         """
553         if not (
554             comment.type == STANDALONE_COMMENT and
555             self.bracket_tracker.any_open_brackets()
556         ):
557             return False
558
559         comment.type = token.COMMENT
560         comment.prefix = '\n' + '    ' * (self.depth + 1)
561         return self.append_comment(comment)
562
563     def append_comment(self, comment: Leaf) -> bool:
564         if comment.type != token.COMMENT:
565             return False
566
567         try:
568             after = id(self.last_non_delimiter())
569         except LookupError:
570             comment.type = STANDALONE_COMMENT
571             comment.prefix = ''
572             return False
573
574         else:
575             if after in self.comments:
576                 self.comments[after].value += str(comment)
577             else:
578                 self.comments[after] = comment
579             return True
580
581     def last_non_delimiter(self) -> Leaf:
582         for i in range(len(self.leaves)):
583             last = self.leaves[-i - 1]
584             if not is_delimiter(last):
585                 return last
586
587         raise LookupError("No non-delimiters found")
588
589     def __str__(self) -> str:
590         if not self:
591             return '\n'
592
593         indent = '    ' * self.depth
594         leaves = iter(self.leaves)
595         first = next(leaves)
596         res = f'{first.prefix}{indent}{first.value}'
597         for leaf in leaves:
598             res += str(leaf)
599         for comment in self.comments.values():
600             res += str(comment)
601         return res + '\n'
602
603     def __bool__(self) -> bool:
604         return bool(self.leaves or self.comments)
605
606
607 @dataclass
608 class EmptyLineTracker:
609     """Provides a stateful method that returns the number of potential extra
610     empty lines needed before and after the currently processed line.
611
612     Note: this tracker works on lines that haven't been split yet.
613     """
614     previous_line: Optional[Line] = None
615     previous_after: int = 0
616     previous_defs: List[int] = Factory(list)
617
618     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
619         """Returns the number of extra empty lines before and after the `current_line`.
620
621         This is for separating `def`, `async def` and `class` with extra empty lines
622         (two on module-level), as well as providing an extra empty line after flow
623         control keywords to make them more prominent.
624         """
625         before, after = self._maybe_empty_lines(current_line)
626         self.previous_after = after
627         self.previous_line = current_line
628         return before, after
629
630     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
631         before = 0
632         depth = current_line.depth
633         while self.previous_defs and self.previous_defs[-1] >= depth:
634             self.previous_defs.pop()
635             before = (1 if depth else 2) - self.previous_after
636         is_decorator = current_line.is_decorator
637         if is_decorator or current_line.is_def or current_line.is_class:
638             if not is_decorator:
639                 self.previous_defs.append(depth)
640             if self.previous_line is None:
641                 # Don't insert empty lines before the first line in the file.
642                 return 0, 0
643
644             if self.previous_line and self.previous_line.is_decorator:
645                 # Don't insert empty lines between decorators.
646                 return 0, 0
647
648             newlines = 2
649             if current_line.depth:
650                 newlines -= 1
651             newlines -= self.previous_after
652             return newlines, 0
653
654         if current_line.is_flow_control:
655             return before, 1
656
657         if (
658             self.previous_line and
659             self.previous_line.is_import and
660             not current_line.is_import and
661             depth == self.previous_line.depth
662         ):
663             return (before or 1), 0
664
665         if (
666             self.previous_line and
667             self.previous_line.is_yield and
668             (not current_line.is_yield or depth != self.previous_line.depth)
669         ):
670             return (before or 1), 0
671
672         return before, 0
673
674
675 @dataclass
676 class LineGenerator(Visitor[Line]):
677     """Generates reformatted Line objects.  Empty lines are not emitted.
678
679     Note: destroys the tree it's visiting by mutating prefixes of its leaves
680     in ways that will no longer stringify to valid Python code on the tree.
681     """
682     current_line: Line = Factory(Line)
683     standalone_comments: List[Leaf] = Factory(list)
684
685     def line(self, indent: int = 0) -> Iterator[Line]:
686         """Generate a line.
687
688         If the line is empty, only emit if it makes sense.
689         If the line is too long, split it first and then generate.
690
691         If any lines were generated, set up a new current_line.
692         """
693         if not self.current_line:
694             self.current_line.depth += indent
695             return  # Line is empty, don't emit. Creating a new one unnecessary.
696
697         complete_line = self.current_line
698         self.current_line = Line(depth=complete_line.depth + indent)
699         yield complete_line
700
701     def visit_default(self, node: LN) -> Iterator[Line]:
702         if isinstance(node, Leaf):
703             for comment in generate_comments(node):
704                 if self.current_line.bracket_tracker.any_open_brackets():
705                     # any comment within brackets is subject to splitting
706                     self.current_line.append(comment)
707                 elif comment.type == token.COMMENT:
708                     # regular trailing comment
709                     self.current_line.append(comment)
710                     yield from self.line()
711
712                 else:
713                     # regular standalone comment, to be processed later (see
714                     # docstring in `generate_comments()`
715                     self.standalone_comments.append(comment)
716             normalize_prefix(node)
717             if node.type not in WHITESPACE:
718                 for comment in self.standalone_comments:
719                     yield from self.line()
720
721                     self.current_line.append(comment)
722                     yield from self.line()
723
724                 self.standalone_comments = []
725                 self.current_line.append(node)
726         yield from super().visit_default(node)
727
728     def visit_suite(self, node: Node) -> Iterator[Line]:
729         """Body of a statement after a colon."""
730         children = iter(node.children)
731         # Process newline before indenting.  It might contain an inline
732         # comment that should go right after the colon.
733         newline = next(children)
734         yield from self.visit(newline)
735         yield from self.line(+1)
736
737         for child in children:
738             yield from self.visit(child)
739
740         yield from self.line(-1)
741
742     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
743         """Visit a statement.
744
745         The relevant Python language keywords for this statement are NAME leaves
746         within it.
747         """
748         for child in node.children:
749             if child.type == token.NAME and child.value in keywords:  # type: ignore
750                 yield from self.line()
751
752             yield from self.visit(child)
753
754     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
755         """A statement without nested statements."""
756         is_suite_like = node.parent and node.parent.type in STATEMENT
757         if is_suite_like:
758             yield from self.line(+1)
759             yield from self.visit_default(node)
760             yield from self.line(-1)
761
762         else:
763             yield from self.line()
764             yield from self.visit_default(node)
765
766     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
767         yield from self.line()
768
769         children = iter(node.children)
770         for child in children:
771             yield from self.visit(child)
772
773             if child.type == token.NAME and child.value == 'async':  # type: ignore
774                 break
775
776         internal_stmt = next(children)
777         for child in internal_stmt.children:
778             yield from self.visit(child)
779
780     def visit_decorators(self, node: Node) -> Iterator[Line]:
781         for child in node.children:
782             yield from self.line()
783             yield from self.visit(child)
784
785     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
786         yield from self.line()
787
788     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
789         yield from self.visit_default(leaf)
790         yield from self.line()
791
792     def __attrs_post_init__(self) -> None:
793         """You are in a twisty little maze of passages."""
794         v = self.visit_stmt
795         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
796         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
797         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
798         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
799         self.visit_except_clause = partial(v, keywords={'except'})
800         self.visit_funcdef = partial(v, keywords={'def'})
801         self.visit_with_stmt = partial(v, keywords={'with'})
802         self.visit_classdef = partial(v, keywords={'class'})
803         self.visit_async_funcdef = self.visit_async_stmt
804         self.visit_decorated = self.visit_decorators
805
806
807 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
808 OPENING_BRACKETS = set(BRACKET.keys())
809 CLOSING_BRACKETS = set(BRACKET.values())
810 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
811 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
812
813
814 def whitespace(leaf: Leaf) -> str:  # noqa C901
815     """Return whitespace prefix if needed for the given `leaf`."""
816     NO = ''
817     SPACE = ' '
818     DOUBLESPACE = '  '
819     t = leaf.type
820     p = leaf.parent
821     v = leaf.value
822     if t in ALWAYS_NO_SPACE:
823         return NO
824
825     if t == token.COMMENT:
826         return DOUBLESPACE
827
828     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
829     prev = leaf.prev_sibling
830     if not prev:
831         prevp = preceding_leaf(p)
832         if not prevp or prevp.type in OPENING_BRACKETS:
833             return NO
834
835         if prevp.type == token.EQUAL:
836             if prevp.parent and prevp.parent.type in {
837                 syms.typedargslist,
838                 syms.varargslist,
839                 syms.parameters,
840                 syms.arglist,
841                 syms.argument,
842             }:
843                 return NO
844
845         elif prevp.type == token.DOUBLESTAR:
846             if prevp.parent and prevp.parent.type in {
847                 syms.typedargslist,
848                 syms.varargslist,
849                 syms.parameters,
850                 syms.arglist,
851                 syms.dictsetmaker,
852             }:
853                 return NO
854
855         elif prevp.type == token.COLON:
856             if prevp.parent and prevp.parent.type == syms.subscript:
857                 return NO
858
859         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
860             return NO
861
862     elif prev.type in OPENING_BRACKETS:
863         return NO
864
865     if p.type in {syms.parameters, syms.arglist}:
866         # untyped function signatures or calls
867         if t == token.RPAR:
868             return NO
869
870         if not prev or prev.type != token.COMMA:
871             return NO
872
873     if p.type == syms.varargslist:
874         # lambdas
875         if t == token.RPAR:
876             return NO
877
878         if prev and prev.type != token.COMMA:
879             return NO
880
881     elif p.type == syms.typedargslist:
882         # typed function signatures
883         if not prev:
884             return NO
885
886         if t == token.EQUAL:
887             if prev.type != syms.tname:
888                 return NO
889
890         elif prev.type == token.EQUAL:
891             # A bit hacky: if the equal sign has whitespace, it means we
892             # previously found it's a typed argument.  So, we're using that, too.
893             return prev.prefix
894
895         elif prev.type != token.COMMA:
896             return NO
897
898     elif p.type == syms.tname:
899         # type names
900         if not prev:
901             prevp = preceding_leaf(p)
902             if not prevp or prevp.type != token.COMMA:
903                 return NO
904
905     elif p.type == syms.trailer:
906         # attributes and calls
907         if t == token.LPAR or t == token.RPAR:
908             return NO
909
910         if not prev:
911             if t == token.DOT:
912                 prevp = preceding_leaf(p)
913                 if not prevp or prevp.type != token.NUMBER:
914                     return NO
915
916             elif t == token.LSQB:
917                 return NO
918
919         elif prev.type != token.COMMA:
920             return NO
921
922     elif p.type == syms.argument:
923         # single argument
924         if t == token.EQUAL:
925             return NO
926
927         if not prev:
928             prevp = preceding_leaf(p)
929             if not prevp or prevp.type == token.LPAR:
930                 return NO
931
932         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
933             return NO
934
935     elif p.type == syms.decorator:
936         # decorators
937         return NO
938
939     elif p.type == syms.dotted_name:
940         if prev:
941             return NO
942
943         prevp = preceding_leaf(p)
944         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
945             return NO
946
947     elif p.type == syms.classdef:
948         if t == token.LPAR:
949             return NO
950
951         if prev and prev.type == token.LPAR:
952             return NO
953
954     elif p.type == syms.subscript:
955         # indexing
956         if not prev:
957             assert p.parent is not None, "subscripts are always parented"
958             if p.parent.type == syms.subscriptlist:
959                 return SPACE
960
961             return NO
962
963         elif prev.type == token.COLON:
964             return NO
965
966     elif p.type == syms.atom:
967         if prev and t == token.DOT:
968             # dots, but not the first one.
969             return NO
970
971     elif (
972         p.type == syms.listmaker or
973         p.type == syms.testlist_gexp or
974         p.type == syms.subscriptlist
975     ):
976         # list interior, including unpacking
977         if not prev:
978             return NO
979
980     elif p.type == syms.dictsetmaker:
981         # dict and set interior, including unpacking
982         if not prev:
983             return NO
984
985         if prev.type == token.DOUBLESTAR:
986             return NO
987
988     elif p.type in {syms.factor, syms.star_expr}:
989         # unary ops
990         if not prev:
991             prevp = preceding_leaf(p)
992             if not prevp or prevp.type in OPENING_BRACKETS:
993                 return NO
994
995             prevp_parent = prevp.parent
996             assert prevp_parent is not None
997             if prevp.type == token.COLON and prevp_parent.type in {
998                 syms.subscript, syms.sliceop
999             }:
1000                 return NO
1001
1002             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1003                 return NO
1004
1005         elif t == token.NAME or t == token.NUMBER:
1006             return NO
1007
1008     elif p.type == syms.import_from:
1009         if t == token.DOT:
1010             if prev and prev.type == token.DOT:
1011                 return NO
1012
1013         elif t == token.NAME:
1014             if v == 'import':
1015                 return SPACE
1016
1017             if prev and prev.type == token.DOT:
1018                 return NO
1019
1020     elif p.type == syms.sliceop:
1021         return NO
1022
1023     return SPACE
1024
1025
1026 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1027     """Returns the first leaf that precedes `node`, if any."""
1028     while node:
1029         res = node.prev_sibling
1030         if res:
1031             if isinstance(res, Leaf):
1032                 return res
1033
1034             try:
1035                 return list(res.leaves())[-1]
1036
1037             except IndexError:
1038                 return None
1039
1040         node = node.parent
1041     return None
1042
1043
1044 def is_delimiter(leaf: Leaf) -> int:
1045     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1046
1047     Higher numbers are higher priority.
1048     """
1049     if leaf.type == token.COMMA:
1050         return COMMA_PRIORITY
1051
1052     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1053         return LOGIC_PRIORITY
1054
1055     if leaf.type in COMPARATORS:
1056         return COMPARATOR_PRIORITY
1057
1058     if (
1059         leaf.type in MATH_OPERATORS and
1060         leaf.parent and
1061         leaf.parent.type not in {syms.factor, syms.star_expr}
1062     ):
1063         return MATH_PRIORITY
1064
1065     return 0
1066
1067
1068 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1069     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1070
1071     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1072     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1073     move because it does away with modifying the grammar to include all the
1074     possible places in which comments can be placed.
1075
1076     The sad consequence for us though is that comments don't "belong" anywhere.
1077     This is why this function generates simple parentless Leaf objects for
1078     comments.  We simply don't know what the correct parent should be.
1079
1080     No matter though, we can live without this.  We really only need to
1081     differentiate between inline and standalone comments.  The latter don't
1082     share the line with any code.
1083
1084     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1085     are emitted with a fake STANDALONE_COMMENT token identifier.
1086     """
1087     if not leaf.prefix:
1088         return
1089
1090     if '#' not in leaf.prefix:
1091         return
1092
1093     before_comment, content = leaf.prefix.split('#', 1)
1094     content = content.rstrip()
1095     if content and (content[0] not in {' ', '!', '#'}):
1096         content = ' ' + content
1097     is_standalone_comment = (
1098         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1099     )
1100     if not is_standalone_comment:
1101         # simple trailing comment
1102         yield Leaf(token.COMMENT, value='#' + content)
1103         return
1104
1105     for line in ('#' + content).split('\n'):
1106         line = line.lstrip()
1107         if not line.startswith('#'):
1108             continue
1109
1110         yield Leaf(STANDALONE_COMMENT, line)
1111
1112
1113 def split_line(
1114     line: Line, line_length: int, inner: bool = False, py36: bool = False
1115 ) -> Iterator[Line]:
1116     """Splits a `line` into potentially many lines.
1117
1118     They should fit in the allotted `line_length` but might not be able to.
1119     `inner` signifies that there were a pair of brackets somewhere around the
1120     current `line`, possibly transitively. This means we can fallback to splitting
1121     by delimiters if the LHS/RHS don't yield any results.
1122
1123     If `py36` is True, splitting may generate syntax that is only compatible
1124     with Python 3.6 and later.
1125     """
1126     line_str = str(line).strip('\n')
1127     if len(line_str) <= line_length and '\n' not in line_str:
1128         yield line
1129         return
1130
1131     if line.is_def:
1132         split_funcs = [left_hand_split]
1133     elif line.inside_brackets:
1134         split_funcs = [delimiter_split]
1135         if '\n' not in line_str:
1136             # Only attempt RHS if we don't have multiline strings or comments
1137             # on this line.
1138             split_funcs.append(right_hand_split)
1139     else:
1140         split_funcs = [right_hand_split]
1141     for split_func in split_funcs:
1142         # We are accumulating lines in `result` because we might want to abort
1143         # mission and return the original line in the end, or attempt a different
1144         # split altogether.
1145         result: List[Line] = []
1146         try:
1147             for l in split_func(line, py36=py36):
1148                 if str(l).strip('\n') == line_str:
1149                     raise CannotSplit("Split function returned an unchanged result")
1150
1151                 result.extend(
1152                     split_line(l, line_length=line_length, inner=True, py36=py36)
1153                 )
1154         except CannotSplit as cs:
1155             continue
1156
1157         else:
1158             yield from result
1159             break
1160
1161     else:
1162         yield line
1163
1164
1165 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1166     """Split line into many lines, starting with the first matching bracket pair.
1167
1168     Note: this usually looks weird, only use this for function definitions.
1169     Prefer RHS otherwise.
1170     """
1171     head = Line(depth=line.depth)
1172     body = Line(depth=line.depth + 1, inside_brackets=True)
1173     tail = Line(depth=line.depth)
1174     tail_leaves: List[Leaf] = []
1175     body_leaves: List[Leaf] = []
1176     head_leaves: List[Leaf] = []
1177     current_leaves = head_leaves
1178     matching_bracket = None
1179     for leaf in line.leaves:
1180         if (
1181             current_leaves is body_leaves and
1182             leaf.type in CLOSING_BRACKETS and
1183             leaf.opening_bracket is matching_bracket
1184         ):
1185             current_leaves = tail_leaves if body_leaves else head_leaves
1186         current_leaves.append(leaf)
1187         if current_leaves is head_leaves:
1188             if leaf.type in OPENING_BRACKETS:
1189                 matching_bracket = leaf
1190                 current_leaves = body_leaves
1191     # Since body is a new indent level, remove spurious leading whitespace.
1192     if body_leaves:
1193         normalize_prefix(body_leaves[0])
1194     # Build the new lines.
1195     for result, leaves in (
1196         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1197     ):
1198         for leaf in leaves:
1199             result.append(leaf, preformatted=True)
1200             comment_after = line.comments.get(id(leaf))
1201             if comment_after:
1202                 result.append(comment_after, preformatted=True)
1203     split_succeeded_or_raise(head, body, tail)
1204     for result in (head, body, tail):
1205         if result:
1206             yield result
1207
1208
1209 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1210     """Split line into many lines, starting with the last matching bracket pair."""
1211     head = Line(depth=line.depth)
1212     body = Line(depth=line.depth + 1, inside_brackets=True)
1213     tail = Line(depth=line.depth)
1214     tail_leaves: List[Leaf] = []
1215     body_leaves: List[Leaf] = []
1216     head_leaves: List[Leaf] = []
1217     current_leaves = tail_leaves
1218     opening_bracket = None
1219     for leaf in reversed(line.leaves):
1220         if current_leaves is body_leaves:
1221             if leaf is opening_bracket:
1222                 current_leaves = head_leaves if body_leaves else tail_leaves
1223         current_leaves.append(leaf)
1224         if current_leaves is tail_leaves:
1225             if leaf.type in CLOSING_BRACKETS:
1226                 opening_bracket = leaf.opening_bracket
1227                 current_leaves = body_leaves
1228     tail_leaves.reverse()
1229     body_leaves.reverse()
1230     head_leaves.reverse()
1231     # Since body is a new indent level, remove spurious leading whitespace.
1232     if body_leaves:
1233         normalize_prefix(body_leaves[0])
1234     # Build the new lines.
1235     for result, leaves in (
1236         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1237     ):
1238         for leaf in leaves:
1239             result.append(leaf, preformatted=True)
1240             comment_after = line.comments.get(id(leaf))
1241             if comment_after:
1242                 result.append(comment_after, preformatted=True)
1243     split_succeeded_or_raise(head, body, tail)
1244     for result in (head, body, tail):
1245         if result:
1246             yield result
1247
1248
1249 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1250     tail_len = len(str(tail).strip())
1251     if not body:
1252         if tail_len == 0:
1253             raise CannotSplit("Splitting brackets produced the same line")
1254
1255         elif tail_len < 3:
1256             raise CannotSplit(
1257                 f"Splitting brackets on an empty body to save "
1258                 f"{tail_len} characters is not worth it"
1259             )
1260
1261
1262 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1263     """Split according to delimiters of the highest priority.
1264
1265     This kind of split doesn't increase indentation.
1266     If `py36` is True, the split will add trailing commas also in function
1267     signatures that contain * and **.
1268     """
1269     try:
1270         last_leaf = line.leaves[-1]
1271     except IndexError:
1272         raise CannotSplit("Line empty")
1273
1274     delimiters = line.bracket_tracker.delimiters
1275     try:
1276         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1277     except ValueError:
1278         raise CannotSplit("No delimiters found")
1279
1280     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1281     lowest_depth = sys.maxsize
1282     trailing_comma_safe = True
1283     for leaf in line.leaves:
1284         current_line.append(leaf, preformatted=True)
1285         comment_after = line.comments.get(id(leaf))
1286         if comment_after:
1287             current_line.append(comment_after, preformatted=True)
1288         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1289         if (
1290             leaf.bracket_depth == lowest_depth and
1291             leaf.type == token.STAR or
1292             leaf.type == token.DOUBLESTAR
1293         ):
1294             trailing_comma_safe = trailing_comma_safe and py36
1295         leaf_priority = delimiters.get(id(leaf))
1296         if leaf_priority == delimiter_priority:
1297             normalize_prefix(current_line.leaves[0])
1298             yield current_line
1299
1300             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1301     if current_line:
1302         if (
1303             delimiter_priority == COMMA_PRIORITY and
1304             current_line.leaves[-1].type != token.COMMA and
1305             trailing_comma_safe
1306         ):
1307             current_line.append(Leaf(token.COMMA, ','))
1308         normalize_prefix(current_line.leaves[0])
1309         yield current_line
1310
1311
1312 def is_import(leaf: Leaf) -> bool:
1313     """Returns True if the given leaf starts an import statement."""
1314     p = leaf.parent
1315     t = leaf.type
1316     v = leaf.value
1317     return bool(
1318         t == token.NAME and
1319         (
1320             (v == 'import' and p and p.type == syms.import_name) or
1321             (v == 'from' and p and p.type == syms.import_from)
1322         )
1323     )
1324
1325
1326 def normalize_prefix(leaf: Leaf) -> None:
1327     """Leave existing extra newlines for imports.  Remove everything else."""
1328     if is_import(leaf):
1329         spl = leaf.prefix.split('#', 1)
1330         nl_count = spl[0].count('\n')
1331         if len(spl) > 1:
1332             # Skip one newline since it was for a standalone comment.
1333             nl_count -= 1
1334         leaf.prefix = '\n' * nl_count
1335         return
1336
1337     leaf.prefix = ''
1338
1339
1340 def is_python36(node: Node) -> bool:
1341     """Returns True if the current file is using Python 3.6+ features.
1342
1343     Currently looking for:
1344     - f-strings; and
1345     - trailing commas after * or ** in function signatures.
1346     """
1347     for n in node.pre_order():
1348         if n.type == token.STRING:
1349             value_head = n.value[:2]  # type: ignore
1350             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1351                 return True
1352
1353         elif (
1354             n.type == syms.typedargslist and
1355             n.children and
1356             n.children[-1].type == token.COMMA
1357         ):
1358             for ch in n.children:
1359                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1360                     return True
1361
1362     return False
1363
1364
1365 PYTHON_EXTENSIONS = {'.py'}
1366 BLACKLISTED_DIRECTORIES = {
1367     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1368 }
1369
1370
1371 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1372     for child in path.iterdir():
1373         if child.is_dir():
1374             if child.name in BLACKLISTED_DIRECTORIES:
1375                 continue
1376
1377             yield from gen_python_files_in_dir(child)
1378
1379         elif child.suffix in PYTHON_EXTENSIONS:
1380             yield child
1381
1382
1383 @dataclass
1384 class Report:
1385     """Provides a reformatting counter."""
1386     change_count: int = 0
1387     same_count: int = 0
1388     failure_count: int = 0
1389
1390     def done(self, src: Path, changed: bool) -> None:
1391         """Increment the counter for successful reformatting. Write out a message."""
1392         if changed:
1393             out(f'reformatted {src}')
1394             self.change_count += 1
1395         else:
1396             out(f'{src} already well formatted, good job.', bold=False)
1397             self.same_count += 1
1398
1399     def failed(self, src: Path, message: str) -> None:
1400         """Increment the counter for failed reformatting. Write out a message."""
1401         err(f'error: cannot format {src}: {message}')
1402         self.failure_count += 1
1403
1404     @property
1405     def return_code(self) -> int:
1406         """Which return code should the app use considering the current state."""
1407         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1408         # 126 we have special returncodes reserved by the shell.
1409         if self.failure_count:
1410             return 123
1411
1412         elif self.change_count:
1413             return 1
1414
1415         return 0
1416
1417     def __str__(self) -> str:
1418         """A color report of the current state.
1419
1420         Use `click.unstyle` to remove colors.
1421         """
1422         report = []
1423         if self.change_count:
1424             s = 's' if self.change_count > 1 else ''
1425             report.append(
1426                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1427             )
1428         if self.same_count:
1429             s = 's' if self.same_count > 1 else ''
1430             report.append(f'{self.same_count} file{s} left unchanged')
1431         if self.failure_count:
1432             s = 's' if self.failure_count > 1 else ''
1433             report.append(
1434                 click.style(
1435                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1436                 )
1437             )
1438         return ', '.join(report) + '.'
1439
1440
1441 def assert_equivalent(src: str, dst: str) -> None:
1442     """Raises AssertionError if `src` and `dst` aren't equivalent.
1443
1444     This is a temporary sanity check until Black becomes stable.
1445     """
1446
1447     import ast
1448     import traceback
1449
1450     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1451         """Simple visitor generating strings to compare ASTs by content."""
1452         yield f"{'  ' * depth}{node.__class__.__name__}("
1453
1454         for field in sorted(node._fields):
1455             try:
1456                 value = getattr(node, field)
1457             except AttributeError:
1458                 continue
1459
1460             yield f"{'  ' * (depth+1)}{field}="
1461
1462             if isinstance(value, list):
1463                 for item in value:
1464                     if isinstance(item, ast.AST):
1465                         yield from _v(item, depth + 2)
1466
1467             elif isinstance(value, ast.AST):
1468                 yield from _v(value, depth + 2)
1469
1470             else:
1471                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1472
1473         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1474
1475     try:
1476         src_ast = ast.parse(src)
1477     except Exception as exc:
1478         raise AssertionError(f"cannot parse source: {exc}") from None
1479
1480     try:
1481         dst_ast = ast.parse(dst)
1482     except Exception as exc:
1483         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1484         raise AssertionError(
1485             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1486             f"Please report a bug on https://github.com/ambv/black/issues.  "
1487             f"This invalid output might be helpful: {log}"
1488         ) from None
1489
1490     src_ast_str = '\n'.join(_v(src_ast))
1491     dst_ast_str = '\n'.join(_v(dst_ast))
1492     if src_ast_str != dst_ast_str:
1493         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1494         raise AssertionError(
1495             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1496             f"the source.  "
1497             f"Please report a bug on https://github.com/ambv/black/issues.  "
1498             f"This diff might be helpful: {log}"
1499         ) from None
1500
1501
1502 def assert_stable(src: str, dst: str, line_length: int) -> None:
1503     """Raises AssertionError if `dst` reformats differently the second time.
1504
1505     This is a temporary sanity check until Black becomes stable.
1506     """
1507     newdst = format_str(dst, line_length=line_length)
1508     if dst != newdst:
1509         log = dump_to_file(
1510             diff(src, dst, 'source', 'first pass'),
1511             diff(dst, newdst, 'first pass', 'second pass'),
1512         )
1513         raise AssertionError(
1514             f"INTERNAL ERROR: Black produced different code on the second pass "
1515             f"of the formatter.  "
1516             f"Please report a bug on https://github.com/ambv/black/issues.  "
1517             f"This diff might be helpful: {log}"
1518         ) from None
1519
1520
1521 def dump_to_file(*output: str) -> str:
1522     """Dumps `output` to a temporary file. Returns path to the file."""
1523     import tempfile
1524
1525     with tempfile.NamedTemporaryFile(
1526         mode='w', prefix='blk_', suffix='.log', delete=False
1527     ) as f:
1528         for lines in output:
1529             f.write(lines)
1530             f.write('\n')
1531     return f.name
1532
1533
1534 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1535     """Returns a udiff string between strings `a` and `b`."""
1536     import difflib
1537
1538     a_lines = [line + '\n' for line in a.split('\n')]
1539     b_lines = [line + '\n' for line in b.split('\n')]
1540     return ''.join(
1541         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1542     )
1543
1544
1545 if __name__ == '__main__':
1546     main()