]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Only use trailing commas in function signatures when it's safe
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import attrib, dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a1"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
78 )
79 @click.pass_context
80 def main(
81     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
82 ) -> None:
83     """The uncompromising code formatter."""
84     sources: List[Path] = []
85     for s in src:
86         p = Path(s)
87         if p.is_dir():
88             sources.extend(gen_python_files_in_dir(p))
89         elif p.is_file():
90             # if a file was explicitly given, we don't care about its extension
91             sources.append(p)
92         else:
93             err(f'invalid path: {s}')
94     if len(sources) == 0:
95         ctx.exit(0)
96     elif len(sources) == 1:
97         p = sources[0]
98         report = Report()
99         try:
100             changed = format_file_in_place(
101                 p, line_length=line_length, fast=fast, write_back=not check
102             )
103             report.done(p, changed)
104         except Exception as exc:
105             report.failed(p, str(exc))
106         ctx.exit(report.return_code)
107     else:
108         loop = asyncio.get_event_loop()
109         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
110         return_code = 1
111         try:
112             return_code = loop.run_until_complete(
113                 schedule_formatting(
114                     sources, line_length, not check, fast, loop, executor
115                 )
116             )
117         finally:
118             loop.close()
119             ctx.exit(return_code)
120
121
122 async def schedule_formatting(
123     sources: List[Path],
124     line_length: int,
125     write_back: bool,
126     fast: bool,
127     loop: BaseEventLoop,
128     executor: Executor,
129 ) -> int:
130     tasks = {
131         src: loop.run_in_executor(
132             executor, format_file_in_place, src, line_length, fast, write_back
133         )
134         for src in sources
135     }
136     await asyncio.wait(tasks.values())
137     cancelled = []
138     report = Report()
139     for src, task in tasks.items():
140         if not task.done():
141             report.failed(src, 'timed out, cancelling')
142             task.cancel()
143             cancelled.append(task)
144         elif task.exception():
145             report.failed(src, str(task.exception()))
146         else:
147             report.done(src, task.result())
148     if cancelled:
149         await asyncio.wait(cancelled, timeout=2)
150     out('All done! ✨ 🍰 ✨')
151     click.echo(str(report))
152     return report.return_code
153
154
155 def format_file_in_place(
156     src: Path, line_length: int, fast: bool, write_back: bool = False
157 ) -> bool:
158     """Format the file and rewrite if changed. Return True if changed."""
159     try:
160         contents, encoding = format_file(src, line_length=line_length, fast=fast)
161     except NothingChanged:
162         return False
163
164     if write_back:
165         with open(src, "w", encoding=encoding) as f:
166             f.write(contents)
167     return True
168
169
170 def format_file(
171     src: Path, line_length: int, fast: bool
172 ) -> Tuple[FileContent, Encoding]:
173     """Reformats a file and returns its contents and encoding."""
174     with tokenize.open(src) as src_buffer:
175         src_contents = src_buffer.read()
176     if src_contents.strip() == '':
177         raise NothingChanged(src)
178
179     dst_contents = format_str(src_contents, line_length=line_length)
180     if src_contents == dst_contents:
181         raise NothingChanged(src)
182
183     if not fast:
184         assert_equivalent(src_contents, dst_contents)
185         assert_stable(src_contents, dst_contents, line_length=line_length)
186     return dst_contents, src_buffer.encoding
187
188
189 def format_str(src_contents: str, line_length: int) -> FileContent:
190     """Reformats a string and returns new contents."""
191     src_node = lib2to3_parse(src_contents)
192     dst_contents = ""
193     comments: List[Line] = []
194     lines = LineGenerator()
195     elt = EmptyLineTracker()
196     py36 = is_python36(src_node)
197     empty_line = Line()
198     after = 0
199     for current_line in lines.visit(src_node):
200         for _ in range(after):
201             dst_contents += str(empty_line)
202         before, after = elt.maybe_empty_lines(current_line)
203         for _ in range(before):
204             dst_contents += str(empty_line)
205         if not current_line.is_comment:
206             for comment in comments:
207                 dst_contents += str(comment)
208             comments = []
209             for line in split_line(current_line, line_length=line_length, py36=py36):
210                 dst_contents += str(line)
211         else:
212             comments.append(current_line)
213     for comment in comments:
214         dst_contents += str(comment)
215     return dst_contents
216
217
218 def lib2to3_parse(src_txt: str) -> Node:
219     """Given a string with source, return the lib2to3 Node."""
220     grammar = pygram.python_grammar_no_print_statement
221     drv = driver.Driver(grammar, pytree.convert)
222     if src_txt[-1] != '\n':
223         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
224         src_txt += nl
225     try:
226         result = drv.parse_string(src_txt, True)
227     except ParseError as pe:
228         lineno, column = pe.context[1]
229         lines = src_txt.splitlines()
230         try:
231             faulty_line = lines[lineno - 1]
232         except IndexError:
233             faulty_line = "<line number missing in source>"
234         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
235
236     if isinstance(result, Leaf):
237         result = Node(syms.file_input, [result])
238     return result
239
240
241 def lib2to3_unparse(node: Node) -> str:
242     """Given a lib2to3 node, return its string representation."""
243     code = str(node)
244     return code
245
246
247 T = TypeVar('T')
248
249
250 class Visitor(Generic[T]):
251     """Basic lib2to3 visitor that yields things on visiting."""
252
253     def visit(self, node: LN) -> Iterator[T]:
254         if node.type < 256:
255             name = token.tok_name[node.type]
256         else:
257             name = type_repr(node.type)
258         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
259
260     def visit_default(self, node: LN) -> Iterator[T]:
261         if isinstance(node, Node):
262             for child in node.children:
263                 yield from self.visit(child)
264
265
266 @dataclass
267 class DebugVisitor(Visitor[T]):
268     tree_depth: int = attrib(default=0)
269
270     def visit_default(self, node: LN) -> Iterator[T]:
271         indent = ' ' * (2 * self.tree_depth)
272         if isinstance(node, Node):
273             _type = type_repr(node.type)
274             out(f'{indent}{_type}', fg='yellow')
275             self.tree_depth += 1
276             for child in node.children:
277                 yield from self.visit(child)
278
279             self.tree_depth -= 1
280             out(f'{indent}/{_type}', fg='yellow', bold=False)
281         else:
282             _type = token.tok_name.get(node.type, str(node.type))
283             out(f'{indent}{_type}', fg='blue', nl=False)
284             if node.prefix:
285                 # We don't have to handle prefixes for `Node` objects since
286                 # that delegates to the first child anyway.
287                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
288             out(f' {node.value!r}', fg='blue', bold=False)
289
290
291 KEYWORDS = set(keyword.kwlist)
292 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
293 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
294 STATEMENT = {
295     syms.if_stmt,
296     syms.while_stmt,
297     syms.for_stmt,
298     syms.try_stmt,
299     syms.except_clause,
300     syms.with_stmt,
301     syms.funcdef,
302     syms.classdef,
303 }
304 STANDALONE_COMMENT = 153
305 LOGIC_OPERATORS = {'and', 'or'}
306 COMPARATORS = {
307     token.LESS,
308     token.GREATER,
309     token.EQEQUAL,
310     token.NOTEQUAL,
311     token.LESSEQUAL,
312     token.GREATEREQUAL,
313 }
314 MATH_OPERATORS = {
315     token.PLUS,
316     token.MINUS,
317     token.STAR,
318     token.SLASH,
319     token.VBAR,
320     token.AMPER,
321     token.PERCENT,
322     token.CIRCUMFLEX,
323     token.LEFTSHIFT,
324     token.RIGHTSHIFT,
325     token.DOUBLESTAR,
326     token.DOUBLESLASH,
327 }
328 COMPREHENSION_PRIORITY = 20
329 COMMA_PRIORITY = 10
330 LOGIC_PRIORITY = 5
331 STRING_PRIORITY = 4
332 COMPARATOR_PRIORITY = 3
333 MATH_PRIORITY = 1
334
335
336 @dataclass
337 class BracketTracker:
338     depth: int = attrib(default=0)
339     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict))
340     delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict))
341     previous: Optional[Leaf] = attrib(default=None)
342
343     def mark(self, leaf: Leaf) -> None:
344         if leaf.type == token.COMMENT:
345             return
346
347         if leaf.type in CLOSING_BRACKETS:
348             self.depth -= 1
349             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
350             leaf.opening_bracket = opening_bracket  # type: ignore
351         leaf.bracket_depth = self.depth  # type: ignore
352         if self.depth == 0:
353             delim = is_delimiter(leaf)
354             if delim:
355                 self.delimiters[id(leaf)] = delim
356             elif self.previous is not None:
357                 if leaf.type == token.STRING and self.previous.type == token.STRING:
358                     self.delimiters[id(self.previous)] = STRING_PRIORITY
359                 elif (
360                     leaf.type == token.NAME and
361                     leaf.value == 'for' and
362                     leaf.parent and
363                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
364                 ):
365                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
366                 elif (
367                     leaf.type == token.NAME and
368                     leaf.value == 'if' and
369                     leaf.parent and
370                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
371                 ):
372                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
373         if leaf.type in OPENING_BRACKETS:
374             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
375             self.depth += 1
376         self.previous = leaf
377
378     def any_open_brackets(self) -> bool:
379         """Returns True if there is an yet unmatched open bracket on the line."""
380         return bool(self.bracket_match)
381
382     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
383         """Returns the highest priority of a delimiter found on the line.
384
385         Values are consistent with what `is_delimiter()` returns.
386         """
387         return max(v for k, v in self.delimiters.items() if k not in exclude)
388
389
390 @dataclass
391 class Line:
392     depth: int = attrib(default=0)
393     leaves: List[Leaf] = attrib(default=Factory(list))
394     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
395     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
396     inside_brackets: bool = attrib(default=False)
397     has_for: bool = attrib(default=False)
398     _for_loop_variable: bool = attrib(default=False, init=False)
399
400     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
401         has_value = leaf.value.strip()
402         if not has_value:
403             return
404
405         if self.leaves and not preformatted:
406             # Note: at this point leaf.prefix should be empty except for
407             # imports, for which we only preserve newlines.
408             leaf.prefix += whitespace(leaf)
409         if self.inside_brackets or not preformatted:
410             self.maybe_decrement_after_for_loop_variable(leaf)
411             self.bracket_tracker.mark(leaf)
412             self.maybe_remove_trailing_comma(leaf)
413             self.maybe_increment_for_loop_variable(leaf)
414             if self.maybe_adapt_standalone_comment(leaf):
415                 return
416
417         if not self.append_comment(leaf):
418             self.leaves.append(leaf)
419
420     @property
421     def is_comment(self) -> bool:
422         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
423
424     @property
425     def is_decorator(self) -> bool:
426         return bool(self) and self.leaves[0].type == token.AT
427
428     @property
429     def is_import(self) -> bool:
430         return bool(self) and is_import(self.leaves[0])
431
432     @property
433     def is_class(self) -> bool:
434         return (
435             bool(self) and
436             self.leaves[0].type == token.NAME and
437             self.leaves[0].value == 'class'
438         )
439
440     @property
441     def is_def(self) -> bool:
442         """Also returns True for async defs."""
443         try:
444             first_leaf = self.leaves[0]
445         except IndexError:
446             return False
447
448         try:
449             second_leaf: Optional[Leaf] = self.leaves[1]
450         except IndexError:
451             second_leaf = None
452         return (
453             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
454             (
455                 first_leaf.type == token.NAME and
456                 first_leaf.value == 'async' and
457                 second_leaf is not None and
458                 second_leaf.type == token.NAME and
459                 second_leaf.value == 'def'
460             )
461         )
462
463     @property
464     def is_flow_control(self) -> bool:
465         return (
466             bool(self) and
467             self.leaves[0].type == token.NAME and
468             self.leaves[0].value in FLOW_CONTROL
469         )
470
471     @property
472     def is_yield(self) -> bool:
473         return (
474             bool(self) and
475             self.leaves[0].type == token.NAME and
476             self.leaves[0].value == 'yield'
477         )
478
479     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
480         if not (
481             self.leaves and
482             self.leaves[-1].type == token.COMMA and
483             closing.type in CLOSING_BRACKETS
484         ):
485             return False
486
487         if closing.type == token.RSQB or closing.type == token.RBRACE:
488             self.leaves.pop()
489             return True
490
491         # For parens let's check if it's safe to remove the comma.  If the
492         # trailing one is the only one, we might mistakenly change a tuple
493         # into a different type by removing the comma.
494         depth = closing.bracket_depth + 1  # type: ignore
495         commas = 0
496         opening = closing.opening_bracket  # type: ignore
497         for _opening_index, leaf in enumerate(self.leaves):
498             if leaf is opening:
499                 break
500
501         else:
502             return False
503
504         for leaf in self.leaves[_opening_index + 1:]:
505             if leaf is closing:
506                 break
507
508             bracket_depth = leaf.bracket_depth  # type: ignore
509             if bracket_depth == depth and leaf.type == token.COMMA:
510                 commas += 1
511         if commas > 1:
512             self.leaves.pop()
513             return True
514
515         return False
516
517     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
518         """In a for loop, or comprehension, the variables are often unpacks.
519
520         To avoid splitting on the comma in this situation, we will increase
521         the depth of tokens between `for` and `in`.
522         """
523         if leaf.type == token.NAME and leaf.value == 'for':
524             self.has_for = True
525             self.bracket_tracker.depth += 1
526             self._for_loop_variable = True
527             return True
528
529         return False
530
531     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
532         # See `maybe_increment_for_loop_variable` above for explanation.
533         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
534             self.bracket_tracker.depth -= 1
535             self._for_loop_variable = False
536             return True
537
538         return False
539
540     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
541         """Hack a standalone comment to act as a trailing comment for line splitting.
542
543         If this line has brackets and a standalone `comment`, we need to adapt
544         it to be able to still reformat the line.
545
546         This is not perfect, the line to which the standalone comment gets
547         appended will appear "too long" when splitting.
548         """
549         if not (
550             comment.type == STANDALONE_COMMENT and
551             self.bracket_tracker.any_open_brackets()
552         ):
553             return False
554
555         comment.type = token.COMMENT
556         comment.prefix = '\n' + '    ' * (self.depth + 1)
557         return self.append_comment(comment)
558
559     def append_comment(self, comment: Leaf) -> bool:
560         if comment.type != token.COMMENT:
561             return False
562
563         try:
564             after = id(self.last_non_delimiter())
565         except LookupError:
566             comment.type = STANDALONE_COMMENT
567             comment.prefix = ''
568             return False
569
570         else:
571             if after in self.comments:
572                 self.comments[after].value += str(comment)
573             else:
574                 self.comments[after] = comment
575             return True
576
577     def last_non_delimiter(self) -> Leaf:
578         for i in range(len(self.leaves)):
579             last = self.leaves[-i - 1]
580             if not is_delimiter(last):
581                 return last
582
583         raise LookupError("No non-delimiters found")
584
585     def __str__(self) -> str:
586         if not self:
587             return '\n'
588
589         indent = '    ' * self.depth
590         leaves = iter(self.leaves)
591         first = next(leaves)
592         res = f'{first.prefix}{indent}{first.value}'
593         for leaf in leaves:
594             res += str(leaf)
595         for comment in self.comments.values():
596             res += str(comment)
597         return res + '\n'
598
599     def __bool__(self) -> bool:
600         return bool(self.leaves or self.comments)
601
602
603 @dataclass
604 class EmptyLineTracker:
605     """Provides a stateful method that returns the number of potential extra
606     empty lines needed before and after the currently processed line.
607
608     Note: this tracker works on lines that haven't been split yet.
609     """
610     previous_line: Optional[Line] = attrib(default=None)
611     previous_after: int = attrib(default=0)
612     previous_defs: List[int] = attrib(default=Factory(list))
613
614     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
615         """Returns the number of extra empty lines before and after the `current_line`.
616
617         This is for separating `def`, `async def` and `class` with extra empty lines
618         (two on module-level), as well as providing an extra empty line after flow
619         control keywords to make them more prominent.
620         """
621         before, after = self._maybe_empty_lines(current_line)
622         self.previous_after = after
623         self.previous_line = current_line
624         return before, after
625
626     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
627         before = 0
628         depth = current_line.depth
629         while self.previous_defs and self.previous_defs[-1] >= depth:
630             self.previous_defs.pop()
631             before = (1 if depth else 2) - self.previous_after
632         is_decorator = current_line.is_decorator
633         if is_decorator or current_line.is_def or current_line.is_class:
634             if not is_decorator:
635                 self.previous_defs.append(depth)
636             if self.previous_line is None:
637                 # Don't insert empty lines before the first line in the file.
638                 return 0, 0
639
640             if self.previous_line and self.previous_line.is_decorator:
641                 # Don't insert empty lines between decorators.
642                 return 0, 0
643
644             newlines = 2
645             if current_line.depth:
646                 newlines -= 1
647             newlines -= self.previous_after
648             return newlines, 0
649
650         if current_line.is_flow_control:
651             return before, 1
652
653         if (
654             self.previous_line and
655             self.previous_line.is_import and
656             not current_line.is_import and
657             depth == self.previous_line.depth
658         ):
659             return (before or 1), 0
660
661         if (
662             self.previous_line and
663             self.previous_line.is_yield and
664             (not current_line.is_yield or depth != self.previous_line.depth)
665         ):
666             return (before or 1), 0
667
668         return before, 0
669
670
671 @dataclass
672 class LineGenerator(Visitor[Line]):
673     """Generates reformatted Line objects.  Empty lines are not emitted.
674
675     Note: destroys the tree it's visiting by mutating prefixes of its leaves
676     in ways that will no longer stringify to valid Python code on the tree.
677     """
678     current_line: Line = attrib(default=Factory(Line))
679     standalone_comments: List[Leaf] = attrib(default=Factory(list))
680
681     def line(self, indent: int = 0) -> Iterator[Line]:
682         """Generate a line.
683
684         If the line is empty, only emit if it makes sense.
685         If the line is too long, split it first and then generate.
686
687         If any lines were generated, set up a new current_line.
688         """
689         if not self.current_line:
690             self.current_line.depth += indent
691             return  # Line is empty, don't emit. Creating a new one unnecessary.
692
693         complete_line = self.current_line
694         self.current_line = Line(depth=complete_line.depth + indent)
695         yield complete_line
696
697     def visit_default(self, node: LN) -> Iterator[Line]:
698         if isinstance(node, Leaf):
699             for comment in generate_comments(node):
700                 if self.current_line.bracket_tracker.any_open_brackets():
701                     # any comment within brackets is subject to splitting
702                     self.current_line.append(comment)
703                 elif comment.type == token.COMMENT:
704                     # regular trailing comment
705                     self.current_line.append(comment)
706                     yield from self.line()
707
708                 else:
709                     # regular standalone comment, to be processed later (see
710                     # docstring in `generate_comments()`
711                     self.standalone_comments.append(comment)
712             normalize_prefix(node)
713             if node.type not in WHITESPACE:
714                 for comment in self.standalone_comments:
715                     yield from self.line()
716
717                     self.current_line.append(comment)
718                     yield from self.line()
719
720                 self.standalone_comments = []
721                 self.current_line.append(node)
722         yield from super().visit_default(node)
723
724     def visit_suite(self, node: Node) -> Iterator[Line]:
725         """Body of a statement after a colon."""
726         children = iter(node.children)
727         # Process newline before indenting.  It might contain an inline
728         # comment that should go right after the colon.
729         newline = next(children)
730         yield from self.visit(newline)
731         yield from self.line(+1)
732
733         for child in children:
734             yield from self.visit(child)
735
736         yield from self.line(-1)
737
738     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
739         """Visit a statement.
740
741         The relevant Python language keywords for this statement are NAME leaves
742         within it.
743         """
744         for child in node.children:
745             if child.type == token.NAME and child.value in keywords:  # type: ignore
746                 yield from self.line()
747
748             yield from self.visit(child)
749
750     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
751         """A statement without nested statements."""
752         is_suite_like = node.parent and node.parent.type in STATEMENT
753         if is_suite_like:
754             yield from self.line(+1)
755             yield from self.visit_default(node)
756             yield from self.line(-1)
757
758         else:
759             yield from self.line()
760             yield from self.visit_default(node)
761
762     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
763         yield from self.line()
764
765         children = iter(node.children)
766         for child in children:
767             yield from self.visit(child)
768
769             if child.type == token.NAME and child.value == 'async':  # type: ignore
770                 break
771
772         internal_stmt = next(children)
773         for child in internal_stmt.children:
774             yield from self.visit(child)
775
776     def visit_decorators(self, node: Node) -> Iterator[Line]:
777         for child in node.children:
778             yield from self.line()
779             yield from self.visit(child)
780
781     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
782         yield from self.line()
783
784     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
785         yield from self.visit_default(leaf)
786         yield from self.line()
787
788     def __attrs_post_init__(self) -> None:
789         """You are in a twisty little maze of passages."""
790         v = self.visit_stmt
791         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
792         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
793         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
794         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
795         self.visit_except_clause = partial(v, keywords={'except'})
796         self.visit_funcdef = partial(v, keywords={'def'})
797         self.visit_with_stmt = partial(v, keywords={'with'})
798         self.visit_classdef = partial(v, keywords={'class'})
799         self.visit_async_funcdef = self.visit_async_stmt
800         self.visit_decorated = self.visit_decorators
801
802
803 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
804 OPENING_BRACKETS = set(BRACKET.keys())
805 CLOSING_BRACKETS = set(BRACKET.values())
806 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
807
808
809 def whitespace(leaf: Leaf) -> str:
810     """Return whitespace prefix if needed for the given `leaf`."""
811     NO = ''
812     SPACE = ' '
813     DOUBLESPACE = '  '
814     t = leaf.type
815     p = leaf.parent
816     v = leaf.value
817     if t == token.COLON:
818         return NO
819
820     if t == token.COMMA:
821         return NO
822
823     if t == token.RPAR:
824         return NO
825
826     if t == token.COMMENT:
827         return DOUBLESPACE
828
829     if t == STANDALONE_COMMENT:
830         return NO
831
832     if t in CLOSING_BRACKETS:
833         return NO
834
835     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
836     prev = leaf.prev_sibling
837     if not prev:
838         prevp = preceding_leaf(p)
839         if not prevp or prevp.type in OPENING_BRACKETS:
840             return NO
841
842         if prevp.type == token.EQUAL:
843             if prevp.parent and prevp.parent.type in {
844                 syms.typedargslist,
845                 syms.varargslist,
846                 syms.parameters,
847                 syms.arglist,
848                 syms.argument,
849             }:
850                 return NO
851
852         elif prevp.type == token.DOUBLESTAR:
853             if prevp.parent and prevp.parent.type in {
854                 syms.typedargslist,
855                 syms.varargslist,
856                 syms.parameters,
857                 syms.arglist,
858                 syms.dictsetmaker,
859             }:
860                 return NO
861
862         elif prevp.type == token.COLON:
863             if prevp.parent and prevp.parent.type == syms.subscript:
864                 return NO
865
866         elif prevp.parent and prevp.parent.type == syms.factor:
867             return NO
868
869     elif prev.type in OPENING_BRACKETS:
870         return NO
871
872     if p.type in {syms.parameters, syms.arglist}:
873         # untyped function signatures or calls
874         if t == token.RPAR:
875             return NO
876
877         if not prev or prev.type != token.COMMA:
878             return NO
879
880     if p.type == syms.varargslist:
881         # lambdas
882         if t == token.RPAR:
883             return NO
884
885         if prev and prev.type != token.COMMA:
886             return NO
887
888     elif p.type == syms.typedargslist:
889         # typed function signatures
890         if not prev:
891             return NO
892
893         if t == token.EQUAL:
894             if prev.type != syms.tname:
895                 return NO
896
897         elif prev.type == token.EQUAL:
898             # A bit hacky: if the equal sign has whitespace, it means we
899             # previously found it's a typed argument.  So, we're using that, too.
900             return prev.prefix
901
902         elif prev.type != token.COMMA:
903             return NO
904
905     elif p.type == syms.tname:
906         # type names
907         if not prev:
908             prevp = preceding_leaf(p)
909             if not prevp or prevp.type != token.COMMA:
910                 return NO
911
912     elif p.type == syms.trailer:
913         # attributes and calls
914         if t == token.LPAR or t == token.RPAR:
915             return NO
916
917         if not prev:
918             if t == token.DOT:
919                 prevp = preceding_leaf(p)
920                 if not prevp or prevp.type != token.NUMBER:
921                     return NO
922
923             elif t == token.LSQB:
924                 return NO
925
926         elif prev.type != token.COMMA:
927             return NO
928
929     elif p.type == syms.argument:
930         # single argument
931         if t == token.EQUAL:
932             return NO
933
934         if not prev:
935             prevp = preceding_leaf(p)
936             if not prevp or prevp.type == token.LPAR:
937                 return NO
938
939         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
940             return NO
941
942     elif p.type == syms.decorator:
943         # decorators
944         return NO
945
946     elif p.type == syms.dotted_name:
947         if prev:
948             return NO
949
950         prevp = preceding_leaf(p)
951         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
952             return NO
953
954     elif p.type == syms.classdef:
955         if t == token.LPAR:
956             return NO
957
958         if prev and prev.type == token.LPAR:
959             return NO
960
961     elif p.type == syms.subscript:
962         # indexing
963         if not prev or prev.type == token.COLON:
964             return NO
965
966     elif p.type == syms.atom:
967         if prev and t == token.DOT:
968             # dots, but not the first one.
969             return NO
970
971     elif (
972         p.type == syms.listmaker or
973         p.type == syms.testlist_gexp or
974         p.type == syms.subscriptlist
975     ):
976         # list interior, including unpacking
977         if not prev:
978             return NO
979
980     elif p.type == syms.dictsetmaker:
981         # dict and set interior, including unpacking
982         if not prev:
983             return NO
984
985         if prev.type == token.DOUBLESTAR:
986             return NO
987
988     elif p.type == syms.factor or p.type == syms.star_expr:
989         # unary ops
990         if not prev:
991             prevp = preceding_leaf(p)
992             if not prevp or prevp.type in OPENING_BRACKETS:
993                 return NO
994
995             prevp_parent = prevp.parent
996             assert prevp_parent is not None
997             if prevp.type == token.COLON and prevp_parent.type in {
998                 syms.subscript, syms.sliceop
999             }:
1000                 return NO
1001
1002             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1003                 return NO
1004
1005         elif t == token.NAME or t == token.NUMBER:
1006             return NO
1007
1008     elif p.type == syms.import_from:
1009         if t == token.DOT:
1010             if prev and prev.type == token.DOT:
1011                 return NO
1012
1013         elif t == token.NAME:
1014             if v == 'import':
1015                 return SPACE
1016
1017             if prev and prev.type == token.DOT:
1018                 return NO
1019
1020     elif p.type == syms.sliceop:
1021         return NO
1022
1023     return SPACE
1024
1025
1026 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1027     """Returns the first leaf that precedes `node`, if any."""
1028     while node:
1029         res = node.prev_sibling
1030         if res:
1031             if isinstance(res, Leaf):
1032                 return res
1033
1034             try:
1035                 return list(res.leaves())[-1]
1036
1037             except IndexError:
1038                 return None
1039
1040         node = node.parent
1041     return None
1042
1043
1044 def is_delimiter(leaf: Leaf) -> int:
1045     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1046
1047     Higher numbers are higher priority.
1048     """
1049     if leaf.type == token.COMMA:
1050         return COMMA_PRIORITY
1051
1052     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1053         return LOGIC_PRIORITY
1054
1055     if leaf.type in COMPARATORS:
1056         return COMPARATOR_PRIORITY
1057
1058     if (
1059         leaf.type in MATH_OPERATORS and
1060         leaf.parent and
1061         leaf.parent.type not in {syms.factor, syms.star_expr}
1062     ):
1063         return MATH_PRIORITY
1064
1065     return 0
1066
1067
1068 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1069     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1070
1071     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1072     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1073     move because it does away with modifying the grammar to include all the
1074     possible places in which comments can be placed.
1075
1076     The sad consequence for us though is that comments don't "belong" anywhere.
1077     This is why this function generates simple parentless Leaf objects for
1078     comments.  We simply don't know what the correct parent should be.
1079
1080     No matter though, we can live without this.  We really only need to
1081     differentiate between inline and standalone comments.  The latter don't
1082     share the line with any code.
1083
1084     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1085     are emitted with a fake STANDALONE_COMMENT token identifier.
1086     """
1087     if not leaf.prefix:
1088         return
1089
1090     if '#' not in leaf.prefix:
1091         return
1092
1093     before_comment, content = leaf.prefix.split('#', 1)
1094     content = content.rstrip()
1095     if content and (content[0] not in {' ', '!', '#'}):
1096         content = ' ' + content
1097     is_standalone_comment = (
1098         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1099     )
1100     if not is_standalone_comment:
1101         # simple trailing comment
1102         yield Leaf(token.COMMENT, value='#' + content)
1103         return
1104
1105     for line in ('#' + content).split('\n'):
1106         line = line.lstrip()
1107         if not line.startswith('#'):
1108             continue
1109
1110         yield Leaf(STANDALONE_COMMENT, line)
1111
1112
1113 def split_line(
1114     line: Line, line_length: int, inner: bool = False, py36: bool = False
1115 ) -> Iterator[Line]:
1116     """Splits a `line` into potentially many lines.
1117
1118     They should fit in the allotted `line_length` but might not be able to.
1119     `inner` signifies that there were a pair of brackets somewhere around the
1120     current `line`, possibly transitively. This means we can fallback to splitting
1121     by delimiters if the LHS/RHS don't yield any results.
1122
1123     If `py36` is True, splitting may generate syntax that is only compatible
1124     with Python 3.6 and later.
1125     """
1126     line_str = str(line).strip('\n')
1127     if len(line_str) <= line_length and '\n' not in line_str:
1128         yield line
1129         return
1130
1131     if line.is_def:
1132         split_funcs = [left_hand_split]
1133     elif line.inside_brackets:
1134         split_funcs = [delimiter_split]
1135         if '\n' not in line_str:
1136             # Only attempt RHS if we don't have multiline strings or comments
1137             # on this line.
1138             split_funcs.append(right_hand_split)
1139     else:
1140         split_funcs = [right_hand_split]
1141     for split_func in split_funcs:
1142         # We are accumulating lines in `result` because we might want to abort
1143         # mission and return the original line in the end, or attempt a different
1144         # split altogether.
1145         result: List[Line] = []
1146         try:
1147             for l in split_func(line, py36=py36):
1148                 if str(l).strip('\n') == line_str:
1149                     raise CannotSplit("Split function returned an unchanged result")
1150
1151                 result.extend(
1152                     split_line(l, line_length=line_length, inner=True, py36=py36)
1153                 )
1154         except CannotSplit as cs:
1155             continue
1156
1157         else:
1158             yield from result
1159             break
1160
1161     else:
1162         yield line
1163
1164
1165 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1166     """Split line into many lines, starting with the first matching bracket pair.
1167
1168     Note: this usually looks weird, only use this for function definitions.
1169     Prefer RHS otherwise.
1170     """
1171     head = Line(depth=line.depth)
1172     body = Line(depth=line.depth + 1, inside_brackets=True)
1173     tail = Line(depth=line.depth)
1174     tail_leaves: List[Leaf] = []
1175     body_leaves: List[Leaf] = []
1176     head_leaves: List[Leaf] = []
1177     current_leaves = head_leaves
1178     matching_bracket = None
1179     for leaf in line.leaves:
1180         if (
1181             current_leaves is body_leaves and
1182             leaf.type in CLOSING_BRACKETS and
1183             leaf.opening_bracket is matching_bracket  # type: ignore
1184         ):
1185             current_leaves = tail_leaves
1186         current_leaves.append(leaf)
1187         if current_leaves is head_leaves:
1188             if leaf.type in OPENING_BRACKETS:
1189                 matching_bracket = leaf
1190                 current_leaves = body_leaves
1191     # Since body is a new indent level, remove spurious leading whitespace.
1192     if body_leaves:
1193         normalize_prefix(body_leaves[0])
1194     # Build the new lines.
1195     for result, leaves in (
1196         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1197     ):
1198         for leaf in leaves:
1199             result.append(leaf, preformatted=True)
1200             comment_after = line.comments.get(id(leaf))
1201             if comment_after:
1202                 result.append(comment_after, preformatted=True)
1203     # Check if the split succeeded.
1204     tail_len = len(str(tail))
1205     if not body:
1206         if tail_len == 0:
1207             raise CannotSplit("Splitting brackets produced the same line")
1208
1209         elif tail_len < 3:
1210             raise CannotSplit(
1211                 f"Splitting brackets on an empty body to save "
1212                 f"{tail_len} characters is not worth it"
1213             )
1214
1215     for result in (head, body, tail):
1216         if result:
1217             yield result
1218
1219
1220 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1221     """Split line into many lines, starting with the last matching bracket pair."""
1222     head = Line(depth=line.depth)
1223     body = Line(depth=line.depth + 1, inside_brackets=True)
1224     tail = Line(depth=line.depth)
1225     tail_leaves: List[Leaf] = []
1226     body_leaves: List[Leaf] = []
1227     head_leaves: List[Leaf] = []
1228     current_leaves = tail_leaves
1229     opening_bracket = None
1230     for leaf in reversed(line.leaves):
1231         if current_leaves is body_leaves:
1232             if leaf is opening_bracket:
1233                 current_leaves = head_leaves
1234         current_leaves.append(leaf)
1235         if current_leaves is tail_leaves:
1236             if leaf.type in CLOSING_BRACKETS:
1237                 opening_bracket = leaf.opening_bracket  # type: ignore
1238                 current_leaves = body_leaves
1239     tail_leaves.reverse()
1240     body_leaves.reverse()
1241     head_leaves.reverse()
1242     # Since body is a new indent level, remove spurious leading whitespace.
1243     if body_leaves:
1244         normalize_prefix(body_leaves[0])
1245     # Build the new lines.
1246     for result, leaves in (
1247         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1248     ):
1249         for leaf in leaves:
1250             result.append(leaf, preformatted=True)
1251             comment_after = line.comments.get(id(leaf))
1252             if comment_after:
1253                 result.append(comment_after, preformatted=True)
1254     # Check if the split succeeded.
1255     tail_len = len(str(tail).strip('\n'))
1256     if not body:
1257         if tail_len == 0:
1258             raise CannotSplit("Splitting brackets produced the same line")
1259
1260         elif tail_len < 3:
1261             raise CannotSplit(
1262                 f"Splitting brackets on an empty body to save "
1263                 f"{tail_len} characters is not worth it"
1264             )
1265
1266     for result in (head, body, tail):
1267         if result:
1268             yield result
1269
1270
1271 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1272     """Split according to delimiters of the highest priority.
1273
1274     This kind of split doesn't increase indentation.
1275     If `py36` is True, the split will add trailing commas also in function
1276     signatures that contain * and **.
1277     """
1278     try:
1279         last_leaf = line.leaves[-1]
1280     except IndexError:
1281         raise CannotSplit("Line empty")
1282
1283     delimiters = line.bracket_tracker.delimiters
1284     try:
1285         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1286     except ValueError:
1287         raise CannotSplit("No delimiters found")
1288
1289     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1290     lowest_depth = sys.maxsize
1291     trailing_comma_safe = True
1292     for leaf in line.leaves:
1293         current_line.append(leaf, preformatted=True)
1294         comment_after = line.comments.get(id(leaf))
1295         if comment_after:
1296             current_line.append(comment_after, preformatted=True)
1297         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1298         if (
1299             leaf.bracket_depth == lowest_depth and  # type: ignore
1300             leaf.type == token.STAR or
1301             leaf.type == token.DOUBLESTAR
1302         ):
1303             trailing_comma_safe = trailing_comma_safe and py36
1304         leaf_priority = delimiters.get(id(leaf))
1305         if leaf_priority == delimiter_priority:
1306             normalize_prefix(current_line.leaves[0])
1307             yield current_line
1308
1309             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1310     if current_line:
1311         if (
1312             delimiter_priority == COMMA_PRIORITY and
1313             current_line.leaves[-1].type != token.COMMA and
1314             trailing_comma_safe
1315         ):
1316             current_line.append(Leaf(token.COMMA, ','))
1317         normalize_prefix(current_line.leaves[0])
1318         yield current_line
1319
1320
1321 def is_import(leaf: Leaf) -> bool:
1322     """Returns True if the given leaf starts an import statement."""
1323     p = leaf.parent
1324     t = leaf.type
1325     v = leaf.value
1326     return bool(
1327         t == token.NAME and
1328         (
1329             (v == 'import' and p and p.type == syms.import_name) or
1330             (v == 'from' and p and p.type == syms.import_from)
1331         )
1332     )
1333
1334
1335 def normalize_prefix(leaf: Leaf) -> None:
1336     """Leave existing extra newlines for imports.  Remove everything else."""
1337     if is_import(leaf):
1338         spl = leaf.prefix.split('#', 1)
1339         nl_count = spl[0].count('\n')
1340         if len(spl) > 1:
1341             # Skip one newline since it was for a standalone comment.
1342             nl_count -= 1
1343         leaf.prefix = '\n' * nl_count
1344         return
1345
1346     leaf.prefix = ''
1347
1348
1349 def is_python36(node: Node) -> bool:
1350     """Returns True if the current file is using Python 3.6+ features.
1351
1352     Currently looking for:
1353     - f-strings; and
1354     - trailing commas after * or ** in function signatures.
1355     """
1356     for n in node.pre_order():
1357         if n.type == token.STRING:
1358             assert isinstance(n, Leaf)
1359             if n.value[:2] in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1360                 return True
1361
1362         elif (
1363             n.type == syms.typedargslist and
1364             n.children and
1365             n.children[-1].type == token.COMMA
1366         ):
1367             for ch in n.children:
1368                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1369                     return True
1370
1371     return False
1372
1373
1374 PYTHON_EXTENSIONS = {'.py'}
1375 BLACKLISTED_DIRECTORIES = {
1376     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1377 }
1378
1379
1380 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1381     for child in path.iterdir():
1382         if child.is_dir():
1383             if child.name in BLACKLISTED_DIRECTORIES:
1384                 continue
1385
1386             yield from gen_python_files_in_dir(child)
1387
1388         elif child.suffix in PYTHON_EXTENSIONS:
1389             yield child
1390
1391
1392 @dataclass
1393 class Report:
1394     """Provides a reformatting counter."""
1395     change_count: int = attrib(default=0)
1396     same_count: int = attrib(default=0)
1397     failure_count: int = attrib(default=0)
1398
1399     def done(self, src: Path, changed: bool) -> None:
1400         """Increment the counter for successful reformatting. Write out a message."""
1401         if changed:
1402             out(f'reformatted {src}')
1403             self.change_count += 1
1404         else:
1405             out(f'{src} already well formatted, good job.', bold=False)
1406             self.same_count += 1
1407
1408     def failed(self, src: Path, message: str) -> None:
1409         """Increment the counter for failed reformatting. Write out a message."""
1410         err(f'error: cannot format {src}: {message}')
1411         self.failure_count += 1
1412
1413     @property
1414     def return_code(self) -> int:
1415         """Which return code should the app use considering the current state."""
1416         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1417         # 126 we have special returncodes reserved by the shell.
1418         if self.failure_count:
1419             return 123
1420
1421         elif self.change_count:
1422             return 1
1423
1424         return 0
1425
1426     def __str__(self) -> str:
1427         """A color report of the current state.
1428
1429         Use `click.unstyle` to remove colors.
1430         """
1431         report = []
1432         if self.change_count:
1433             s = 's' if self.change_count > 1 else ''
1434             report.append(
1435                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1436             )
1437         if self.same_count:
1438             s = 's' if self.same_count > 1 else ''
1439             report.append(f'{self.same_count} file{s} left unchanged')
1440         if self.failure_count:
1441             s = 's' if self.failure_count > 1 else ''
1442             report.append(
1443                 click.style(
1444                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1445                 )
1446             )
1447         return ', '.join(report) + '.'
1448
1449
1450 def assert_equivalent(src: str, dst: str) -> None:
1451     """Raises AssertionError if `src` and `dst` aren't equivalent.
1452
1453     This is a temporary sanity check until Black becomes stable.
1454     """
1455
1456     import ast
1457     import traceback
1458
1459     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1460         """Simple visitor generating strings to compare ASTs by content."""
1461         yield f"{'  ' * depth}{node.__class__.__name__}("
1462
1463         for field in sorted(node._fields):
1464             try:
1465                 value = getattr(node, field)
1466             except AttributeError:
1467                 continue
1468
1469             yield f"{'  ' * (depth+1)}{field}="
1470
1471             if isinstance(value, list):
1472                 for item in value:
1473                     if isinstance(item, ast.AST):
1474                         yield from _v(item, depth + 2)
1475
1476             elif isinstance(value, ast.AST):
1477                 yield from _v(value, depth + 2)
1478
1479             else:
1480                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1481
1482         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1483
1484     try:
1485         src_ast = ast.parse(src)
1486     except Exception as exc:
1487         raise AssertionError(f"cannot parse source: {exc}") from None
1488
1489     try:
1490         dst_ast = ast.parse(dst)
1491     except Exception as exc:
1492         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1493         raise AssertionError(
1494             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1495             f"Please report a bug on https://github.com/ambv/black/issues.  "
1496             f"This invalid output might be helpful: {log}",
1497         ) from None
1498
1499     src_ast_str = '\n'.join(_v(src_ast))
1500     dst_ast_str = '\n'.join(_v(dst_ast))
1501     if src_ast_str != dst_ast_str:
1502         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1503         raise AssertionError(
1504             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1505             f"the source.  "
1506             f"Please report a bug on https://github.com/ambv/black/issues.  "
1507             f"This diff might be helpful: {log}",
1508         ) from None
1509
1510
1511 def assert_stable(src: str, dst: str, line_length: int) -> None:
1512     """Raises AssertionError if `dst` reformats differently the second time.
1513
1514     This is a temporary sanity check until Black becomes stable.
1515     """
1516     newdst = format_str(dst, line_length=line_length)
1517     if dst != newdst:
1518         log = dump_to_file(
1519             diff(src, dst, 'source', 'first pass'),
1520             diff(dst, newdst, 'first pass', 'second pass'),
1521         )
1522         raise AssertionError(
1523             f"INTERNAL ERROR: Black produced different code on the second pass "
1524             f"of the formatter.  "
1525             f"Please report a bug on https://github.com/ambv/black/issues.  "
1526             f"This diff might be helpful: {log}",
1527         ) from None
1528
1529
1530 def dump_to_file(*output: str) -> str:
1531     """Dumps `output` to a temporary file. Returns path to the file."""
1532     import tempfile
1533
1534     with tempfile.NamedTemporaryFile(
1535         mode='w', prefix='blk_', suffix='.log', delete=False
1536     ) as f:
1537         for lines in output:
1538             f.write(lines)
1539             f.write('\n')
1540     return f.name
1541
1542
1543 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1544     """Returns a udiff string between strings `a` and `b`."""
1545     import difflib
1546
1547     a_lines = [line + '\n' for line in a.split('\n')]
1548     b_lines = [line + '\n' for line in b.split('\n')]
1549     return ''.join(
1550         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1551     )
1552
1553
1554 if __name__ == '__main__':
1555     main()