]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

More support for numpy tuple indexing
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 import sys
11 from typing import (
12     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
13 )
14
15 from attr import dataclass, Factory
16 import click
17
18 # lib2to3 fork
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
23
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
26 # types
27 syms = pygram.python_symbols
28 FileContent = str
29 Encoding = str
30 Depth = int
31 NodeType = int
32 LeafID = int
33 Priority = int
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
37
38
39 class NothingChanged(UserWarning):
40     """Raised by `format_file` when the reformatted code is the same as source."""
41
42
43 class CannotSplit(Exception):
44     """A readable split that fits the allotted line length is impossible.
45
46     Raised by `left_hand_split()` and `right_hand_split()`.
47     """
48
49
50 @click.command()
51 @click.option(
52     '-l',
53     '--line-length',
54     type=int,
55     default=DEFAULT_LINE_LENGTH,
56     help='How many character per line to allow.',
57     show_default=True,
58 )
59 @click.option(
60     '--check',
61     is_flag=True,
62     help=(
63         "Don't write back the files, just return the status.  Return code 0 "
64         "means nothing changed.  Return code 1 means some files were "
65         "reformatted.  Return code 123 means there was an internal error."
66     ),
67 )
68 @click.option(
69     '--fast/--safe',
70     is_flag=True,
71     help='If --fast given, skip temporary sanity checks. [default: --safe]',
72 )
73 @click.version_option(version=__version__)
74 @click.argument(
75     'src',
76     nargs=-1,
77     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
78 )
79 @click.pass_context
80 def main(
81     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
82 ) -> None:
83     """The uncompromising code formatter."""
84     sources: List[Path] = []
85     for s in src:
86         p = Path(s)
87         if p.is_dir():
88             sources.extend(gen_python_files_in_dir(p))
89         elif p.is_file():
90             # if a file was explicitly given, we don't care about its extension
91             sources.append(p)
92         else:
93             err(f'invalid path: {s}')
94     if len(sources) == 0:
95         ctx.exit(0)
96     elif len(sources) == 1:
97         p = sources[0]
98         report = Report()
99         try:
100             changed = format_file_in_place(
101                 p, line_length=line_length, fast=fast, write_back=not check
102             )
103             report.done(p, changed)
104         except Exception as exc:
105             report.failed(p, str(exc))
106         ctx.exit(report.return_code)
107     else:
108         loop = asyncio.get_event_loop()
109         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
110         return_code = 1
111         try:
112             return_code = loop.run_until_complete(
113                 schedule_formatting(
114                     sources, line_length, not check, fast, loop, executor
115                 )
116             )
117         finally:
118             loop.close()
119             ctx.exit(return_code)
120
121
122 async def schedule_formatting(
123     sources: List[Path],
124     line_length: int,
125     write_back: bool,
126     fast: bool,
127     loop: BaseEventLoop,
128     executor: Executor,
129 ) -> int:
130     tasks = {
131         src: loop.run_in_executor(
132             executor, format_file_in_place, src, line_length, fast, write_back
133         )
134         for src in sources
135     }
136     await asyncio.wait(tasks.values())
137     cancelled = []
138     report = Report()
139     for src, task in tasks.items():
140         if not task.done():
141             report.failed(src, 'timed out, cancelling')
142             task.cancel()
143             cancelled.append(task)
144         elif task.exception():
145             report.failed(src, str(task.exception()))
146         else:
147             report.done(src, task.result())
148     if cancelled:
149         await asyncio.wait(cancelled, timeout=2)
150     out('All done! ✨ 🍰 ✨')
151     click.echo(str(report))
152     return report.return_code
153
154
155 def format_file_in_place(
156     src: Path, line_length: int, fast: bool, write_back: bool = False
157 ) -> bool:
158     """Format the file and rewrite if changed. Return True if changed."""
159     try:
160         contents, encoding = format_file(src, line_length=line_length, fast=fast)
161     except NothingChanged:
162         return False
163
164     if write_back:
165         with open(src, "w", encoding=encoding) as f:
166             f.write(contents)
167     return True
168
169
170 def format_file(
171     src: Path, line_length: int, fast: bool
172 ) -> Tuple[FileContent, Encoding]:
173     """Reformats a file and returns its contents and encoding."""
174     with tokenize.open(src) as src_buffer:
175         src_contents = src_buffer.read()
176     if src_contents.strip() == '':
177         raise NothingChanged(src)
178
179     dst_contents = format_str(src_contents, line_length=line_length)
180     if src_contents == dst_contents:
181         raise NothingChanged(src)
182
183     if not fast:
184         assert_equivalent(src_contents, dst_contents)
185         assert_stable(src_contents, dst_contents, line_length=line_length)
186     return dst_contents, src_buffer.encoding
187
188
189 def format_str(src_contents: str, line_length: int) -> FileContent:
190     """Reformats a string and returns new contents."""
191     src_node = lib2to3_parse(src_contents)
192     dst_contents = ""
193     comments: List[Line] = []
194     lines = LineGenerator()
195     elt = EmptyLineTracker()
196     py36 = is_python36(src_node)
197     empty_line = Line()
198     after = 0
199     for current_line in lines.visit(src_node):
200         for _ in range(after):
201             dst_contents += str(empty_line)
202         before, after = elt.maybe_empty_lines(current_line)
203         for _ in range(before):
204             dst_contents += str(empty_line)
205         if not current_line.is_comment:
206             for comment in comments:
207                 dst_contents += str(comment)
208             comments = []
209             for line in split_line(current_line, line_length=line_length, py36=py36):
210                 dst_contents += str(line)
211         else:
212             comments.append(current_line)
213     if comments:
214         if elt.previous_defs:
215             # Separate postscriptum comments from the last module-level def.
216             dst_contents += str(empty_line)
217             dst_contents += str(empty_line)
218         for comment in comments:
219             dst_contents += str(comment)
220     return dst_contents
221
222
223 def lib2to3_parse(src_txt: str) -> Node:
224     """Given a string with source, return the lib2to3 Node."""
225     grammar = pygram.python_grammar_no_print_statement
226     drv = driver.Driver(grammar, pytree.convert)
227     if src_txt[-1] != '\n':
228         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
229         src_txt += nl
230     try:
231         result = drv.parse_string(src_txt, True)
232     except ParseError as pe:
233         lineno, column = pe.context[1]
234         lines = src_txt.splitlines()
235         try:
236             faulty_line = lines[lineno - 1]
237         except IndexError:
238             faulty_line = "<line number missing in source>"
239         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
240
241     if isinstance(result, Leaf):
242         result = Node(syms.file_input, [result])
243     return result
244
245
246 def lib2to3_unparse(node: Node) -> str:
247     """Given a lib2to3 node, return its string representation."""
248     code = str(node)
249     return code
250
251
252 T = TypeVar('T')
253
254
255 class Visitor(Generic[T]):
256     """Basic lib2to3 visitor that yields things on visiting."""
257
258     def visit(self, node: LN) -> Iterator[T]:
259         if node.type < 256:
260             name = token.tok_name[node.type]
261         else:
262             name = type_repr(node.type)
263         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
264
265     def visit_default(self, node: LN) -> Iterator[T]:
266         if isinstance(node, Node):
267             for child in node.children:
268                 yield from self.visit(child)
269
270
271 @dataclass
272 class DebugVisitor(Visitor[T]):
273     tree_depth: int = 0
274
275     def visit_default(self, node: LN) -> Iterator[T]:
276         indent = ' ' * (2 * self.tree_depth)
277         if isinstance(node, Node):
278             _type = type_repr(node.type)
279             out(f'{indent}{_type}', fg='yellow')
280             self.tree_depth += 1
281             for child in node.children:
282                 yield from self.visit(child)
283
284             self.tree_depth -= 1
285             out(f'{indent}/{_type}', fg='yellow', bold=False)
286         else:
287             _type = token.tok_name.get(node.type, str(node.type))
288             out(f'{indent}{_type}', fg='blue', nl=False)
289             if node.prefix:
290                 # We don't have to handle prefixes for `Node` objects since
291                 # that delegates to the first child anyway.
292                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
293             out(f' {node.value!r}', fg='blue', bold=False)
294
295
296 KEYWORDS = set(keyword.kwlist)
297 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
298 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
299 STATEMENT = {
300     syms.if_stmt,
301     syms.while_stmt,
302     syms.for_stmt,
303     syms.try_stmt,
304     syms.except_clause,
305     syms.with_stmt,
306     syms.funcdef,
307     syms.classdef,
308 }
309 STANDALONE_COMMENT = 153
310 LOGIC_OPERATORS = {'and', 'or'}
311 COMPARATORS = {
312     token.LESS,
313     token.GREATER,
314     token.EQEQUAL,
315     token.NOTEQUAL,
316     token.LESSEQUAL,
317     token.GREATEREQUAL,
318 }
319 MATH_OPERATORS = {
320     token.PLUS,
321     token.MINUS,
322     token.STAR,
323     token.SLASH,
324     token.VBAR,
325     token.AMPER,
326     token.PERCENT,
327     token.CIRCUMFLEX,
328     token.LEFTSHIFT,
329     token.RIGHTSHIFT,
330     token.DOUBLESTAR,
331     token.DOUBLESLASH,
332 }
333 COMPREHENSION_PRIORITY = 20
334 COMMA_PRIORITY = 10
335 LOGIC_PRIORITY = 5
336 STRING_PRIORITY = 4
337 COMPARATOR_PRIORITY = 3
338 MATH_PRIORITY = 1
339
340
341 @dataclass
342 class BracketTracker:
343     depth: int = 0
344     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
345     delimiters: Dict[LeafID, Priority] = Factory(dict)
346     previous: Optional[Leaf] = None
347
348     def mark(self, leaf: Leaf) -> None:
349         if leaf.type == token.COMMENT:
350             return
351
352         if leaf.type in CLOSING_BRACKETS:
353             self.depth -= 1
354             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
355             leaf.opening_bracket = opening_bracket
356         leaf.bracket_depth = self.depth
357         if self.depth == 0:
358             delim = is_delimiter(leaf)
359             if delim:
360                 self.delimiters[id(leaf)] = delim
361             elif self.previous is not None:
362                 if leaf.type == token.STRING and self.previous.type == token.STRING:
363                     self.delimiters[id(self.previous)] = STRING_PRIORITY
364                 elif (
365                     leaf.type == token.NAME
366                     and leaf.value == 'for'
367                     and leaf.parent
368                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
369                 ):
370                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
371                 elif (
372                     leaf.type == token.NAME
373                     and leaf.value == 'if'
374                     and leaf.parent
375                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
376                 ):
377                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
378                 elif (
379                     leaf.type == token.NAME
380                     and leaf.value in LOGIC_OPERATORS
381                     and leaf.parent
382                 ):
383                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
384         if leaf.type in OPENING_BRACKETS:
385             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
386             self.depth += 1
387         self.previous = leaf
388
389     def any_open_brackets(self) -> bool:
390         """Returns True if there is an yet unmatched open bracket on the line."""
391         return bool(self.bracket_match)
392
393     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
394         """Returns the highest priority of a delimiter found on the line.
395
396         Values are consistent with what `is_delimiter()` returns.
397         """
398         return max(v for k, v in self.delimiters.items() if k not in exclude)
399
400
401 @dataclass
402 class Line:
403     depth: int = 0
404     leaves: List[Leaf] = Factory(list)
405     comments: Dict[LeafID, Leaf] = Factory(dict)
406     bracket_tracker: BracketTracker = Factory(BracketTracker)
407     inside_brackets: bool = False
408     has_for: bool = False
409     _for_loop_variable: bool = False
410
411     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
412         has_value = leaf.value.strip()
413         if not has_value:
414             return
415
416         if self.leaves and not preformatted:
417             # Note: at this point leaf.prefix should be empty except for
418             # imports, for which we only preserve newlines.
419             leaf.prefix += whitespace(leaf)
420         if self.inside_brackets or not preformatted:
421             self.maybe_decrement_after_for_loop_variable(leaf)
422             self.bracket_tracker.mark(leaf)
423             self.maybe_remove_trailing_comma(leaf)
424             self.maybe_increment_for_loop_variable(leaf)
425             if self.maybe_adapt_standalone_comment(leaf):
426                 return
427
428         if not self.append_comment(leaf):
429             self.leaves.append(leaf)
430
431     @property
432     def is_comment(self) -> bool:
433         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
434
435     @property
436     def is_decorator(self) -> bool:
437         return bool(self) and self.leaves[0].type == token.AT
438
439     @property
440     def is_import(self) -> bool:
441         return bool(self) and is_import(self.leaves[0])
442
443     @property
444     def is_class(self) -> bool:
445         return (
446             bool(self)
447             and self.leaves[0].type == token.NAME
448             and self.leaves[0].value == 'class'
449         )
450
451     @property
452     def is_def(self) -> bool:
453         """Also returns True for async defs."""
454         try:
455             first_leaf = self.leaves[0]
456         except IndexError:
457             return False
458
459         try:
460             second_leaf: Optional[Leaf] = self.leaves[1]
461         except IndexError:
462             second_leaf = None
463         return (
464             (first_leaf.type == token.NAME and first_leaf.value == 'def')
465             or (
466                 first_leaf.type == token.NAME
467                 and first_leaf.value == 'async'
468                 and second_leaf is not None
469                 and second_leaf.type == token.NAME
470                 and second_leaf.value == 'def'
471             )
472         )
473
474     @property
475     def is_flow_control(self) -> bool:
476         return (
477             bool(self)
478             and self.leaves[0].type == token.NAME
479             and self.leaves[0].value in FLOW_CONTROL
480         )
481
482     @property
483     def is_yield(self) -> bool:
484         return (
485             bool(self)
486             and self.leaves[0].type == token.NAME
487             and self.leaves[0].value == 'yield'
488         )
489
490     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
491         if not (
492             self.leaves
493             and self.leaves[-1].type == token.COMMA
494             and closing.type in CLOSING_BRACKETS
495         ):
496             return False
497
498         if closing.type == token.RSQB or closing.type == token.RBRACE:
499             self.leaves.pop()
500             return True
501
502         # For parens let's check if it's safe to remove the comma.  If the
503         # trailing one is the only one, we might mistakenly change a tuple
504         # into a different type by removing the comma.
505         depth = closing.bracket_depth + 1
506         commas = 0
507         opening = closing.opening_bracket
508         for _opening_index, leaf in enumerate(self.leaves):
509             if leaf is opening:
510                 break
511
512         else:
513             return False
514
515         for leaf in self.leaves[_opening_index + 1:]:
516             if leaf is closing:
517                 break
518
519             bracket_depth = leaf.bracket_depth
520             if bracket_depth == depth and leaf.type == token.COMMA:
521                 commas += 1
522                 if leaf.parent and leaf.parent.type == syms.arglist:
523                     commas += 1
524                     break
525
526         if commas > 1:
527             self.leaves.pop()
528             return True
529
530         return False
531
532     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
533         """In a for loop, or comprehension, the variables are often unpacks.
534
535         To avoid splitting on the comma in this situation, we will increase
536         the depth of tokens between `for` and `in`.
537         """
538         if leaf.type == token.NAME and leaf.value == 'for':
539             self.has_for = True
540             self.bracket_tracker.depth += 1
541             self._for_loop_variable = True
542             return True
543
544         return False
545
546     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
547         # See `maybe_increment_for_loop_variable` above for explanation.
548         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
549             self.bracket_tracker.depth -= 1
550             self._for_loop_variable = False
551             return True
552
553         return False
554
555     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
556         """Hack a standalone comment to act as a trailing comment for line splitting.
557
558         If this line has brackets and a standalone `comment`, we need to adapt
559         it to be able to still reformat the line.
560
561         This is not perfect, the line to which the standalone comment gets
562         appended will appear "too long" when splitting.
563         """
564         if not (
565             comment.type == STANDALONE_COMMENT
566             and self.bracket_tracker.any_open_brackets()
567         ):
568             return False
569
570         comment.type = token.COMMENT
571         comment.prefix = '\n' + '    ' * (self.depth + 1)
572         return self.append_comment(comment)
573
574     def append_comment(self, comment: Leaf) -> bool:
575         if comment.type != token.COMMENT:
576             return False
577
578         try:
579             after = id(self.last_non_delimiter())
580         except LookupError:
581             comment.type = STANDALONE_COMMENT
582             comment.prefix = ''
583             return False
584
585         else:
586             if after in self.comments:
587                 self.comments[after].value += str(comment)
588             else:
589                 self.comments[after] = comment
590             return True
591
592     def last_non_delimiter(self) -> Leaf:
593         for i in range(len(self.leaves)):
594             last = self.leaves[-i - 1]
595             if not is_delimiter(last):
596                 return last
597
598         raise LookupError("No non-delimiters found")
599
600     def __str__(self) -> str:
601         if not self:
602             return '\n'
603
604         indent = '    ' * self.depth
605         leaves = iter(self.leaves)
606         first = next(leaves)
607         res = f'{first.prefix}{indent}{first.value}'
608         for leaf in leaves:
609             res += str(leaf)
610         for comment in self.comments.values():
611             res += str(comment)
612         return res + '\n'
613
614     def __bool__(self) -> bool:
615         return bool(self.leaves or self.comments)
616
617
618 @dataclass
619 class EmptyLineTracker:
620     """Provides a stateful method that returns the number of potential extra
621     empty lines needed before and after the currently processed line.
622
623     Note: this tracker works on lines that haven't been split yet.  It assumes
624     the prefix of the first leaf consists of optional newlines.  Those newlines
625     are consumed by `maybe_empty_lines()` and included in the computation.
626     """
627     previous_line: Optional[Line] = None
628     previous_after: int = 0
629     previous_defs: List[int] = Factory(list)
630
631     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
632         """Returns the number of extra empty lines before and after the `current_line`.
633
634         This is for separating `def`, `async def` and `class` with extra empty lines
635         (two on module-level), as well as providing an extra empty line after flow
636         control keywords to make them more prominent.
637         """
638         if current_line.is_comment:
639             # Don't count standalone comments towards previous empty lines.
640             return 0, 0
641
642         before, after = self._maybe_empty_lines(current_line)
643         before -= self.previous_after
644         self.previous_after = after
645         self.previous_line = current_line
646         return before, after
647
648     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
649         if current_line.leaves:
650             # Consume the first leaf's extra newlines.
651             first_leaf = current_line.leaves[0]
652             before = int('\n' in first_leaf.prefix)
653             first_leaf.prefix = ''
654         else:
655             before = 0
656         depth = current_line.depth
657         while self.previous_defs and self.previous_defs[-1] >= depth:
658             self.previous_defs.pop()
659             before = 1 if depth else 2
660         is_decorator = current_line.is_decorator
661         if is_decorator or current_line.is_def or current_line.is_class:
662             if not is_decorator:
663                 self.previous_defs.append(depth)
664             if self.previous_line is None:
665                 # Don't insert empty lines before the first line in the file.
666                 return 0, 0
667
668             if self.previous_line and self.previous_line.is_decorator:
669                 # Don't insert empty lines between decorators.
670                 return 0, 0
671
672             newlines = 2
673             if current_line.depth:
674                 newlines -= 1
675             return newlines, 0
676
677         if current_line.is_flow_control:
678             return before, 1
679
680         if (
681             self.previous_line
682             and self.previous_line.is_import
683             and not current_line.is_import
684             and depth == self.previous_line.depth
685         ):
686             return (before or 1), 0
687
688         if (
689             self.previous_line
690             and self.previous_line.is_yield
691             and (not current_line.is_yield or depth != self.previous_line.depth)
692         ):
693             return (before or 1), 0
694
695         return before, 0
696
697
698 @dataclass
699 class LineGenerator(Visitor[Line]):
700     """Generates reformatted Line objects.  Empty lines are not emitted.
701
702     Note: destroys the tree it's visiting by mutating prefixes of its leaves
703     in ways that will no longer stringify to valid Python code on the tree.
704     """
705     current_line: Line = Factory(Line)
706     standalone_comments: List[Leaf] = Factory(list)
707
708     def line(self, indent: int = 0) -> Iterator[Line]:
709         """Generate a line.
710
711         If the line is empty, only emit if it makes sense.
712         If the line is too long, split it first and then generate.
713
714         If any lines were generated, set up a new current_line.
715         """
716         if not self.current_line:
717             self.current_line.depth += indent
718             return  # Line is empty, don't emit. Creating a new one unnecessary.
719
720         complete_line = self.current_line
721         self.current_line = Line(depth=complete_line.depth + indent)
722         yield complete_line
723
724     def visit_default(self, node: LN) -> Iterator[Line]:
725         if isinstance(node, Leaf):
726             for comment in generate_comments(node):
727                 if self.current_line.bracket_tracker.any_open_brackets():
728                     # any comment within brackets is subject to splitting
729                     self.current_line.append(comment)
730                 elif comment.type == token.COMMENT:
731                     # regular trailing comment
732                     self.current_line.append(comment)
733                     yield from self.line()
734
735                 else:
736                     # regular standalone comment, to be processed later (see
737                     # docstring in `generate_comments()`
738                     self.standalone_comments.append(comment)
739             normalize_prefix(node)
740             if node.type not in WHITESPACE:
741                 for comment in self.standalone_comments:
742                     yield from self.line()
743
744                     self.current_line.append(comment)
745                     yield from self.line()
746
747                 self.standalone_comments = []
748                 self.current_line.append(node)
749         yield from super().visit_default(node)
750
751     def visit_suite(self, node: Node) -> Iterator[Line]:
752         """Body of a statement after a colon."""
753         children = iter(node.children)
754         # Process newline before indenting.  It might contain an inline
755         # comment that should go right after the colon.
756         newline = next(children)
757         yield from self.visit(newline)
758         yield from self.line(+1)
759
760         for child in children:
761             yield from self.visit(child)
762
763         yield from self.line(-1)
764
765     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
766         """Visit a statement.
767
768         The relevant Python language keywords for this statement are NAME leaves
769         within it.
770         """
771         for child in node.children:
772             if child.type == token.NAME and child.value in keywords:  # type: ignore
773                 yield from self.line()
774
775             yield from self.visit(child)
776
777     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
778         """A statement without nested statements."""
779         is_suite_like = node.parent and node.parent.type in STATEMENT
780         if is_suite_like:
781             yield from self.line(+1)
782             yield from self.visit_default(node)
783             yield from self.line(-1)
784
785         else:
786             yield from self.line()
787             yield from self.visit_default(node)
788
789     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
790         yield from self.line()
791
792         children = iter(node.children)
793         for child in children:
794             yield from self.visit(child)
795
796             if child.type == token.NAME and child.value == 'async':  # type: ignore
797                 break
798
799         internal_stmt = next(children)
800         for child in internal_stmt.children:
801             yield from self.visit(child)
802
803     def visit_decorators(self, node: Node) -> Iterator[Line]:
804         for child in node.children:
805             yield from self.line()
806             yield from self.visit(child)
807
808     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
809         yield from self.line()
810
811     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
812         yield from self.visit_default(leaf)
813         yield from self.line()
814
815     def __attrs_post_init__(self) -> None:
816         """You are in a twisty little maze of passages."""
817         v = self.visit_stmt
818         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
819         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
820         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
821         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
822         self.visit_except_clause = partial(v, keywords={'except'})
823         self.visit_funcdef = partial(v, keywords={'def'})
824         self.visit_with_stmt = partial(v, keywords={'with'})
825         self.visit_classdef = partial(v, keywords={'class'})
826         self.visit_async_funcdef = self.visit_async_stmt
827         self.visit_decorated = self.visit_decorators
828
829
830 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
831 OPENING_BRACKETS = set(BRACKET.keys())
832 CLOSING_BRACKETS = set(BRACKET.values())
833 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
834 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
835
836
837 def whitespace(leaf: Leaf) -> str:  # noqa C901
838     """Return whitespace prefix if needed for the given `leaf`."""
839     NO = ''
840     SPACE = ' '
841     DOUBLESPACE = '  '
842     t = leaf.type
843     p = leaf.parent
844     v = leaf.value
845     if t in ALWAYS_NO_SPACE:
846         return NO
847
848     if t == token.COMMENT:
849         return DOUBLESPACE
850
851     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
852     if t == token.COLON and p.type != syms.subscript:
853         return NO
854
855     prev = leaf.prev_sibling
856     if not prev:
857         prevp = preceding_leaf(p)
858         if not prevp or prevp.type in OPENING_BRACKETS:
859             return NO
860
861         if t == token.COLON:
862             return SPACE if prevp.type == token.COMMA else NO
863
864         if prevp.type == token.EQUAL:
865             if prevp.parent and prevp.parent.type in {
866                 syms.typedargslist,
867                 syms.varargslist,
868                 syms.parameters,
869                 syms.arglist,
870                 syms.argument,
871             }:
872                 return NO
873
874         elif prevp.type == token.DOUBLESTAR:
875             if prevp.parent and prevp.parent.type in {
876                 syms.typedargslist,
877                 syms.varargslist,
878                 syms.parameters,
879                 syms.arglist,
880                 syms.dictsetmaker,
881             }:
882                 return NO
883
884         elif prevp.type == token.COLON:
885             if prevp.parent and prevp.parent.type == syms.subscript:
886                 return NO
887
888         elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
889             return NO
890
891     elif prev.type in OPENING_BRACKETS:
892         return NO
893
894     if p.type in {syms.parameters, syms.arglist}:
895         # untyped function signatures or calls
896         if t == token.RPAR:
897             return NO
898
899         if not prev or prev.type != token.COMMA:
900             return NO
901
902     if p.type == syms.varargslist:
903         # lambdas
904         if t == token.RPAR:
905             return NO
906
907         if prev and prev.type != token.COMMA:
908             return NO
909
910     elif p.type == syms.typedargslist:
911         # typed function signatures
912         if not prev:
913             return NO
914
915         if t == token.EQUAL:
916             if prev.type != syms.tname:
917                 return NO
918
919         elif prev.type == token.EQUAL:
920             # A bit hacky: if the equal sign has whitespace, it means we
921             # previously found it's a typed argument.  So, we're using that, too.
922             return prev.prefix
923
924         elif prev.type != token.COMMA:
925             return NO
926
927     elif p.type == syms.tname:
928         # type names
929         if not prev:
930             prevp = preceding_leaf(p)
931             if not prevp or prevp.type != token.COMMA:
932                 return NO
933
934     elif p.type == syms.trailer:
935         # attributes and calls
936         if t == token.LPAR or t == token.RPAR:
937             return NO
938
939         if not prev:
940             if t == token.DOT:
941                 prevp = preceding_leaf(p)
942                 if not prevp or prevp.type != token.NUMBER:
943                     return NO
944
945             elif t == token.LSQB:
946                 return NO
947
948         elif prev.type != token.COMMA:
949             return NO
950
951     elif p.type == syms.argument:
952         # single argument
953         if t == token.EQUAL:
954             return NO
955
956         if not prev:
957             prevp = preceding_leaf(p)
958             if not prevp or prevp.type == token.LPAR:
959                 return NO
960
961         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
962             return NO
963
964     elif p.type == syms.decorator:
965         # decorators
966         return NO
967
968     elif p.type == syms.dotted_name:
969         if prev:
970             return NO
971
972         prevp = preceding_leaf(p)
973         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
974             return NO
975
976     elif p.type == syms.classdef:
977         if t == token.LPAR:
978             return NO
979
980         if prev and prev.type == token.LPAR:
981             return NO
982
983     elif p.type == syms.subscript:
984         # indexing
985         if not prev:
986             assert p.parent is not None, "subscripts are always parented"
987             if p.parent.type == syms.subscriptlist:
988                 return SPACE
989
990             return NO
991
992         else:
993             return NO
994
995     elif p.type == syms.atom:
996         if prev and t == token.DOT:
997             # dots, but not the first one.
998             return NO
999
1000     elif (
1001         p.type == syms.listmaker
1002         or p.type == syms.testlist_gexp
1003         or p.type == syms.subscriptlist
1004     ):
1005         # list interior, including unpacking
1006         if not prev:
1007             return NO
1008
1009     elif p.type == syms.dictsetmaker:
1010         # dict and set interior, including unpacking
1011         if not prev:
1012             return NO
1013
1014         if prev.type == token.DOUBLESTAR:
1015             return NO
1016
1017     elif p.type in {syms.factor, syms.star_expr}:
1018         # unary ops
1019         if not prev:
1020             prevp = preceding_leaf(p)
1021             if not prevp or prevp.type in OPENING_BRACKETS:
1022                 return NO
1023
1024             prevp_parent = prevp.parent
1025             assert prevp_parent is not None
1026             if prevp.type == token.COLON and prevp_parent.type in {
1027                 syms.subscript, syms.sliceop
1028             }:
1029                 return NO
1030
1031             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1032                 return NO
1033
1034         elif t == token.NAME or t == token.NUMBER:
1035             return NO
1036
1037     elif p.type == syms.import_from:
1038         if t == token.DOT:
1039             if prev and prev.type == token.DOT:
1040                 return NO
1041
1042         elif t == token.NAME:
1043             if v == 'import':
1044                 return SPACE
1045
1046             if prev and prev.type == token.DOT:
1047                 return NO
1048
1049     elif p.type == syms.sliceop:
1050         return NO
1051
1052     return SPACE
1053
1054
1055 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1056     """Returns the first leaf that precedes `node`, if any."""
1057     while node:
1058         res = node.prev_sibling
1059         if res:
1060             if isinstance(res, Leaf):
1061                 return res
1062
1063             try:
1064                 return list(res.leaves())[-1]
1065
1066             except IndexError:
1067                 return None
1068
1069         node = node.parent
1070     return None
1071
1072
1073 def is_delimiter(leaf: Leaf) -> int:
1074     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1075
1076     Higher numbers are higher priority.
1077     """
1078     if leaf.type == token.COMMA:
1079         return COMMA_PRIORITY
1080
1081     if leaf.type in COMPARATORS:
1082         return COMPARATOR_PRIORITY
1083
1084     if (
1085         leaf.type in MATH_OPERATORS
1086         and leaf.parent
1087         and leaf.parent.type not in {syms.factor, syms.star_expr}
1088     ):
1089         return MATH_PRIORITY
1090
1091     return 0
1092
1093
1094 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1095     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1096
1097     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1098     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1099     move because it does away with modifying the grammar to include all the
1100     possible places in which comments can be placed.
1101
1102     The sad consequence for us though is that comments don't "belong" anywhere.
1103     This is why this function generates simple parentless Leaf objects for
1104     comments.  We simply don't know what the correct parent should be.
1105
1106     No matter though, we can live without this.  We really only need to
1107     differentiate between inline and standalone comments.  The latter don't
1108     share the line with any code.
1109
1110     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1111     are emitted with a fake STANDALONE_COMMENT token identifier.
1112     """
1113     if not leaf.prefix:
1114         return
1115
1116     if '#' not in leaf.prefix:
1117         return
1118
1119     before_comment, content = leaf.prefix.split('#', 1)
1120     content = content.rstrip()
1121     if content and (content[0] not in {' ', '!', '#'}):
1122         content = ' ' + content
1123     is_standalone_comment = (
1124         '\n' in before_comment or '\n' in content or leaf.type == token.ENDMARKER
1125     )
1126     if not is_standalone_comment:
1127         # simple trailing comment
1128         yield Leaf(token.COMMENT, value='#' + content)
1129         return
1130
1131     for line in ('#' + content).split('\n'):
1132         line = line.lstrip()
1133         if not line.startswith('#'):
1134             continue
1135
1136         yield Leaf(STANDALONE_COMMENT, line)
1137
1138
1139 def split_line(
1140     line: Line, line_length: int, inner: bool = False, py36: bool = False
1141 ) -> Iterator[Line]:
1142     """Splits a `line` into potentially many lines.
1143
1144     They should fit in the allotted `line_length` but might not be able to.
1145     `inner` signifies that there were a pair of brackets somewhere around the
1146     current `line`, possibly transitively. This means we can fallback to splitting
1147     by delimiters if the LHS/RHS don't yield any results.
1148
1149     If `py36` is True, splitting may generate syntax that is only compatible
1150     with Python 3.6 and later.
1151     """
1152     line_str = str(line).strip('\n')
1153     if len(line_str) <= line_length and '\n' not in line_str:
1154         yield line
1155         return
1156
1157     if line.is_def:
1158         split_funcs = [left_hand_split]
1159     elif line.inside_brackets:
1160         split_funcs = [delimiter_split]
1161         if '\n' not in line_str:
1162             # Only attempt RHS if we don't have multiline strings or comments
1163             # on this line.
1164             split_funcs.append(right_hand_split)
1165     else:
1166         split_funcs = [right_hand_split]
1167     for split_func in split_funcs:
1168         # We are accumulating lines in `result` because we might want to abort
1169         # mission and return the original line in the end, or attempt a different
1170         # split altogether.
1171         result: List[Line] = []
1172         try:
1173             for l in split_func(line, py36=py36):
1174                 if str(l).strip('\n') == line_str:
1175                     raise CannotSplit("Split function returned an unchanged result")
1176
1177                 result.extend(
1178                     split_line(l, line_length=line_length, inner=True, py36=py36)
1179                 )
1180         except CannotSplit as cs:
1181             continue
1182
1183         else:
1184             yield from result
1185             break
1186
1187     else:
1188         yield line
1189
1190
1191 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1192     """Split line into many lines, starting with the first matching bracket pair.
1193
1194     Note: this usually looks weird, only use this for function definitions.
1195     Prefer RHS otherwise.
1196     """
1197     head = Line(depth=line.depth)
1198     body = Line(depth=line.depth + 1, inside_brackets=True)
1199     tail = Line(depth=line.depth)
1200     tail_leaves: List[Leaf] = []
1201     body_leaves: List[Leaf] = []
1202     head_leaves: List[Leaf] = []
1203     current_leaves = head_leaves
1204     matching_bracket = None
1205     for leaf in line.leaves:
1206         if (
1207             current_leaves is body_leaves
1208             and leaf.type in CLOSING_BRACKETS
1209             and leaf.opening_bracket is matching_bracket
1210         ):
1211             current_leaves = tail_leaves if body_leaves else head_leaves
1212         current_leaves.append(leaf)
1213         if current_leaves is head_leaves:
1214             if leaf.type in OPENING_BRACKETS:
1215                 matching_bracket = leaf
1216                 current_leaves = body_leaves
1217     # Since body is a new indent level, remove spurious leading whitespace.
1218     if body_leaves:
1219         normalize_prefix(body_leaves[0])
1220     # Build the new lines.
1221     for result, leaves in (
1222         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1223     ):
1224         for leaf in leaves:
1225             result.append(leaf, preformatted=True)
1226             comment_after = line.comments.get(id(leaf))
1227             if comment_after:
1228                 result.append(comment_after, preformatted=True)
1229     split_succeeded_or_raise(head, body, tail)
1230     for result in (head, body, tail):
1231         if result:
1232             yield result
1233
1234
1235 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1236     """Split line into many lines, starting with the last matching bracket pair."""
1237     head = Line(depth=line.depth)
1238     body = Line(depth=line.depth + 1, inside_brackets=True)
1239     tail = Line(depth=line.depth)
1240     tail_leaves: List[Leaf] = []
1241     body_leaves: List[Leaf] = []
1242     head_leaves: List[Leaf] = []
1243     current_leaves = tail_leaves
1244     opening_bracket = None
1245     for leaf in reversed(line.leaves):
1246         if current_leaves is body_leaves:
1247             if leaf is opening_bracket:
1248                 current_leaves = head_leaves if body_leaves else tail_leaves
1249         current_leaves.append(leaf)
1250         if current_leaves is tail_leaves:
1251             if leaf.type in CLOSING_BRACKETS:
1252                 opening_bracket = leaf.opening_bracket
1253                 current_leaves = body_leaves
1254     tail_leaves.reverse()
1255     body_leaves.reverse()
1256     head_leaves.reverse()
1257     # Since body is a new indent level, remove spurious leading whitespace.
1258     if body_leaves:
1259         normalize_prefix(body_leaves[0])
1260     # Build the new lines.
1261     for result, leaves in (
1262         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1263     ):
1264         for leaf in leaves:
1265             result.append(leaf, preformatted=True)
1266             comment_after = line.comments.get(id(leaf))
1267             if comment_after:
1268                 result.append(comment_after, preformatted=True)
1269     split_succeeded_or_raise(head, body, tail)
1270     for result in (head, body, tail):
1271         if result:
1272             yield result
1273
1274
1275 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1276     tail_len = len(str(tail).strip())
1277     if not body:
1278         if tail_len == 0:
1279             raise CannotSplit("Splitting brackets produced the same line")
1280
1281         elif tail_len < 3:
1282             raise CannotSplit(
1283                 f"Splitting brackets on an empty body to save "
1284                 f"{tail_len} characters is not worth it"
1285             )
1286
1287
1288 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1289     """Split according to delimiters of the highest priority.
1290
1291     This kind of split doesn't increase indentation.
1292     If `py36` is True, the split will add trailing commas also in function
1293     signatures that contain * and **.
1294     """
1295     try:
1296         last_leaf = line.leaves[-1]
1297     except IndexError:
1298         raise CannotSplit("Line empty")
1299
1300     delimiters = line.bracket_tracker.delimiters
1301     try:
1302         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1303     except ValueError:
1304         raise CannotSplit("No delimiters found")
1305
1306     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1307     lowest_depth = sys.maxsize
1308     trailing_comma_safe = True
1309     for leaf in line.leaves:
1310         current_line.append(leaf, preformatted=True)
1311         comment_after = line.comments.get(id(leaf))
1312         if comment_after:
1313             current_line.append(comment_after, preformatted=True)
1314         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1315         if (
1316             leaf.bracket_depth == lowest_depth
1317             and leaf.type == token.STAR
1318             or leaf.type == token.DOUBLESTAR
1319         ):
1320             trailing_comma_safe = trailing_comma_safe and py36
1321         leaf_priority = delimiters.get(id(leaf))
1322         if leaf_priority == delimiter_priority:
1323             normalize_prefix(current_line.leaves[0])
1324             yield current_line
1325
1326             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1327     if current_line:
1328         if (
1329             delimiter_priority == COMMA_PRIORITY
1330             and current_line.leaves[-1].type != token.COMMA
1331             and trailing_comma_safe
1332         ):
1333             current_line.append(Leaf(token.COMMA, ','))
1334         normalize_prefix(current_line.leaves[0])
1335         yield current_line
1336
1337
1338 def is_import(leaf: Leaf) -> bool:
1339     """Returns True if the given leaf starts an import statement."""
1340     p = leaf.parent
1341     t = leaf.type
1342     v = leaf.value
1343     return bool(
1344         t == token.NAME
1345         and (
1346             (v == 'import' and p and p.type == syms.import_name)
1347             or (v == 'from' and p and p.type == syms.import_from)
1348         )
1349     )
1350
1351
1352 def normalize_prefix(leaf: Leaf) -> None:
1353     """Leave existing extra newlines for imports.  Remove everything else."""
1354     if is_import(leaf):
1355         spl = leaf.prefix.split('#', 1)
1356         nl_count = spl[0].count('\n')
1357         leaf.prefix = '\n' * nl_count
1358         return
1359
1360     leaf.prefix = ''
1361
1362
1363 def is_python36(node: Node) -> bool:
1364     """Returns True if the current file is using Python 3.6+ features.
1365
1366     Currently looking for:
1367     - f-strings; and
1368     - trailing commas after * or ** in function signatures.
1369     """
1370     for n in node.pre_order():
1371         if n.type == token.STRING:
1372             value_head = n.value[:2]  # type: ignore
1373             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1374                 return True
1375
1376         elif (
1377             n.type == syms.typedargslist
1378             and n.children
1379             and n.children[-1].type == token.COMMA
1380         ):
1381             for ch in n.children:
1382                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1383                     return True
1384
1385     return False
1386
1387
1388 PYTHON_EXTENSIONS = {'.py'}
1389 BLACKLISTED_DIRECTORIES = {
1390     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1391 }
1392
1393
1394 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1395     for child in path.iterdir():
1396         if child.is_dir():
1397             if child.name in BLACKLISTED_DIRECTORIES:
1398                 continue
1399
1400             yield from gen_python_files_in_dir(child)
1401
1402         elif child.suffix in PYTHON_EXTENSIONS:
1403             yield child
1404
1405
1406 @dataclass
1407 class Report:
1408     """Provides a reformatting counter."""
1409     change_count: int = 0
1410     same_count: int = 0
1411     failure_count: int = 0
1412
1413     def done(self, src: Path, changed: bool) -> None:
1414         """Increment the counter for successful reformatting. Write out a message."""
1415         if changed:
1416             out(f'reformatted {src}')
1417             self.change_count += 1
1418         else:
1419             out(f'{src} already well formatted, good job.', bold=False)
1420             self.same_count += 1
1421
1422     def failed(self, src: Path, message: str) -> None:
1423         """Increment the counter for failed reformatting. Write out a message."""
1424         err(f'error: cannot format {src}: {message}')
1425         self.failure_count += 1
1426
1427     @property
1428     def return_code(self) -> int:
1429         """Which return code should the app use considering the current state."""
1430         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1431         # 126 we have special returncodes reserved by the shell.
1432         if self.failure_count:
1433             return 123
1434
1435         elif self.change_count:
1436             return 1
1437
1438         return 0
1439
1440     def __str__(self) -> str:
1441         """A color report of the current state.
1442
1443         Use `click.unstyle` to remove colors.
1444         """
1445         report = []
1446         if self.change_count:
1447             s = 's' if self.change_count > 1 else ''
1448             report.append(
1449                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1450             )
1451         if self.same_count:
1452             s = 's' if self.same_count > 1 else ''
1453             report.append(f'{self.same_count} file{s} left unchanged')
1454         if self.failure_count:
1455             s = 's' if self.failure_count > 1 else ''
1456             report.append(
1457                 click.style(
1458                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1459                 )
1460             )
1461         return ', '.join(report) + '.'
1462
1463
1464 def assert_equivalent(src: str, dst: str) -> None:
1465     """Raises AssertionError if `src` and `dst` aren't equivalent.
1466
1467     This is a temporary sanity check until Black becomes stable.
1468     """
1469
1470     import ast
1471     import traceback
1472
1473     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1474         """Simple visitor generating strings to compare ASTs by content."""
1475         yield f"{'  ' * depth}{node.__class__.__name__}("
1476
1477         for field in sorted(node._fields):
1478             try:
1479                 value = getattr(node, field)
1480             except AttributeError:
1481                 continue
1482
1483             yield f"{'  ' * (depth+1)}{field}="
1484
1485             if isinstance(value, list):
1486                 for item in value:
1487                     if isinstance(item, ast.AST):
1488                         yield from _v(item, depth + 2)
1489
1490             elif isinstance(value, ast.AST):
1491                 yield from _v(value, depth + 2)
1492
1493             else:
1494                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1495
1496         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1497
1498     try:
1499         src_ast = ast.parse(src)
1500     except Exception as exc:
1501         raise AssertionError(f"cannot parse source: {exc}") from None
1502
1503     try:
1504         dst_ast = ast.parse(dst)
1505     except Exception as exc:
1506         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1507         raise AssertionError(
1508             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1509             f"Please report a bug on https://github.com/ambv/black/issues.  "
1510             f"This invalid output might be helpful: {log}"
1511         ) from None
1512
1513     src_ast_str = '\n'.join(_v(src_ast))
1514     dst_ast_str = '\n'.join(_v(dst_ast))
1515     if src_ast_str != dst_ast_str:
1516         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1517         raise AssertionError(
1518             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1519             f"the source.  "
1520             f"Please report a bug on https://github.com/ambv/black/issues.  "
1521             f"This diff might be helpful: {log}"
1522         ) from None
1523
1524
1525 def assert_stable(src: str, dst: str, line_length: int) -> None:
1526     """Raises AssertionError if `dst` reformats differently the second time.
1527
1528     This is a temporary sanity check until Black becomes stable.
1529     """
1530     newdst = format_str(dst, line_length=line_length)
1531     if dst != newdst:
1532         log = dump_to_file(
1533             diff(src, dst, 'source', 'first pass'),
1534             diff(dst, newdst, 'first pass', 'second pass'),
1535         )
1536         raise AssertionError(
1537             f"INTERNAL ERROR: Black produced different code on the second pass "
1538             f"of the formatter.  "
1539             f"Please report a bug on https://github.com/ambv/black/issues.  "
1540             f"This diff might be helpful: {log}"
1541         ) from None
1542
1543
1544 def dump_to_file(*output: str) -> str:
1545     """Dumps `output` to a temporary file. Returns path to the file."""
1546     import tempfile
1547
1548     with tempfile.NamedTemporaryFile(
1549         mode='w', prefix='blk_', suffix='.log', delete=False
1550     ) as f:
1551         for lines in output:
1552             f.write(lines)
1553             f.write('\n')
1554     return f.name
1555
1556
1557 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1558     """Returns a udiff string between strings `a` and `b`."""
1559     import difflib
1560
1561     a_lines = [line + '\n' for line in a.split('\n')]
1562     b_lines = [line + '\n' for line in b.split('\n')]
1563     return ''.join(
1564         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1565     )
1566
1567
1568 if __name__ == '__main__':
1569     main()