]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

add sphinx docs skeleton (#71)
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
14 )
15
16 from attr import dataclass, Factory
17 import click
18
19 # lib2to3 fork
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
24
25 __version__ = "18.3a3"
26 DEFAULT_LINE_LENGTH = 88
27 # types
28 syms = pygram.python_symbols
29 FileContent = str
30 Encoding = str
31 Depth = int
32 NodeType = int
33 LeafID = int
34 Priority = int
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
38
39
40 class NothingChanged(UserWarning):
41     """Raised by `format_file` when the reformatted code is the same as source."""
42
43
44 class CannotSplit(Exception):
45     """A readable split that fits the allotted line length is impossible.
46
47     Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`.
48     """
49
50
51 @click.command()
52 @click.option(
53     '-l',
54     '--line-length',
55     type=int,
56     default=DEFAULT_LINE_LENGTH,
57     help='How many character per line to allow.',
58     show_default=True,
59 )
60 @click.option(
61     '--check',
62     is_flag=True,
63     help=(
64         "Don't write back the files, just return the status.  Return code 0 "
65         "means nothing would change.  Return code 1 means some files would be "
66         "reformatted.  Return code 123 means there was an internal error."
67     ),
68 )
69 @click.option(
70     '--fast/--safe',
71     is_flag=True,
72     help='If --fast given, skip temporary sanity checks. [default: --safe]',
73 )
74 @click.version_option(version=__version__)
75 @click.argument(
76     'src',
77     nargs=-1,
78     type=click.Path(
79         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
80     ),
81 )
82 @click.pass_context
83 def main(
84     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
85 ) -> None:
86     """The uncompromising code formatter."""
87     sources: List[Path] = []
88     for s in src:
89         p = Path(s)
90         if p.is_dir():
91             sources.extend(gen_python_files_in_dir(p))
92         elif p.is_file():
93             # if a file was explicitly given, we don't care about its extension
94             sources.append(p)
95         elif s == '-':
96             sources.append(Path('-'))
97         else:
98             err(f'invalid path: {s}')
99     if len(sources) == 0:
100         ctx.exit(0)
101     elif len(sources) == 1:
102         p = sources[0]
103         report = Report(check=check)
104         try:
105             if not p.is_file() and str(p) == '-':
106                 changed = format_stdin_to_stdout(
107                     line_length=line_length, fast=fast, write_back=not check
108                 )
109             else:
110                 changed = format_file_in_place(
111                     p, line_length=line_length, fast=fast, write_back=not check
112                 )
113             report.done(p, changed)
114         except Exception as exc:
115             report.failed(p, str(exc))
116         ctx.exit(report.return_code)
117     else:
118         loop = asyncio.get_event_loop()
119         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
120         return_code = 1
121         try:
122             return_code = loop.run_until_complete(
123                 schedule_formatting(
124                     sources, line_length, not check, fast, loop, executor
125                 )
126             )
127         finally:
128             loop.close()
129             ctx.exit(return_code)
130
131
132 async def schedule_formatting(
133     sources: List[Path],
134     line_length: int,
135     write_back: bool,
136     fast: bool,
137     loop: BaseEventLoop,
138     executor: Executor,
139 ) -> int:
140     tasks = {
141         src: loop.run_in_executor(
142             executor, format_file_in_place, src, line_length, fast, write_back
143         )
144         for src in sources
145     }
146     await asyncio.wait(tasks.values())
147     cancelled = []
148     report = Report()
149     for src, task in tasks.items():
150         if not task.done():
151             report.failed(src, 'timed out, cancelling')
152             task.cancel()
153             cancelled.append(task)
154         elif task.exception():
155             report.failed(src, str(task.exception()))
156         else:
157             report.done(src, task.result())
158     if cancelled:
159         await asyncio.wait(cancelled, timeout=2)
160     out('All done! ✨ 🍰 ✨')
161     click.echo(str(report))
162     return report.return_code
163
164
165 def format_file_in_place(
166     src: Path, line_length: int, fast: bool, write_back: bool = False
167 ) -> bool:
168     """Format the file and rewrite if changed. Return True if changed."""
169     with tokenize.open(src) as src_buffer:
170         src_contents = src_buffer.read()
171     try:
172         contents = format_file_contents(
173             src_contents, line_length=line_length, fast=fast
174         )
175     except NothingChanged:
176         return False
177
178     if write_back:
179         with open(src, "w", encoding=src_buffer.encoding) as f:
180             f.write(contents)
181     return True
182
183
184 def format_stdin_to_stdout(
185     line_length: int, fast: bool, write_back: bool = False
186 ) -> bool:
187     """Format file on stdin and pipe output to stdout. Return True if changed."""
188     contents = sys.stdin.read()
189     try:
190         contents = format_file_contents(contents, line_length=line_length, fast=fast)
191         return True
192
193     except NothingChanged:
194         return False
195
196     finally:
197         if write_back:
198             sys.stdout.write(contents)
199
200
201 def format_file_contents(
202     src_contents: str, line_length: int, fast: bool
203 ) -> FileContent:
204     """Reformats a file and returns its contents and encoding."""
205     if src_contents.strip() == '':
206         raise NothingChanged
207
208     dst_contents = format_str(src_contents, line_length=line_length)
209     if src_contents == dst_contents:
210         raise NothingChanged
211
212     if not fast:
213         assert_equivalent(src_contents, dst_contents)
214         assert_stable(src_contents, dst_contents, line_length=line_length)
215     return dst_contents
216
217
218 def format_str(src_contents: str, line_length: int) -> FileContent:
219     """Reformats a string and returns new contents."""
220     src_node = lib2to3_parse(src_contents)
221     dst_contents = ""
222     lines = LineGenerator()
223     elt = EmptyLineTracker()
224     py36 = is_python36(src_node)
225     empty_line = Line()
226     after = 0
227     for current_line in lines.visit(src_node):
228         for _ in range(after):
229             dst_contents += str(empty_line)
230         before, after = elt.maybe_empty_lines(current_line)
231         for _ in range(before):
232             dst_contents += str(empty_line)
233         for line in split_line(current_line, line_length=line_length, py36=py36):
234             dst_contents += str(line)
235     return dst_contents
236
237
238 GRAMMARS = [
239     pygram.python_grammar_no_print_statement_no_exec_statement,
240     pygram.python_grammar_no_print_statement,
241     pygram.python_grammar_no_exec_statement,
242     pygram.python_grammar,
243 ]
244
245
246 def lib2to3_parse(src_txt: str) -> Node:
247     """Given a string with source, return the lib2to3 Node."""
248     grammar = pygram.python_grammar_no_print_statement
249     if src_txt[-1] != '\n':
250         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
251         src_txt += nl
252     for grammar in GRAMMARS:
253         drv = driver.Driver(grammar, pytree.convert)
254         try:
255             result = drv.parse_string(src_txt, True)
256             break
257
258         except ParseError as pe:
259             lineno, column = pe.context[1]
260             lines = src_txt.splitlines()
261             try:
262                 faulty_line = lines[lineno - 1]
263             except IndexError:
264                 faulty_line = "<line number missing in source>"
265             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
266     else:
267         raise exc from None
268
269     if isinstance(result, Leaf):
270         result = Node(syms.file_input, [result])
271     return result
272
273
274 def lib2to3_unparse(node: Node) -> str:
275     """Given a lib2to3 node, return its string representation."""
276     code = str(node)
277     return code
278
279
280 T = TypeVar('T')
281
282
283 class Visitor(Generic[T]):
284     """Basic lib2to3 visitor that yields things on visiting."""
285
286     def visit(self, node: LN) -> Iterator[T]:
287         if node.type < 256:
288             name = token.tok_name[node.type]
289         else:
290             name = type_repr(node.type)
291         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
292
293     def visit_default(self, node: LN) -> Iterator[T]:
294         if isinstance(node, Node):
295             for child in node.children:
296                 yield from self.visit(child)
297
298
299 @dataclass
300 class DebugVisitor(Visitor[T]):
301     tree_depth: int = 0
302
303     def visit_default(self, node: LN) -> Iterator[T]:
304         indent = ' ' * (2 * self.tree_depth)
305         if isinstance(node, Node):
306             _type = type_repr(node.type)
307             out(f'{indent}{_type}', fg='yellow')
308             self.tree_depth += 1
309             for child in node.children:
310                 yield from self.visit(child)
311
312             self.tree_depth -= 1
313             out(f'{indent}/{_type}', fg='yellow', bold=False)
314         else:
315             _type = token.tok_name.get(node.type, str(node.type))
316             out(f'{indent}{_type}', fg='blue', nl=False)
317             if node.prefix:
318                 # We don't have to handle prefixes for `Node` objects since
319                 # that delegates to the first child anyway.
320                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
321             out(f' {node.value!r}', fg='blue', bold=False)
322
323
324 KEYWORDS = set(keyword.kwlist)
325 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
326 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
327 STATEMENT = {
328     syms.if_stmt,
329     syms.while_stmt,
330     syms.for_stmt,
331     syms.try_stmt,
332     syms.except_clause,
333     syms.with_stmt,
334     syms.funcdef,
335     syms.classdef,
336 }
337 STANDALONE_COMMENT = 153
338 LOGIC_OPERATORS = {'and', 'or'}
339 COMPARATORS = {
340     token.LESS,
341     token.GREATER,
342     token.EQEQUAL,
343     token.NOTEQUAL,
344     token.LESSEQUAL,
345     token.GREATEREQUAL,
346 }
347 MATH_OPERATORS = {
348     token.PLUS,
349     token.MINUS,
350     token.STAR,
351     token.SLASH,
352     token.VBAR,
353     token.AMPER,
354     token.PERCENT,
355     token.CIRCUMFLEX,
356     token.TILDE,
357     token.LEFTSHIFT,
358     token.RIGHTSHIFT,
359     token.DOUBLESTAR,
360     token.DOUBLESLASH,
361 }
362 COMPREHENSION_PRIORITY = 20
363 COMMA_PRIORITY = 10
364 LOGIC_PRIORITY = 5
365 STRING_PRIORITY = 4
366 COMPARATOR_PRIORITY = 3
367 MATH_PRIORITY = 1
368
369
370 @dataclass
371 class BracketTracker:
372     depth: int = 0
373     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
374     delimiters: Dict[LeafID, Priority] = Factory(dict)
375     previous: Optional[Leaf] = None
376
377     def mark(self, leaf: Leaf) -> None:
378         if leaf.type == token.COMMENT:
379             return
380
381         if leaf.type in CLOSING_BRACKETS:
382             self.depth -= 1
383             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
384             leaf.opening_bracket = opening_bracket
385         leaf.bracket_depth = self.depth
386         if self.depth == 0:
387             delim = is_delimiter(leaf)
388             if delim:
389                 self.delimiters[id(leaf)] = delim
390             elif self.previous is not None:
391                 if leaf.type == token.STRING and self.previous.type == token.STRING:
392                     self.delimiters[id(self.previous)] = STRING_PRIORITY
393                 elif (
394                     leaf.type == token.NAME
395                     and leaf.value == 'for'
396                     and leaf.parent
397                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
398                 ):
399                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
400                 elif (
401                     leaf.type == token.NAME
402                     and leaf.value == 'if'
403                     and leaf.parent
404                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
405                 ):
406                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
407                 elif (
408                     leaf.type == token.NAME
409                     and leaf.value in LOGIC_OPERATORS
410                     and leaf.parent
411                 ):
412                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
413         if leaf.type in OPENING_BRACKETS:
414             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
415             self.depth += 1
416         self.previous = leaf
417
418     def any_open_brackets(self) -> bool:
419         """Returns True if there is an yet unmatched open bracket on the line."""
420         return bool(self.bracket_match)
421
422     def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
423         """Returns the highest priority of a delimiter found on the line.
424
425         Values are consistent with what `is_delimiter()` returns.
426         """
427         return max(v for k, v in self.delimiters.items() if k not in exclude)
428
429
430 @dataclass
431 class Line:
432     depth: int = 0
433     leaves: List[Leaf] = Factory(list)
434     comments: Dict[LeafID, Leaf] = Factory(dict)
435     bracket_tracker: BracketTracker = Factory(BracketTracker)
436     inside_brackets: bool = False
437     has_for: bool = False
438     _for_loop_variable: bool = False
439
440     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
441         has_value = leaf.value.strip()
442         if not has_value:
443             return
444
445         if self.leaves and not preformatted:
446             # Note: at this point leaf.prefix should be empty except for
447             # imports, for which we only preserve newlines.
448             leaf.prefix += whitespace(leaf)
449         if self.inside_brackets or not preformatted:
450             self.maybe_decrement_after_for_loop_variable(leaf)
451             self.bracket_tracker.mark(leaf)
452             self.maybe_remove_trailing_comma(leaf)
453             self.maybe_increment_for_loop_variable(leaf)
454             if self.maybe_adapt_standalone_comment(leaf):
455                 return
456
457         if not self.append_comment(leaf):
458             self.leaves.append(leaf)
459
460     @property
461     def is_comment(self) -> bool:
462         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
463
464     @property
465     def is_decorator(self) -> bool:
466         return bool(self) and self.leaves[0].type == token.AT
467
468     @property
469     def is_import(self) -> bool:
470         return bool(self) and is_import(self.leaves[0])
471
472     @property
473     def is_class(self) -> bool:
474         return (
475             bool(self)
476             and self.leaves[0].type == token.NAME
477             and self.leaves[0].value == 'class'
478         )
479
480     @property
481     def is_def(self) -> bool:
482         """Also returns True for async defs."""
483         try:
484             first_leaf = self.leaves[0]
485         except IndexError:
486             return False
487
488         try:
489             second_leaf: Optional[Leaf] = self.leaves[1]
490         except IndexError:
491             second_leaf = None
492         return (
493             (first_leaf.type == token.NAME and first_leaf.value == 'def')
494             or (
495                 first_leaf.type == token.ASYNC
496                 and second_leaf is not None
497                 and second_leaf.type == token.NAME
498                 and second_leaf.value == 'def'
499             )
500         )
501
502     @property
503     def is_flow_control(self) -> bool:
504         return (
505             bool(self)
506             and self.leaves[0].type == token.NAME
507             and self.leaves[0].value in FLOW_CONTROL
508         )
509
510     @property
511     def is_yield(self) -> bool:
512         return (
513             bool(self)
514             and self.leaves[0].type == token.NAME
515             and self.leaves[0].value == 'yield'
516         )
517
518     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
519         if not (
520             self.leaves
521             and self.leaves[-1].type == token.COMMA
522             and closing.type in CLOSING_BRACKETS
523         ):
524             return False
525
526         if closing.type == token.RBRACE:
527             self.leaves.pop()
528             return True
529
530         if closing.type == token.RSQB:
531             comma = self.leaves[-1]
532             if comma.parent and comma.parent.type == syms.listmaker:
533                 self.leaves.pop()
534                 return True
535
536         # For parens let's check if it's safe to remove the comma.  If the
537         # trailing one is the only one, we might mistakenly change a tuple
538         # into a different type by removing the comma.
539         depth = closing.bracket_depth + 1
540         commas = 0
541         opening = closing.opening_bracket
542         for _opening_index, leaf in enumerate(self.leaves):
543             if leaf is opening:
544                 break
545
546         else:
547             return False
548
549         for leaf in self.leaves[_opening_index + 1:]:
550             if leaf is closing:
551                 break
552
553             bracket_depth = leaf.bracket_depth
554             if bracket_depth == depth and leaf.type == token.COMMA:
555                 commas += 1
556                 if leaf.parent and leaf.parent.type == syms.arglist:
557                     commas += 1
558                     break
559
560         if commas > 1:
561             self.leaves.pop()
562             return True
563
564         return False
565
566     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
567         """In a for loop, or comprehension, the variables are often unpacks.
568
569         To avoid splitting on the comma in this situation, we will increase
570         the depth of tokens between `for` and `in`.
571         """
572         if leaf.type == token.NAME and leaf.value == 'for':
573             self.has_for = True
574             self.bracket_tracker.depth += 1
575             self._for_loop_variable = True
576             return True
577
578         return False
579
580     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
581         # See `maybe_increment_for_loop_variable` above for explanation.
582         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
583             self.bracket_tracker.depth -= 1
584             self._for_loop_variable = False
585             return True
586
587         return False
588
589     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
590         """Hack a standalone comment to act as a trailing comment for line splitting.
591
592         If this line has brackets and a standalone `comment`, we need to adapt
593         it to be able to still reformat the line.
594
595         This is not perfect, the line to which the standalone comment gets
596         appended will appear "too long" when splitting.
597         """
598         if not (
599             comment.type == STANDALONE_COMMENT
600             and self.bracket_tracker.any_open_brackets()
601         ):
602             return False
603
604         comment.type = token.COMMENT
605         comment.prefix = '\n' + '    ' * (self.depth + 1)
606         return self.append_comment(comment)
607
608     def append_comment(self, comment: Leaf) -> bool:
609         if comment.type != token.COMMENT:
610             return False
611
612         try:
613             after = id(self.last_non_delimiter())
614         except LookupError:
615             comment.type = STANDALONE_COMMENT
616             comment.prefix = ''
617             return False
618
619         else:
620             if after in self.comments:
621                 self.comments[after].value += str(comment)
622             else:
623                 self.comments[after] = comment
624             return True
625
626     def last_non_delimiter(self) -> Leaf:
627         for i in range(len(self.leaves)):
628             last = self.leaves[-i - 1]
629             if not is_delimiter(last):
630                 return last
631
632         raise LookupError("No non-delimiters found")
633
634     def __str__(self) -> str:
635         if not self:
636             return '\n'
637
638         indent = '    ' * self.depth
639         leaves = iter(self.leaves)
640         first = next(leaves)
641         res = f'{first.prefix}{indent}{first.value}'
642         for leaf in leaves:
643             res += str(leaf)
644         for comment in self.comments.values():
645             res += str(comment)
646         return res + '\n'
647
648     def __bool__(self) -> bool:
649         return bool(self.leaves or self.comments)
650
651
652 @dataclass
653 class EmptyLineTracker:
654     """Provides a stateful method that returns the number of potential extra
655     empty lines needed before and after the currently processed line.
656
657     Note: this tracker works on lines that haven't been split yet.  It assumes
658     the prefix of the first leaf consists of optional newlines.  Those newlines
659     are consumed by `maybe_empty_lines()` and included in the computation.
660     """
661     previous_line: Optional[Line] = None
662     previous_after: int = 0
663     previous_defs: List[int] = Factory(list)
664
665     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
666         """Returns the number of extra empty lines before and after the `current_line`.
667
668         This is for separating `def`, `async def` and `class` with extra empty lines
669         (two on module-level), as well as providing an extra empty line after flow
670         control keywords to make them more prominent.
671         """
672         before, after = self._maybe_empty_lines(current_line)
673         before -= self.previous_after
674         self.previous_after = after
675         self.previous_line = current_line
676         return before, after
677
678     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
679         max_allowed = 1
680         if current_line.is_comment and current_line.depth == 0:
681             max_allowed = 2
682         if current_line.leaves:
683             # Consume the first leaf's extra newlines.
684             first_leaf = current_line.leaves[0]
685             before = first_leaf.prefix.count('\n')
686             before = min(before, max(before, max_allowed))
687             first_leaf.prefix = ''
688         else:
689             before = 0
690         depth = current_line.depth
691         while self.previous_defs and self.previous_defs[-1] >= depth:
692             self.previous_defs.pop()
693             before = 1 if depth else 2
694         is_decorator = current_line.is_decorator
695         if is_decorator or current_line.is_def or current_line.is_class:
696             if not is_decorator:
697                 self.previous_defs.append(depth)
698             if self.previous_line is None:
699                 # Don't insert empty lines before the first line in the file.
700                 return 0, 0
701
702             if self.previous_line and self.previous_line.is_decorator:
703                 # Don't insert empty lines between decorators.
704                 return 0, 0
705
706             newlines = 2
707             if current_line.depth:
708                 newlines -= 1
709             return newlines, 0
710
711         if current_line.is_flow_control:
712             return before, 1
713
714         if (
715             self.previous_line
716             and self.previous_line.is_import
717             and not current_line.is_import
718             and depth == self.previous_line.depth
719         ):
720             return (before or 1), 0
721
722         if (
723             self.previous_line
724             and self.previous_line.is_yield
725             and (not current_line.is_yield or depth != self.previous_line.depth)
726         ):
727             return (before or 1), 0
728
729         return before, 0
730
731
732 @dataclass
733 class LineGenerator(Visitor[Line]):
734     """Generates reformatted Line objects.  Empty lines are not emitted.
735
736     Note: destroys the tree it's visiting by mutating prefixes of its leaves
737     in ways that will no longer stringify to valid Python code on the tree.
738     """
739     current_line: Line = Factory(Line)
740
741     def line(self, indent: int = 0) -> Iterator[Line]:
742         """Generate a line.
743
744         If the line is empty, only emit if it makes sense.
745         If the line is too long, split it first and then generate.
746
747         If any lines were generated, set up a new current_line.
748         """
749         if not self.current_line:
750             self.current_line.depth += indent
751             return  # Line is empty, don't emit. Creating a new one unnecessary.
752
753         complete_line = self.current_line
754         self.current_line = Line(depth=complete_line.depth + indent)
755         yield complete_line
756
757     def visit_default(self, node: LN) -> Iterator[Line]:
758         if isinstance(node, Leaf):
759             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
760             for comment in generate_comments(node):
761                 if any_open_brackets:
762                     # any comment within brackets is subject to splitting
763                     self.current_line.append(comment)
764                 elif comment.type == token.COMMENT:
765                     # regular trailing comment
766                     self.current_line.append(comment)
767                     yield from self.line()
768
769                 else:
770                     # regular standalone comment
771                     yield from self.line()
772
773                     self.current_line.append(comment)
774                     yield from self.line()
775
776             normalize_prefix(node, inside_brackets=any_open_brackets)
777             if node.type not in WHITESPACE:
778                 self.current_line.append(node)
779         yield from super().visit_default(node)
780
781     def visit_INDENT(self, node: Node) -> Iterator[Line]:
782         yield from self.line(+1)
783         yield from self.visit_default(node)
784
785     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
786         yield from self.line(-1)
787
788     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
789         """Visit a statement.
790
791         The relevant Python language keywords for this statement are NAME leaves
792         within it.
793         """
794         for child in node.children:
795             if child.type == token.NAME and child.value in keywords:  # type: ignore
796                 yield from self.line()
797
798             yield from self.visit(child)
799
800     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
801         """A statement without nested statements."""
802         is_suite_like = node.parent and node.parent.type in STATEMENT
803         if is_suite_like:
804             yield from self.line(+1)
805             yield from self.visit_default(node)
806             yield from self.line(-1)
807
808         else:
809             yield from self.line()
810             yield from self.visit_default(node)
811
812     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
813         yield from self.line()
814
815         children = iter(node.children)
816         for child in children:
817             yield from self.visit(child)
818
819             if child.type == token.ASYNC:
820                 break
821
822         internal_stmt = next(children)
823         for child in internal_stmt.children:
824             yield from self.visit(child)
825
826     def visit_decorators(self, node: Node) -> Iterator[Line]:
827         for child in node.children:
828             yield from self.line()
829             yield from self.visit(child)
830
831     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
832         yield from self.line()
833
834     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
835         yield from self.visit_default(leaf)
836         yield from self.line()
837
838     def __attrs_post_init__(self) -> None:
839         """You are in a twisty little maze of passages."""
840         v = self.visit_stmt
841         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
842         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
843         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
844         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
845         self.visit_except_clause = partial(v, keywords={'except'})
846         self.visit_funcdef = partial(v, keywords={'def'})
847         self.visit_with_stmt = partial(v, keywords={'with'})
848         self.visit_classdef = partial(v, keywords={'class'})
849         self.visit_async_funcdef = self.visit_async_stmt
850         self.visit_decorated = self.visit_decorators
851
852
853 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
854 OPENING_BRACKETS = set(BRACKET.keys())
855 CLOSING_BRACKETS = set(BRACKET.values())
856 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
857 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
858
859
860 def whitespace(leaf: Leaf) -> str:  # noqa C901
861     """Return whitespace prefix if needed for the given `leaf`."""
862     NO = ''
863     SPACE = ' '
864     DOUBLESPACE = '  '
865     t = leaf.type
866     p = leaf.parent
867     v = leaf.value
868     if t in ALWAYS_NO_SPACE:
869         return NO
870
871     if t == token.COMMENT:
872         return DOUBLESPACE
873
874     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
875     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
876         return NO
877
878     prev = leaf.prev_sibling
879     if not prev:
880         prevp = preceding_leaf(p)
881         if not prevp or prevp.type in OPENING_BRACKETS:
882             return NO
883
884         if t == token.COLON:
885             return SPACE if prevp.type == token.COMMA else NO
886
887         if prevp.type == token.EQUAL:
888             if prevp.parent:
889                 if prevp.parent.type in {
890                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
891                 }:
892                     return NO
893
894                 elif prevp.parent.type == syms.typedargslist:
895                     # A bit hacky: if the equal sign has whitespace, it means we
896                     # previously found it's a typed argument.  So, we're using
897                     # that, too.
898                     return prevp.prefix
899
900         elif prevp.type == token.DOUBLESTAR:
901             if prevp.parent and prevp.parent.type in {
902                 syms.arglist,
903                 syms.argument,
904                 syms.dictsetmaker,
905                 syms.parameters,
906                 syms.typedargslist,
907                 syms.varargslist,
908             }:
909                 return NO
910
911         elif prevp.type == token.COLON:
912             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
913                 return NO
914
915         elif (
916             prevp.parent
917             and prevp.parent.type in {syms.factor, syms.star_expr}
918             and prevp.type in MATH_OPERATORS
919         ):
920             return NO
921
922         elif (
923             prevp.type == token.RIGHTSHIFT
924             and prevp.parent
925             and prevp.parent.type == syms.shift_expr
926             and prevp.prev_sibling
927             and prevp.prev_sibling.type == token.NAME
928             and prevp.prev_sibling.value == 'print'  # type: ignore
929         ):
930             # Python 2 print chevron
931             return NO
932
933     elif prev.type in OPENING_BRACKETS:
934         return NO
935
936     if p.type in {syms.parameters, syms.arglist}:
937         # untyped function signatures or calls
938         if t == token.RPAR:
939             return NO
940
941         if not prev or prev.type != token.COMMA:
942             return NO
943
944     elif p.type == syms.varargslist:
945         # lambdas
946         if t == token.RPAR:
947             return NO
948
949         if prev and prev.type != token.COMMA:
950             return NO
951
952     elif p.type == syms.typedargslist:
953         # typed function signatures
954         if not prev:
955             return NO
956
957         if t == token.EQUAL:
958             if prev.type != syms.tname:
959                 return NO
960
961         elif prev.type == token.EQUAL:
962             # A bit hacky: if the equal sign has whitespace, it means we
963             # previously found it's a typed argument.  So, we're using that, too.
964             return prev.prefix
965
966         elif prev.type != token.COMMA:
967             return NO
968
969     elif p.type == syms.tname:
970         # type names
971         if not prev:
972             prevp = preceding_leaf(p)
973             if not prevp or prevp.type != token.COMMA:
974                 return NO
975
976     elif p.type == syms.trailer:
977         # attributes and calls
978         if t == token.LPAR or t == token.RPAR:
979             return NO
980
981         if not prev:
982             if t == token.DOT:
983                 prevp = preceding_leaf(p)
984                 if not prevp or prevp.type != token.NUMBER:
985                     return NO
986
987             elif t == token.LSQB:
988                 return NO
989
990         elif prev.type != token.COMMA:
991             return NO
992
993     elif p.type == syms.argument:
994         # single argument
995         if t == token.EQUAL:
996             return NO
997
998         if not prev:
999             prevp = preceding_leaf(p)
1000             if not prevp or prevp.type == token.LPAR:
1001                 return NO
1002
1003         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1004             return NO
1005
1006     elif p.type == syms.decorator:
1007         # decorators
1008         return NO
1009
1010     elif p.type == syms.dotted_name:
1011         if prev:
1012             return NO
1013
1014         prevp = preceding_leaf(p)
1015         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1016             return NO
1017
1018     elif p.type == syms.classdef:
1019         if t == token.LPAR:
1020             return NO
1021
1022         if prev and prev.type == token.LPAR:
1023             return NO
1024
1025     elif p.type == syms.subscript:
1026         # indexing
1027         if not prev:
1028             assert p.parent is not None, "subscripts are always parented"
1029             if p.parent.type == syms.subscriptlist:
1030                 return SPACE
1031
1032             return NO
1033
1034         else:
1035             return NO
1036
1037     elif p.type == syms.atom:
1038         if prev and t == token.DOT:
1039             # dots, but not the first one.
1040             return NO
1041
1042     elif (
1043         p.type == syms.listmaker
1044         or p.type == syms.testlist_gexp
1045         or p.type == syms.subscriptlist
1046     ):
1047         # list interior, including unpacking
1048         if not prev:
1049             return NO
1050
1051     elif p.type == syms.dictsetmaker:
1052         # dict and set interior, including unpacking
1053         if not prev:
1054             return NO
1055
1056         if prev.type == token.DOUBLESTAR:
1057             return NO
1058
1059     elif p.type in {syms.factor, syms.star_expr}:
1060         # unary ops
1061         if not prev:
1062             prevp = preceding_leaf(p)
1063             if not prevp or prevp.type in OPENING_BRACKETS:
1064                 return NO
1065
1066             prevp_parent = prevp.parent
1067             assert prevp_parent is not None
1068             if prevp.type == token.COLON and prevp_parent.type in {
1069                 syms.subscript, syms.sliceop
1070             }:
1071                 return NO
1072
1073             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1074                 return NO
1075
1076         elif t == token.NAME or t == token.NUMBER:
1077             return NO
1078
1079     elif p.type == syms.import_from:
1080         if t == token.DOT:
1081             if prev and prev.type == token.DOT:
1082                 return NO
1083
1084         elif t == token.NAME:
1085             if v == 'import':
1086                 return SPACE
1087
1088             if prev and prev.type == token.DOT:
1089                 return NO
1090
1091     elif p.type == syms.sliceop:
1092         return NO
1093
1094     return SPACE
1095
1096
1097 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1098     """Returns the first leaf that precedes `node`, if any."""
1099     while node:
1100         res = node.prev_sibling
1101         if res:
1102             if isinstance(res, Leaf):
1103                 return res
1104
1105             try:
1106                 return list(res.leaves())[-1]
1107
1108             except IndexError:
1109                 return None
1110
1111         node = node.parent
1112     return None
1113
1114
1115 def is_delimiter(leaf: Leaf) -> int:
1116     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1117
1118     Higher numbers are higher priority.
1119     """
1120     if leaf.type == token.COMMA:
1121         return COMMA_PRIORITY
1122
1123     if leaf.type in COMPARATORS:
1124         return COMPARATOR_PRIORITY
1125
1126     if (
1127         leaf.type in MATH_OPERATORS
1128         and leaf.parent
1129         and leaf.parent.type not in {syms.factor, syms.star_expr}
1130     ):
1131         return MATH_PRIORITY
1132
1133     return 0
1134
1135
1136 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1137     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1138
1139     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1140     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1141     move because it does away with modifying the grammar to include all the
1142     possible places in which comments can be placed.
1143
1144     The sad consequence for us though is that comments don't "belong" anywhere.
1145     This is why this function generates simple parentless Leaf objects for
1146     comments.  We simply don't know what the correct parent should be.
1147
1148     No matter though, we can live without this.  We really only need to
1149     differentiate between inline and standalone comments.  The latter don't
1150     share the line with any code.
1151
1152     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1153     are emitted with a fake STANDALONE_COMMENT token identifier.
1154     """
1155     p = leaf.prefix
1156     if not p:
1157         return
1158
1159     if '#' not in p:
1160         return
1161
1162     nlines = 0
1163     for index, line in enumerate(p.split('\n')):
1164         line = line.lstrip()
1165         if not line:
1166             nlines += 1
1167         if not line.startswith('#'):
1168             continue
1169
1170         if index == 0 and leaf.type != token.ENDMARKER:
1171             comment_type = token.COMMENT  # simple trailing comment
1172         else:
1173             comment_type = STANDALONE_COMMENT
1174         yield Leaf(comment_type, make_comment(line), prefix='\n' * nlines)
1175
1176         nlines = 0
1177
1178
1179 def make_comment(content: str) -> str:
1180     content = content.rstrip()
1181     if not content:
1182         return '#'
1183
1184     if content[0] == '#':
1185         content = content[1:]
1186     if content and content[0] not in ' !:#':
1187         content = ' ' + content
1188     return '#' + content
1189
1190
1191 def split_line(
1192     line: Line, line_length: int, inner: bool = False, py36: bool = False
1193 ) -> Iterator[Line]:
1194     """Splits a `line` into potentially many lines.
1195
1196     They should fit in the allotted `line_length` but might not be able to.
1197     `inner` signifies that there were a pair of brackets somewhere around the
1198     current `line`, possibly transitively. This means we can fallback to splitting
1199     by delimiters if the LHS/RHS don't yield any results.
1200
1201     If `py36` is True, splitting may generate syntax that is only compatible
1202     with Python 3.6 and later.
1203     """
1204     line_str = str(line).strip('\n')
1205     if len(line_str) <= line_length and '\n' not in line_str:
1206         yield line
1207         return
1208
1209     if line.is_def:
1210         split_funcs = [left_hand_split]
1211     elif line.inside_brackets:
1212         split_funcs = [delimiter_split]
1213         if '\n' not in line_str:
1214             # Only attempt RHS if we don't have multiline strings or comments
1215             # on this line.
1216             split_funcs.append(right_hand_split)
1217     else:
1218         split_funcs = [right_hand_split]
1219     for split_func in split_funcs:
1220         # We are accumulating lines in `result` because we might want to abort
1221         # mission and return the original line in the end, or attempt a different
1222         # split altogether.
1223         result: List[Line] = []
1224         try:
1225             for l in split_func(line, py36=py36):
1226                 if str(l).strip('\n') == line_str:
1227                     raise CannotSplit("Split function returned an unchanged result")
1228
1229                 result.extend(
1230                     split_line(l, line_length=line_length, inner=True, py36=py36)
1231                 )
1232         except CannotSplit as cs:
1233             continue
1234
1235         else:
1236             yield from result
1237             break
1238
1239     else:
1240         yield line
1241
1242
1243 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1244     """Split line into many lines, starting with the first matching bracket pair.
1245
1246     Note: this usually looks weird, only use this for function definitions.
1247     Prefer RHS otherwise.
1248     """
1249     head = Line(depth=line.depth)
1250     body = Line(depth=line.depth + 1, inside_brackets=True)
1251     tail = Line(depth=line.depth)
1252     tail_leaves: List[Leaf] = []
1253     body_leaves: List[Leaf] = []
1254     head_leaves: List[Leaf] = []
1255     current_leaves = head_leaves
1256     matching_bracket = None
1257     for leaf in line.leaves:
1258         if (
1259             current_leaves is body_leaves
1260             and leaf.type in CLOSING_BRACKETS
1261             and leaf.opening_bracket is matching_bracket
1262         ):
1263             current_leaves = tail_leaves if body_leaves else head_leaves
1264         current_leaves.append(leaf)
1265         if current_leaves is head_leaves:
1266             if leaf.type in OPENING_BRACKETS:
1267                 matching_bracket = leaf
1268                 current_leaves = body_leaves
1269     # Since body is a new indent level, remove spurious leading whitespace.
1270     if body_leaves:
1271         normalize_prefix(body_leaves[0], inside_brackets=True)
1272     # Build the new lines.
1273     for result, leaves in (
1274         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1275     ):
1276         for leaf in leaves:
1277             result.append(leaf, preformatted=True)
1278             comment_after = line.comments.get(id(leaf))
1279             if comment_after:
1280                 result.append(comment_after, preformatted=True)
1281     split_succeeded_or_raise(head, body, tail)
1282     for result in (head, body, tail):
1283         if result:
1284             yield result
1285
1286
1287 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1288     """Split line into many lines, starting with the last matching bracket pair."""
1289     head = Line(depth=line.depth)
1290     body = Line(depth=line.depth + 1, inside_brackets=True)
1291     tail = Line(depth=line.depth)
1292     tail_leaves: List[Leaf] = []
1293     body_leaves: List[Leaf] = []
1294     head_leaves: List[Leaf] = []
1295     current_leaves = tail_leaves
1296     opening_bracket = None
1297     for leaf in reversed(line.leaves):
1298         if current_leaves is body_leaves:
1299             if leaf is opening_bracket:
1300                 current_leaves = head_leaves if body_leaves else tail_leaves
1301         current_leaves.append(leaf)
1302         if current_leaves is tail_leaves:
1303             if leaf.type in CLOSING_BRACKETS:
1304                 opening_bracket = leaf.opening_bracket
1305                 current_leaves = body_leaves
1306     tail_leaves.reverse()
1307     body_leaves.reverse()
1308     head_leaves.reverse()
1309     # Since body is a new indent level, remove spurious leading whitespace.
1310     if body_leaves:
1311         normalize_prefix(body_leaves[0], inside_brackets=True)
1312     # Build the new lines.
1313     for result, leaves in (
1314         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1315     ):
1316         for leaf in leaves:
1317             result.append(leaf, preformatted=True)
1318             comment_after = line.comments.get(id(leaf))
1319             if comment_after:
1320                 result.append(comment_after, preformatted=True)
1321     split_succeeded_or_raise(head, body, tail)
1322     for result in (head, body, tail):
1323         if result:
1324             yield result
1325
1326
1327 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1328     tail_len = len(str(tail).strip())
1329     if not body:
1330         if tail_len == 0:
1331             raise CannotSplit("Splitting brackets produced the same line")
1332
1333         elif tail_len < 3:
1334             raise CannotSplit(
1335                 f"Splitting brackets on an empty body to save "
1336                 f"{tail_len} characters is not worth it"
1337             )
1338
1339
1340 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1341     """Split according to delimiters of the highest priority.
1342
1343     This kind of split doesn't increase indentation.
1344     If `py36` is True, the split will add trailing commas also in function
1345     signatures that contain * and **.
1346     """
1347     try:
1348         last_leaf = line.leaves[-1]
1349     except IndexError:
1350         raise CannotSplit("Line empty")
1351
1352     delimiters = line.bracket_tracker.delimiters
1353     try:
1354         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1355     except ValueError:
1356         raise CannotSplit("No delimiters found")
1357
1358     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1359     lowest_depth = sys.maxsize
1360     trailing_comma_safe = True
1361     for leaf in line.leaves:
1362         current_line.append(leaf, preformatted=True)
1363         comment_after = line.comments.get(id(leaf))
1364         if comment_after:
1365             current_line.append(comment_after, preformatted=True)
1366         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1367         if (
1368             leaf.bracket_depth == lowest_depth
1369             and leaf.type == token.STAR
1370             or leaf.type == token.DOUBLESTAR
1371         ):
1372             trailing_comma_safe = trailing_comma_safe and py36
1373         leaf_priority = delimiters.get(id(leaf))
1374         if leaf_priority == delimiter_priority:
1375             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1376             yield current_line
1377
1378             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1379     if current_line:
1380         if (
1381             delimiter_priority == COMMA_PRIORITY
1382             and current_line.leaves[-1].type != token.COMMA
1383             and trailing_comma_safe
1384         ):
1385             current_line.append(Leaf(token.COMMA, ','))
1386         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1387         yield current_line
1388
1389
1390 def is_import(leaf: Leaf) -> bool:
1391     """Returns True if the given leaf starts an import statement."""
1392     p = leaf.parent
1393     t = leaf.type
1394     v = leaf.value
1395     return bool(
1396         t == token.NAME
1397         and (
1398             (v == 'import' and p and p.type == syms.import_name)
1399             or (v == 'from' and p and p.type == syms.import_from)
1400         )
1401     )
1402
1403
1404 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1405     """Leave existing extra newlines if not `inside_brackets`.
1406
1407     Remove everything else.  Note: don't use backslashes for formatting or
1408     you'll lose your voting rights.
1409     """
1410     if not inside_brackets:
1411         spl = leaf.prefix.split('#')
1412         if '\\' not in spl[0]:
1413             nl_count = spl[-1].count('\n')
1414             if len(spl) > 1:
1415                 nl_count -= 1
1416             leaf.prefix = '\n' * nl_count
1417             return
1418
1419     leaf.prefix = ''
1420
1421
1422 def is_python36(node: Node) -> bool:
1423     """Returns True if the current file is using Python 3.6+ features.
1424
1425     Currently looking for:
1426     - f-strings; and
1427     - trailing commas after * or ** in function signatures.
1428     """
1429     for n in node.pre_order():
1430         if n.type == token.STRING:
1431             value_head = n.value[:2]  # type: ignore
1432             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1433                 return True
1434
1435         elif (
1436             n.type == syms.typedargslist
1437             and n.children
1438             and n.children[-1].type == token.COMMA
1439         ):
1440             for ch in n.children:
1441                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1442                     return True
1443
1444     return False
1445
1446
1447 PYTHON_EXTENSIONS = {'.py'}
1448 BLACKLISTED_DIRECTORIES = {
1449     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1450 }
1451
1452
1453 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1454     for child in path.iterdir():
1455         if child.is_dir():
1456             if child.name in BLACKLISTED_DIRECTORIES:
1457                 continue
1458
1459             yield from gen_python_files_in_dir(child)
1460
1461         elif child.suffix in PYTHON_EXTENSIONS:
1462             yield child
1463
1464
1465 @dataclass
1466 class Report:
1467     """Provides a reformatting counter."""
1468     check: bool = False
1469     change_count: int = 0
1470     same_count: int = 0
1471     failure_count: int = 0
1472
1473     def done(self, src: Path, changed: bool) -> None:
1474         """Increment the counter for successful reformatting. Write out a message."""
1475         if changed:
1476             reformatted = 'would reformat' if self.check else 'reformatted'
1477             out(f'{reformatted} {src}')
1478             self.change_count += 1
1479         else:
1480             out(f'{src} already well formatted, good job.', bold=False)
1481             self.same_count += 1
1482
1483     def failed(self, src: Path, message: str) -> None:
1484         """Increment the counter for failed reformatting. Write out a message."""
1485         err(f'error: cannot format {src}: {message}')
1486         self.failure_count += 1
1487
1488     @property
1489     def return_code(self) -> int:
1490         """Which return code should the app use considering the current state."""
1491         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1492         # 126 we have special returncodes reserved by the shell.
1493         if self.failure_count:
1494             return 123
1495
1496         elif self.change_count and self.check:
1497             return 1
1498
1499         return 0
1500
1501     def __str__(self) -> str:
1502         """A color report of the current state.
1503
1504         Use `click.unstyle` to remove colors.
1505         """
1506         if self.check:
1507             reformatted = "would be reformatted"
1508             unchanged = "would be left unchanged"
1509             failed = "would fail to reformat"
1510         else:
1511             reformatted = "reformatted"
1512             unchanged = "left unchanged"
1513             failed = "failed to reformat"
1514         report = []
1515         if self.change_count:
1516             s = 's' if self.change_count > 1 else ''
1517             report.append(
1518                 click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
1519             )
1520         if self.same_count:
1521             s = 's' if self.same_count > 1 else ''
1522             report.append(f'{self.same_count} file{s} {unchanged}')
1523         if self.failure_count:
1524             s = 's' if self.failure_count > 1 else ''
1525             report.append(
1526                 click.style(f'{self.failure_count} file{s} {failed}', fg='red')
1527             )
1528         return ', '.join(report) + '.'
1529
1530
1531 def assert_equivalent(src: str, dst: str) -> None:
1532     """Raises AssertionError if `src` and `dst` aren't equivalent.
1533
1534     This is a temporary sanity check until Black becomes stable.
1535     """
1536
1537     import ast
1538     import traceback
1539
1540     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1541         """Simple visitor generating strings to compare ASTs by content."""
1542         yield f"{'  ' * depth}{node.__class__.__name__}("
1543
1544         for field in sorted(node._fields):
1545             try:
1546                 value = getattr(node, field)
1547             except AttributeError:
1548                 continue
1549
1550             yield f"{'  ' * (depth+1)}{field}="
1551
1552             if isinstance(value, list):
1553                 for item in value:
1554                     if isinstance(item, ast.AST):
1555                         yield from _v(item, depth + 2)
1556
1557             elif isinstance(value, ast.AST):
1558                 yield from _v(value, depth + 2)
1559
1560             else:
1561                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1562
1563         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1564
1565     try:
1566         src_ast = ast.parse(src)
1567     except Exception as exc:
1568         major, minor = sys.version_info[:2]
1569         raise AssertionError(
1570             f"cannot use --safe with this file; failed to parse source file "
1571             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
1572             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
1573         )
1574
1575     try:
1576         dst_ast = ast.parse(dst)
1577     except Exception as exc:
1578         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1579         raise AssertionError(
1580             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1581             f"Please report a bug on https://github.com/ambv/black/issues.  "
1582             f"This invalid output might be helpful: {log}"
1583         ) from None
1584
1585     src_ast_str = '\n'.join(_v(src_ast))
1586     dst_ast_str = '\n'.join(_v(dst_ast))
1587     if src_ast_str != dst_ast_str:
1588         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1589         raise AssertionError(
1590             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1591             f"the source.  "
1592             f"Please report a bug on https://github.com/ambv/black/issues.  "
1593             f"This diff might be helpful: {log}"
1594         ) from None
1595
1596
1597 def assert_stable(src: str, dst: str, line_length: int) -> None:
1598     """Raises AssertionError if `dst` reformats differently the second time.
1599
1600     This is a temporary sanity check until Black becomes stable.
1601     """
1602     newdst = format_str(dst, line_length=line_length)
1603     if dst != newdst:
1604         log = dump_to_file(
1605             diff(src, dst, 'source', 'first pass'),
1606             diff(dst, newdst, 'first pass', 'second pass'),
1607         )
1608         raise AssertionError(
1609             f"INTERNAL ERROR: Black produced different code on the second pass "
1610             f"of the formatter.  "
1611             f"Please report a bug on https://github.com/ambv/black/issues.  "
1612             f"This diff might be helpful: {log}"
1613         ) from None
1614
1615
1616 def dump_to_file(*output: str) -> str:
1617     """Dumps `output` to a temporary file. Returns path to the file."""
1618     import tempfile
1619
1620     with tempfile.NamedTemporaryFile(
1621         mode='w', prefix='blk_', suffix='.log', delete=False
1622     ) as f:
1623         for lines in output:
1624             f.write(lines)
1625             f.write('\n')
1626     return f.name
1627
1628
1629 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1630     """Returns a udiff string between strings `a` and `b`."""
1631     import difflib
1632
1633     a_lines = [line + '\n' for line in a.split('\n')]
1634     b_lines = [line + '\n' for line in b.split('\n')]
1635     return ''.join(
1636         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1637     )
1638
1639
1640 if __name__ == '__main__':
1641     main()