]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

53d332addada506c5228406c0529b3ea58922ee6
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
14 )
15
16 from attr import dataclass, Factory
17 import click
18
19 # lib2to3 fork
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
24
25 __version__ = "18.3a3"
26 DEFAULT_LINE_LENGTH = 88
27 # types
28 syms = pygram.python_symbols
29 FileContent = str
30 Encoding = str
31 Depth = int
32 NodeType = int
33 LeafID = int
34 Priority = int
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
38
39
40 class NothingChanged(UserWarning):
41     """Raised by `format_file` when the reformatted code is the same as source."""
42
43
44 class CannotSplit(Exception):
45     """A readable split that fits the allotted line length is impossible.
46
47     Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`.
48     """
49
50
51 @click.command()
52 @click.option(
53     '-l',
54     '--line-length',
55     type=int,
56     default=DEFAULT_LINE_LENGTH,
57     help='How many character per line to allow.',
58     show_default=True,
59 )
60 @click.option(
61     '--check',
62     is_flag=True,
63     help=(
64         "Don't write back the files, just return the status.  Return code 0 "
65         "means nothing would change.  Return code 1 means some files would be "
66         "reformatted.  Return code 123 means there was an internal error."
67     ),
68 )
69 @click.option(
70     '--fast/--safe',
71     is_flag=True,
72     help='If --fast given, skip temporary sanity checks. [default: --safe]',
73 )
74 @click.version_option(version=__version__)
75 @click.argument(
76     'src',
77     nargs=-1,
78     type=click.Path(
79         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
80     ),
81 )
82 @click.pass_context
83 def main(
84     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
85 ) -> None:
86     """The uncompromising code formatter."""
87     sources: List[Path] = []
88     for s in src:
89         p = Path(s)
90         if p.is_dir():
91             sources.extend(gen_python_files_in_dir(p))
92         elif p.is_file():
93             # if a file was explicitly given, we don't care about its extension
94             sources.append(p)
95         elif s == '-':
96             sources.append(Path('-'))
97         else:
98             err(f'invalid path: {s}')
99     if len(sources) == 0:
100         ctx.exit(0)
101     elif len(sources) == 1:
102         p = sources[0]
103         report = Report(check=check)
104         try:
105             if not p.is_file() and str(p) == '-':
106                 changed = format_stdin_to_stdout(
107                     line_length=line_length, fast=fast, write_back=not check
108                 )
109             else:
110                 changed = format_file_in_place(
111                     p, line_length=line_length, fast=fast, write_back=not check
112                 )
113             report.done(p, changed)
114         except Exception as exc:
115             report.failed(p, str(exc))
116         ctx.exit(report.return_code)
117     else:
118         loop = asyncio.get_event_loop()
119         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
120         return_code = 1
121         try:
122             return_code = loop.run_until_complete(
123                 schedule_formatting(
124                     sources, line_length, not check, fast, loop, executor
125                 )
126             )
127         finally:
128             loop.close()
129             ctx.exit(return_code)
130
131
132 async def schedule_formatting(
133     sources: List[Path],
134     line_length: int,
135     write_back: bool,
136     fast: bool,
137     loop: BaseEventLoop,
138     executor: Executor,
139 ) -> int:
140     tasks = {
141         src: loop.run_in_executor(
142             executor, format_file_in_place, src, line_length, fast, write_back
143         )
144         for src in sources
145     }
146     await asyncio.wait(tasks.values())
147     cancelled = []
148     report = Report()
149     for src, task in tasks.items():
150         if not task.done():
151             report.failed(src, 'timed out, cancelling')
152             task.cancel()
153             cancelled.append(task)
154         elif task.exception():
155             report.failed(src, str(task.exception()))
156         else:
157             report.done(src, task.result())
158     if cancelled:
159         await asyncio.wait(cancelled, timeout=2)
160     out('All done! ✨ 🍰 ✨')
161     click.echo(str(report))
162     return report.return_code
163
164
165 def format_file_in_place(
166     src: Path, line_length: int, fast: bool, write_back: bool = False
167 ) -> bool:
168     """Format the file and rewrite if changed. Return True if changed."""
169     with tokenize.open(src) as src_buffer:
170         src_contents = src_buffer.read()
171     try:
172         contents = format_file_contents(
173             src_contents, line_length=line_length, fast=fast
174         )
175     except NothingChanged:
176         return False
177
178     if write_back:
179         with open(src, "w", encoding=src_buffer.encoding) as f:
180             f.write(contents)
181     return True
182
183
184 def format_stdin_to_stdout(
185     line_length: int, fast: bool, write_back: bool = False
186 ) -> bool:
187     """Format file on stdin and pipe output to stdout. Return True if changed."""
188     contents = sys.stdin.read()
189     try:
190         contents = format_file_contents(contents, line_length=line_length, fast=fast)
191         return True
192
193     except NothingChanged:
194         return False
195
196     finally:
197         if write_back:
198             sys.stdout.write(contents)
199
200
201 def format_file_contents(
202     src_contents: str, line_length: int, fast: bool
203 ) -> FileContent:
204     """Reformats a file and returns its contents and encoding."""
205     if src_contents.strip() == '':
206         raise NothingChanged
207
208     dst_contents = format_str(src_contents, line_length=line_length)
209     if src_contents == dst_contents:
210         raise NothingChanged
211
212     if not fast:
213         assert_equivalent(src_contents, dst_contents)
214         assert_stable(src_contents, dst_contents, line_length=line_length)
215     return dst_contents
216
217
218 def format_str(src_contents: str, line_length: int) -> FileContent:
219     """Reformats a string and returns new contents."""
220     src_node = lib2to3_parse(src_contents)
221     dst_contents = ""
222     lines = LineGenerator()
223     elt = EmptyLineTracker()
224     py36 = is_python36(src_node)
225     empty_line = Line()
226     after = 0
227     for current_line in lines.visit(src_node):
228         for _ in range(after):
229             dst_contents += str(empty_line)
230         before, after = elt.maybe_empty_lines(current_line)
231         for _ in range(before):
232             dst_contents += str(empty_line)
233         for line in split_line(current_line, line_length=line_length, py36=py36):
234             dst_contents += str(line)
235     return dst_contents
236
237
238 GRAMMARS = [
239     pygram.python_grammar_no_print_statement_no_exec_statement,
240     pygram.python_grammar_no_print_statement,
241     pygram.python_grammar_no_exec_statement,
242     pygram.python_grammar,
243 ]
244
245
246 def lib2to3_parse(src_txt: str) -> Node:
247     """Given a string with source, return the lib2to3 Node."""
248     grammar = pygram.python_grammar_no_print_statement
249     if src_txt[-1] != '\n':
250         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
251         src_txt += nl
252     for grammar in GRAMMARS:
253         drv = driver.Driver(grammar, pytree.convert)
254         try:
255             result = drv.parse_string(src_txt, True)
256             break
257
258         except ParseError as pe:
259             lineno, column = pe.context[1]
260             lines = src_txt.splitlines()
261             try:
262                 faulty_line = lines[lineno - 1]
263             except IndexError:
264                 faulty_line = "<line number missing in source>"
265             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
266     else:
267         raise exc from None
268
269     if isinstance(result, Leaf):
270         result = Node(syms.file_input, [result])
271     return result
272
273
274 def lib2to3_unparse(node: Node) -> str:
275     """Given a lib2to3 node, return its string representation."""
276     code = str(node)
277     return code
278
279
280 T = TypeVar('T')
281
282
283 class Visitor(Generic[T]):
284     """Basic lib2to3 visitor that yields things on visiting."""
285
286     def visit(self, node: LN) -> Iterator[T]:
287         if node.type < 256:
288             name = token.tok_name[node.type]
289         else:
290             name = type_repr(node.type)
291         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
292
293     def visit_default(self, node: LN) -> Iterator[T]:
294         if isinstance(node, Node):
295             for child in node.children:
296                 yield from self.visit(child)
297
298
299 @dataclass
300 class DebugVisitor(Visitor[T]):
301     tree_depth: int = 0
302
303     def visit_default(self, node: LN) -> Iterator[T]:
304         indent = ' ' * (2 * self.tree_depth)
305         if isinstance(node, Node):
306             _type = type_repr(node.type)
307             out(f'{indent}{_type}', fg='yellow')
308             self.tree_depth += 1
309             for child in node.children:
310                 yield from self.visit(child)
311
312             self.tree_depth -= 1
313             out(f'{indent}/{_type}', fg='yellow', bold=False)
314         else:
315             _type = token.tok_name.get(node.type, str(node.type))
316             out(f'{indent}{_type}', fg='blue', nl=False)
317             if node.prefix:
318                 # We don't have to handle prefixes for `Node` objects since
319                 # that delegates to the first child anyway.
320                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
321             out(f' {node.value!r}', fg='blue', bold=False)
322
323     @classmethod
324     def show(cls, code: str) -> None:
325         """Pretty-prints a given string of `code`.
326
327         Convenience method for debugging.
328         """
329         v: DebugVisitor[None] = DebugVisitor()
330         list(v.visit(lib2to3_parse(code)))
331
332
333 KEYWORDS = set(keyword.kwlist)
334 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
335 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
336 STATEMENT = {
337     syms.if_stmt,
338     syms.while_stmt,
339     syms.for_stmt,
340     syms.try_stmt,
341     syms.except_clause,
342     syms.with_stmt,
343     syms.funcdef,
344     syms.classdef,
345 }
346 STANDALONE_COMMENT = 153
347 LOGIC_OPERATORS = {'and', 'or'}
348 COMPARATORS = {
349     token.LESS,
350     token.GREATER,
351     token.EQEQUAL,
352     token.NOTEQUAL,
353     token.LESSEQUAL,
354     token.GREATEREQUAL,
355 }
356 MATH_OPERATORS = {
357     token.PLUS,
358     token.MINUS,
359     token.STAR,
360     token.SLASH,
361     token.VBAR,
362     token.AMPER,
363     token.PERCENT,
364     token.CIRCUMFLEX,
365     token.TILDE,
366     token.LEFTSHIFT,
367     token.RIGHTSHIFT,
368     token.DOUBLESTAR,
369     token.DOUBLESLASH,
370 }
371 COMPREHENSION_PRIORITY = 20
372 COMMA_PRIORITY = 10
373 LOGIC_PRIORITY = 5
374 STRING_PRIORITY = 4
375 COMPARATOR_PRIORITY = 3
376 MATH_PRIORITY = 1
377
378
379 @dataclass
380 class BracketTracker:
381     depth: int = 0
382     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
383     delimiters: Dict[LeafID, Priority] = Factory(dict)
384     previous: Optional[Leaf] = None
385
386     def mark(self, leaf: Leaf) -> None:
387         if leaf.type == token.COMMENT:
388             return
389
390         if leaf.type in CLOSING_BRACKETS:
391             self.depth -= 1
392             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
393             leaf.opening_bracket = opening_bracket
394         leaf.bracket_depth = self.depth
395         if self.depth == 0:
396             delim = is_delimiter(leaf)
397             if delim:
398                 self.delimiters[id(leaf)] = delim
399             elif self.previous is not None:
400                 if leaf.type == token.STRING and self.previous.type == token.STRING:
401                     self.delimiters[id(self.previous)] = STRING_PRIORITY
402                 elif (
403                     leaf.type == token.NAME
404                     and leaf.value == 'for'
405                     and leaf.parent
406                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
407                 ):
408                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
409                 elif (
410                     leaf.type == token.NAME
411                     and leaf.value == 'if'
412                     and leaf.parent
413                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
414                 ):
415                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
416                 elif (
417                     leaf.type == token.NAME
418                     and leaf.value in LOGIC_OPERATORS
419                     and leaf.parent
420                 ):
421                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
422         if leaf.type in OPENING_BRACKETS:
423             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
424             self.depth += 1
425         self.previous = leaf
426
427     def any_open_brackets(self) -> bool:
428         """Returns True if there is an yet unmatched open bracket on the line."""
429         return bool(self.bracket_match)
430
431     def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
432         """Returns the highest priority of a delimiter found on the line.
433
434         Values are consistent with what `is_delimiter()` returns.
435         """
436         return max(v for k, v in self.delimiters.items() if k not in exclude)
437
438
439 @dataclass
440 class Line:
441     depth: int = 0
442     leaves: List[Leaf] = Factory(list)
443     comments: Dict[LeafID, Leaf] = Factory(dict)
444     bracket_tracker: BracketTracker = Factory(BracketTracker)
445     inside_brackets: bool = False
446     has_for: bool = False
447     _for_loop_variable: bool = False
448
449     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
450         has_value = leaf.value.strip()
451         if not has_value:
452             return
453
454         if self.leaves and not preformatted:
455             # Note: at this point leaf.prefix should be empty except for
456             # imports, for which we only preserve newlines.
457             leaf.prefix += whitespace(leaf)
458         if self.inside_brackets or not preformatted:
459             self.maybe_decrement_after_for_loop_variable(leaf)
460             self.bracket_tracker.mark(leaf)
461             self.maybe_remove_trailing_comma(leaf)
462             self.maybe_increment_for_loop_variable(leaf)
463             if self.maybe_adapt_standalone_comment(leaf):
464                 return
465
466         if not self.append_comment(leaf):
467             self.leaves.append(leaf)
468
469     @property
470     def is_comment(self) -> bool:
471         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
472
473     @property
474     def is_decorator(self) -> bool:
475         return bool(self) and self.leaves[0].type == token.AT
476
477     @property
478     def is_import(self) -> bool:
479         return bool(self) and is_import(self.leaves[0])
480
481     @property
482     def is_class(self) -> bool:
483         return (
484             bool(self)
485             and self.leaves[0].type == token.NAME
486             and self.leaves[0].value == 'class'
487         )
488
489     @property
490     def is_def(self) -> bool:
491         """Also returns True for async defs."""
492         try:
493             first_leaf = self.leaves[0]
494         except IndexError:
495             return False
496
497         try:
498             second_leaf: Optional[Leaf] = self.leaves[1]
499         except IndexError:
500             second_leaf = None
501         return (
502             (first_leaf.type == token.NAME and first_leaf.value == 'def')
503             or (
504                 first_leaf.type == token.ASYNC
505                 and second_leaf is not None
506                 and second_leaf.type == token.NAME
507                 and second_leaf.value == 'def'
508             )
509         )
510
511     @property
512     def is_flow_control(self) -> bool:
513         return (
514             bool(self)
515             and self.leaves[0].type == token.NAME
516             and self.leaves[0].value in FLOW_CONTROL
517         )
518
519     @property
520     def is_yield(self) -> bool:
521         return (
522             bool(self)
523             and self.leaves[0].type == token.NAME
524             and self.leaves[0].value == 'yield'
525         )
526
527     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
528         if not (
529             self.leaves
530             and self.leaves[-1].type == token.COMMA
531             and closing.type in CLOSING_BRACKETS
532         ):
533             return False
534
535         if closing.type == token.RBRACE:
536             self.leaves.pop()
537             return True
538
539         if closing.type == token.RSQB:
540             comma = self.leaves[-1]
541             if comma.parent and comma.parent.type == syms.listmaker:
542                 self.leaves.pop()
543                 return True
544
545         # For parens let's check if it's safe to remove the comma.  If the
546         # trailing one is the only one, we might mistakenly change a tuple
547         # into a different type by removing the comma.
548         depth = closing.bracket_depth + 1
549         commas = 0
550         opening = closing.opening_bracket
551         for _opening_index, leaf in enumerate(self.leaves):
552             if leaf is opening:
553                 break
554
555         else:
556             return False
557
558         for leaf in self.leaves[_opening_index + 1:]:
559             if leaf is closing:
560                 break
561
562             bracket_depth = leaf.bracket_depth
563             if bracket_depth == depth and leaf.type == token.COMMA:
564                 commas += 1
565                 if leaf.parent and leaf.parent.type == syms.arglist:
566                     commas += 1
567                     break
568
569         if commas > 1:
570             self.leaves.pop()
571             return True
572
573         return False
574
575     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
576         """In a for loop, or comprehension, the variables are often unpacks.
577
578         To avoid splitting on the comma in this situation, we will increase
579         the depth of tokens between `for` and `in`.
580         """
581         if leaf.type == token.NAME and leaf.value == 'for':
582             self.has_for = True
583             self.bracket_tracker.depth += 1
584             self._for_loop_variable = True
585             return True
586
587         return False
588
589     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
590         # See `maybe_increment_for_loop_variable` above for explanation.
591         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
592             self.bracket_tracker.depth -= 1
593             self._for_loop_variable = False
594             return True
595
596         return False
597
598     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
599         """Hack a standalone comment to act as a trailing comment for line splitting.
600
601         If this line has brackets and a standalone `comment`, we need to adapt
602         it to be able to still reformat the line.
603
604         This is not perfect, the line to which the standalone comment gets
605         appended will appear "too long" when splitting.
606         """
607         if not (
608             comment.type == STANDALONE_COMMENT
609             and self.bracket_tracker.any_open_brackets()
610         ):
611             return False
612
613         comment.type = token.COMMENT
614         comment.prefix = '\n' + '    ' * (self.depth + 1)
615         return self.append_comment(comment)
616
617     def append_comment(self, comment: Leaf) -> bool:
618         if comment.type != token.COMMENT:
619             return False
620
621         try:
622             after = id(self.last_non_delimiter())
623         except LookupError:
624             comment.type = STANDALONE_COMMENT
625             comment.prefix = ''
626             return False
627
628         else:
629             if after in self.comments:
630                 self.comments[after].value += str(comment)
631             else:
632                 self.comments[after] = comment
633             return True
634
635     def last_non_delimiter(self) -> Leaf:
636         for i in range(len(self.leaves)):
637             last = self.leaves[-i - 1]
638             if not is_delimiter(last):
639                 return last
640
641         raise LookupError("No non-delimiters found")
642
643     def __str__(self) -> str:
644         if not self:
645             return '\n'
646
647         indent = '    ' * self.depth
648         leaves = iter(self.leaves)
649         first = next(leaves)
650         res = f'{first.prefix}{indent}{first.value}'
651         for leaf in leaves:
652             res += str(leaf)
653         for comment in self.comments.values():
654             res += str(comment)
655         return res + '\n'
656
657     def __bool__(self) -> bool:
658         return bool(self.leaves or self.comments)
659
660
661 @dataclass
662 class EmptyLineTracker:
663     """Provides a stateful method that returns the number of potential extra
664     empty lines needed before and after the currently processed line.
665
666     Note: this tracker works on lines that haven't been split yet.  It assumes
667     the prefix of the first leaf consists of optional newlines.  Those newlines
668     are consumed by `maybe_empty_lines()` and included in the computation.
669     """
670     previous_line: Optional[Line] = None
671     previous_after: int = 0
672     previous_defs: List[int] = Factory(list)
673
674     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
675         """Returns the number of extra empty lines before and after the `current_line`.
676
677         This is for separating `def`, `async def` and `class` with extra empty lines
678         (two on module-level), as well as providing an extra empty line after flow
679         control keywords to make them more prominent.
680         """
681         before, after = self._maybe_empty_lines(current_line)
682         before -= self.previous_after
683         self.previous_after = after
684         self.previous_line = current_line
685         return before, after
686
687     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
688         max_allowed = 1
689         if current_line.is_comment and current_line.depth == 0:
690             max_allowed = 2
691         if current_line.leaves:
692             # Consume the first leaf's extra newlines.
693             first_leaf = current_line.leaves[0]
694             before = first_leaf.prefix.count('\n')
695             before = min(before, max(before, max_allowed))
696             first_leaf.prefix = ''
697         else:
698             before = 0
699         depth = current_line.depth
700         while self.previous_defs and self.previous_defs[-1] >= depth:
701             self.previous_defs.pop()
702             before = 1 if depth else 2
703         is_decorator = current_line.is_decorator
704         if is_decorator or current_line.is_def or current_line.is_class:
705             if not is_decorator:
706                 self.previous_defs.append(depth)
707             if self.previous_line is None:
708                 # Don't insert empty lines before the first line in the file.
709                 return 0, 0
710
711             if self.previous_line and self.previous_line.is_decorator:
712                 # Don't insert empty lines between decorators.
713                 return 0, 0
714
715             newlines = 2
716             if current_line.depth:
717                 newlines -= 1
718             return newlines, 0
719
720         if current_line.is_flow_control:
721             return before, 1
722
723         if (
724             self.previous_line
725             and self.previous_line.is_import
726             and not current_line.is_import
727             and depth == self.previous_line.depth
728         ):
729             return (before or 1), 0
730
731         if (
732             self.previous_line
733             and self.previous_line.is_yield
734             and (not current_line.is_yield or depth != self.previous_line.depth)
735         ):
736             return (before or 1), 0
737
738         return before, 0
739
740
741 @dataclass
742 class LineGenerator(Visitor[Line]):
743     """Generates reformatted Line objects.  Empty lines are not emitted.
744
745     Note: destroys the tree it's visiting by mutating prefixes of its leaves
746     in ways that will no longer stringify to valid Python code on the tree.
747     """
748     current_line: Line = Factory(Line)
749
750     def line(self, indent: int = 0) -> Iterator[Line]:
751         """Generate a line.
752
753         If the line is empty, only emit if it makes sense.
754         If the line is too long, split it first and then generate.
755
756         If any lines were generated, set up a new current_line.
757         """
758         if not self.current_line:
759             self.current_line.depth += indent
760             return  # Line is empty, don't emit. Creating a new one unnecessary.
761
762         complete_line = self.current_line
763         self.current_line = Line(depth=complete_line.depth + indent)
764         yield complete_line
765
766     def visit_default(self, node: LN) -> Iterator[Line]:
767         if isinstance(node, Leaf):
768             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
769             for comment in generate_comments(node):
770                 if any_open_brackets:
771                     # any comment within brackets is subject to splitting
772                     self.current_line.append(comment)
773                 elif comment.type == token.COMMENT:
774                     # regular trailing comment
775                     self.current_line.append(comment)
776                     yield from self.line()
777
778                 else:
779                     # regular standalone comment
780                     yield from self.line()
781
782                     self.current_line.append(comment)
783                     yield from self.line()
784
785             normalize_prefix(node, inside_brackets=any_open_brackets)
786             if node.type not in WHITESPACE:
787                 self.current_line.append(node)
788         yield from super().visit_default(node)
789
790     def visit_INDENT(self, node: Node) -> Iterator[Line]:
791         yield from self.line(+1)
792         yield from self.visit_default(node)
793
794     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
795         yield from self.line(-1)
796
797     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
798         """Visit a statement.
799
800         The relevant Python language keywords for this statement are NAME leaves
801         within it.
802         """
803         for child in node.children:
804             if child.type == token.NAME and child.value in keywords:  # type: ignore
805                 yield from self.line()
806
807             yield from self.visit(child)
808
809     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
810         """A statement without nested statements."""
811         is_suite_like = node.parent and node.parent.type in STATEMENT
812         if is_suite_like:
813             yield from self.line(+1)
814             yield from self.visit_default(node)
815             yield from self.line(-1)
816
817         else:
818             yield from self.line()
819             yield from self.visit_default(node)
820
821     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
822         yield from self.line()
823
824         children = iter(node.children)
825         for child in children:
826             yield from self.visit(child)
827
828             if child.type == token.ASYNC:
829                 break
830
831         internal_stmt = next(children)
832         for child in internal_stmt.children:
833             yield from self.visit(child)
834
835     def visit_decorators(self, node: Node) -> Iterator[Line]:
836         for child in node.children:
837             yield from self.line()
838             yield from self.visit(child)
839
840     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
841         yield from self.line()
842
843     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
844         yield from self.visit_default(leaf)
845         yield from self.line()
846
847     def __attrs_post_init__(self) -> None:
848         """You are in a twisty little maze of passages."""
849         v = self.visit_stmt
850         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
851         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
852         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
853         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
854         self.visit_except_clause = partial(v, keywords={'except'})
855         self.visit_funcdef = partial(v, keywords={'def'})
856         self.visit_with_stmt = partial(v, keywords={'with'})
857         self.visit_classdef = partial(v, keywords={'class'})
858         self.visit_async_funcdef = self.visit_async_stmt
859         self.visit_decorated = self.visit_decorators
860
861
862 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
863 OPENING_BRACKETS = set(BRACKET.keys())
864 CLOSING_BRACKETS = set(BRACKET.values())
865 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
866 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
867
868
869 def whitespace(leaf: Leaf) -> str:  # noqa C901
870     """Return whitespace prefix if needed for the given `leaf`."""
871     NO = ''
872     SPACE = ' '
873     DOUBLESPACE = '  '
874     t = leaf.type
875     p = leaf.parent
876     v = leaf.value
877     if t in ALWAYS_NO_SPACE:
878         return NO
879
880     if t == token.COMMENT:
881         return DOUBLESPACE
882
883     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
884     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
885         return NO
886
887     prev = leaf.prev_sibling
888     if not prev:
889         prevp = preceding_leaf(p)
890         if not prevp or prevp.type in OPENING_BRACKETS:
891             return NO
892
893         if t == token.COLON:
894             return SPACE if prevp.type == token.COMMA else NO
895
896         if prevp.type == token.EQUAL:
897             if prevp.parent:
898                 if prevp.parent.type in {
899                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
900                 }:
901                     return NO
902
903                 elif prevp.parent.type == syms.typedargslist:
904                     # A bit hacky: if the equal sign has whitespace, it means we
905                     # previously found it's a typed argument.  So, we're using
906                     # that, too.
907                     return prevp.prefix
908
909         elif prevp.type == token.DOUBLESTAR:
910             if prevp.parent and prevp.parent.type in {
911                 syms.arglist,
912                 syms.argument,
913                 syms.dictsetmaker,
914                 syms.parameters,
915                 syms.typedargslist,
916                 syms.varargslist,
917             }:
918                 return NO
919
920         elif prevp.type == token.COLON:
921             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
922                 return NO
923
924         elif (
925             prevp.parent
926             and prevp.parent.type in {syms.factor, syms.star_expr}
927             and prevp.type in MATH_OPERATORS
928         ):
929             return NO
930
931         elif (
932             prevp.type == token.RIGHTSHIFT
933             and prevp.parent
934             and prevp.parent.type == syms.shift_expr
935             and prevp.prev_sibling
936             and prevp.prev_sibling.type == token.NAME
937             and prevp.prev_sibling.value == 'print'  # type: ignore
938         ):
939             # Python 2 print chevron
940             return NO
941
942     elif prev.type in OPENING_BRACKETS:
943         return NO
944
945     if p.type in {syms.parameters, syms.arglist}:
946         # untyped function signatures or calls
947         if t == token.RPAR:
948             return NO
949
950         if not prev or prev.type != token.COMMA:
951             return NO
952
953     elif p.type == syms.varargslist:
954         # lambdas
955         if t == token.RPAR:
956             return NO
957
958         if prev and prev.type != token.COMMA:
959             return NO
960
961     elif p.type == syms.typedargslist:
962         # typed function signatures
963         if not prev:
964             return NO
965
966         if t == token.EQUAL:
967             if prev.type != syms.tname:
968                 return NO
969
970         elif prev.type == token.EQUAL:
971             # A bit hacky: if the equal sign has whitespace, it means we
972             # previously found it's a typed argument.  So, we're using that, too.
973             return prev.prefix
974
975         elif prev.type != token.COMMA:
976             return NO
977
978     elif p.type == syms.tname:
979         # type names
980         if not prev:
981             prevp = preceding_leaf(p)
982             if not prevp or prevp.type != token.COMMA:
983                 return NO
984
985     elif p.type == syms.trailer:
986         # attributes and calls
987         if t == token.LPAR or t == token.RPAR:
988             return NO
989
990         if not prev:
991             if t == token.DOT:
992                 prevp = preceding_leaf(p)
993                 if not prevp or prevp.type != token.NUMBER:
994                     return NO
995
996             elif t == token.LSQB:
997                 return NO
998
999         elif prev.type != token.COMMA:
1000             return NO
1001
1002     elif p.type == syms.argument:
1003         # single argument
1004         if t == token.EQUAL:
1005             return NO
1006
1007         if not prev:
1008             prevp = preceding_leaf(p)
1009             if not prevp or prevp.type == token.LPAR:
1010                 return NO
1011
1012         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1013             return NO
1014
1015     elif p.type == syms.decorator:
1016         # decorators
1017         return NO
1018
1019     elif p.type == syms.dotted_name:
1020         if prev:
1021             return NO
1022
1023         prevp = preceding_leaf(p)
1024         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1025             return NO
1026
1027     elif p.type == syms.classdef:
1028         if t == token.LPAR:
1029             return NO
1030
1031         if prev and prev.type == token.LPAR:
1032             return NO
1033
1034     elif p.type == syms.subscript:
1035         # indexing
1036         if not prev:
1037             assert p.parent is not None, "subscripts are always parented"
1038             if p.parent.type == syms.subscriptlist:
1039                 return SPACE
1040
1041             return NO
1042
1043         else:
1044             return NO
1045
1046     elif p.type == syms.atom:
1047         if prev and t == token.DOT:
1048             # dots, but not the first one.
1049             return NO
1050
1051     elif (
1052         p.type == syms.listmaker
1053         or p.type == syms.testlist_gexp
1054         or p.type == syms.subscriptlist
1055     ):
1056         # list interior, including unpacking
1057         if not prev:
1058             return NO
1059
1060     elif p.type == syms.dictsetmaker:
1061         # dict and set interior, including unpacking
1062         if not prev:
1063             return NO
1064
1065         if prev.type == token.DOUBLESTAR:
1066             return NO
1067
1068     elif p.type in {syms.factor, syms.star_expr}:
1069         # unary ops
1070         if not prev:
1071             prevp = preceding_leaf(p)
1072             if not prevp or prevp.type in OPENING_BRACKETS:
1073                 return NO
1074
1075             prevp_parent = prevp.parent
1076             assert prevp_parent is not None
1077             if prevp.type == token.COLON and prevp_parent.type in {
1078                 syms.subscript, syms.sliceop
1079             }:
1080                 return NO
1081
1082             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1083                 return NO
1084
1085         elif t == token.NAME or t == token.NUMBER:
1086             return NO
1087
1088     elif p.type == syms.import_from:
1089         if t == token.DOT:
1090             if prev and prev.type == token.DOT:
1091                 return NO
1092
1093         elif t == token.NAME:
1094             if v == 'import':
1095                 return SPACE
1096
1097             if prev and prev.type == token.DOT:
1098                 return NO
1099
1100     elif p.type == syms.sliceop:
1101         return NO
1102
1103     return SPACE
1104
1105
1106 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1107     """Returns the first leaf that precedes `node`, if any."""
1108     while node:
1109         res = node.prev_sibling
1110         if res:
1111             if isinstance(res, Leaf):
1112                 return res
1113
1114             try:
1115                 return list(res.leaves())[-1]
1116
1117             except IndexError:
1118                 return None
1119
1120         node = node.parent
1121     return None
1122
1123
1124 def is_delimiter(leaf: Leaf) -> int:
1125     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1126
1127     Higher numbers are higher priority.
1128     """
1129     if leaf.type == token.COMMA:
1130         return COMMA_PRIORITY
1131
1132     if leaf.type in COMPARATORS:
1133         return COMPARATOR_PRIORITY
1134
1135     if (
1136         leaf.type in MATH_OPERATORS
1137         and leaf.parent
1138         and leaf.parent.type not in {syms.factor, syms.star_expr}
1139     ):
1140         return MATH_PRIORITY
1141
1142     return 0
1143
1144
1145 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1146     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1147
1148     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1149     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1150     move because it does away with modifying the grammar to include all the
1151     possible places in which comments can be placed.
1152
1153     The sad consequence for us though is that comments don't "belong" anywhere.
1154     This is why this function generates simple parentless Leaf objects for
1155     comments.  We simply don't know what the correct parent should be.
1156
1157     No matter though, we can live without this.  We really only need to
1158     differentiate between inline and standalone comments.  The latter don't
1159     share the line with any code.
1160
1161     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1162     are emitted with a fake STANDALONE_COMMENT token identifier.
1163     """
1164     p = leaf.prefix
1165     if not p:
1166         return
1167
1168     if '#' not in p:
1169         return
1170
1171     nlines = 0
1172     for index, line in enumerate(p.split('\n')):
1173         line = line.lstrip()
1174         if not line:
1175             nlines += 1
1176         if not line.startswith('#'):
1177             continue
1178
1179         if index == 0 and leaf.type != token.ENDMARKER:
1180             comment_type = token.COMMENT  # simple trailing comment
1181         else:
1182             comment_type = STANDALONE_COMMENT
1183         yield Leaf(comment_type, make_comment(line), prefix='\n' * nlines)
1184
1185         nlines = 0
1186
1187
1188 def make_comment(content: str) -> str:
1189     content = content.rstrip()
1190     if not content:
1191         return '#'
1192
1193     if content[0] == '#':
1194         content = content[1:]
1195     if content and content[0] not in ' !:#':
1196         content = ' ' + content
1197     return '#' + content
1198
1199
1200 def split_line(
1201     line: Line, line_length: int, inner: bool = False, py36: bool = False
1202 ) -> Iterator[Line]:
1203     """Splits a `line` into potentially many lines.
1204
1205     They should fit in the allotted `line_length` but might not be able to.
1206     `inner` signifies that there were a pair of brackets somewhere around the
1207     current `line`, possibly transitively. This means we can fallback to splitting
1208     by delimiters if the LHS/RHS don't yield any results.
1209
1210     If `py36` is True, splitting may generate syntax that is only compatible
1211     with Python 3.6 and later.
1212     """
1213     line_str = str(line).strip('\n')
1214     if len(line_str) <= line_length and '\n' not in line_str:
1215         yield line
1216         return
1217
1218     if line.is_def:
1219         split_funcs = [left_hand_split]
1220     elif line.inside_brackets:
1221         split_funcs = [delimiter_split]
1222         if '\n' not in line_str:
1223             # Only attempt RHS if we don't have multiline strings or comments
1224             # on this line.
1225             split_funcs.append(right_hand_split)
1226     else:
1227         split_funcs = [right_hand_split]
1228     for split_func in split_funcs:
1229         # We are accumulating lines in `result` because we might want to abort
1230         # mission and return the original line in the end, or attempt a different
1231         # split altogether.
1232         result: List[Line] = []
1233         try:
1234             for l in split_func(line, py36=py36):
1235                 if str(l).strip('\n') == line_str:
1236                     raise CannotSplit("Split function returned an unchanged result")
1237
1238                 result.extend(
1239                     split_line(l, line_length=line_length, inner=True, py36=py36)
1240                 )
1241         except CannotSplit as cs:
1242             continue
1243
1244         else:
1245             yield from result
1246             break
1247
1248     else:
1249         yield line
1250
1251
1252 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1253     """Split line into many lines, starting with the first matching bracket pair.
1254
1255     Note: this usually looks weird, only use this for function definitions.
1256     Prefer RHS otherwise.
1257     """
1258     head = Line(depth=line.depth)
1259     body = Line(depth=line.depth + 1, inside_brackets=True)
1260     tail = Line(depth=line.depth)
1261     tail_leaves: List[Leaf] = []
1262     body_leaves: List[Leaf] = []
1263     head_leaves: List[Leaf] = []
1264     current_leaves = head_leaves
1265     matching_bracket = None
1266     for leaf in line.leaves:
1267         if (
1268             current_leaves is body_leaves
1269             and leaf.type in CLOSING_BRACKETS
1270             and leaf.opening_bracket is matching_bracket
1271         ):
1272             current_leaves = tail_leaves if body_leaves else head_leaves
1273         current_leaves.append(leaf)
1274         if current_leaves is head_leaves:
1275             if leaf.type in OPENING_BRACKETS:
1276                 matching_bracket = leaf
1277                 current_leaves = body_leaves
1278     # Since body is a new indent level, remove spurious leading whitespace.
1279     if body_leaves:
1280         normalize_prefix(body_leaves[0], inside_brackets=True)
1281     # Build the new lines.
1282     for result, leaves in (
1283         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1284     ):
1285         for leaf in leaves:
1286             result.append(leaf, preformatted=True)
1287             comment_after = line.comments.get(id(leaf))
1288             if comment_after:
1289                 result.append(comment_after, preformatted=True)
1290     split_succeeded_or_raise(head, body, tail)
1291     for result in (head, body, tail):
1292         if result:
1293             yield result
1294
1295
1296 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1297     """Split line into many lines, starting with the last matching bracket pair."""
1298     head = Line(depth=line.depth)
1299     body = Line(depth=line.depth + 1, inside_brackets=True)
1300     tail = Line(depth=line.depth)
1301     tail_leaves: List[Leaf] = []
1302     body_leaves: List[Leaf] = []
1303     head_leaves: List[Leaf] = []
1304     current_leaves = tail_leaves
1305     opening_bracket = None
1306     for leaf in reversed(line.leaves):
1307         if current_leaves is body_leaves:
1308             if leaf is opening_bracket:
1309                 current_leaves = head_leaves if body_leaves else tail_leaves
1310         current_leaves.append(leaf)
1311         if current_leaves is tail_leaves:
1312             if leaf.type in CLOSING_BRACKETS:
1313                 opening_bracket = leaf.opening_bracket
1314                 current_leaves = body_leaves
1315     tail_leaves.reverse()
1316     body_leaves.reverse()
1317     head_leaves.reverse()
1318     # Since body is a new indent level, remove spurious leading whitespace.
1319     if body_leaves:
1320         normalize_prefix(body_leaves[0], inside_brackets=True)
1321     # Build the new lines.
1322     for result, leaves in (
1323         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1324     ):
1325         for leaf in leaves:
1326             result.append(leaf, preformatted=True)
1327             comment_after = line.comments.get(id(leaf))
1328             if comment_after:
1329                 result.append(comment_after, preformatted=True)
1330     split_succeeded_or_raise(head, body, tail)
1331     for result in (head, body, tail):
1332         if result:
1333             yield result
1334
1335
1336 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1337     tail_len = len(str(tail).strip())
1338     if not body:
1339         if tail_len == 0:
1340             raise CannotSplit("Splitting brackets produced the same line")
1341
1342         elif tail_len < 3:
1343             raise CannotSplit(
1344                 f"Splitting brackets on an empty body to save "
1345                 f"{tail_len} characters is not worth it"
1346             )
1347
1348
1349 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1350     """Split according to delimiters of the highest priority.
1351
1352     This kind of split doesn't increase indentation.
1353     If `py36` is True, the split will add trailing commas also in function
1354     signatures that contain * and **.
1355     """
1356     try:
1357         last_leaf = line.leaves[-1]
1358     except IndexError:
1359         raise CannotSplit("Line empty")
1360
1361     delimiters = line.bracket_tracker.delimiters
1362     try:
1363         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1364     except ValueError:
1365         raise CannotSplit("No delimiters found")
1366
1367     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1368     lowest_depth = sys.maxsize
1369     trailing_comma_safe = True
1370     for leaf in line.leaves:
1371         current_line.append(leaf, preformatted=True)
1372         comment_after = line.comments.get(id(leaf))
1373         if comment_after:
1374             current_line.append(comment_after, preformatted=True)
1375         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1376         if (
1377             leaf.bracket_depth == lowest_depth
1378             and leaf.type == token.STAR
1379             or leaf.type == token.DOUBLESTAR
1380         ):
1381             trailing_comma_safe = trailing_comma_safe and py36
1382         leaf_priority = delimiters.get(id(leaf))
1383         if leaf_priority == delimiter_priority:
1384             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1385             yield current_line
1386
1387             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1388     if current_line:
1389         if (
1390             delimiter_priority == COMMA_PRIORITY
1391             and current_line.leaves[-1].type != token.COMMA
1392             and trailing_comma_safe
1393         ):
1394             current_line.append(Leaf(token.COMMA, ','))
1395         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1396         yield current_line
1397
1398
1399 def is_import(leaf: Leaf) -> bool:
1400     """Returns True if the given leaf starts an import statement."""
1401     p = leaf.parent
1402     t = leaf.type
1403     v = leaf.value
1404     return bool(
1405         t == token.NAME
1406         and (
1407             (v == 'import' and p and p.type == syms.import_name)
1408             or (v == 'from' and p and p.type == syms.import_from)
1409         )
1410     )
1411
1412
1413 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1414     """Leave existing extra newlines if not `inside_brackets`.
1415
1416     Remove everything else.  Note: don't use backslashes for formatting or
1417     you'll lose your voting rights.
1418     """
1419     if not inside_brackets:
1420         spl = leaf.prefix.split('#')
1421         if '\\' not in spl[0]:
1422             nl_count = spl[-1].count('\n')
1423             if len(spl) > 1:
1424                 nl_count -= 1
1425             leaf.prefix = '\n' * nl_count
1426             return
1427
1428     leaf.prefix = ''
1429
1430
1431 def is_python36(node: Node) -> bool:
1432     """Returns True if the current file is using Python 3.6+ features.
1433
1434     Currently looking for:
1435     - f-strings; and
1436     - trailing commas after * or ** in function signatures.
1437     """
1438     for n in node.pre_order():
1439         if n.type == token.STRING:
1440             value_head = n.value[:2]  # type: ignore
1441             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1442                 return True
1443
1444         elif (
1445             n.type == syms.typedargslist
1446             and n.children
1447             and n.children[-1].type == token.COMMA
1448         ):
1449             for ch in n.children:
1450                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1451                     return True
1452
1453     return False
1454
1455
1456 PYTHON_EXTENSIONS = {'.py'}
1457 BLACKLISTED_DIRECTORIES = {
1458     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1459 }
1460
1461
1462 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1463     for child in path.iterdir():
1464         if child.is_dir():
1465             if child.name in BLACKLISTED_DIRECTORIES:
1466                 continue
1467
1468             yield from gen_python_files_in_dir(child)
1469
1470         elif child.suffix in PYTHON_EXTENSIONS:
1471             yield child
1472
1473
1474 @dataclass
1475 class Report:
1476     """Provides a reformatting counter."""
1477     check: bool = False
1478     change_count: int = 0
1479     same_count: int = 0
1480     failure_count: int = 0
1481
1482     def done(self, src: Path, changed: bool) -> None:
1483         """Increment the counter for successful reformatting. Write out a message."""
1484         if changed:
1485             reformatted = 'would reformat' if self.check else 'reformatted'
1486             out(f'{reformatted} {src}')
1487             self.change_count += 1
1488         else:
1489             out(f'{src} already well formatted, good job.', bold=False)
1490             self.same_count += 1
1491
1492     def failed(self, src: Path, message: str) -> None:
1493         """Increment the counter for failed reformatting. Write out a message."""
1494         err(f'error: cannot format {src}: {message}')
1495         self.failure_count += 1
1496
1497     @property
1498     def return_code(self) -> int:
1499         """Which return code should the app use considering the current state."""
1500         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1501         # 126 we have special returncodes reserved by the shell.
1502         if self.failure_count:
1503             return 123
1504
1505         elif self.change_count and self.check:
1506             return 1
1507
1508         return 0
1509
1510     def __str__(self) -> str:
1511         """A color report of the current state.
1512
1513         Use `click.unstyle` to remove colors.
1514         """
1515         if self.check:
1516             reformatted = "would be reformatted"
1517             unchanged = "would be left unchanged"
1518             failed = "would fail to reformat"
1519         else:
1520             reformatted = "reformatted"
1521             unchanged = "left unchanged"
1522             failed = "failed to reformat"
1523         report = []
1524         if self.change_count:
1525             s = 's' if self.change_count > 1 else ''
1526             report.append(
1527                 click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
1528             )
1529         if self.same_count:
1530             s = 's' if self.same_count > 1 else ''
1531             report.append(f'{self.same_count} file{s} {unchanged}')
1532         if self.failure_count:
1533             s = 's' if self.failure_count > 1 else ''
1534             report.append(
1535                 click.style(f'{self.failure_count} file{s} {failed}', fg='red')
1536             )
1537         return ', '.join(report) + '.'
1538
1539
1540 def assert_equivalent(src: str, dst: str) -> None:
1541     """Raises AssertionError if `src` and `dst` aren't equivalent.
1542
1543     This is a temporary sanity check until Black becomes stable.
1544     """
1545
1546     import ast
1547     import traceback
1548
1549     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1550         """Simple visitor generating strings to compare ASTs by content."""
1551         yield f"{'  ' * depth}{node.__class__.__name__}("
1552
1553         for field in sorted(node._fields):
1554             try:
1555                 value = getattr(node, field)
1556             except AttributeError:
1557                 continue
1558
1559             yield f"{'  ' * (depth+1)}{field}="
1560
1561             if isinstance(value, list):
1562                 for item in value:
1563                     if isinstance(item, ast.AST):
1564                         yield from _v(item, depth + 2)
1565
1566             elif isinstance(value, ast.AST):
1567                 yield from _v(value, depth + 2)
1568
1569             else:
1570                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1571
1572         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1573
1574     try:
1575         src_ast = ast.parse(src)
1576     except Exception as exc:
1577         major, minor = sys.version_info[:2]
1578         raise AssertionError(
1579             f"cannot use --safe with this file; failed to parse source file "
1580             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
1581             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
1582         )
1583
1584     try:
1585         dst_ast = ast.parse(dst)
1586     except Exception as exc:
1587         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1588         raise AssertionError(
1589             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1590             f"Please report a bug on https://github.com/ambv/black/issues.  "
1591             f"This invalid output might be helpful: {log}"
1592         ) from None
1593
1594     src_ast_str = '\n'.join(_v(src_ast))
1595     dst_ast_str = '\n'.join(_v(dst_ast))
1596     if src_ast_str != dst_ast_str:
1597         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1598         raise AssertionError(
1599             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1600             f"the source.  "
1601             f"Please report a bug on https://github.com/ambv/black/issues.  "
1602             f"This diff might be helpful: {log}"
1603         ) from None
1604
1605
1606 def assert_stable(src: str, dst: str, line_length: int) -> None:
1607     """Raises AssertionError if `dst` reformats differently the second time.
1608
1609     This is a temporary sanity check until Black becomes stable.
1610     """
1611     newdst = format_str(dst, line_length=line_length)
1612     if dst != newdst:
1613         log = dump_to_file(
1614             diff(src, dst, 'source', 'first pass'),
1615             diff(dst, newdst, 'first pass', 'second pass'),
1616         )
1617         raise AssertionError(
1618             f"INTERNAL ERROR: Black produced different code on the second pass "
1619             f"of the formatter.  "
1620             f"Please report a bug on https://github.com/ambv/black/issues.  "
1621             f"This diff might be helpful: {log}"
1622         ) from None
1623
1624
1625 def dump_to_file(*output: str) -> str:
1626     """Dumps `output` to a temporary file. Returns path to the file."""
1627     import tempfile
1628
1629     with tempfile.NamedTemporaryFile(
1630         mode='w', prefix='blk_', suffix='.log', delete=False
1631     ) as f:
1632         for lines in output:
1633             f.write(lines)
1634             f.write('\n')
1635     return f.name
1636
1637
1638 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1639     """Returns a udiff string between strings `a` and `b`."""
1640     import difflib
1641
1642     a_lines = [line + '\n' for line in a.split('\n')]
1643     b_lines = [line + '\n' for line in b.split('\n')]
1644     return ''.join(
1645         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1646     )
1647
1648
1649 if __name__ == '__main__':
1650     main()