]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

89155f60103a32e2f89cd86ac1e6c92aa1bed22f
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
78 )
79 @click.pass_context
80 def main(
81     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
82 ) -> None:
83     """The uncompromising code formatter."""
84     sources: List[Path] = []
85     for s in src:
86         p = Path(s)
87         if p.is_dir():
88             sources.extend(gen_python_files_in_dir(p))
89         elif p.is_file():
90             # if a file was explicitly given, we don't care about its extension
91             sources.append(p)
92         else:
93             err(f'invalid path: {s}')
94     if len(sources) == 0:
95         ctx.exit(0)
96     elif len(sources) == 1:
97         p = sources[0]
98         report = Report()
99         try:
100             changed = format_file_in_place(
101                 p, line_length=line_length, fast=fast, write_back=not check
102             )
103             report.done(p, changed)
104         except Exception as exc:
105             report.failed(p, str(exc))
106         ctx.exit(report.return_code)
107     else:
108         loop = asyncio.get_event_loop()
109         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
110         return_code = 1
111         try:
112             return_code = loop.run_until_complete(
113                 schedule_formatting(
114                     sources, line_length, not check, fast, loop, executor
115                 )
116             )
117         finally:
118             loop.close()
119             ctx.exit(return_code)
120
121
122 async def schedule_formatting(
123     sources: List[Path],
124     line_length: int,
125     write_back: bool,
126     fast: bool,
127     loop: BaseEventLoop,
128     executor: Executor,
129 ) -> int:
130     tasks = {
131         src: loop.run_in_executor(
132             executor, format_file_in_place, src, line_length, fast, write_back
133         )
134         for src in sources
135     }
136     await asyncio.wait(tasks.values())
137     cancelled = []
138     report = Report()
139     for src, task in tasks.items():
140         if not task.done():
141             report.failed(src, 'timed out, cancelling')
142             task.cancel()
143             cancelled.append(task)
144         elif task.exception():
145             report.failed(src, str(task.exception()))
146         else:
147             report.done(src, task.result())
148     if cancelled:
149         await asyncio.wait(cancelled, timeout=2)
150     out('All done! ✨ 🍰 ✨')
151     click.echo(str(report))
152     return report.return_code
153
154
155 def format_file_in_place(
156     src: Path, line_length: int, fast: bool, write_back: bool = False
157 ) -> bool:
158     """Format the file and rewrite if changed. Return True if changed."""
159     try:
160         contents, encoding = format_file(src, line_length=line_length, fast=fast)
161     except NothingChanged:
162         return False
163
164     if write_back:
165         with open(src, "w", encoding=encoding) as f:
166             f.write(contents)
167     return True
168
169
170 def format_file(
171     src: Path, line_length: int, fast: bool
172 ) -> Tuple[FileContent, Encoding]:
173     """Reformats a file and returns its contents and encoding."""
174     with tokenize.open(src) as src_buffer:
175         src_contents = src_buffer.read()
176     if src_contents.strip() == '':
177         raise NothingChanged(src)
178
179     dst_contents = format_str(src_contents, line_length=line_length)
180     if src_contents == dst_contents:
181         raise NothingChanged(src)
182
183     if not fast:
184         assert_equivalent(src_contents, dst_contents)
185         assert_stable(src_contents, dst_contents, line_length=line_length)
186     return dst_contents, src_buffer.encoding
187
188
189 def format_str(src_contents: str, line_length: int) -> FileContent:
190     """Reformats a string and returns new contents."""
191     src_node = lib2to3_parse(src_contents)
192     dst_contents = ""
193     comments: List[Line] = []
194     lines = LineGenerator()
195     elt = EmptyLineTracker()
196     py36 = is_python36(src_node)
197     empty_line = Line()
198     after = 0
199     for current_line in lines.visit(src_node):
200         for _ in range(after):
201             dst_contents += str(empty_line)
202         before, after = elt.maybe_empty_lines(current_line)
203         for _ in range(before):
204             dst_contents += str(empty_line)
205         if not current_line.is_comment:
206             for comment in comments:
207                 dst_contents += str(comment)
208             comments = []
209             for line in split_line(current_line, line_length=line_length, py36=py36):
210                 dst_contents += str(line)
211         else:
212             comments.append(current_line)
213     if comments:
214         if elt.previous_defs:
215             # Separate postscriptum comments from the last module-level def.
216             dst_contents += str(empty_line)
217             dst_contents += str(empty_line)
218         for comment in comments:
219             dst_contents += str(comment)
220     return dst_contents
221
222
223 def lib2to3_parse(src_txt: str) -> Node:
224     """Given a string with source, return the lib2to3 Node."""
225     grammar = pygram.python_grammar_no_print_statement
226     drv = driver.Driver(grammar, pytree.convert)
227     if src_txt[-1] != '\n':
228         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
229         src_txt += nl
230     try:
231         result = drv.parse_string(src_txt, True)
232     except ParseError as pe:
233         lineno, column = pe.context[1]
234         lines = src_txt.splitlines()
235         try:
236             faulty_line = lines[lineno - 1]
237         except IndexError:
238             faulty_line = "<line number missing in source>"
239         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
240
241     if isinstance(result, Leaf):
242         result = Node(syms.file_input, [result])
243     return result
244
245
246 def lib2to3_unparse(node: Node) -> str:
247     """Given a lib2to3 node, return its string representation."""
248     code = str(node)
249     return code
250
251
252 T = TypeVar('T')
253
254
255 class Visitor(Generic[T]):
256     """Basic lib2to3 visitor that yields things on visiting."""
257
258     def visit(self, node: LN) -> Iterator[T]:
259         if node.type < 256:
260             name = token.tok_name[node.type]
261         else:
262             name = type_repr(node.type)
263         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
264
265     def visit_default(self, node: LN) -> Iterator[T]:
266         if isinstance(node, Node):
267             for child in node.children:
268                 yield from self.visit(child)
269
270
271 @dataclass
272 class DebugVisitor(Visitor[T]):
273     tree_depth: int = 0
274
275     def visit_default(self, node: LN) -> Iterator[T]:
276         indent = ' ' * (2 * self.tree_depth)
277         if isinstance(node, Node):
278             _type = type_repr(node.type)
279             out(f'{indent}{_type}', fg='yellow')
280             self.tree_depth += 1
281             for child in node.children:
282                 yield from self.visit(child)
283
284             self.tree_depth -= 1
285             out(f'{indent}/{_type}', fg='yellow', bold=False)
286         else:
287             _type = token.tok_name.get(node.type, str(node.type))
288             out(f'{indent}{_type}', fg='blue', nl=False)
289             if node.prefix:
290                 # We don't have to handle prefixes for `Node` objects since
291                 # that delegates to the first child anyway.
292                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
293             out(f' {node.value!r}', fg='blue', bold=False)
294
295
296 KEYWORDS = set(keyword.kwlist)
297 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
298 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
299 STATEMENT = {
300     syms.if_stmt,
301     syms.while_stmt,
302     syms.for_stmt,
303     syms.try_stmt,
304     syms.except_clause,
305     syms.with_stmt,
306     syms.funcdef,
307     syms.classdef,
308 }
309 STANDALONE_COMMENT = 153
310 LOGIC_OPERATORS = {'and', 'or'}
311 COMPARATORS = {
312     token.LESS,
313     token.GREATER,
314     token.EQEQUAL,
315     token.NOTEQUAL,
316     token.LESSEQUAL,
317     token.GREATEREQUAL,
318 }
319 MATH_OPERATORS = {
320     token.PLUS,
321     token.MINUS,
322     token.STAR,
323     token.SLASH,
324     token.VBAR,
325     token.AMPER,
326     token.PERCENT,
327     token.CIRCUMFLEX,
328     token.LEFTSHIFT,
329     token.RIGHTSHIFT,
330     token.DOUBLESTAR,
331     token.DOUBLESLASH,
332 }
333 COMPREHENSION_PRIORITY = 20
334 COMMA_PRIORITY = 10
335 LOGIC_PRIORITY = 5
336 STRING_PRIORITY = 4
337 COMPARATOR_PRIORITY = 3
338 MATH_PRIORITY = 1
339
340
341 @dataclass
342 class BracketTracker:
343     depth: int = 0
344     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
345     delimiters: Dict[LeafID, Priority] = Factory(dict)
346     previous: Optional[Leaf] = None
347
348     def mark(self, leaf: Leaf) -> None:
349         if leaf.type == token.COMMENT:
350             return
351
352         if leaf.type in CLOSING_BRACKETS:
353             self.depth -= 1
354             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
355             leaf.opening_bracket = opening_bracket
356         leaf.bracket_depth = self.depth
357         if self.depth == 0:
358             delim = is_delimiter(leaf)
359             if delim:
360                 self.delimiters[id(leaf)] = delim
361             elif self.previous is not None:
362                 if leaf.type == token.STRING and self.previous.type == token.STRING:
363                     self.delimiters[id(self.previous)] = STRING_PRIORITY
364                 elif (
365                     leaf.type == token.NAME
366                     and leaf.value == 'for'
367                     and leaf.parent
368                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
369                 ):
370                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
371                 elif (
372                     leaf.type == token.NAME
373                     and leaf.value == 'if'
374                     and leaf.parent
375                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
376                 ):
377                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
378                 elif (
379                     leaf.type == token.NAME
380                     and leaf.value in LOGIC_OPERATORS
381                     and leaf.parent
382                 ):
383                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
384         if leaf.type in OPENING_BRACKETS:
385             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
386             self.depth += 1
387         self.previous = leaf
388
389     def any_open_brackets(self) -> bool:
390         """Returns True if there is an yet unmatched open bracket on the line."""
391         return bool(self.bracket_match)
392
393     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
394         """Returns the highest priority of a delimiter found on the line.
395
396         Values are consistent with what `is_delimiter()` returns.
397         """
398         return max(v for k, v in self.delimiters.items() if k not in exclude)
399
400
401 @dataclass
402 class Line:
403     depth: int = 0
404     leaves: List[Leaf] = Factory(list)
405     comments: Dict[LeafID, Leaf] = Factory(dict)
406     bracket_tracker: BracketTracker = Factory(BracketTracker)
407     inside_brackets: bool = False
408     has_for: bool = False
409     _for_loop_variable: bool = False
410
411     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
412         has_value = leaf.value.strip()
413         if not has_value:
414             return
415
416         if self.leaves and not preformatted:
417             # Note: at this point leaf.prefix should be empty except for
418             # imports, for which we only preserve newlines.
419             leaf.prefix += whitespace(leaf)
420         if self.inside_brackets or not preformatted:
421             self.maybe_decrement_after_for_loop_variable(leaf)
422             self.bracket_tracker.mark(leaf)
423             self.maybe_remove_trailing_comma(leaf)
424             self.maybe_increment_for_loop_variable(leaf)
425             if self.maybe_adapt_standalone_comment(leaf):
426                 return
427
428         if not self.append_comment(leaf):
429             self.leaves.append(leaf)
430
431     @property
432     def is_comment(self) -> bool:
433         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
434
435     @property
436     def is_decorator(self) -> bool:
437         return bool(self) and self.leaves[0].type == token.AT
438
439     @property
440     def is_import(self) -> bool:
441         return bool(self) and is_import(self.leaves[0])
442
443     @property
444     def is_class(self) -> bool:
445         return (
446             bool(self)
447             and self.leaves[0].type == token.NAME
448             and self.leaves[0].value == 'class'
449         )
450
451     @property
452     def is_def(self) -> bool:
453         """Also returns True for async defs."""
454         try:
455             first_leaf = self.leaves[0]
456         except IndexError:
457             return False
458
459         try:
460             second_leaf: Optional[Leaf] = self.leaves[1]
461         except IndexError:
462             second_leaf = None
463         return (
464             (first_leaf.type == token.NAME and first_leaf.value == 'def')
465             or (
466                 first_leaf.type == token.NAME
467                 and first_leaf.value == 'async'
468                 and second_leaf is not None
469                 and second_leaf.type == token.NAME
470                 and second_leaf.value == 'def'
471             )
472         )
473
474     @property
475     def is_flow_control(self) -> bool:
476         return (
477             bool(self)
478             and self.leaves[0].type == token.NAME
479             and self.leaves[0].value in FLOW_CONTROL
480         )
481
482     @property
483     def is_yield(self) -> bool:
484         return (
485             bool(self)
486             and self.leaves[0].type == token.NAME
487             and self.leaves[0].value == 'yield'
488         )
489
490     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
491         if not (
492             self.leaves
493             and self.leaves[-1].type == token.COMMA
494             and closing.type in CLOSING_BRACKETS
495         ):
496             return False
497
498         if closing.type == token.RSQB or closing.type == token.RBRACE:
499             self.leaves.pop()
500             return True
501
502         # For parens let's check if it's safe to remove the comma.  If the
503         # trailing one is the only one, we might mistakenly change a tuple
504         # into a different type by removing the comma.
505         depth = closing.bracket_depth + 1
506         commas = 0
507         opening = closing.opening_bracket
508         for _opening_index, leaf in enumerate(self.leaves):
509             if leaf is opening:
510                 break
511
512         else:
513             return False
514
515         for leaf in self.leaves[_opening_index + 1:]:
516             if leaf is closing:
517                 break
518
519             bracket_depth = leaf.bracket_depth
520             if bracket_depth == depth and leaf.type == token.COMMA:
521                 commas += 1
522                 if leaf.parent and leaf.parent.type == syms.arglist:
523                     commas += 1
524                     break
525
526         if commas > 1:
527             self.leaves.pop()
528             return True
529
530         return False
531
532     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
533         """In a for loop, or comprehension, the variables are often unpacks.
534
535         To avoid splitting on the comma in this situation, we will increase
536         the depth of tokens between `for` and `in`.
537         """
538         if leaf.type == token.NAME and leaf.value == 'for':
539             self.has_for = True
540             self.bracket_tracker.depth += 1
541             self._for_loop_variable = True
542             return True
543
544         return False
545
546     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
547         # See `maybe_increment_for_loop_variable` above for explanation.
548         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
549             self.bracket_tracker.depth -= 1
550             self._for_loop_variable = False
551             return True
552
553         return False
554
555     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
556         """Hack a standalone comment to act as a trailing comment for line splitting.
557
558         If this line has brackets and a standalone `comment`, we need to adapt
559         it to be able to still reformat the line.
560
561         This is not perfect, the line to which the standalone comment gets
562         appended will appear "too long" when splitting.
563         """
564         if not (
565             comment.type == STANDALONE_COMMENT
566             and self.bracket_tracker.any_open_brackets()
567         ):
568             return False
569
570         comment.type = token.COMMENT
571         comment.prefix = '\n' + '    ' * (self.depth + 1)
572         return self.append_comment(comment)
573
574     def append_comment(self, comment: Leaf) -> bool:
575         if comment.type != token.COMMENT:
576             return False
577
578         try:
579             after = id(self.last_non_delimiter())
580         except LookupError:
581             comment.type = STANDALONE_COMMENT
582             comment.prefix = ''
583             return False
584
585         else:
586             if after in self.comments:
587                 self.comments[after].value += str(comment)
588             else:
589                 self.comments[after] = comment
590             return True
591
592     def last_non_delimiter(self) -> Leaf:
593         for i in range(len(self.leaves)):
594             last = self.leaves[-i - 1]
595             if not is_delimiter(last):
596                 return last
597
598         raise LookupError("No non-delimiters found")
599
600     def __str__(self) -> str:
601         if not self:
602             return '\n'
603
604         indent = '    ' * self.depth
605         leaves = iter(self.leaves)
606         first = next(leaves)
607         res = f'{first.prefix}{indent}{first.value}'
608         for leaf in leaves:
609             res += str(leaf)
610         for comment in self.comments.values():
611             res += str(comment)
612         return res + '\n'
613
614     def __bool__(self) -> bool:
615         return bool(self.leaves or self.comments)
616
617
618 @dataclass
619 class EmptyLineTracker:
620     """Provides a stateful method that returns the number of potential extra
621     empty lines needed before and after the currently processed line.
622
623     Note: this tracker works on lines that haven't been split yet.  It assumes
624     the prefix of the first leaf consists of optional newlines.  Those newlines
625     are consumed by `maybe_empty_lines()` and included in the computation.
626     """
627     previous_line: Optional[Line] = None
628     previous_after: int = 0
629     previous_defs: List[int] = Factory(list)
630
631     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
632         """Returns the number of extra empty lines before and after the `current_line`.
633
634         This is for separating `def`, `async def` and `class` with extra empty lines
635         (two on module-level), as well as providing an extra empty line after flow
636         control keywords to make them more prominent.
637         """
638         if current_line.is_comment:
639             # Don't count standalone comments towards previous empty lines.
640             return 0, 0
641
642         before, after = self._maybe_empty_lines(current_line)
643         before -= self.previous_after
644         self.previous_after = after
645         self.previous_line = current_line
646         return before, after
647
648     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
649         if current_line.leaves:
650             # Consume the first leaf's extra newlines.
651             first_leaf = current_line.leaves[0]
652             before = int('\n' in first_leaf.prefix)
653             first_leaf.prefix = ''
654         else:
655             before = 0
656         depth = current_line.depth
657         while self.previous_defs and self.previous_defs[-1] >= depth:
658             self.previous_defs.pop()
659             before = 1 if depth else 2
660         is_decorator = current_line.is_decorator
661         if is_decorator or current_line.is_def or current_line.is_class:
662             if not is_decorator:
663                 self.previous_defs.append(depth)
664             if self.previous_line is None:
665                 # Don't insert empty lines before the first line in the file.
666                 return 0, 0
667
668             if self.previous_line and self.previous_line.is_decorator:
669                 # Don't insert empty lines between decorators.
670                 return 0, 0
671
672             newlines = 2
673             if current_line.depth:
674                 newlines -= 1
675             return newlines, 0
676
677         if current_line.is_flow_control:
678             return before, 1
679
680         if (
681             self.previous_line
682             and self.previous_line.is_import
683             and not current_line.is_import
684             and depth == self.previous_line.depth
685         ):
686             return (before or 1), 0
687
688         if (
689             self.previous_line
690             and self.previous_line.is_yield
691             and (not current_line.is_yield or depth != self.previous_line.depth)
692         ):
693             return (before or 1), 0
694
695         return before, 0
696
697
698 @dataclass
699 class LineGenerator(Visitor[Line]):
700     """Generates reformatted Line objects.  Empty lines are not emitted.
701
702     Note: destroys the tree it's visiting by mutating prefixes of its leaves
703     in ways that will no longer stringify to valid Python code on the tree.
704     """
705     current_line: Line = Factory(Line)
706     standalone_comments: List[Leaf] = Factory(list)
707
708     def line(self, indent: int = 0) -> Iterator[Line]:
709         """Generate a line.
710
711         If the line is empty, only emit if it makes sense.
712         If the line is too long, split it first and then generate.
713
714         If any lines were generated, set up a new current_line.
715         """
716         if not self.current_line:
717             self.current_line.depth += indent
718             return  # Line is empty, don't emit. Creating a new one unnecessary.
719
720         complete_line = self.current_line
721         self.current_line = Line(depth=complete_line.depth + indent)
722         yield complete_line
723
724     def visit_default(self, node: LN) -> Iterator[Line]:
725         if isinstance(node, Leaf):
726             for comment in generate_comments(node):
727                 if self.current_line.bracket_tracker.any_open_brackets():
728                     # any comment within brackets is subject to splitting
729                     self.current_line.append(comment)
730                 elif comment.type == token.COMMENT:
731                     # regular trailing comment
732                     self.current_line.append(comment)
733                     yield from self.line()
734
735                 else:
736                     # regular standalone comment, to be processed later (see
737                     # docstring in `generate_comments()`
738                     self.standalone_comments.append(comment)
739             normalize_prefix(node)
740             if node.type not in WHITESPACE:
741                 for comment in self.standalone_comments:
742                     yield from self.line()
743
744                     self.current_line.append(comment)
745                     yield from self.line()
746
747                 self.standalone_comments = []
748                 self.current_line.append(node)
749         yield from super().visit_default(node)
750
751     def visit_suite(self, node: Node) -> Iterator[Line]:
752         """Body of a statement after a colon."""
753         children = iter(node.children)
754         # Process newline before indenting.  It might contain an inline
755         # comment that should go right after the colon.
756         newline = next(children)
757         yield from self.visit(newline)
758         yield from self.line(+1)
759
760         for child in children:
761             yield from self.visit(child)
762
763         yield from self.line(-1)
764
765     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
766         """Visit a statement.
767
768         The relevant Python language keywords for this statement are NAME leaves
769         within it.
770         """
771         for child in node.children:
772             if child.type == token.NAME and child.value in keywords:  # type: ignore
773                 yield from self.line()
774
775             yield from self.visit(child)
776
777     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
778         """A statement without nested statements."""
779         is_suite_like = node.parent and node.parent.type in STATEMENT
780         if is_suite_like:
781             yield from self.line(+1)
782             yield from self.visit_default(node)
783             yield from self.line(-1)
784
785         else:
786             yield from self.line()
787             yield from self.visit_default(node)
788
789     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
790         yield from self.line()
791
792         children = iter(node.children)
793         for child in children:
794             yield from self.visit(child)
795
796             if child.type == token.NAME and child.value == 'async':  # type: ignore
797                 break
798
799         internal_stmt = next(children)
800         for child in internal_stmt.children:
801             yield from self.visit(child)
802
803     def visit_decorators(self, node: Node) -> Iterator[Line]:
804         for child in node.children:
805             yield from self.line()
806             yield from self.visit(child)
807
808     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
809         yield from self.line()
810
811     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
812         yield from self.visit_default(leaf)
813         yield from self.line()
814
815     def __attrs_post_init__(self) -> None:
816         """You are in a twisty little maze of passages."""
817         v = self.visit_stmt
818         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
819         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
820         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
821         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
822         self.visit_except_clause = partial(v, keywords={'except'})
823         self.visit_funcdef = partial(v, keywords={'def'})
824         self.visit_with_stmt = partial(v, keywords={'with'})
825         self.visit_classdef = partial(v, keywords={'class'})
826         self.visit_async_funcdef = self.visit_async_stmt
827         self.visit_decorated = self.visit_decorators
828
829
830 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
831 OPENING_BRACKETS = set(BRACKET.keys())
832 CLOSING_BRACKETS = set(BRACKET.values())
833 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
834 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
835
836
837 def whitespace(leaf: Leaf) -> str:  # noqa C901
838     """Return whitespace prefix if needed for the given `leaf`."""
839     NO = ''
840     SPACE = ' '
841     DOUBLESPACE = '  '
842     t = leaf.type
843     p = leaf.parent
844     v = leaf.value
845     if t in ALWAYS_NO_SPACE:
846         return NO
847
848     if t == token.COMMENT:
849         return DOUBLESPACE
850
851     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
852     prev = leaf.prev_sibling
853     if not prev:
854         prevp = preceding_leaf(p)
855         if not prevp or prevp.type in OPENING_BRACKETS:
856             return NO
857
858         if prevp.type == token.EQUAL:
859             if prevp.parent and prevp.parent.type in {
860                 syms.typedargslist,
861                 syms.varargslist,
862                 syms.parameters,
863                 syms.arglist,
864                 syms.argument,
865             }:
866                 return NO
867
868         elif prevp.type == token.DOUBLESTAR:
869             if prevp.parent and prevp.parent.type in {
870                 syms.typedargslist,
871                 syms.varargslist,
872                 syms.parameters,
873                 syms.arglist,
874                 syms.dictsetmaker,
875             }:
876                 return NO
877
878         elif prevp.type == token.COLON:
879             if prevp.parent and prevp.parent.type == syms.subscript:
880                 return NO
881
882         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
883             return NO
884
885     elif prev.type in OPENING_BRACKETS:
886         return NO
887
888     if p.type in {syms.parameters, syms.arglist}:
889         # untyped function signatures or calls
890         if t == token.RPAR:
891             return NO
892
893         if not prev or prev.type != token.COMMA:
894             return NO
895
896     if p.type == syms.varargslist:
897         # lambdas
898         if t == token.RPAR:
899             return NO
900
901         if prev and prev.type != token.COMMA:
902             return NO
903
904     elif p.type == syms.typedargslist:
905         # typed function signatures
906         if not prev:
907             return NO
908
909         if t == token.EQUAL:
910             if prev.type != syms.tname:
911                 return NO
912
913         elif prev.type == token.EQUAL:
914             # A bit hacky: if the equal sign has whitespace, it means we
915             # previously found it's a typed argument.  So, we're using that, too.
916             return prev.prefix
917
918         elif prev.type != token.COMMA:
919             return NO
920
921     elif p.type == syms.tname:
922         # type names
923         if not prev:
924             prevp = preceding_leaf(p)
925             if not prevp or prevp.type != token.COMMA:
926                 return NO
927
928     elif p.type == syms.trailer:
929         # attributes and calls
930         if t == token.LPAR or t == token.RPAR:
931             return NO
932
933         if not prev:
934             if t == token.DOT:
935                 prevp = preceding_leaf(p)
936                 if not prevp or prevp.type != token.NUMBER:
937                     return NO
938
939             elif t == token.LSQB:
940                 return NO
941
942         elif prev.type != token.COMMA:
943             return NO
944
945     elif p.type == syms.argument:
946         # single argument
947         if t == token.EQUAL:
948             return NO
949
950         if not prev:
951             prevp = preceding_leaf(p)
952             if not prevp or prevp.type == token.LPAR:
953                 return NO
954
955         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
956             return NO
957
958     elif p.type == syms.decorator:
959         # decorators
960         return NO
961
962     elif p.type == syms.dotted_name:
963         if prev:
964             return NO
965
966         prevp = preceding_leaf(p)
967         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
968             return NO
969
970     elif p.type == syms.classdef:
971         if t == token.LPAR:
972             return NO
973
974         if prev and prev.type == token.LPAR:
975             return NO
976
977     elif p.type == syms.subscript:
978         # indexing
979         if not prev:
980             assert p.parent is not None, "subscripts are always parented"
981             if p.parent.type == syms.subscriptlist:
982                 return SPACE
983
984             return NO
985
986         elif prev.type == token.COLON:
987             return NO
988
989     elif p.type == syms.atom:
990         if prev and t == token.DOT:
991             # dots, but not the first one.
992             return NO
993
994     elif (
995         p.type == syms.listmaker
996         or p.type == syms.testlist_gexp
997         or p.type == syms.subscriptlist
998     ):
999         # list interior, including unpacking
1000         if not prev:
1001             return NO
1002
1003     elif p.type == syms.dictsetmaker:
1004         # dict and set interior, including unpacking
1005         if not prev:
1006             return NO
1007
1008         if prev.type == token.DOUBLESTAR:
1009             return NO
1010
1011     elif p.type in {syms.factor, syms.star_expr}:
1012         # unary ops
1013         if not prev:
1014             prevp = preceding_leaf(p)
1015             if not prevp or prevp.type in OPENING_BRACKETS:
1016                 return NO
1017
1018             prevp_parent = prevp.parent
1019             assert prevp_parent is not None
1020             if prevp.type == token.COLON and prevp_parent.type in {
1021                 syms.subscript, syms.sliceop
1022             }:
1023                 return NO
1024
1025             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1026                 return NO
1027
1028         elif t == token.NAME or t == token.NUMBER:
1029             return NO
1030
1031     elif p.type == syms.import_from:
1032         if t == token.DOT:
1033             if prev and prev.type == token.DOT:
1034                 return NO
1035
1036         elif t == token.NAME:
1037             if v == 'import':
1038                 return SPACE
1039
1040             if prev and prev.type == token.DOT:
1041                 return NO
1042
1043     elif p.type == syms.sliceop:
1044         return NO
1045
1046     return SPACE
1047
1048
1049 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1050     """Returns the first leaf that precedes `node`, if any."""
1051     while node:
1052         res = node.prev_sibling
1053         if res:
1054             if isinstance(res, Leaf):
1055                 return res
1056
1057             try:
1058                 return list(res.leaves())[-1]
1059
1060             except IndexError:
1061                 return None
1062
1063         node = node.parent
1064     return None
1065
1066
1067 def is_delimiter(leaf: Leaf) -> int:
1068     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1069
1070     Higher numbers are higher priority.
1071     """
1072     if leaf.type == token.COMMA:
1073         return COMMA_PRIORITY
1074
1075     if leaf.type in COMPARATORS:
1076         return COMPARATOR_PRIORITY
1077
1078     if (
1079         leaf.type in MATH_OPERATORS
1080         and leaf.parent
1081         and leaf.parent.type not in {syms.factor, syms.star_expr}
1082     ):
1083         return MATH_PRIORITY
1084
1085     return 0
1086
1087
1088 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1089     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1090
1091     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1092     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1093     move because it does away with modifying the grammar to include all the
1094     possible places in which comments can be placed.
1095
1096     The sad consequence for us though is that comments don't "belong" anywhere.
1097     This is why this function generates simple parentless Leaf objects for
1098     comments.  We simply don't know what the correct parent should be.
1099
1100     No matter though, we can live without this.  We really only need to
1101     differentiate between inline and standalone comments.  The latter don't
1102     share the line with any code.
1103
1104     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1105     are emitted with a fake STANDALONE_COMMENT token identifier.
1106     """
1107     if not leaf.prefix:
1108         return
1109
1110     if '#' not in leaf.prefix:
1111         return
1112
1113     before_comment, content = leaf.prefix.split('#', 1)
1114     content = content.rstrip()
1115     if content and (content[0] not in {' ', '!', '#'}):
1116         content = ' ' + content
1117     is_standalone_comment = (
1118         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1119     )
1120     if not is_standalone_comment:
1121         # simple trailing comment
1122         yield Leaf(token.COMMENT, value='#' + content)
1123         return
1124
1125     for line in ('#' + content).split('\n'):
1126         line = line.lstrip()
1127         if not line.startswith('#'):
1128             continue
1129
1130         yield Leaf(STANDALONE_COMMENT, line)
1131
1132
1133 def split_line(
1134     line: Line, line_length: int, inner: bool = False, py36: bool = False
1135 ) -> Iterator[Line]:
1136     """Splits a `line` into potentially many lines.
1137
1138     They should fit in the allotted `line_length` but might not be able to.
1139     `inner` signifies that there were a pair of brackets somewhere around the
1140     current `line`, possibly transitively. This means we can fallback to splitting
1141     by delimiters if the LHS/RHS don't yield any results.
1142
1143     If `py36` is True, splitting may generate syntax that is only compatible
1144     with Python 3.6 and later.
1145     """
1146     line_str = str(line).strip('\n')
1147     if len(line_str) <= line_length and '\n' not in line_str:
1148         yield line
1149         return
1150
1151     if line.is_def:
1152         split_funcs = [left_hand_split]
1153     elif line.inside_brackets:
1154         split_funcs = [delimiter_split]
1155         if '\n' not in line_str:
1156             # Only attempt RHS if we don't have multiline strings or comments
1157             # on this line.
1158             split_funcs.append(right_hand_split)
1159     else:
1160         split_funcs = [right_hand_split]
1161     for split_func in split_funcs:
1162         # We are accumulating lines in `result` because we might want to abort
1163         # mission and return the original line in the end, or attempt a different
1164         # split altogether.
1165         result: List[Line] = []
1166         try:
1167             for l in split_func(line, py36=py36):
1168                 if str(l).strip('\n') == line_str:
1169                     raise CannotSplit("Split function returned an unchanged result")
1170
1171                 result.extend(
1172                     split_line(l, line_length=line_length, inner=True, py36=py36)
1173                 )
1174         except CannotSplit as cs:
1175             continue
1176
1177         else:
1178             yield from result
1179             break
1180
1181     else:
1182         yield line
1183
1184
1185 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1186     """Split line into many lines, starting with the first matching bracket pair.
1187
1188     Note: this usually looks weird, only use this for function definitions.
1189     Prefer RHS otherwise.
1190     """
1191     head = Line(depth=line.depth)
1192     body = Line(depth=line.depth + 1, inside_brackets=True)
1193     tail = Line(depth=line.depth)
1194     tail_leaves: List[Leaf] = []
1195     body_leaves: List[Leaf] = []
1196     head_leaves: List[Leaf] = []
1197     current_leaves = head_leaves
1198     matching_bracket = None
1199     for leaf in line.leaves:
1200         if (
1201             current_leaves is body_leaves
1202             and leaf.type in CLOSING_BRACKETS
1203             and leaf.opening_bracket is matching_bracket
1204         ):
1205             current_leaves = tail_leaves if body_leaves else head_leaves
1206         current_leaves.append(leaf)
1207         if current_leaves is head_leaves:
1208             if leaf.type in OPENING_BRACKETS:
1209                 matching_bracket = leaf
1210                 current_leaves = body_leaves
1211     # Since body is a new indent level, remove spurious leading whitespace.
1212     if body_leaves:
1213         normalize_prefix(body_leaves[0])
1214     # Build the new lines.
1215     for result, leaves in (
1216         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1217     ):
1218         for leaf in leaves:
1219             result.append(leaf, preformatted=True)
1220             comment_after = line.comments.get(id(leaf))
1221             if comment_after:
1222                 result.append(comment_after, preformatted=True)
1223     split_succeeded_or_raise(head, body, tail)
1224     for result in (head, body, tail):
1225         if result:
1226             yield result
1227
1228
1229 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1230     """Split line into many lines, starting with the last matching bracket pair."""
1231     head = Line(depth=line.depth)
1232     body = Line(depth=line.depth + 1, inside_brackets=True)
1233     tail = Line(depth=line.depth)
1234     tail_leaves: List[Leaf] = []
1235     body_leaves: List[Leaf] = []
1236     head_leaves: List[Leaf] = []
1237     current_leaves = tail_leaves
1238     opening_bracket = None
1239     for leaf in reversed(line.leaves):
1240         if current_leaves is body_leaves:
1241             if leaf is opening_bracket:
1242                 current_leaves = head_leaves if body_leaves else tail_leaves
1243         current_leaves.append(leaf)
1244         if current_leaves is tail_leaves:
1245             if leaf.type in CLOSING_BRACKETS:
1246                 opening_bracket = leaf.opening_bracket
1247                 current_leaves = body_leaves
1248     tail_leaves.reverse()
1249     body_leaves.reverse()
1250     head_leaves.reverse()
1251     # Since body is a new indent level, remove spurious leading whitespace.
1252     if body_leaves:
1253         normalize_prefix(body_leaves[0])
1254     # Build the new lines.
1255     for result, leaves in (
1256         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1257     ):
1258         for leaf in leaves:
1259             result.append(leaf, preformatted=True)
1260             comment_after = line.comments.get(id(leaf))
1261             if comment_after:
1262                 result.append(comment_after, preformatted=True)
1263     split_succeeded_or_raise(head, body, tail)
1264     for result in (head, body, tail):
1265         if result:
1266             yield result
1267
1268
1269 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1270     tail_len = len(str(tail).strip())
1271     if not body:
1272         if tail_len == 0:
1273             raise CannotSplit("Splitting brackets produced the same line")
1274
1275         elif tail_len < 3:
1276             raise CannotSplit(
1277                 f"Splitting brackets on an empty body to save "
1278                 f"{tail_len} characters is not worth it"
1279             )
1280
1281
1282 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1283     """Split according to delimiters of the highest priority.
1284
1285     This kind of split doesn't increase indentation.
1286     If `py36` is True, the split will add trailing commas also in function
1287     signatures that contain * and **.
1288     """
1289     try:
1290         last_leaf = line.leaves[-1]
1291     except IndexError:
1292         raise CannotSplit("Line empty")
1293
1294     delimiters = line.bracket_tracker.delimiters
1295     try:
1296         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1297     except ValueError:
1298         raise CannotSplit("No delimiters found")
1299
1300     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1301     lowest_depth = sys.maxsize
1302     trailing_comma_safe = True
1303     for leaf in line.leaves:
1304         current_line.append(leaf, preformatted=True)
1305         comment_after = line.comments.get(id(leaf))
1306         if comment_after:
1307             current_line.append(comment_after, preformatted=True)
1308         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1309         if (
1310             leaf.bracket_depth == lowest_depth
1311             and leaf.type == token.STAR
1312             or leaf.type == token.DOUBLESTAR
1313         ):
1314             trailing_comma_safe = trailing_comma_safe and py36
1315         leaf_priority = delimiters.get(id(leaf))
1316         if leaf_priority == delimiter_priority:
1317             normalize_prefix(current_line.leaves[0])
1318             yield current_line
1319
1320             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1321     if current_line:
1322         if (
1323             delimiter_priority == COMMA_PRIORITY
1324             and current_line.leaves[-1].type != token.COMMA
1325             and trailing_comma_safe
1326         ):
1327             current_line.append(Leaf(token.COMMA, ','))
1328         normalize_prefix(current_line.leaves[0])
1329         yield current_line
1330
1331
1332 def is_import(leaf: Leaf) -> bool:
1333     """Returns True if the given leaf starts an import statement."""
1334     p = leaf.parent
1335     t = leaf.type
1336     v = leaf.value
1337     return bool(
1338         t == token.NAME
1339         and (
1340             (v == 'import' and p and p.type == syms.import_name)
1341             or (v == 'from' and p and p.type == syms.import_from)
1342         )
1343     )
1344
1345
1346 def normalize_prefix(leaf: Leaf) -> None:
1347     """Leave existing extra newlines for imports.  Remove everything else."""
1348     if is_import(leaf):
1349         spl = leaf.prefix.split('#', 1)
1350         nl_count = spl[0].count('\n')
1351         leaf.prefix = '\n' * nl_count
1352         return
1353
1354     leaf.prefix = ''
1355
1356
1357 def is_python36(node: Node) -> bool:
1358     """Returns True if the current file is using Python 3.6+ features.
1359
1360     Currently looking for:
1361     - f-strings; and
1362     - trailing commas after * or ** in function signatures.
1363     """
1364     for n in node.pre_order():
1365         if n.type == token.STRING:
1366             value_head = n.value[:2]  # type: ignore
1367             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1368                 return True
1369
1370         elif (
1371             n.type == syms.typedargslist
1372             and n.children
1373             and n.children[-1].type == token.COMMA
1374         ):
1375             for ch in n.children:
1376                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1377                     return True
1378
1379     return False
1380
1381
1382 PYTHON_EXTENSIONS = {'.py'}
1383 BLACKLISTED_DIRECTORIES = {
1384     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1385 }
1386
1387
1388 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1389     for child in path.iterdir():
1390         if child.is_dir():
1391             if child.name in BLACKLISTED_DIRECTORIES:
1392                 continue
1393
1394             yield from gen_python_files_in_dir(child)
1395
1396         elif child.suffix in PYTHON_EXTENSIONS:
1397             yield child
1398
1399
1400 @dataclass
1401 class Report:
1402     """Provides a reformatting counter."""
1403     change_count: int = 0
1404     same_count: int = 0
1405     failure_count: int = 0
1406
1407     def done(self, src: Path, changed: bool) -> None:
1408         """Increment the counter for successful reformatting. Write out a message."""
1409         if changed:
1410             out(f'reformatted {src}')
1411             self.change_count += 1
1412         else:
1413             out(f'{src} already well formatted, good job.', bold=False)
1414             self.same_count += 1
1415
1416     def failed(self, src: Path, message: str) -> None:
1417         """Increment the counter for failed reformatting. Write out a message."""
1418         err(f'error: cannot format {src}: {message}')
1419         self.failure_count += 1
1420
1421     @property
1422     def return_code(self) -> int:
1423         """Which return code should the app use considering the current state."""
1424         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1425         # 126 we have special returncodes reserved by the shell.
1426         if self.failure_count:
1427             return 123
1428
1429         elif self.change_count:
1430             return 1
1431
1432         return 0
1433
1434     def __str__(self) -> str:
1435         """A color report of the current state.
1436
1437         Use `click.unstyle` to remove colors.
1438         """
1439         report = []
1440         if self.change_count:
1441             s = 's' if self.change_count > 1 else ''
1442             report.append(
1443                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1444             )
1445         if self.same_count:
1446             s = 's' if self.same_count > 1 else ''
1447             report.append(f'{self.same_count} file{s} left unchanged')
1448         if self.failure_count:
1449             s = 's' if self.failure_count > 1 else ''
1450             report.append(
1451                 click.style(
1452                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1453                 )
1454             )
1455         return ', '.join(report) + '.'
1456
1457
1458 def assert_equivalent(src: str, dst: str) -> None:
1459     """Raises AssertionError if `src` and `dst` aren't equivalent.
1460
1461     This is a temporary sanity check until Black becomes stable.
1462     """
1463
1464     import ast
1465     import traceback
1466
1467     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1468         """Simple visitor generating strings to compare ASTs by content."""
1469         yield f"{'  ' * depth}{node.__class__.__name__}("
1470
1471         for field in sorted(node._fields):
1472             try:
1473                 value = getattr(node, field)
1474             except AttributeError:
1475                 continue
1476
1477             yield f"{'  ' * (depth+1)}{field}="
1478
1479             if isinstance(value, list):
1480                 for item in value:
1481                     if isinstance(item, ast.AST):
1482                         yield from _v(item, depth + 2)
1483
1484             elif isinstance(value, ast.AST):
1485                 yield from _v(value, depth + 2)
1486
1487             else:
1488                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1489
1490         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1491
1492     try:
1493         src_ast = ast.parse(src)
1494     except Exception as exc:
1495         raise AssertionError(f"cannot parse source: {exc}") from None
1496
1497     try:
1498         dst_ast = ast.parse(dst)
1499     except Exception as exc:
1500         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1501         raise AssertionError(
1502             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1503             f"Please report a bug on https://github.com/ambv/black/issues.  "
1504             f"This invalid output might be helpful: {log}"
1505         ) from None
1506
1507     src_ast_str = '\n'.join(_v(src_ast))
1508     dst_ast_str = '\n'.join(_v(dst_ast))
1509     if src_ast_str != dst_ast_str:
1510         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1511         raise AssertionError(
1512             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1513             f"the source.  "
1514             f"Please report a bug on https://github.com/ambv/black/issues.  "
1515             f"This diff might be helpful: {log}"
1516         ) from None
1517
1518
1519 def assert_stable(src: str, dst: str, line_length: int) -> None:
1520     """Raises AssertionError if `dst` reformats differently the second time.
1521
1522     This is a temporary sanity check until Black becomes stable.
1523     """
1524     newdst = format_str(dst, line_length=line_length)
1525     if dst != newdst:
1526         log = dump_to_file(
1527             diff(src, dst, 'source', 'first pass'),
1528             diff(dst, newdst, 'first pass', 'second pass'),
1529         )
1530         raise AssertionError(
1531             f"INTERNAL ERROR: Black produced different code on the second pass "
1532             f"of the formatter.  "
1533             f"Please report a bug on https://github.com/ambv/black/issues.  "
1534             f"This diff might be helpful: {log}"
1535         ) from None
1536
1537
1538 def dump_to_file(*output: str) -> str:
1539     """Dumps `output` to a temporary file. Returns path to the file."""
1540     import tempfile
1541
1542     with tempfile.NamedTemporaryFile(
1543         mode='w', prefix='blk_', suffix='.log', delete=False
1544     ) as f:
1545         for lines in output:
1546             f.write(lines)
1547             f.write('\n')
1548     return f.name
1549
1550
1551 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1552     """Returns a udiff string between strings `a` and `b`."""
1553     import difflib
1554
1555     a_lines = [line + '\n' for line in a.split('\n')]
1556     b_lines = [line + '\n' for line in b.split('\n')]
1557     return ''.join(
1558         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1559     )
1560
1561
1562 if __name__ == '__main__':
1563     main()