]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Describe fix for #21 in README
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
78 )
79 @click.pass_context
80 def main(
81     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
82 ) -> None:
83     """The uncompromising code formatter."""
84     sources: List[Path] = []
85     for s in src:
86         p = Path(s)
87         if p.is_dir():
88             sources.extend(gen_python_files_in_dir(p))
89         elif p.is_file():
90             # if a file was explicitly given, we don't care about its extension
91             sources.append(p)
92         else:
93             err(f'invalid path: {s}')
94     if len(sources) == 0:
95         ctx.exit(0)
96     elif len(sources) == 1:
97         p = sources[0]
98         report = Report()
99         try:
100             changed = format_file_in_place(
101                 p, line_length=line_length, fast=fast, write_back=not check
102             )
103             report.done(p, changed)
104         except Exception as exc:
105             report.failed(p, str(exc))
106         ctx.exit(report.return_code)
107     else:
108         loop = asyncio.get_event_loop()
109         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
110         return_code = 1
111         try:
112             return_code = loop.run_until_complete(
113                 schedule_formatting(
114                     sources, line_length, not check, fast, loop, executor
115                 )
116             )
117         finally:
118             loop.close()
119             ctx.exit(return_code)
120
121
122 async def schedule_formatting(
123     sources: List[Path],
124     line_length: int,
125     write_back: bool,
126     fast: bool,
127     loop: BaseEventLoop,
128     executor: Executor,
129 ) -> int:
130     tasks = {
131         src: loop.run_in_executor(
132             executor, format_file_in_place, src, line_length, fast, write_back
133         )
134         for src in sources
135     }
136     await asyncio.wait(tasks.values())
137     cancelled = []
138     report = Report()
139     for src, task in tasks.items():
140         if not task.done():
141             report.failed(src, 'timed out, cancelling')
142             task.cancel()
143             cancelled.append(task)
144         elif task.exception():
145             report.failed(src, str(task.exception()))
146         else:
147             report.done(src, task.result())
148     if cancelled:
149         await asyncio.wait(cancelled, timeout=2)
150     out('All done! ✨ 🍰 ✨')
151     click.echo(str(report))
152     return report.return_code
153
154
155 def format_file_in_place(
156     src: Path, line_length: int, fast: bool, write_back: bool = False
157 ) -> bool:
158     """Format the file and rewrite if changed. Return True if changed."""
159     try:
160         contents, encoding = format_file(src, line_length=line_length, fast=fast)
161     except NothingChanged:
162         return False
163
164     if write_back:
165         with open(src, "w", encoding=encoding) as f:
166             f.write(contents)
167     return True
168
169
170 def format_file(
171     src: Path, line_length: int, fast: bool
172 ) -> Tuple[FileContent, Encoding]:
173     """Reformats a file and returns its contents and encoding."""
174     with tokenize.open(src) as src_buffer:
175         src_contents = src_buffer.read()
176     if src_contents.strip() == '':
177         raise NothingChanged(src)
178
179     dst_contents = format_str(src_contents, line_length=line_length)
180     if src_contents == dst_contents:
181         raise NothingChanged(src)
182
183     if not fast:
184         assert_equivalent(src_contents, dst_contents)
185         assert_stable(src_contents, dst_contents, line_length=line_length)
186     return dst_contents, src_buffer.encoding
187
188
189 def format_str(src_contents: str, line_length: int) -> FileContent:
190     """Reformats a string and returns new contents."""
191     src_node = lib2to3_parse(src_contents)
192     dst_contents = ""
193     comments: List[Line] = []
194     lines = LineGenerator()
195     elt = EmptyLineTracker()
196     py36 = is_python36(src_node)
197     empty_line = Line()
198     after = 0
199     for current_line in lines.visit(src_node):
200         for _ in range(after):
201             dst_contents += str(empty_line)
202         before, after = elt.maybe_empty_lines(current_line)
203         for _ in range(before):
204             dst_contents += str(empty_line)
205         if not current_line.is_comment:
206             for comment in comments:
207                 dst_contents += str(comment)
208             comments = []
209             for line in split_line(current_line, line_length=line_length, py36=py36):
210                 dst_contents += str(line)
211         else:
212             comments.append(current_line)
213     for comment in comments:
214         dst_contents += str(comment)
215     return dst_contents
216
217
218 def lib2to3_parse(src_txt: str) -> Node:
219     """Given a string with source, return the lib2to3 Node."""
220     grammar = pygram.python_grammar_no_print_statement
221     drv = driver.Driver(grammar, pytree.convert)
222     if src_txt[-1] != '\n':
223         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
224         src_txt += nl
225     try:
226         result = drv.parse_string(src_txt, True)
227     except ParseError as pe:
228         lineno, column = pe.context[1]
229         lines = src_txt.splitlines()
230         try:
231             faulty_line = lines[lineno - 1]
232         except IndexError:
233             faulty_line = "<line number missing in source>"
234         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
235
236     if isinstance(result, Leaf):
237         result = Node(syms.file_input, [result])
238     return result
239
240
241 def lib2to3_unparse(node: Node) -> str:
242     """Given a lib2to3 node, return its string representation."""
243     code = str(node)
244     return code
245
246
247 T = TypeVar('T')
248
249
250 class Visitor(Generic[T]):
251     """Basic lib2to3 visitor that yields things on visiting."""
252
253     def visit(self, node: LN) -> Iterator[T]:
254         if node.type < 256:
255             name = token.tok_name[node.type]
256         else:
257             name = type_repr(node.type)
258         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
259
260     def visit_default(self, node: LN) -> Iterator[T]:
261         if isinstance(node, Node):
262             for child in node.children:
263                 yield from self.visit(child)
264
265
266 @dataclass
267 class DebugVisitor(Visitor[T]):
268     tree_depth: int = 0
269
270     def visit_default(self, node: LN) -> Iterator[T]:
271         indent = ' ' * (2 * self.tree_depth)
272         if isinstance(node, Node):
273             _type = type_repr(node.type)
274             out(f'{indent}{_type}', fg='yellow')
275             self.tree_depth += 1
276             for child in node.children:
277                 yield from self.visit(child)
278
279             self.tree_depth -= 1
280             out(f'{indent}/{_type}', fg='yellow', bold=False)
281         else:
282             _type = token.tok_name.get(node.type, str(node.type))
283             out(f'{indent}{_type}', fg='blue', nl=False)
284             if node.prefix:
285                 # We don't have to handle prefixes for `Node` objects since
286                 # that delegates to the first child anyway.
287                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
288             out(f' {node.value!r}', fg='blue', bold=False)
289
290
291 KEYWORDS = set(keyword.kwlist)
292 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
293 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
294 STATEMENT = {
295     syms.if_stmt,
296     syms.while_stmt,
297     syms.for_stmt,
298     syms.try_stmt,
299     syms.except_clause,
300     syms.with_stmt,
301     syms.funcdef,
302     syms.classdef,
303 }
304 STANDALONE_COMMENT = 153
305 LOGIC_OPERATORS = {'and', 'or'}
306 COMPARATORS = {
307     token.LESS,
308     token.GREATER,
309     token.EQEQUAL,
310     token.NOTEQUAL,
311     token.LESSEQUAL,
312     token.GREATEREQUAL,
313 }
314 MATH_OPERATORS = {
315     token.PLUS,
316     token.MINUS,
317     token.STAR,
318     token.SLASH,
319     token.VBAR,
320     token.AMPER,
321     token.PERCENT,
322     token.CIRCUMFLEX,
323     token.LEFTSHIFT,
324     token.RIGHTSHIFT,
325     token.DOUBLESTAR,
326     token.DOUBLESLASH,
327 }
328 COMPREHENSION_PRIORITY = 20
329 COMMA_PRIORITY = 10
330 LOGIC_PRIORITY = 5
331 STRING_PRIORITY = 4
332 COMPARATOR_PRIORITY = 3
333 MATH_PRIORITY = 1
334
335
336 @dataclass
337 class BracketTracker:
338     depth: int = 0
339     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
340     delimiters: Dict[LeafID, Priority] = Factory(dict)
341     previous: Optional[Leaf] = None
342
343     def mark(self, leaf: Leaf) -> None:
344         if leaf.type == token.COMMENT:
345             return
346
347         if leaf.type in CLOSING_BRACKETS:
348             self.depth -= 1
349             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
350             leaf.opening_bracket = opening_bracket
351         leaf.bracket_depth = self.depth
352         if self.depth == 0:
353             delim = is_delimiter(leaf)
354             if delim:
355                 self.delimiters[id(leaf)] = delim
356             elif self.previous is not None:
357                 if leaf.type == token.STRING and self.previous.type == token.STRING:
358                     self.delimiters[id(self.previous)] = STRING_PRIORITY
359                 elif (
360                     leaf.type == token.NAME
361                     and leaf.value == 'for'
362                     and leaf.parent
363                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
364                 ):
365                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
366                 elif (
367                     leaf.type == token.NAME
368                     and leaf.value == 'if'
369                     and leaf.parent
370                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
371                 ):
372                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
373                 elif (
374                     leaf.type == token.NAME
375                     and leaf.value in LOGIC_OPERATORS
376                     and leaf.parent
377                 ):
378                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
379         if leaf.type in OPENING_BRACKETS:
380             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
381             self.depth += 1
382         self.previous = leaf
383
384     def any_open_brackets(self) -> bool:
385         """Returns True if there is an yet unmatched open bracket on the line."""
386         return bool(self.bracket_match)
387
388     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
389         """Returns the highest priority of a delimiter found on the line.
390
391         Values are consistent with what `is_delimiter()` returns.
392         """
393         return max(v for k, v in self.delimiters.items() if k not in exclude)
394
395
396 @dataclass
397 class Line:
398     depth: int = 0
399     leaves: List[Leaf] = Factory(list)
400     comments: Dict[LeafID, Leaf] = Factory(dict)
401     bracket_tracker: BracketTracker = Factory(BracketTracker)
402     inside_brackets: bool = False
403     has_for: bool = False
404     _for_loop_variable: bool = False
405
406     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
407         has_value = leaf.value.strip()
408         if not has_value:
409             return
410
411         if self.leaves and not preformatted:
412             # Note: at this point leaf.prefix should be empty except for
413             # imports, for which we only preserve newlines.
414             leaf.prefix += whitespace(leaf)
415         if self.inside_brackets or not preformatted:
416             self.maybe_decrement_after_for_loop_variable(leaf)
417             self.bracket_tracker.mark(leaf)
418             self.maybe_remove_trailing_comma(leaf)
419             self.maybe_increment_for_loop_variable(leaf)
420             if self.maybe_adapt_standalone_comment(leaf):
421                 return
422
423         if not self.append_comment(leaf):
424             self.leaves.append(leaf)
425
426     @property
427     def is_comment(self) -> bool:
428         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
429
430     @property
431     def is_decorator(self) -> bool:
432         return bool(self) and self.leaves[0].type == token.AT
433
434     @property
435     def is_import(self) -> bool:
436         return bool(self) and is_import(self.leaves[0])
437
438     @property
439     def is_class(self) -> bool:
440         return (
441             bool(self)
442             and self.leaves[0].type == token.NAME
443             and self.leaves[0].value == 'class'
444         )
445
446     @property
447     def is_def(self) -> bool:
448         """Also returns True for async defs."""
449         try:
450             first_leaf = self.leaves[0]
451         except IndexError:
452             return False
453
454         try:
455             second_leaf: Optional[Leaf] = self.leaves[1]
456         except IndexError:
457             second_leaf = None
458         return (
459             (first_leaf.type == token.NAME and first_leaf.value == 'def')
460             or (
461                 first_leaf.type == token.NAME
462                 and first_leaf.value == 'async'
463                 and second_leaf is not None
464                 and second_leaf.type == token.NAME
465                 and second_leaf.value == 'def'
466             )
467         )
468
469     @property
470     def is_flow_control(self) -> bool:
471         return (
472             bool(self)
473             and self.leaves[0].type == token.NAME
474             and self.leaves[0].value in FLOW_CONTROL
475         )
476
477     @property
478     def is_yield(self) -> bool:
479         return (
480             bool(self)
481             and self.leaves[0].type == token.NAME
482             and self.leaves[0].value == 'yield'
483         )
484
485     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
486         if not (
487             self.leaves
488             and self.leaves[-1].type == token.COMMA
489             and closing.type in CLOSING_BRACKETS
490         ):
491             return False
492
493         if closing.type == token.RSQB or closing.type == token.RBRACE:
494             self.leaves.pop()
495             return True
496
497         # For parens let's check if it's safe to remove the comma.  If the
498         # trailing one is the only one, we might mistakenly change a tuple
499         # into a different type by removing the comma.
500         depth = closing.bracket_depth + 1
501         commas = 0
502         opening = closing.opening_bracket
503         for _opening_index, leaf in enumerate(self.leaves):
504             if leaf is opening:
505                 break
506
507         else:
508             return False
509
510         for leaf in self.leaves[_opening_index + 1:]:
511             if leaf is closing:
512                 break
513
514             bracket_depth = leaf.bracket_depth
515             if bracket_depth == depth and leaf.type == token.COMMA:
516                 commas += 1
517                 if leaf.parent and leaf.parent.type == syms.arglist:
518                     commas += 1
519                     break
520
521         if commas > 1:
522             self.leaves.pop()
523             return True
524
525         return False
526
527     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
528         """In a for loop, or comprehension, the variables are often unpacks.
529
530         To avoid splitting on the comma in this situation, we will increase
531         the depth of tokens between `for` and `in`.
532         """
533         if leaf.type == token.NAME and leaf.value == 'for':
534             self.has_for = True
535             self.bracket_tracker.depth += 1
536             self._for_loop_variable = True
537             return True
538
539         return False
540
541     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
542         # See `maybe_increment_for_loop_variable` above for explanation.
543         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
544             self.bracket_tracker.depth -= 1
545             self._for_loop_variable = False
546             return True
547
548         return False
549
550     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
551         """Hack a standalone comment to act as a trailing comment for line splitting.
552
553         If this line has brackets and a standalone `comment`, we need to adapt
554         it to be able to still reformat the line.
555
556         This is not perfect, the line to which the standalone comment gets
557         appended will appear "too long" when splitting.
558         """
559         if not (
560             comment.type == STANDALONE_COMMENT
561             and self.bracket_tracker.any_open_brackets()
562         ):
563             return False
564
565         comment.type = token.COMMENT
566         comment.prefix = '\n' + '    ' * (self.depth + 1)
567         return self.append_comment(comment)
568
569     def append_comment(self, comment: Leaf) -> bool:
570         if comment.type != token.COMMENT:
571             return False
572
573         try:
574             after = id(self.last_non_delimiter())
575         except LookupError:
576             comment.type = STANDALONE_COMMENT
577             comment.prefix = ''
578             return False
579
580         else:
581             if after in self.comments:
582                 self.comments[after].value += str(comment)
583             else:
584                 self.comments[after] = comment
585             return True
586
587     def last_non_delimiter(self) -> Leaf:
588         for i in range(len(self.leaves)):
589             last = self.leaves[-i - 1]
590             if not is_delimiter(last):
591                 return last
592
593         raise LookupError("No non-delimiters found")
594
595     def __str__(self) -> str:
596         if not self:
597             return '\n'
598
599         indent = '    ' * self.depth
600         leaves = iter(self.leaves)
601         first = next(leaves)
602         res = f'{first.prefix}{indent}{first.value}'
603         for leaf in leaves:
604             res += str(leaf)
605         for comment in self.comments.values():
606             res += str(comment)
607         return res + '\n'
608
609     def __bool__(self) -> bool:
610         return bool(self.leaves or self.comments)
611
612
613 @dataclass
614 class EmptyLineTracker:
615     """Provides a stateful method that returns the number of potential extra
616     empty lines needed before and after the currently processed line.
617
618     Note: this tracker works on lines that haven't been split yet.
619     """
620     previous_line: Optional[Line] = None
621     previous_after: int = 0
622     previous_defs: List[int] = Factory(list)
623
624     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
625         """Returns the number of extra empty lines before and after the `current_line`.
626
627         This is for separating `def`, `async def` and `class` with extra empty lines
628         (two on module-level), as well as providing an extra empty line after flow
629         control keywords to make them more prominent.
630         """
631         before, after = self._maybe_empty_lines(current_line)
632         self.previous_after = after
633         self.previous_line = current_line
634         return before, after
635
636     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
637         before = 0
638         depth = current_line.depth
639         while self.previous_defs and self.previous_defs[-1] >= depth:
640             self.previous_defs.pop()
641             before = (1 if depth else 2) - self.previous_after
642         is_decorator = current_line.is_decorator
643         if is_decorator or current_line.is_def or current_line.is_class:
644             if not is_decorator:
645                 self.previous_defs.append(depth)
646             if self.previous_line is None:
647                 # Don't insert empty lines before the first line in the file.
648                 return 0, 0
649
650             if self.previous_line and self.previous_line.is_decorator:
651                 # Don't insert empty lines between decorators.
652                 return 0, 0
653
654             newlines = 2
655             if current_line.depth:
656                 newlines -= 1
657             newlines -= self.previous_after
658             return newlines, 0
659
660         if current_line.is_flow_control:
661             return before, 1
662
663         if (
664             self.previous_line
665             and self.previous_line.is_import
666             and not current_line.is_import
667             and depth == self.previous_line.depth
668         ):
669             return (before or 1), 0
670
671         if (
672             self.previous_line
673             and self.previous_line.is_yield
674             and (not current_line.is_yield or depth != self.previous_line.depth)
675         ):
676             return (before or 1), 0
677
678         return before, 0
679
680
681 @dataclass
682 class LineGenerator(Visitor[Line]):
683     """Generates reformatted Line objects.  Empty lines are not emitted.
684
685     Note: destroys the tree it's visiting by mutating prefixes of its leaves
686     in ways that will no longer stringify to valid Python code on the tree.
687     """
688     current_line: Line = Factory(Line)
689     standalone_comments: List[Leaf] = Factory(list)
690
691     def line(self, indent: int = 0) -> Iterator[Line]:
692         """Generate a line.
693
694         If the line is empty, only emit if it makes sense.
695         If the line is too long, split it first and then generate.
696
697         If any lines were generated, set up a new current_line.
698         """
699         if not self.current_line:
700             self.current_line.depth += indent
701             return  # Line is empty, don't emit. Creating a new one unnecessary.
702
703         complete_line = self.current_line
704         self.current_line = Line(depth=complete_line.depth + indent)
705         yield complete_line
706
707     def visit_default(self, node: LN) -> Iterator[Line]:
708         if isinstance(node, Leaf):
709             for comment in generate_comments(node):
710                 if self.current_line.bracket_tracker.any_open_brackets():
711                     # any comment within brackets is subject to splitting
712                     self.current_line.append(comment)
713                 elif comment.type == token.COMMENT:
714                     # regular trailing comment
715                     self.current_line.append(comment)
716                     yield from self.line()
717
718                 else:
719                     # regular standalone comment, to be processed later (see
720                     # docstring in `generate_comments()`
721                     self.standalone_comments.append(comment)
722             normalize_prefix(node)
723             if node.type not in WHITESPACE:
724                 for comment in self.standalone_comments:
725                     yield from self.line()
726
727                     self.current_line.append(comment)
728                     yield from self.line()
729
730                 self.standalone_comments = []
731                 self.current_line.append(node)
732         yield from super().visit_default(node)
733
734     def visit_suite(self, node: Node) -> Iterator[Line]:
735         """Body of a statement after a colon."""
736         children = iter(node.children)
737         # Process newline before indenting.  It might contain an inline
738         # comment that should go right after the colon.
739         newline = next(children)
740         yield from self.visit(newline)
741         yield from self.line(+1)
742
743         for child in children:
744             yield from self.visit(child)
745
746         yield from self.line(-1)
747
748     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
749         """Visit a statement.
750
751         The relevant Python language keywords for this statement are NAME leaves
752         within it.
753         """
754         for child in node.children:
755             if child.type == token.NAME and child.value in keywords:  # type: ignore
756                 yield from self.line()
757
758             yield from self.visit(child)
759
760     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
761         """A statement without nested statements."""
762         is_suite_like = node.parent and node.parent.type in STATEMENT
763         if is_suite_like:
764             yield from self.line(+1)
765             yield from self.visit_default(node)
766             yield from self.line(-1)
767
768         else:
769             yield from self.line()
770             yield from self.visit_default(node)
771
772     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
773         yield from self.line()
774
775         children = iter(node.children)
776         for child in children:
777             yield from self.visit(child)
778
779             if child.type == token.NAME and child.value == 'async':  # type: ignore
780                 break
781
782         internal_stmt = next(children)
783         for child in internal_stmt.children:
784             yield from self.visit(child)
785
786     def visit_decorators(self, node: Node) -> Iterator[Line]:
787         for child in node.children:
788             yield from self.line()
789             yield from self.visit(child)
790
791     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
792         yield from self.line()
793
794     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
795         yield from self.visit_default(leaf)
796         yield from self.line()
797
798     def __attrs_post_init__(self) -> None:
799         """You are in a twisty little maze of passages."""
800         v = self.visit_stmt
801         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
802         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
803         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
804         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
805         self.visit_except_clause = partial(v, keywords={'except'})
806         self.visit_funcdef = partial(v, keywords={'def'})
807         self.visit_with_stmt = partial(v, keywords={'with'})
808         self.visit_classdef = partial(v, keywords={'class'})
809         self.visit_async_funcdef = self.visit_async_stmt
810         self.visit_decorated = self.visit_decorators
811
812
813 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
814 OPENING_BRACKETS = set(BRACKET.keys())
815 CLOSING_BRACKETS = set(BRACKET.values())
816 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
817 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
818
819
820 def whitespace(leaf: Leaf) -> str:  # noqa C901
821     """Return whitespace prefix if needed for the given `leaf`."""
822     NO = ''
823     SPACE = ' '
824     DOUBLESPACE = '  '
825     t = leaf.type
826     p = leaf.parent
827     v = leaf.value
828     if t in ALWAYS_NO_SPACE:
829         return NO
830
831     if t == token.COMMENT:
832         return DOUBLESPACE
833
834     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
835     prev = leaf.prev_sibling
836     if not prev:
837         prevp = preceding_leaf(p)
838         if not prevp or prevp.type in OPENING_BRACKETS:
839             return NO
840
841         if prevp.type == token.EQUAL:
842             if prevp.parent and prevp.parent.type in {
843                 syms.typedargslist,
844                 syms.varargslist,
845                 syms.parameters,
846                 syms.arglist,
847                 syms.argument,
848             }:
849                 return NO
850
851         elif prevp.type == token.DOUBLESTAR:
852             if prevp.parent and prevp.parent.type in {
853                 syms.typedargslist,
854                 syms.varargslist,
855                 syms.parameters,
856                 syms.arglist,
857                 syms.dictsetmaker,
858             }:
859                 return NO
860
861         elif prevp.type == token.COLON:
862             if prevp.parent and prevp.parent.type == syms.subscript:
863                 return NO
864
865         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
866             return NO
867
868     elif prev.type in OPENING_BRACKETS:
869         return NO
870
871     if p.type in {syms.parameters, syms.arglist}:
872         # untyped function signatures or calls
873         if t == token.RPAR:
874             return NO
875
876         if not prev or prev.type != token.COMMA:
877             return NO
878
879     if p.type == syms.varargslist:
880         # lambdas
881         if t == token.RPAR:
882             return NO
883
884         if prev and prev.type != token.COMMA:
885             return NO
886
887     elif p.type == syms.typedargslist:
888         # typed function signatures
889         if not prev:
890             return NO
891
892         if t == token.EQUAL:
893             if prev.type != syms.tname:
894                 return NO
895
896         elif prev.type == token.EQUAL:
897             # A bit hacky: if the equal sign has whitespace, it means we
898             # previously found it's a typed argument.  So, we're using that, too.
899             return prev.prefix
900
901         elif prev.type != token.COMMA:
902             return NO
903
904     elif p.type == syms.tname:
905         # type names
906         if not prev:
907             prevp = preceding_leaf(p)
908             if not prevp or prevp.type != token.COMMA:
909                 return NO
910
911     elif p.type == syms.trailer:
912         # attributes and calls
913         if t == token.LPAR or t == token.RPAR:
914             return NO
915
916         if not prev:
917             if t == token.DOT:
918                 prevp = preceding_leaf(p)
919                 if not prevp or prevp.type != token.NUMBER:
920                     return NO
921
922             elif t == token.LSQB:
923                 return NO
924
925         elif prev.type != token.COMMA:
926             return NO
927
928     elif p.type == syms.argument:
929         # single argument
930         if t == token.EQUAL:
931             return NO
932
933         if not prev:
934             prevp = preceding_leaf(p)
935             if not prevp or prevp.type == token.LPAR:
936                 return NO
937
938         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
939             return NO
940
941     elif p.type == syms.decorator:
942         # decorators
943         return NO
944
945     elif p.type == syms.dotted_name:
946         if prev:
947             return NO
948
949         prevp = preceding_leaf(p)
950         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
951             return NO
952
953     elif p.type == syms.classdef:
954         if t == token.LPAR:
955             return NO
956
957         if prev and prev.type == token.LPAR:
958             return NO
959
960     elif p.type == syms.subscript:
961         # indexing
962         if not prev:
963             assert p.parent is not None, "subscripts are always parented"
964             if p.parent.type == syms.subscriptlist:
965                 return SPACE
966
967             return NO
968
969         elif prev.type == token.COLON:
970             return NO
971
972     elif p.type == syms.atom:
973         if prev and t == token.DOT:
974             # dots, but not the first one.
975             return NO
976
977     elif (
978         p.type == syms.listmaker
979         or p.type == syms.testlist_gexp
980         or p.type == syms.subscriptlist
981     ):
982         # list interior, including unpacking
983         if not prev:
984             return NO
985
986     elif p.type == syms.dictsetmaker:
987         # dict and set interior, including unpacking
988         if not prev:
989             return NO
990
991         if prev.type == token.DOUBLESTAR:
992             return NO
993
994     elif p.type in {syms.factor, syms.star_expr}:
995         # unary ops
996         if not prev:
997             prevp = preceding_leaf(p)
998             if not prevp or prevp.type in OPENING_BRACKETS:
999                 return NO
1000
1001             prevp_parent = prevp.parent
1002             assert prevp_parent is not None
1003             if prevp.type == token.COLON and prevp_parent.type in {
1004                 syms.subscript, syms.sliceop
1005             }:
1006                 return NO
1007
1008             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1009                 return NO
1010
1011         elif t == token.NAME or t == token.NUMBER:
1012             return NO
1013
1014     elif p.type == syms.import_from:
1015         if t == token.DOT:
1016             if prev and prev.type == token.DOT:
1017                 return NO
1018
1019         elif t == token.NAME:
1020             if v == 'import':
1021                 return SPACE
1022
1023             if prev and prev.type == token.DOT:
1024                 return NO
1025
1026     elif p.type == syms.sliceop:
1027         return NO
1028
1029     return SPACE
1030
1031
1032 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1033     """Returns the first leaf that precedes `node`, if any."""
1034     while node:
1035         res = node.prev_sibling
1036         if res:
1037             if isinstance(res, Leaf):
1038                 return res
1039
1040             try:
1041                 return list(res.leaves())[-1]
1042
1043             except IndexError:
1044                 return None
1045
1046         node = node.parent
1047     return None
1048
1049
1050 def is_delimiter(leaf: Leaf) -> int:
1051     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1052
1053     Higher numbers are higher priority.
1054     """
1055     if leaf.type == token.COMMA:
1056         return COMMA_PRIORITY
1057
1058     if leaf.type in COMPARATORS:
1059         return COMPARATOR_PRIORITY
1060
1061     if (
1062         leaf.type in MATH_OPERATORS
1063         and leaf.parent
1064         and leaf.parent.type not in {syms.factor, syms.star_expr}
1065     ):
1066         return MATH_PRIORITY
1067
1068     return 0
1069
1070
1071 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1072     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1073
1074     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1075     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1076     move because it does away with modifying the grammar to include all the
1077     possible places in which comments can be placed.
1078
1079     The sad consequence for us though is that comments don't "belong" anywhere.
1080     This is why this function generates simple parentless Leaf objects for
1081     comments.  We simply don't know what the correct parent should be.
1082
1083     No matter though, we can live without this.  We really only need to
1084     differentiate between inline and standalone comments.  The latter don't
1085     share the line with any code.
1086
1087     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1088     are emitted with a fake STANDALONE_COMMENT token identifier.
1089     """
1090     if not leaf.prefix:
1091         return
1092
1093     if '#' not in leaf.prefix:
1094         return
1095
1096     before_comment, content = leaf.prefix.split('#', 1)
1097     content = content.rstrip()
1098     if content and (content[0] not in {' ', '!', '#'}):
1099         content = ' ' + content
1100     is_standalone_comment = (
1101         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1102     )
1103     if not is_standalone_comment:
1104         # simple trailing comment
1105         yield Leaf(token.COMMENT, value='#' + content)
1106         return
1107
1108     for line in ('#' + content).split('\n'):
1109         line = line.lstrip()
1110         if not line.startswith('#'):
1111             continue
1112
1113         yield Leaf(STANDALONE_COMMENT, line)
1114
1115
1116 def split_line(
1117     line: Line, line_length: int, inner: bool = False, py36: bool = False
1118 ) -> Iterator[Line]:
1119     """Splits a `line` into potentially many lines.
1120
1121     They should fit in the allotted `line_length` but might not be able to.
1122     `inner` signifies that there were a pair of brackets somewhere around the
1123     current `line`, possibly transitively. This means we can fallback to splitting
1124     by delimiters if the LHS/RHS don't yield any results.
1125
1126     If `py36` is True, splitting may generate syntax that is only compatible
1127     with Python 3.6 and later.
1128     """
1129     line_str = str(line).strip('\n')
1130     if len(line_str) <= line_length and '\n' not in line_str:
1131         yield line
1132         return
1133
1134     if line.is_def:
1135         split_funcs = [left_hand_split]
1136     elif line.inside_brackets:
1137         split_funcs = [delimiter_split]
1138         if '\n' not in line_str:
1139             # Only attempt RHS if we don't have multiline strings or comments
1140             # on this line.
1141             split_funcs.append(right_hand_split)
1142     else:
1143         split_funcs = [right_hand_split]
1144     for split_func in split_funcs:
1145         # We are accumulating lines in `result` because we might want to abort
1146         # mission and return the original line in the end, or attempt a different
1147         # split altogether.
1148         result: List[Line] = []
1149         try:
1150             for l in split_func(line, py36=py36):
1151                 if str(l).strip('\n') == line_str:
1152                     raise CannotSplit("Split function returned an unchanged result")
1153
1154                 result.extend(
1155                     split_line(l, line_length=line_length, inner=True, py36=py36)
1156                 )
1157         except CannotSplit as cs:
1158             continue
1159
1160         else:
1161             yield from result
1162             break
1163
1164     else:
1165         yield line
1166
1167
1168 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1169     """Split line into many lines, starting with the first matching bracket pair.
1170
1171     Note: this usually looks weird, only use this for function definitions.
1172     Prefer RHS otherwise.
1173     """
1174     head = Line(depth=line.depth)
1175     body = Line(depth=line.depth + 1, inside_brackets=True)
1176     tail = Line(depth=line.depth)
1177     tail_leaves: List[Leaf] = []
1178     body_leaves: List[Leaf] = []
1179     head_leaves: List[Leaf] = []
1180     current_leaves = head_leaves
1181     matching_bracket = None
1182     for leaf in line.leaves:
1183         if (
1184             current_leaves is body_leaves
1185             and leaf.type in CLOSING_BRACKETS
1186             and leaf.opening_bracket is matching_bracket
1187         ):
1188             current_leaves = tail_leaves if body_leaves else head_leaves
1189         current_leaves.append(leaf)
1190         if current_leaves is head_leaves:
1191             if leaf.type in OPENING_BRACKETS:
1192                 matching_bracket = leaf
1193                 current_leaves = body_leaves
1194     # Since body is a new indent level, remove spurious leading whitespace.
1195     if body_leaves:
1196         normalize_prefix(body_leaves[0])
1197     # Build the new lines.
1198     for result, leaves in (
1199         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1200     ):
1201         for leaf in leaves:
1202             result.append(leaf, preformatted=True)
1203             comment_after = line.comments.get(id(leaf))
1204             if comment_after:
1205                 result.append(comment_after, preformatted=True)
1206     split_succeeded_or_raise(head, body, tail)
1207     for result in (head, body, tail):
1208         if result:
1209             yield result
1210
1211
1212 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1213     """Split line into many lines, starting with the last matching bracket pair."""
1214     head = Line(depth=line.depth)
1215     body = Line(depth=line.depth + 1, inside_brackets=True)
1216     tail = Line(depth=line.depth)
1217     tail_leaves: List[Leaf] = []
1218     body_leaves: List[Leaf] = []
1219     head_leaves: List[Leaf] = []
1220     current_leaves = tail_leaves
1221     opening_bracket = None
1222     for leaf in reversed(line.leaves):
1223         if current_leaves is body_leaves:
1224             if leaf is opening_bracket:
1225                 current_leaves = head_leaves if body_leaves else tail_leaves
1226         current_leaves.append(leaf)
1227         if current_leaves is tail_leaves:
1228             if leaf.type in CLOSING_BRACKETS:
1229                 opening_bracket = leaf.opening_bracket
1230                 current_leaves = body_leaves
1231     tail_leaves.reverse()
1232     body_leaves.reverse()
1233     head_leaves.reverse()
1234     # Since body is a new indent level, remove spurious leading whitespace.
1235     if body_leaves:
1236         normalize_prefix(body_leaves[0])
1237     # Build the new lines.
1238     for result, leaves in (
1239         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1240     ):
1241         for leaf in leaves:
1242             result.append(leaf, preformatted=True)
1243             comment_after = line.comments.get(id(leaf))
1244             if comment_after:
1245                 result.append(comment_after, preformatted=True)
1246     split_succeeded_or_raise(head, body, tail)
1247     for result in (head, body, tail):
1248         if result:
1249             yield result
1250
1251
1252 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1253     tail_len = len(str(tail).strip())
1254     if not body:
1255         if tail_len == 0:
1256             raise CannotSplit("Splitting brackets produced the same line")
1257
1258         elif tail_len < 3:
1259             raise CannotSplit(
1260                 f"Splitting brackets on an empty body to save "
1261                 f"{tail_len} characters is not worth it"
1262             )
1263
1264
1265 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1266     """Split according to delimiters of the highest priority.
1267
1268     This kind of split doesn't increase indentation.
1269     If `py36` is True, the split will add trailing commas also in function
1270     signatures that contain * and **.
1271     """
1272     try:
1273         last_leaf = line.leaves[-1]
1274     except IndexError:
1275         raise CannotSplit("Line empty")
1276
1277     delimiters = line.bracket_tracker.delimiters
1278     try:
1279         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1280     except ValueError:
1281         raise CannotSplit("No delimiters found")
1282
1283     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1284     lowest_depth = sys.maxsize
1285     trailing_comma_safe = True
1286     for leaf in line.leaves:
1287         current_line.append(leaf, preformatted=True)
1288         comment_after = line.comments.get(id(leaf))
1289         if comment_after:
1290             current_line.append(comment_after, preformatted=True)
1291         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1292         if (
1293             leaf.bracket_depth == lowest_depth
1294             and leaf.type == token.STAR
1295             or leaf.type == token.DOUBLESTAR
1296         ):
1297             trailing_comma_safe = trailing_comma_safe and py36
1298         leaf_priority = delimiters.get(id(leaf))
1299         if leaf_priority == delimiter_priority:
1300             normalize_prefix(current_line.leaves[0])
1301             yield current_line
1302
1303             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1304     if current_line:
1305         if (
1306             delimiter_priority == COMMA_PRIORITY
1307             and current_line.leaves[-1].type != token.COMMA
1308             and trailing_comma_safe
1309         ):
1310             current_line.append(Leaf(token.COMMA, ','))
1311         normalize_prefix(current_line.leaves[0])
1312         yield current_line
1313
1314
1315 def is_import(leaf: Leaf) -> bool:
1316     """Returns True if the given leaf starts an import statement."""
1317     p = leaf.parent
1318     t = leaf.type
1319     v = leaf.value
1320     return bool(
1321         t == token.NAME
1322         and (
1323             (v == 'import' and p and p.type == syms.import_name)
1324             or (v == 'from' and p and p.type == syms.import_from)
1325         )
1326     )
1327
1328
1329 def normalize_prefix(leaf: Leaf) -> None:
1330     """Leave existing extra newlines for imports.  Remove everything else."""
1331     if is_import(leaf):
1332         spl = leaf.prefix.split('#', 1)
1333         nl_count = spl[0].count('\n')
1334         if len(spl) > 1:
1335             # Skip one newline since it was for a standalone comment.
1336             nl_count -= 1
1337         leaf.prefix = '\n' * nl_count
1338         return
1339
1340     leaf.prefix = ''
1341
1342
1343 def is_python36(node: Node) -> bool:
1344     """Returns True if the current file is using Python 3.6+ features.
1345
1346     Currently looking for:
1347     - f-strings; and
1348     - trailing commas after * or ** in function signatures.
1349     """
1350     for n in node.pre_order():
1351         if n.type == token.STRING:
1352             value_head = n.value[:2]  # type: ignore
1353             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1354                 return True
1355
1356         elif (
1357             n.type == syms.typedargslist
1358             and n.children
1359             and n.children[-1].type == token.COMMA
1360         ):
1361             for ch in n.children:
1362                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1363                     return True
1364
1365     return False
1366
1367
1368 PYTHON_EXTENSIONS = {'.py'}
1369 BLACKLISTED_DIRECTORIES = {
1370     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1371 }
1372
1373
1374 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1375     for child in path.iterdir():
1376         if child.is_dir():
1377             if child.name in BLACKLISTED_DIRECTORIES:
1378                 continue
1379
1380             yield from gen_python_files_in_dir(child)
1381
1382         elif child.suffix in PYTHON_EXTENSIONS:
1383             yield child
1384
1385
1386 @dataclass
1387 class Report:
1388     """Provides a reformatting counter."""
1389     change_count: int = 0
1390     same_count: int = 0
1391     failure_count: int = 0
1392
1393     def done(self, src: Path, changed: bool) -> None:
1394         """Increment the counter for successful reformatting. Write out a message."""
1395         if changed:
1396             out(f'reformatted {src}')
1397             self.change_count += 1
1398         else:
1399             out(f'{src} already well formatted, good job.', bold=False)
1400             self.same_count += 1
1401
1402     def failed(self, src: Path, message: str) -> None:
1403         """Increment the counter for failed reformatting. Write out a message."""
1404         err(f'error: cannot format {src}: {message}')
1405         self.failure_count += 1
1406
1407     @property
1408     def return_code(self) -> int:
1409         """Which return code should the app use considering the current state."""
1410         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1411         # 126 we have special returncodes reserved by the shell.
1412         if self.failure_count:
1413             return 123
1414
1415         elif self.change_count:
1416             return 1
1417
1418         return 0
1419
1420     def __str__(self) -> str:
1421         """A color report of the current state.
1422
1423         Use `click.unstyle` to remove colors.
1424         """
1425         report = []
1426         if self.change_count:
1427             s = 's' if self.change_count > 1 else ''
1428             report.append(
1429                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1430             )
1431         if self.same_count:
1432             s = 's' if self.same_count > 1 else ''
1433             report.append(f'{self.same_count} file{s} left unchanged')
1434         if self.failure_count:
1435             s = 's' if self.failure_count > 1 else ''
1436             report.append(
1437                 click.style(
1438                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1439                 )
1440             )
1441         return ', '.join(report) + '.'
1442
1443
1444 def assert_equivalent(src: str, dst: str) -> None:
1445     """Raises AssertionError if `src` and `dst` aren't equivalent.
1446
1447     This is a temporary sanity check until Black becomes stable.
1448     """
1449
1450     import ast
1451     import traceback
1452
1453     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1454         """Simple visitor generating strings to compare ASTs by content."""
1455         yield f"{'  ' * depth}{node.__class__.__name__}("
1456
1457         for field in sorted(node._fields):
1458             try:
1459                 value = getattr(node, field)
1460             except AttributeError:
1461                 continue
1462
1463             yield f"{'  ' * (depth+1)}{field}="
1464
1465             if isinstance(value, list):
1466                 for item in value:
1467                     if isinstance(item, ast.AST):
1468                         yield from _v(item, depth + 2)
1469
1470             elif isinstance(value, ast.AST):
1471                 yield from _v(value, depth + 2)
1472
1473             else:
1474                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1475
1476         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1477
1478     try:
1479         src_ast = ast.parse(src)
1480     except Exception as exc:
1481         raise AssertionError(f"cannot parse source: {exc}") from None
1482
1483     try:
1484         dst_ast = ast.parse(dst)
1485     except Exception as exc:
1486         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1487         raise AssertionError(
1488             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1489             f"Please report a bug on https://github.com/ambv/black/issues.  "
1490             f"This invalid output might be helpful: {log}"
1491         ) from None
1492
1493     src_ast_str = '\n'.join(_v(src_ast))
1494     dst_ast_str = '\n'.join(_v(dst_ast))
1495     if src_ast_str != dst_ast_str:
1496         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1497         raise AssertionError(
1498             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1499             f"the source.  "
1500             f"Please report a bug on https://github.com/ambv/black/issues.  "
1501             f"This diff might be helpful: {log}"
1502         ) from None
1503
1504
1505 def assert_stable(src: str, dst: str, line_length: int) -> None:
1506     """Raises AssertionError if `dst` reformats differently the second time.
1507
1508     This is a temporary sanity check until Black becomes stable.
1509     """
1510     newdst = format_str(dst, line_length=line_length)
1511     if dst != newdst:
1512         log = dump_to_file(
1513             diff(src, dst, 'source', 'first pass'),
1514             diff(dst, newdst, 'first pass', 'second pass'),
1515         )
1516         raise AssertionError(
1517             f"INTERNAL ERROR: Black produced different code on the second pass "
1518             f"of the formatter.  "
1519             f"Please report a bug on https://github.com/ambv/black/issues.  "
1520             f"This diff might be helpful: {log}"
1521         ) from None
1522
1523
1524 def dump_to_file(*output: str) -> str:
1525     """Dumps `output` to a temporary file. Returns path to the file."""
1526     import tempfile
1527
1528     with tempfile.NamedTemporaryFile(
1529         mode='w', prefix='blk_', suffix='.log', delete=False
1530     ) as f:
1531         for lines in output:
1532             f.write(lines)
1533             f.write('\n')
1534     return f.name
1535
1536
1537 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1538     """Returns a udiff string between strings `a` and `b`."""
1539     import difflib
1540
1541     a_lines = [line + '\n' for line in a.split('\n')]
1542     b_lines = [line + '\n' for line in b.split('\n')]
1543     return ''.join(
1544         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1545     )
1546
1547
1548 if __name__ == '__main__':
1549     main()