]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

0fe8fef255be119d3569071404544b153d5ce89f
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
78 )
79 @click.pass_context
80 def main(
81     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
82 ) -> None:
83     """The uncompromising code formatter."""
84     sources: List[Path] = []
85     for s in src:
86         p = Path(s)
87         if p.is_dir():
88             sources.extend(gen_python_files_in_dir(p))
89         elif p.is_file():
90             # if a file was explicitly given, we don't care about its extension
91             sources.append(p)
92         else:
93             err(f'invalid path: {s}')
94     if len(sources) == 0:
95         ctx.exit(0)
96     elif len(sources) == 1:
97         p = sources[0]
98         report = Report()
99         try:
100             changed = format_file_in_place(
101                 p, line_length=line_length, fast=fast, write_back=not check
102             )
103             report.done(p, changed)
104         except Exception as exc:
105             report.failed(p, str(exc))
106         ctx.exit(report.return_code)
107     else:
108         loop = asyncio.get_event_loop()
109         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
110         return_code = 1
111         try:
112             return_code = loop.run_until_complete(
113                 schedule_formatting(
114                     sources, line_length, not check, fast, loop, executor
115                 )
116             )
117         finally:
118             loop.close()
119             ctx.exit(return_code)
120
121
122 async def schedule_formatting(
123     sources: List[Path],
124     line_length: int,
125     write_back: bool,
126     fast: bool,
127     loop: BaseEventLoop,
128     executor: Executor,
129 ) -> int:
130     tasks = {
131         src: loop.run_in_executor(
132             executor, format_file_in_place, src, line_length, fast, write_back
133         )
134         for src in sources
135     }
136     await asyncio.wait(tasks.values())
137     cancelled = []
138     report = Report()
139     for src, task in tasks.items():
140         if not task.done():
141             report.failed(src, 'timed out, cancelling')
142             task.cancel()
143             cancelled.append(task)
144         elif task.exception():
145             report.failed(src, str(task.exception()))
146         else:
147             report.done(src, task.result())
148     if cancelled:
149         await asyncio.wait(cancelled, timeout=2)
150     out('All done! ✨ 🍰 ✨')
151     click.echo(str(report))
152     return report.return_code
153
154
155 def format_file_in_place(
156     src: Path, line_length: int, fast: bool, write_back: bool = False
157 ) -> bool:
158     """Format the file and rewrite if changed. Return True if changed."""
159     try:
160         contents, encoding = format_file(src, line_length=line_length, fast=fast)
161     except NothingChanged:
162         return False
163
164     if write_back:
165         with open(src, "w", encoding=encoding) as f:
166             f.write(contents)
167     return True
168
169
170 def format_file(
171     src: Path, line_length: int, fast: bool
172 ) -> Tuple[FileContent, Encoding]:
173     """Reformats a file and returns its contents and encoding."""
174     with tokenize.open(src) as src_buffer:
175         src_contents = src_buffer.read()
176     if src_contents.strip() == '':
177         raise NothingChanged(src)
178
179     dst_contents = format_str(src_contents, line_length=line_length)
180     if src_contents == dst_contents:
181         raise NothingChanged(src)
182
183     if not fast:
184         assert_equivalent(src_contents, dst_contents)
185         assert_stable(src_contents, dst_contents, line_length=line_length)
186     return dst_contents, src_buffer.encoding
187
188
189 def format_str(src_contents: str, line_length: int) -> FileContent:
190     """Reformats a string and returns new contents."""
191     src_node = lib2to3_parse(src_contents)
192     dst_contents = ""
193     comments: List[Line] = []
194     lines = LineGenerator()
195     elt = EmptyLineTracker()
196     py36 = is_python36(src_node)
197     empty_line = Line()
198     after = 0
199     for current_line in lines.visit(src_node):
200         for _ in range(after):
201             dst_contents += str(empty_line)
202         before, after = elt.maybe_empty_lines(current_line)
203         for _ in range(before):
204             dst_contents += str(empty_line)
205         if not current_line.is_comment:
206             for comment in comments:
207                 dst_contents += str(comment)
208             comments = []
209             for line in split_line(current_line, line_length=line_length, py36=py36):
210                 dst_contents += str(line)
211         else:
212             comments.append(current_line)
213     for comment in comments:
214         dst_contents += str(comment)
215     return dst_contents
216
217
218 def lib2to3_parse(src_txt: str) -> Node:
219     """Given a string with source, return the lib2to3 Node."""
220     grammar = pygram.python_grammar_no_print_statement
221     drv = driver.Driver(grammar, pytree.convert)
222     if src_txt[-1] != '\n':
223         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
224         src_txt += nl
225     try:
226         result = drv.parse_string(src_txt, True)
227     except ParseError as pe:
228         lineno, column = pe.context[1]
229         lines = src_txt.splitlines()
230         try:
231             faulty_line = lines[lineno - 1]
232         except IndexError:
233             faulty_line = "<line number missing in source>"
234         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
235
236     if isinstance(result, Leaf):
237         result = Node(syms.file_input, [result])
238     return result
239
240
241 def lib2to3_unparse(node: Node) -> str:
242     """Given a lib2to3 node, return its string representation."""
243     code = str(node)
244     return code
245
246
247 T = TypeVar('T')
248
249
250 class Visitor(Generic[T]):
251     """Basic lib2to3 visitor that yields things on visiting."""
252
253     def visit(self, node: LN) -> Iterator[T]:
254         if node.type < 256:
255             name = token.tok_name[node.type]
256         else:
257             name = type_repr(node.type)
258         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
259
260     def visit_default(self, node: LN) -> Iterator[T]:
261         if isinstance(node, Node):
262             for child in node.children:
263                 yield from self.visit(child)
264
265
266 @dataclass
267 class DebugVisitor(Visitor[T]):
268     tree_depth: int = 0
269
270     def visit_default(self, node: LN) -> Iterator[T]:
271         indent = ' ' * (2 * self.tree_depth)
272         if isinstance(node, Node):
273             _type = type_repr(node.type)
274             out(f'{indent}{_type}', fg='yellow')
275             self.tree_depth += 1
276             for child in node.children:
277                 yield from self.visit(child)
278
279             self.tree_depth -= 1
280             out(f'{indent}/{_type}', fg='yellow', bold=False)
281         else:
282             _type = token.tok_name.get(node.type, str(node.type))
283             out(f'{indent}{_type}', fg='blue', nl=False)
284             if node.prefix:
285                 # We don't have to handle prefixes for `Node` objects since
286                 # that delegates to the first child anyway.
287                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
288             out(f' {node.value!r}', fg='blue', bold=False)
289
290
291 KEYWORDS = set(keyword.kwlist)
292 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
293 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
294 STATEMENT = {
295     syms.if_stmt,
296     syms.while_stmt,
297     syms.for_stmt,
298     syms.try_stmt,
299     syms.except_clause,
300     syms.with_stmt,
301     syms.funcdef,
302     syms.classdef,
303 }
304 STANDALONE_COMMENT = 153
305 LOGIC_OPERATORS = {'and', 'or'}
306 COMPARATORS = {
307     token.LESS,
308     token.GREATER,
309     token.EQEQUAL,
310     token.NOTEQUAL,
311     token.LESSEQUAL,
312     token.GREATEREQUAL,
313 }
314 MATH_OPERATORS = {
315     token.PLUS,
316     token.MINUS,
317     token.STAR,
318     token.SLASH,
319     token.VBAR,
320     token.AMPER,
321     token.PERCENT,
322     token.CIRCUMFLEX,
323     token.LEFTSHIFT,
324     token.RIGHTSHIFT,
325     token.DOUBLESTAR,
326     token.DOUBLESLASH,
327 }
328 COMPREHENSION_PRIORITY = 20
329 COMMA_PRIORITY = 10
330 LOGIC_PRIORITY = 5
331 STRING_PRIORITY = 4
332 COMPARATOR_PRIORITY = 3
333 MATH_PRIORITY = 1
334
335
336 @dataclass
337 class BracketTracker:
338     depth: int = 0
339     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
340     delimiters: Dict[LeafID, Priority] = Factory(dict)
341     previous: Optional[Leaf] = None
342
343     def mark(self, leaf: Leaf) -> None:
344         if leaf.type == token.COMMENT:
345             return
346
347         if leaf.type in CLOSING_BRACKETS:
348             self.depth -= 1
349             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
350             leaf.opening_bracket = opening_bracket
351         leaf.bracket_depth = self.depth
352         if self.depth == 0:
353             delim = is_delimiter(leaf)
354             if delim:
355                 self.delimiters[id(leaf)] = delim
356             elif self.previous is not None:
357                 if leaf.type == token.STRING and self.previous.type == token.STRING:
358                     self.delimiters[id(self.previous)] = STRING_PRIORITY
359                 elif (
360                     leaf.type == token.NAME
361                     and leaf.value == 'for'
362                     and leaf.parent
363                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
364                 ):
365                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
366                 elif (
367                     leaf.type == token.NAME
368                     and leaf.value == 'if'
369                     and leaf.parent
370                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
371                 ):
372                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
373                 elif (
374                     leaf.type == token.NAME
375                     and leaf.value in LOGIC_OPERATORS
376                     and leaf.parent
377                 ):
378                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
379         if leaf.type in OPENING_BRACKETS:
380             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
381             self.depth += 1
382         self.previous = leaf
383
384     def any_open_brackets(self) -> bool:
385         """Returns True if there is an yet unmatched open bracket on the line."""
386         return bool(self.bracket_match)
387
388     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
389         """Returns the highest priority of a delimiter found on the line.
390
391         Values are consistent with what `is_delimiter()` returns.
392         """
393         return max(v for k, v in self.delimiters.items() if k not in exclude)
394
395
396 @dataclass
397 class Line:
398     depth: int = 0
399     leaves: List[Leaf] = Factory(list)
400     comments: Dict[LeafID, Leaf] = Factory(dict)
401     bracket_tracker: BracketTracker = Factory(BracketTracker)
402     inside_brackets: bool = False
403     has_for: bool = False
404     _for_loop_variable: bool = False
405
406     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
407         has_value = leaf.value.strip()
408         if not has_value:
409             return
410
411         if self.leaves and not preformatted:
412             # Note: at this point leaf.prefix should be empty except for
413             # imports, for which we only preserve newlines.
414             leaf.prefix += whitespace(leaf)
415         if self.inside_brackets or not preformatted:
416             self.maybe_decrement_after_for_loop_variable(leaf)
417             self.bracket_tracker.mark(leaf)
418             self.maybe_remove_trailing_comma(leaf)
419             self.maybe_increment_for_loop_variable(leaf)
420             if self.maybe_adapt_standalone_comment(leaf):
421                 return
422
423         if not self.append_comment(leaf):
424             self.leaves.append(leaf)
425
426     @property
427     def is_comment(self) -> bool:
428         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
429
430     @property
431     def is_decorator(self) -> bool:
432         return bool(self) and self.leaves[0].type == token.AT
433
434     @property
435     def is_import(self) -> bool:
436         return bool(self) and is_import(self.leaves[0])
437
438     @property
439     def is_class(self) -> bool:
440         return (
441             bool(self)
442             and self.leaves[0].type == token.NAME
443             and self.leaves[0].value == 'class'
444         )
445
446     @property
447     def is_def(self) -> bool:
448         """Also returns True for async defs."""
449         try:
450             first_leaf = self.leaves[0]
451         except IndexError:
452             return False
453
454         try:
455             second_leaf: Optional[Leaf] = self.leaves[1]
456         except IndexError:
457             second_leaf = None
458         return (
459             (first_leaf.type == token.NAME and first_leaf.value == 'def')
460             or (
461                 first_leaf.type == token.NAME
462                 and first_leaf.value == 'async'
463                 and second_leaf is not None
464                 and second_leaf.type == token.NAME
465                 and second_leaf.value == 'def'
466             )
467         )
468
469     @property
470     def is_flow_control(self) -> bool:
471         return (
472             bool(self)
473             and self.leaves[0].type == token.NAME
474             and self.leaves[0].value in FLOW_CONTROL
475         )
476
477     @property
478     def is_yield(self) -> bool:
479         return (
480             bool(self)
481             and self.leaves[0].type == token.NAME
482             and self.leaves[0].value == 'yield'
483         )
484
485     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
486         if not (
487             self.leaves
488             and self.leaves[-1].type == token.COMMA
489             and closing.type in CLOSING_BRACKETS
490         ):
491             return False
492
493         if closing.type == token.RSQB or closing.type == token.RBRACE:
494             self.leaves.pop()
495             return True
496
497         # For parens let's check if it's safe to remove the comma.  If the
498         # trailing one is the only one, we might mistakenly change a tuple
499         # into a different type by removing the comma.
500         depth = closing.bracket_depth + 1
501         commas = 0
502         opening = closing.opening_bracket
503         for _opening_index, leaf in enumerate(self.leaves):
504             if leaf is opening:
505                 break
506
507         else:
508             return False
509
510         for leaf in self.leaves[_opening_index + 1:]:
511             if leaf is closing:
512                 break
513
514             bracket_depth = leaf.bracket_depth
515             if bracket_depth == depth and leaf.type == token.COMMA:
516                 commas += 1
517                 if leaf.parent and leaf.parent.type == syms.arglist:
518                     commas += 1
519                     break
520
521         if commas > 1:
522             self.leaves.pop()
523             return True
524
525         return False
526
527     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
528         """In a for loop, or comprehension, the variables are often unpacks.
529
530         To avoid splitting on the comma in this situation, we will increase
531         the depth of tokens between `for` and `in`.
532         """
533         if leaf.type == token.NAME and leaf.value == 'for':
534             self.has_for = True
535             self.bracket_tracker.depth += 1
536             self._for_loop_variable = True
537             return True
538
539         return False
540
541     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
542         # See `maybe_increment_for_loop_variable` above for explanation.
543         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
544             self.bracket_tracker.depth -= 1
545             self._for_loop_variable = False
546             return True
547
548         return False
549
550     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
551         """Hack a standalone comment to act as a trailing comment for line splitting.
552
553         If this line has brackets and a standalone `comment`, we need to adapt
554         it to be able to still reformat the line.
555
556         This is not perfect, the line to which the standalone comment gets
557         appended will appear "too long" when splitting.
558         """
559         if not (
560             comment.type == STANDALONE_COMMENT
561             and self.bracket_tracker.any_open_brackets()
562         ):
563             return False
564
565         comment.type = token.COMMENT
566         comment.prefix = '\n' + '    ' * (self.depth + 1)
567         return self.append_comment(comment)
568
569     def append_comment(self, comment: Leaf) -> bool:
570         if comment.type != token.COMMENT:
571             return False
572
573         try:
574             after = id(self.last_non_delimiter())
575         except LookupError:
576             comment.type = STANDALONE_COMMENT
577             comment.prefix = ''
578             return False
579
580         else:
581             if after in self.comments:
582                 self.comments[after].value += str(comment)
583             else:
584                 self.comments[after] = comment
585             return True
586
587     def last_non_delimiter(self) -> Leaf:
588         for i in range(len(self.leaves)):
589             last = self.leaves[-i - 1]
590             if not is_delimiter(last):
591                 return last
592
593         raise LookupError("No non-delimiters found")
594
595     def __str__(self) -> str:
596         if not self:
597             return '\n'
598
599         indent = '    ' * self.depth
600         leaves = iter(self.leaves)
601         first = next(leaves)
602         res = f'{first.prefix}{indent}{first.value}'
603         for leaf in leaves:
604             res += str(leaf)
605         for comment in self.comments.values():
606             res += str(comment)
607         return res + '\n'
608
609     def __bool__(self) -> bool:
610         return bool(self.leaves or self.comments)
611
612
613 @dataclass
614 class EmptyLineTracker:
615     """Provides a stateful method that returns the number of potential extra
616     empty lines needed before and after the currently processed line.
617
618     Note: this tracker works on lines that haven't been split yet.
619     """
620     previous_line: Optional[Line] = None
621     previous_after: int = 0
622     previous_defs: List[int] = Factory(list)
623
624     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
625         """Returns the number of extra empty lines before and after the `current_line`.
626
627         This is for separating `def`, `async def` and `class` with extra empty lines
628         (two on module-level), as well as providing an extra empty line after flow
629         control keywords to make them more prominent.
630         """
631         if current_line.is_comment:
632             # Don't count standalone comments towards previous empty lines.
633             return 0, 0
634
635         before, after = self._maybe_empty_lines(current_line)
636         self.previous_after = after
637         self.previous_line = current_line
638         return before, after
639
640     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
641         before = 0
642         depth = current_line.depth
643         while self.previous_defs and self.previous_defs[-1] >= depth:
644             self.previous_defs.pop()
645             before = (1 if depth else 2) - self.previous_after
646         is_decorator = current_line.is_decorator
647         if is_decorator or current_line.is_def or current_line.is_class:
648             if not is_decorator:
649                 self.previous_defs.append(depth)
650             if self.previous_line is None:
651                 # Don't insert empty lines before the first line in the file.
652                 return 0, 0
653
654             if self.previous_line and self.previous_line.is_decorator:
655                 # Don't insert empty lines between decorators.
656                 return 0, 0
657
658             newlines = 2
659             if current_line.depth:
660                 newlines -= 1
661             newlines -= self.previous_after
662             return newlines, 0
663
664         if current_line.is_flow_control:
665             return before, 1
666
667         if (
668             self.previous_line
669             and self.previous_line.is_import
670             and not current_line.is_import
671             and depth == self.previous_line.depth
672         ):
673             return (before or 1), 0
674
675         if (
676             self.previous_line
677             and self.previous_line.is_yield
678             and (not current_line.is_yield or depth != self.previous_line.depth)
679         ):
680             return (before or 1), 0
681
682         return before, 0
683
684
685 @dataclass
686 class LineGenerator(Visitor[Line]):
687     """Generates reformatted Line objects.  Empty lines are not emitted.
688
689     Note: destroys the tree it's visiting by mutating prefixes of its leaves
690     in ways that will no longer stringify to valid Python code on the tree.
691     """
692     current_line: Line = Factory(Line)
693     standalone_comments: List[Leaf] = Factory(list)
694
695     def line(self, indent: int = 0) -> Iterator[Line]:
696         """Generate a line.
697
698         If the line is empty, only emit if it makes sense.
699         If the line is too long, split it first and then generate.
700
701         If any lines were generated, set up a new current_line.
702         """
703         if not self.current_line:
704             self.current_line.depth += indent
705             return  # Line is empty, don't emit. Creating a new one unnecessary.
706
707         complete_line = self.current_line
708         self.current_line = Line(depth=complete_line.depth + indent)
709         yield complete_line
710
711     def visit_default(self, node: LN) -> Iterator[Line]:
712         if isinstance(node, Leaf):
713             for comment in generate_comments(node):
714                 if self.current_line.bracket_tracker.any_open_brackets():
715                     # any comment within brackets is subject to splitting
716                     self.current_line.append(comment)
717                 elif comment.type == token.COMMENT:
718                     # regular trailing comment
719                     self.current_line.append(comment)
720                     yield from self.line()
721
722                 else:
723                     # regular standalone comment, to be processed later (see
724                     # docstring in `generate_comments()`
725                     self.standalone_comments.append(comment)
726             normalize_prefix(node)
727             if node.type not in WHITESPACE:
728                 for comment in self.standalone_comments:
729                     yield from self.line()
730
731                     self.current_line.append(comment)
732                     yield from self.line()
733
734                 self.standalone_comments = []
735                 self.current_line.append(node)
736         yield from super().visit_default(node)
737
738     def visit_suite(self, node: Node) -> Iterator[Line]:
739         """Body of a statement after a colon."""
740         children = iter(node.children)
741         # Process newline before indenting.  It might contain an inline
742         # comment that should go right after the colon.
743         newline = next(children)
744         yield from self.visit(newline)
745         yield from self.line(+1)
746
747         for child in children:
748             yield from self.visit(child)
749
750         yield from self.line(-1)
751
752     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
753         """Visit a statement.
754
755         The relevant Python language keywords for this statement are NAME leaves
756         within it.
757         """
758         for child in node.children:
759             if child.type == token.NAME and child.value in keywords:  # type: ignore
760                 yield from self.line()
761
762             yield from self.visit(child)
763
764     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
765         """A statement without nested statements."""
766         is_suite_like = node.parent and node.parent.type in STATEMENT
767         if is_suite_like:
768             yield from self.line(+1)
769             yield from self.visit_default(node)
770             yield from self.line(-1)
771
772         else:
773             yield from self.line()
774             yield from self.visit_default(node)
775
776     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
777         yield from self.line()
778
779         children = iter(node.children)
780         for child in children:
781             yield from self.visit(child)
782
783             if child.type == token.NAME and child.value == 'async':  # type: ignore
784                 break
785
786         internal_stmt = next(children)
787         for child in internal_stmt.children:
788             yield from self.visit(child)
789
790     def visit_decorators(self, node: Node) -> Iterator[Line]:
791         for child in node.children:
792             yield from self.line()
793             yield from self.visit(child)
794
795     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
796         yield from self.line()
797
798     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
799         yield from self.visit_default(leaf)
800         yield from self.line()
801
802     def __attrs_post_init__(self) -> None:
803         """You are in a twisty little maze of passages."""
804         v = self.visit_stmt
805         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
806         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
807         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
808         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
809         self.visit_except_clause = partial(v, keywords={'except'})
810         self.visit_funcdef = partial(v, keywords={'def'})
811         self.visit_with_stmt = partial(v, keywords={'with'})
812         self.visit_classdef = partial(v, keywords={'class'})
813         self.visit_async_funcdef = self.visit_async_stmt
814         self.visit_decorated = self.visit_decorators
815
816
817 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
818 OPENING_BRACKETS = set(BRACKET.keys())
819 CLOSING_BRACKETS = set(BRACKET.values())
820 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
821 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
822
823
824 def whitespace(leaf: Leaf) -> str:  # noqa C901
825     """Return whitespace prefix if needed for the given `leaf`."""
826     NO = ''
827     SPACE = ' '
828     DOUBLESPACE = '  '
829     t = leaf.type
830     p = leaf.parent
831     v = leaf.value
832     if t in ALWAYS_NO_SPACE:
833         return NO
834
835     if t == token.COMMENT:
836         return DOUBLESPACE
837
838     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
839     prev = leaf.prev_sibling
840     if not prev:
841         prevp = preceding_leaf(p)
842         if not prevp or prevp.type in OPENING_BRACKETS:
843             return NO
844
845         if prevp.type == token.EQUAL:
846             if prevp.parent and prevp.parent.type in {
847                 syms.typedargslist,
848                 syms.varargslist,
849                 syms.parameters,
850                 syms.arglist,
851                 syms.argument,
852             }:
853                 return NO
854
855         elif prevp.type == token.DOUBLESTAR:
856             if prevp.parent and prevp.parent.type in {
857                 syms.typedargslist,
858                 syms.varargslist,
859                 syms.parameters,
860                 syms.arglist,
861                 syms.dictsetmaker,
862             }:
863                 return NO
864
865         elif prevp.type == token.COLON:
866             if prevp.parent and prevp.parent.type == syms.subscript:
867                 return NO
868
869         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
870             return NO
871
872     elif prev.type in OPENING_BRACKETS:
873         return NO
874
875     if p.type in {syms.parameters, syms.arglist}:
876         # untyped function signatures or calls
877         if t == token.RPAR:
878             return NO
879
880         if not prev or prev.type != token.COMMA:
881             return NO
882
883     if p.type == syms.varargslist:
884         # lambdas
885         if t == token.RPAR:
886             return NO
887
888         if prev and prev.type != token.COMMA:
889             return NO
890
891     elif p.type == syms.typedargslist:
892         # typed function signatures
893         if not prev:
894             return NO
895
896         if t == token.EQUAL:
897             if prev.type != syms.tname:
898                 return NO
899
900         elif prev.type == token.EQUAL:
901             # A bit hacky: if the equal sign has whitespace, it means we
902             # previously found it's a typed argument.  So, we're using that, too.
903             return prev.prefix
904
905         elif prev.type != token.COMMA:
906             return NO
907
908     elif p.type == syms.tname:
909         # type names
910         if not prev:
911             prevp = preceding_leaf(p)
912             if not prevp or prevp.type != token.COMMA:
913                 return NO
914
915     elif p.type == syms.trailer:
916         # attributes and calls
917         if t == token.LPAR or t == token.RPAR:
918             return NO
919
920         if not prev:
921             if t == token.DOT:
922                 prevp = preceding_leaf(p)
923                 if not prevp or prevp.type != token.NUMBER:
924                     return NO
925
926             elif t == token.LSQB:
927                 return NO
928
929         elif prev.type != token.COMMA:
930             return NO
931
932     elif p.type == syms.argument:
933         # single argument
934         if t == token.EQUAL:
935             return NO
936
937         if not prev:
938             prevp = preceding_leaf(p)
939             if not prevp or prevp.type == token.LPAR:
940                 return NO
941
942         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
943             return NO
944
945     elif p.type == syms.decorator:
946         # decorators
947         return NO
948
949     elif p.type == syms.dotted_name:
950         if prev:
951             return NO
952
953         prevp = preceding_leaf(p)
954         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
955             return NO
956
957     elif p.type == syms.classdef:
958         if t == token.LPAR:
959             return NO
960
961         if prev and prev.type == token.LPAR:
962             return NO
963
964     elif p.type == syms.subscript:
965         # indexing
966         if not prev:
967             assert p.parent is not None, "subscripts are always parented"
968             if p.parent.type == syms.subscriptlist:
969                 return SPACE
970
971             return NO
972
973         elif prev.type == token.COLON:
974             return NO
975
976     elif p.type == syms.atom:
977         if prev and t == token.DOT:
978             # dots, but not the first one.
979             return NO
980
981     elif (
982         p.type == syms.listmaker
983         or p.type == syms.testlist_gexp
984         or p.type == syms.subscriptlist
985     ):
986         # list interior, including unpacking
987         if not prev:
988             return NO
989
990     elif p.type == syms.dictsetmaker:
991         # dict and set interior, including unpacking
992         if not prev:
993             return NO
994
995         if prev.type == token.DOUBLESTAR:
996             return NO
997
998     elif p.type in {syms.factor, syms.star_expr}:
999         # unary ops
1000         if not prev:
1001             prevp = preceding_leaf(p)
1002             if not prevp or prevp.type in OPENING_BRACKETS:
1003                 return NO
1004
1005             prevp_parent = prevp.parent
1006             assert prevp_parent is not None
1007             if prevp.type == token.COLON and prevp_parent.type in {
1008                 syms.subscript, syms.sliceop
1009             }:
1010                 return NO
1011
1012             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1013                 return NO
1014
1015         elif t == token.NAME or t == token.NUMBER:
1016             return NO
1017
1018     elif p.type == syms.import_from:
1019         if t == token.DOT:
1020             if prev and prev.type == token.DOT:
1021                 return NO
1022
1023         elif t == token.NAME:
1024             if v == 'import':
1025                 return SPACE
1026
1027             if prev and prev.type == token.DOT:
1028                 return NO
1029
1030     elif p.type == syms.sliceop:
1031         return NO
1032
1033     return SPACE
1034
1035
1036 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1037     """Returns the first leaf that precedes `node`, if any."""
1038     while node:
1039         res = node.prev_sibling
1040         if res:
1041             if isinstance(res, Leaf):
1042                 return res
1043
1044             try:
1045                 return list(res.leaves())[-1]
1046
1047             except IndexError:
1048                 return None
1049
1050         node = node.parent
1051     return None
1052
1053
1054 def is_delimiter(leaf: Leaf) -> int:
1055     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1056
1057     Higher numbers are higher priority.
1058     """
1059     if leaf.type == token.COMMA:
1060         return COMMA_PRIORITY
1061
1062     if leaf.type in COMPARATORS:
1063         return COMPARATOR_PRIORITY
1064
1065     if (
1066         leaf.type in MATH_OPERATORS
1067         and leaf.parent
1068         and leaf.parent.type not in {syms.factor, syms.star_expr}
1069     ):
1070         return MATH_PRIORITY
1071
1072     return 0
1073
1074
1075 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1076     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1077
1078     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1079     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1080     move because it does away with modifying the grammar to include all the
1081     possible places in which comments can be placed.
1082
1083     The sad consequence for us though is that comments don't "belong" anywhere.
1084     This is why this function generates simple parentless Leaf objects for
1085     comments.  We simply don't know what the correct parent should be.
1086
1087     No matter though, we can live without this.  We really only need to
1088     differentiate between inline and standalone comments.  The latter don't
1089     share the line with any code.
1090
1091     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1092     are emitted with a fake STANDALONE_COMMENT token identifier.
1093     """
1094     if not leaf.prefix:
1095         return
1096
1097     if '#' not in leaf.prefix:
1098         return
1099
1100     before_comment, content = leaf.prefix.split('#', 1)
1101     content = content.rstrip()
1102     if content and (content[0] not in {' ', '!', '#'}):
1103         content = ' ' + content
1104     is_standalone_comment = (
1105         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1106     )
1107     if not is_standalone_comment:
1108         # simple trailing comment
1109         yield Leaf(token.COMMENT, value='#' + content)
1110         return
1111
1112     for line in ('#' + content).split('\n'):
1113         line = line.lstrip()
1114         if not line.startswith('#'):
1115             continue
1116
1117         yield Leaf(STANDALONE_COMMENT, line)
1118
1119
1120 def split_line(
1121     line: Line, line_length: int, inner: bool = False, py36: bool = False
1122 ) -> Iterator[Line]:
1123     """Splits a `line` into potentially many lines.
1124
1125     They should fit in the allotted `line_length` but might not be able to.
1126     `inner` signifies that there were a pair of brackets somewhere around the
1127     current `line`, possibly transitively. This means we can fallback to splitting
1128     by delimiters if the LHS/RHS don't yield any results.
1129
1130     If `py36` is True, splitting may generate syntax that is only compatible
1131     with Python 3.6 and later.
1132     """
1133     line_str = str(line).strip('\n')
1134     if len(line_str) <= line_length and '\n' not in line_str:
1135         yield line
1136         return
1137
1138     if line.is_def:
1139         split_funcs = [left_hand_split]
1140     elif line.inside_brackets:
1141         split_funcs = [delimiter_split]
1142         if '\n' not in line_str:
1143             # Only attempt RHS if we don't have multiline strings or comments
1144             # on this line.
1145             split_funcs.append(right_hand_split)
1146     else:
1147         split_funcs = [right_hand_split]
1148     for split_func in split_funcs:
1149         # We are accumulating lines in `result` because we might want to abort
1150         # mission and return the original line in the end, or attempt a different
1151         # split altogether.
1152         result: List[Line] = []
1153         try:
1154             for l in split_func(line, py36=py36):
1155                 if str(l).strip('\n') == line_str:
1156                     raise CannotSplit("Split function returned an unchanged result")
1157
1158                 result.extend(
1159                     split_line(l, line_length=line_length, inner=True, py36=py36)
1160                 )
1161         except CannotSplit as cs:
1162             continue
1163
1164         else:
1165             yield from result
1166             break
1167
1168     else:
1169         yield line
1170
1171
1172 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1173     """Split line into many lines, starting with the first matching bracket pair.
1174
1175     Note: this usually looks weird, only use this for function definitions.
1176     Prefer RHS otherwise.
1177     """
1178     head = Line(depth=line.depth)
1179     body = Line(depth=line.depth + 1, inside_brackets=True)
1180     tail = Line(depth=line.depth)
1181     tail_leaves: List[Leaf] = []
1182     body_leaves: List[Leaf] = []
1183     head_leaves: List[Leaf] = []
1184     current_leaves = head_leaves
1185     matching_bracket = None
1186     for leaf in line.leaves:
1187         if (
1188             current_leaves is body_leaves
1189             and leaf.type in CLOSING_BRACKETS
1190             and leaf.opening_bracket is matching_bracket
1191         ):
1192             current_leaves = tail_leaves if body_leaves else head_leaves
1193         current_leaves.append(leaf)
1194         if current_leaves is head_leaves:
1195             if leaf.type in OPENING_BRACKETS:
1196                 matching_bracket = leaf
1197                 current_leaves = body_leaves
1198     # Since body is a new indent level, remove spurious leading whitespace.
1199     if body_leaves:
1200         normalize_prefix(body_leaves[0])
1201     # Build the new lines.
1202     for result, leaves in (
1203         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1204     ):
1205         for leaf in leaves:
1206             result.append(leaf, preformatted=True)
1207             comment_after = line.comments.get(id(leaf))
1208             if comment_after:
1209                 result.append(comment_after, preformatted=True)
1210     split_succeeded_or_raise(head, body, tail)
1211     for result in (head, body, tail):
1212         if result:
1213             yield result
1214
1215
1216 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1217     """Split line into many lines, starting with the last matching bracket pair."""
1218     head = Line(depth=line.depth)
1219     body = Line(depth=line.depth + 1, inside_brackets=True)
1220     tail = Line(depth=line.depth)
1221     tail_leaves: List[Leaf] = []
1222     body_leaves: List[Leaf] = []
1223     head_leaves: List[Leaf] = []
1224     current_leaves = tail_leaves
1225     opening_bracket = None
1226     for leaf in reversed(line.leaves):
1227         if current_leaves is body_leaves:
1228             if leaf is opening_bracket:
1229                 current_leaves = head_leaves if body_leaves else tail_leaves
1230         current_leaves.append(leaf)
1231         if current_leaves is tail_leaves:
1232             if leaf.type in CLOSING_BRACKETS:
1233                 opening_bracket = leaf.opening_bracket
1234                 current_leaves = body_leaves
1235     tail_leaves.reverse()
1236     body_leaves.reverse()
1237     head_leaves.reverse()
1238     # Since body is a new indent level, remove spurious leading whitespace.
1239     if body_leaves:
1240         normalize_prefix(body_leaves[0])
1241     # Build the new lines.
1242     for result, leaves in (
1243         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1244     ):
1245         for leaf in leaves:
1246             result.append(leaf, preformatted=True)
1247             comment_after = line.comments.get(id(leaf))
1248             if comment_after:
1249                 result.append(comment_after, preformatted=True)
1250     split_succeeded_or_raise(head, body, tail)
1251     for result in (head, body, tail):
1252         if result:
1253             yield result
1254
1255
1256 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1257     tail_len = len(str(tail).strip())
1258     if not body:
1259         if tail_len == 0:
1260             raise CannotSplit("Splitting brackets produced the same line")
1261
1262         elif tail_len < 3:
1263             raise CannotSplit(
1264                 f"Splitting brackets on an empty body to save "
1265                 f"{tail_len} characters is not worth it"
1266             )
1267
1268
1269 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1270     """Split according to delimiters of the highest priority.
1271
1272     This kind of split doesn't increase indentation.
1273     If `py36` is True, the split will add trailing commas also in function
1274     signatures that contain * and **.
1275     """
1276     try:
1277         last_leaf = line.leaves[-1]
1278     except IndexError:
1279         raise CannotSplit("Line empty")
1280
1281     delimiters = line.bracket_tracker.delimiters
1282     try:
1283         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1284     except ValueError:
1285         raise CannotSplit("No delimiters found")
1286
1287     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1288     lowest_depth = sys.maxsize
1289     trailing_comma_safe = True
1290     for leaf in line.leaves:
1291         current_line.append(leaf, preformatted=True)
1292         comment_after = line.comments.get(id(leaf))
1293         if comment_after:
1294             current_line.append(comment_after, preformatted=True)
1295         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1296         if (
1297             leaf.bracket_depth == lowest_depth
1298             and leaf.type == token.STAR
1299             or leaf.type == token.DOUBLESTAR
1300         ):
1301             trailing_comma_safe = trailing_comma_safe and py36
1302         leaf_priority = delimiters.get(id(leaf))
1303         if leaf_priority == delimiter_priority:
1304             normalize_prefix(current_line.leaves[0])
1305             yield current_line
1306
1307             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1308     if current_line:
1309         if (
1310             delimiter_priority == COMMA_PRIORITY
1311             and current_line.leaves[-1].type != token.COMMA
1312             and trailing_comma_safe
1313         ):
1314             current_line.append(Leaf(token.COMMA, ','))
1315         normalize_prefix(current_line.leaves[0])
1316         yield current_line
1317
1318
1319 def is_import(leaf: Leaf) -> bool:
1320     """Returns True if the given leaf starts an import statement."""
1321     p = leaf.parent
1322     t = leaf.type
1323     v = leaf.value
1324     return bool(
1325         t == token.NAME
1326         and (
1327             (v == 'import' and p and p.type == syms.import_name)
1328             or (v == 'from' and p and p.type == syms.import_from)
1329         )
1330     )
1331
1332
1333 def normalize_prefix(leaf: Leaf) -> None:
1334     """Leave existing extra newlines for imports.  Remove everything else."""
1335     if is_import(leaf):
1336         spl = leaf.prefix.split('#', 1)
1337         nl_count = spl[0].count('\n')
1338         if len(spl) > 1:
1339             # Skip one newline since it was for a standalone comment.
1340             nl_count -= 1
1341         leaf.prefix = '\n' * nl_count
1342         return
1343
1344     leaf.prefix = ''
1345
1346
1347 def is_python36(node: Node) -> bool:
1348     """Returns True if the current file is using Python 3.6+ features.
1349
1350     Currently looking for:
1351     - f-strings; and
1352     - trailing commas after * or ** in function signatures.
1353     """
1354     for n in node.pre_order():
1355         if n.type == token.STRING:
1356             value_head = n.value[:2]  # type: ignore
1357             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1358                 return True
1359
1360         elif (
1361             n.type == syms.typedargslist
1362             and n.children
1363             and n.children[-1].type == token.COMMA
1364         ):
1365             for ch in n.children:
1366                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1367                     return True
1368
1369     return False
1370
1371
1372 PYTHON_EXTENSIONS = {'.py'}
1373 BLACKLISTED_DIRECTORIES = {
1374     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1375 }
1376
1377
1378 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1379     for child in path.iterdir():
1380         if child.is_dir():
1381             if child.name in BLACKLISTED_DIRECTORIES:
1382                 continue
1383
1384             yield from gen_python_files_in_dir(child)
1385
1386         elif child.suffix in PYTHON_EXTENSIONS:
1387             yield child
1388
1389
1390 @dataclass
1391 class Report:
1392     """Provides a reformatting counter."""
1393     change_count: int = 0
1394     same_count: int = 0
1395     failure_count: int = 0
1396
1397     def done(self, src: Path, changed: bool) -> None:
1398         """Increment the counter for successful reformatting. Write out a message."""
1399         if changed:
1400             out(f'reformatted {src}')
1401             self.change_count += 1
1402         else:
1403             out(f'{src} already well formatted, good job.', bold=False)
1404             self.same_count += 1
1405
1406     def failed(self, src: Path, message: str) -> None:
1407         """Increment the counter for failed reformatting. Write out a message."""
1408         err(f'error: cannot format {src}: {message}')
1409         self.failure_count += 1
1410
1411     @property
1412     def return_code(self) -> int:
1413         """Which return code should the app use considering the current state."""
1414         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1415         # 126 we have special returncodes reserved by the shell.
1416         if self.failure_count:
1417             return 123
1418
1419         elif self.change_count:
1420             return 1
1421
1422         return 0
1423
1424     def __str__(self) -> str:
1425         """A color report of the current state.
1426
1427         Use `click.unstyle` to remove colors.
1428         """
1429         report = []
1430         if self.change_count:
1431             s = 's' if self.change_count > 1 else ''
1432             report.append(
1433                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1434             )
1435         if self.same_count:
1436             s = 's' if self.same_count > 1 else ''
1437             report.append(f'{self.same_count} file{s} left unchanged')
1438         if self.failure_count:
1439             s = 's' if self.failure_count > 1 else ''
1440             report.append(
1441                 click.style(
1442                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1443                 )
1444             )
1445         return ', '.join(report) + '.'
1446
1447
1448 def assert_equivalent(src: str, dst: str) -> None:
1449     """Raises AssertionError if `src` and `dst` aren't equivalent.
1450
1451     This is a temporary sanity check until Black becomes stable.
1452     """
1453
1454     import ast
1455     import traceback
1456
1457     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1458         """Simple visitor generating strings to compare ASTs by content."""
1459         yield f"{'  ' * depth}{node.__class__.__name__}("
1460
1461         for field in sorted(node._fields):
1462             try:
1463                 value = getattr(node, field)
1464             except AttributeError:
1465                 continue
1466
1467             yield f"{'  ' * (depth+1)}{field}="
1468
1469             if isinstance(value, list):
1470                 for item in value:
1471                     if isinstance(item, ast.AST):
1472                         yield from _v(item, depth + 2)
1473
1474             elif isinstance(value, ast.AST):
1475                 yield from _v(value, depth + 2)
1476
1477             else:
1478                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1479
1480         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1481
1482     try:
1483         src_ast = ast.parse(src)
1484     except Exception as exc:
1485         raise AssertionError(f"cannot parse source: {exc}") from None
1486
1487     try:
1488         dst_ast = ast.parse(dst)
1489     except Exception as exc:
1490         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1491         raise AssertionError(
1492             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1493             f"Please report a bug on https://github.com/ambv/black/issues.  "
1494             f"This invalid output might be helpful: {log}"
1495         ) from None
1496
1497     src_ast_str = '\n'.join(_v(src_ast))
1498     dst_ast_str = '\n'.join(_v(dst_ast))
1499     if src_ast_str != dst_ast_str:
1500         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1501         raise AssertionError(
1502             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1503             f"the source.  "
1504             f"Please report a bug on https://github.com/ambv/black/issues.  "
1505             f"This diff might be helpful: {log}"
1506         ) from None
1507
1508
1509 def assert_stable(src: str, dst: str, line_length: int) -> None:
1510     """Raises AssertionError if `dst` reformats differently the second time.
1511
1512     This is a temporary sanity check until Black becomes stable.
1513     """
1514     newdst = format_str(dst, line_length=line_length)
1515     if dst != newdst:
1516         log = dump_to_file(
1517             diff(src, dst, 'source', 'first pass'),
1518             diff(dst, newdst, 'first pass', 'second pass'),
1519         )
1520         raise AssertionError(
1521             f"INTERNAL ERROR: Black produced different code on the second pass "
1522             f"of the formatter.  "
1523             f"Please report a bug on https://github.com/ambv/black/issues.  "
1524             f"This diff might be helpful: {log}"
1525         ) from None
1526
1527
1528 def dump_to_file(*output: str) -> str:
1529     """Dumps `output` to a temporary file. Returns path to the file."""
1530     import tempfile
1531
1532     with tempfile.NamedTemporaryFile(
1533         mode='w', prefix='blk_', suffix='.log', delete=False
1534     ) as f:
1535         for lines in output:
1536             f.write(lines)
1537             f.write('\n')
1538     return f.name
1539
1540
1541 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1542     """Returns a udiff string between strings `a` and `b`."""
1543     import difflib
1544
1545     a_lines = [line + '\n' for line in a.split('\n')]
1546     b_lines = [line + '\n' for line in b.split('\n')]
1547     return ''.join(
1548         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1549     )
1550
1551
1552 if __name__ == '__main__':
1553     main()