]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

24c57ca4aee1d54ae862e8e879a4d27f9577f029
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 from typing import (
11     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
12 )
13
14 from attr import attrib, dataclass, Factory
15 import click
16
17 # lib2to3 fork
18 from blib2to3.pytree import Node, Leaf, type_repr
19 from blib2to3 import pygram, pytree
20 from blib2to3.pgen2 import driver, token
21 from blib2to3.pgen2.parse import ParseError
22
23 __version__ = "18.3a0"
24 DEFAULT_LINE_LENGTH = 88
25 # types
26 syms = pygram.python_symbols
27 FileContent = str
28 Encoding = str
29 Depth = int
30 NodeType = int
31 LeafID = int
32 Priority = int
33 LN = Union[Leaf, Node]
34 out = partial(click.secho, bold=True, err=True)
35 err = partial(click.secho, fg='red', err=True)
36
37
38 class NothingChanged(UserWarning):
39     """Raised by `format_file` when the reformatted code is the same as source."""
40
41
42 class CannotSplit(Exception):
43     """A readable split that fits the allotted line length is impossible.
44
45     Raised by `left_hand_split()` and `right_hand_split()`.
46     """
47
48
49 @click.command()
50 @click.option(
51     '-l',
52     '--line-length',
53     type=int,
54     default=DEFAULT_LINE_LENGTH,
55     help='How many character per line to allow.',
56     show_default=True,
57 )
58 @click.option(
59     '--fast/--safe',
60     is_flag=True,
61     help='If --fast given, skip temporary sanity checks. [default: --safe]',
62 )
63 @click.version_option(version=__version__)
64 @click.argument(
65     'src',
66     nargs=-1,
67     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
68 )
69 @click.pass_context
70 def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> None:
71     """The uncompromising code formatter."""
72     sources: List[Path] = []
73     for s in src:
74         p = Path(s)
75         if p.is_dir():
76             sources.extend(gen_python_files_in_dir(p))
77         elif p.is_file():
78             # if a file was explicitly given, we don't care about its extension
79             sources.append(p)
80         else:
81             err(f'invalid path: {s}')
82     if len(sources) == 0:
83         ctx.exit(0)
84     elif len(sources) == 1:
85         p = sources[0]
86         report = Report()
87         try:
88             changed = format_file_in_place(p, line_length=line_length, fast=fast)
89             report.done(p, changed)
90         except Exception as exc:
91             report.failed(p, str(exc))
92         ctx.exit(report.return_code)
93     else:
94         loop = asyncio.get_event_loop()
95         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
96         return_code = 1
97         try:
98             return_code = loop.run_until_complete(
99                 schedule_formatting(sources, line_length, fast, loop, executor)
100             )
101         finally:
102             loop.close()
103             ctx.exit(return_code)
104
105
106 async def schedule_formatting(
107     sources: List[Path],
108     line_length: int,
109     fast: bool,
110     loop: BaseEventLoop,
111     executor: Executor,
112 ) -> int:
113     tasks = {
114         src: loop.run_in_executor(
115             executor, format_file_in_place, src, line_length, fast
116         )
117         for src in sources
118     }
119     await asyncio.wait(tasks.values())
120     cancelled = []
121     report = Report()
122     for src, task in tasks.items():
123         if not task.done():
124             report.failed(src, 'timed out, cancelling')
125             task.cancel()
126             cancelled.append(task)
127         elif task.exception():
128             report.failed(src, str(task.exception()))
129         else:
130             report.done(src, task.result())
131     if cancelled:
132         await asyncio.wait(cancelled, timeout=2)
133     out('All done! ✨ 🍰 ✨')
134     click.echo(str(report))
135     return report.return_code
136
137
138 def format_file_in_place(src: Path, line_length: int, fast: bool) -> bool:
139     """Format the file and rewrite if changed. Return True if changed."""
140     try:
141         contents, encoding = format_file(src, line_length=line_length, fast=fast)
142     except NothingChanged:
143         return False
144
145     with open(src, "w", encoding=encoding) as f:
146         f.write(contents)
147     return True
148
149
150 def format_file(
151     src: Path, line_length: int, fast: bool
152 ) -> Tuple[FileContent, Encoding]:
153     """Reformats a file and returns its contents and encoding."""
154     with tokenize.open(src) as src_buffer:
155         src_contents = src_buffer.read()
156     if src_contents.strip() == '':
157         raise NothingChanged(src)
158
159     dst_contents = format_str(src_contents, line_length=line_length)
160     if src_contents == dst_contents:
161         raise NothingChanged(src)
162
163     if not fast:
164         assert_equivalent(src_contents, dst_contents)
165         assert_stable(src_contents, dst_contents, line_length=line_length)
166     return dst_contents, src_buffer.encoding
167
168
169 def format_str(src_contents: str, line_length: int) -> FileContent:
170     """Reformats a string and returns new contents."""
171     src_node = lib2to3_parse(src_contents)
172     dst_contents = ""
173     comments: List[Line] = []
174     lines = LineGenerator()
175     elt = EmptyLineTracker()
176     empty_line = Line()
177     after = 0
178     for current_line in lines.visit(src_node):
179         for _ in range(after):
180             dst_contents += str(empty_line)
181         before, after = elt.maybe_empty_lines(current_line)
182         for _ in range(before):
183             dst_contents += str(empty_line)
184         if not current_line.is_comment:
185             for comment in comments:
186                 dst_contents += str(comment)
187             comments = []
188             for line in split_line(current_line, line_length=line_length):
189                 dst_contents += str(line)
190         else:
191             comments.append(current_line)
192     for comment in comments:
193         dst_contents += str(comment)
194     return dst_contents
195
196
197 def lib2to3_parse(src_txt: str) -> Node:
198     """Given a string with source, return the lib2to3 Node."""
199     grammar = pygram.python_grammar_no_print_statement
200     drv = driver.Driver(grammar, pytree.convert)
201     if src_txt[-1] != '\n':
202         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
203         src_txt += nl
204     try:
205         result = drv.parse_string(src_txt, True)
206     except ParseError as pe:
207         lineno, column = pe.context[1]
208         lines = src_txt.splitlines()
209         try:
210             faulty_line = lines[lineno - 1]
211         except IndexError:
212             faulty_line = "<line number missing in source>"
213         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
214
215     if isinstance(result, Leaf):
216         result = Node(syms.file_input, [result])
217     return result
218
219
220 def lib2to3_unparse(node: Node) -> str:
221     """Given a lib2to3 node, return its string representation."""
222     code = str(node)
223     return code
224
225
226 T = TypeVar('T')
227
228
229 class Visitor(Generic[T]):
230     """Basic lib2to3 visitor that yields things on visiting."""
231
232     def visit(self, node: LN) -> Iterator[T]:
233         if node.type < 256:
234             name = token.tok_name[node.type]
235         else:
236             name = type_repr(node.type)
237         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
238
239     def visit_default(self, node: LN) -> Iterator[T]:
240         if isinstance(node, Node):
241             for child in node.children:
242                 yield from self.visit(child)
243
244
245 @dataclass
246 class DebugVisitor(Visitor[T]):
247     tree_depth: int = attrib(default=0)
248
249     def visit_default(self, node: LN) -> Iterator[T]:
250         indent = ' ' * (2 * self.tree_depth)
251         if isinstance(node, Node):
252             _type = type_repr(node.type)
253             out(f'{indent}{_type}', fg='yellow')
254             self.tree_depth += 1
255             for child in node.children:
256                 yield from self.visit(child)
257
258             self.tree_depth -= 1
259             out(f'{indent}/{_type}', fg='yellow', bold=False)
260         else:
261             _type = token.tok_name.get(node.type, str(node.type))
262             out(f'{indent}{_type}', fg='blue', nl=False)
263             if node.prefix:
264                 # We don't have to handle prefixes for `Node` objects since
265                 # that delegates to the first child anyway.
266                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
267             out(f' {node.value!r}', fg='blue', bold=False)
268
269
270 KEYWORDS = set(keyword.kwlist)
271 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
272 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
273 STATEMENT = {
274     syms.if_stmt,
275     syms.while_stmt,
276     syms.for_stmt,
277     syms.try_stmt,
278     syms.except_clause,
279     syms.with_stmt,
280     syms.funcdef,
281     syms.classdef,
282 }
283 STANDALONE_COMMENT = 153
284 LOGIC_OPERATORS = {'and', 'or'}
285 COMPARATORS = {
286     token.LESS,
287     token.GREATER,
288     token.EQEQUAL,
289     token.NOTEQUAL,
290     token.LESSEQUAL,
291     token.GREATEREQUAL,
292 }
293 MATH_OPERATORS = {
294     token.PLUS,
295     token.MINUS,
296     token.STAR,
297     token.SLASH,
298     token.VBAR,
299     token.AMPER,
300     token.PERCENT,
301     token.CIRCUMFLEX,
302     token.LEFTSHIFT,
303     token.RIGHTSHIFT,
304     token.DOUBLESTAR,
305     token.DOUBLESLASH,
306 }
307 COMPREHENSION_PRIORITY = 20
308 COMMA_PRIORITY = 10
309 LOGIC_PRIORITY = 5
310 STRING_PRIORITY = 4
311 COMPARATOR_PRIORITY = 3
312 MATH_PRIORITY = 1
313
314
315 @dataclass
316 class BracketTracker:
317     depth: int = attrib(default=0)
318     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict))
319     delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict))
320     previous: Optional[Leaf] = attrib(default=None)
321
322     def mark(self, leaf: Leaf) -> None:
323         if leaf.type == token.COMMENT:
324             return
325
326         if leaf.type in CLOSING_BRACKETS:
327             self.depth -= 1
328             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
329             leaf.opening_bracket = opening_bracket  # type: ignore
330         leaf.bracket_depth = self.depth  # type: ignore
331         if self.depth == 0:
332             delim = is_delimiter(leaf)
333             if delim:
334                 self.delimiters[id(leaf)] = delim
335             elif self.previous is not None:
336                 if leaf.type == token.STRING and self.previous.type == token.STRING:
337                     self.delimiters[id(self.previous)] = STRING_PRIORITY
338                 elif (
339                     leaf.type == token.NAME and
340                     leaf.value == 'for' and
341                     leaf.parent and
342                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
343                 ):
344                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
345                 elif (
346                     leaf.type == token.NAME and
347                     leaf.value == 'if' and
348                     leaf.parent and
349                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
350                 ):
351                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
352         if leaf.type in OPENING_BRACKETS:
353             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
354             self.depth += 1
355         self.previous = leaf
356
357     def any_open_brackets(self) -> bool:
358         """Returns True if there is an yet unmatched open bracket on the line."""
359         return bool(self.bracket_match)
360
361     def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
362         """Returns the highest priority of a delimiter found on the line.
363
364         Values are consistent with what `is_delimiter()` returns.
365         """
366         return max(v for k, v in self.delimiters.items() if k not in exclude)
367
368
369 @dataclass
370 class Line:
371     depth: int = attrib(default=0)
372     leaves: List[Leaf] = attrib(default=Factory(list))
373     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
374     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
375     inside_brackets: bool = attrib(default=False)
376
377     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
378         has_value = leaf.value.strip()
379         if not has_value:
380             return
381
382         if self.leaves and not preformatted:
383             # Note: at this point leaf.prefix should be empty except for
384             # imports, for which we only preserve newlines.
385             leaf.prefix += whitespace(leaf)
386         if self.inside_brackets or not preformatted:
387             self.bracket_tracker.mark(leaf)
388             self.maybe_remove_trailing_comma(leaf)
389             if self.maybe_adapt_standalone_comment(leaf):
390                 return
391
392         if not self.append_comment(leaf):
393             self.leaves.append(leaf)
394
395     @property
396     def is_comment(self) -> bool:
397         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
398
399     @property
400     def is_decorator(self) -> bool:
401         return bool(self) and self.leaves[0].type == token.AT
402
403     @property
404     def is_import(self) -> bool:
405         return bool(self) and is_import(self.leaves[0])
406
407     @property
408     def is_class(self) -> bool:
409         return (
410             bool(self) and
411             self.leaves[0].type == token.NAME and
412             self.leaves[0].value == 'class'
413         )
414
415     @property
416     def is_def(self) -> bool:
417         """Also returns True for async defs."""
418         try:
419             first_leaf = self.leaves[0]
420         except IndexError:
421             return False
422
423         try:
424             second_leaf: Optional[Leaf] = self.leaves[1]
425         except IndexError:
426             second_leaf = None
427         return (
428             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
429             (
430                 first_leaf.type == token.NAME and
431                 first_leaf.value == 'async' and
432                 second_leaf is not None and
433                 second_leaf.type == token.NAME and
434                 second_leaf.value == 'def'
435             )
436         )
437
438     @property
439     def is_flow_control(self) -> bool:
440         return (
441             bool(self) and
442             self.leaves[0].type == token.NAME and
443             self.leaves[0].value in FLOW_CONTROL
444         )
445
446     @property
447     def is_yield(self) -> bool:
448         return (
449             bool(self) and
450             self.leaves[0].type == token.NAME and
451             self.leaves[0].value == 'yield'
452         )
453
454     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
455         if not (
456             self.leaves and
457             self.leaves[-1].type == token.COMMA and
458             closing.type in CLOSING_BRACKETS
459         ):
460             return False
461
462         if closing.type == token.RSQB or closing.type == token.RBRACE:
463             self.leaves.pop()
464             return True
465
466         # For parens let's check if it's safe to remove the comma.  If the
467         # trailing one is the only one, we might mistakenly change a tuple
468         # into a different type by removing the comma.
469         depth = closing.bracket_depth + 1  # type: ignore
470         commas = 0
471         opening = closing.opening_bracket  # type: ignore
472         for _opening_index, leaf in enumerate(self.leaves):
473             if leaf is opening:
474                 break
475
476         else:
477             return False
478
479         for leaf in self.leaves[_opening_index + 1:]:
480             if leaf is closing:
481                 break
482
483             bracket_depth = leaf.bracket_depth  # type: ignore
484             if bracket_depth == depth and leaf.type == token.COMMA:
485                 commas += 1
486         if commas > 1:
487             self.leaves.pop()
488             return True
489
490         return False
491
492     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
493         """Hack a standalone comment to act as a trailing comment for line splitting.
494
495         If this line has brackets and a standalone `comment`, we need to adapt
496         it to be able to still reformat the line.
497
498         This is not perfect, the line to which the standalone comment gets
499         appended will appear "too long" when splitting.
500         """
501         if not (
502             comment.type == STANDALONE_COMMENT and
503             self.bracket_tracker.any_open_brackets()
504         ):
505             return False
506
507         comment.type = token.COMMENT
508         comment.prefix = '\n' + '    ' * (self.depth + 1)
509         return self.append_comment(comment)
510
511     def append_comment(self, comment: Leaf) -> bool:
512         if comment.type != token.COMMENT:
513             return False
514
515         try:
516             after = id(self.last_non_delimiter())
517         except LookupError:
518             comment.type = STANDALONE_COMMENT
519             comment.prefix = ''
520             return False
521
522         else:
523             if after in self.comments:
524                 self.comments[after].value += str(comment)
525             else:
526                 self.comments[after] = comment
527             return True
528
529     def last_non_delimiter(self) -> Leaf:
530         for i in range(len(self.leaves)):
531             last = self.leaves[-i - 1]
532             if not is_delimiter(last):
533                 return last
534
535         raise LookupError("No non-delimiters found")
536
537     def __str__(self) -> str:
538         if not self:
539             return '\n'
540
541         indent = '    ' * self.depth
542         leaves = iter(self.leaves)
543         first = next(leaves)
544         res = f'{first.prefix}{indent}{first.value}'
545         for leaf in leaves:
546             res += str(leaf)
547         for comment in self.comments.values():
548             res += str(comment)
549         return res + '\n'
550
551     def __bool__(self) -> bool:
552         return bool(self.leaves or self.comments)
553
554
555 @dataclass
556 class EmptyLineTracker:
557     """Provides a stateful method that returns the number of potential extra
558     empty lines needed before and after the currently processed line.
559
560     Note: this tracker works on lines that haven't been split yet.
561     """
562     previous_line: Optional[Line] = attrib(default=None)
563     previous_after: int = attrib(default=0)
564     previous_defs: List[int] = attrib(default=Factory(list))
565
566     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
567         """Returns the number of extra empty lines before and after the `current_line`.
568
569         This is for separating `def`, `async def` and `class` with extra empty lines
570         (two on module-level), as well as providing an extra empty line after flow
571         control keywords to make them more prominent.
572         """
573         before, after = self._maybe_empty_lines(current_line)
574         self.previous_after = after
575         self.previous_line = current_line
576         return before, after
577
578     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
579         before = 0
580         depth = current_line.depth
581         while self.previous_defs and self.previous_defs[-1] >= depth:
582             self.previous_defs.pop()
583             before = (1 if depth else 2) - self.previous_after
584         is_decorator = current_line.is_decorator
585         if is_decorator or current_line.is_def or current_line.is_class:
586             if not is_decorator:
587                 self.previous_defs.append(depth)
588             if self.previous_line is None:
589                 # Don't insert empty lines before the first line in the file.
590                 return 0, 0
591
592             if self.previous_line and self.previous_line.is_decorator:
593                 # Don't insert empty lines between decorators.
594                 return 0, 0
595
596             newlines = 2
597             if current_line.depth:
598                 newlines -= 1
599             newlines -= self.previous_after
600             return newlines, 0
601
602         if current_line.is_flow_control:
603             return before, 1
604
605         if (
606             self.previous_line and
607             self.previous_line.is_import and
608             not current_line.is_import and
609             depth == self.previous_line.depth
610         ):
611             return (before or 1), 0
612
613         if (
614             self.previous_line and
615             self.previous_line.is_yield and
616             (not current_line.is_yield or depth != self.previous_line.depth)
617         ):
618             return (before or 1), 0
619
620         return before, 0
621
622
623 @dataclass
624 class LineGenerator(Visitor[Line]):
625     """Generates reformatted Line objects.  Empty lines are not emitted.
626
627     Note: destroys the tree it's visiting by mutating prefixes of its leaves
628     in ways that will no longer stringify to valid Python code on the tree.
629     """
630     current_line: Line = attrib(default=Factory(Line))
631     standalone_comments: List[Leaf] = attrib(default=Factory(list))
632
633     def line(self, indent: int = 0) -> Iterator[Line]:
634         """Generate a line.
635
636         If the line is empty, only emit if it makes sense.
637         If the line is too long, split it first and then generate.
638
639         If any lines were generated, set up a new current_line.
640         """
641         if not self.current_line:
642             self.current_line.depth += indent
643             return  # Line is empty, don't emit. Creating a new one unnecessary.
644
645         complete_line = self.current_line
646         self.current_line = Line(depth=complete_line.depth + indent)
647         yield complete_line
648
649     def visit_default(self, node: LN) -> Iterator[Line]:
650         if isinstance(node, Leaf):
651             for comment in generate_comments(node):
652                 if self.current_line.bracket_tracker.any_open_brackets():
653                     # any comment within brackets is subject to splitting
654                     self.current_line.append(comment)
655                 elif comment.type == token.COMMENT:
656                     # regular trailing comment
657                     self.current_line.append(comment)
658                     yield from self.line()
659
660                 else:
661                     # regular standalone comment, to be processed later (see
662                     # docstring in `generate_comments()`
663                     self.standalone_comments.append(comment)
664             normalize_prefix(node)
665             if node.type not in WHITESPACE:
666                 for comment in self.standalone_comments:
667                     yield from self.line()
668
669                     self.current_line.append(comment)
670                     yield from self.line()
671
672                 self.standalone_comments = []
673                 self.current_line.append(node)
674         yield from super().visit_default(node)
675
676     def visit_suite(self, node: Node) -> Iterator[Line]:
677         """Body of a statement after a colon."""
678         children = iter(node.children)
679         # Process newline before indenting.  It might contain an inline
680         # comment that should go right after the colon.
681         newline = next(children)
682         yield from self.visit(newline)
683         yield from self.line(+1)
684
685         for child in children:
686             yield from self.visit(child)
687
688         yield from self.line(-1)
689
690     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
691         """Visit a statement.
692
693         The relevant Python language keywords for this statement are NAME leaves
694         within it.
695         """
696         for child in node.children:
697             if child.type == token.NAME and child.value in keywords:  # type: ignore
698                 yield from self.line()
699
700             yield from self.visit(child)
701
702     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
703         """A statement without nested statements."""
704         is_suite_like = node.parent and node.parent.type in STATEMENT
705         if is_suite_like:
706             yield from self.line(+1)
707             yield from self.visit_default(node)
708             yield from self.line(-1)
709
710         else:
711             yield from self.line()
712             yield from self.visit_default(node)
713
714     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
715         yield from self.line()
716
717         children = iter(node.children)
718         for child in children:
719             yield from self.visit(child)
720
721             if child.type == token.NAME and child.value == 'async':  # type: ignore
722                 break
723
724         internal_stmt = next(children)
725         for child in internal_stmt.children:
726             yield from self.visit(child)
727
728     def visit_decorators(self, node: Node) -> Iterator[Line]:
729         for child in node.children:
730             yield from self.line()
731             yield from self.visit(child)
732
733     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
734         yield from self.line()
735
736     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
737         yield from self.visit_default(leaf)
738         yield from self.line()
739
740     def __attrs_post_init__(self) -> None:
741         """You are in a twisty little maze of passages."""
742         v = self.visit_stmt
743         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
744         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
745         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
746         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
747         self.visit_except_clause = partial(v, keywords={'except'})
748         self.visit_funcdef = partial(v, keywords={'def'})
749         self.visit_with_stmt = partial(v, keywords={'with'})
750         self.visit_classdef = partial(v, keywords={'class'})
751         self.visit_async_funcdef = self.visit_async_stmt
752         self.visit_decorated = self.visit_decorators
753
754
755 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
756 OPENING_BRACKETS = set(BRACKET.keys())
757 CLOSING_BRACKETS = set(BRACKET.values())
758 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
759
760
761 def whitespace(leaf: Leaf) -> str:
762     """Return whitespace prefix if needed for the given `leaf`."""
763     NO = ''
764     SPACE = ' '
765     DOUBLESPACE = '  '
766     t = leaf.type
767     p = leaf.parent
768     if t == token.COLON:
769         return NO
770
771     if t == token.COMMA:
772         return NO
773
774     if t == token.RPAR:
775         return NO
776
777     if t == token.COMMENT:
778         return DOUBLESPACE
779
780     if t == STANDALONE_COMMENT:
781         return NO
782
783     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
784     if p.type in {syms.parameters, syms.arglist}:
785         # untyped function signatures or calls
786         if t == token.RPAR:
787             return NO
788
789         prev = leaf.prev_sibling
790         if not prev or prev.type != token.COMMA:
791             return NO
792
793     if p.type == syms.varargslist:
794         # lambdas
795         if t == token.RPAR:
796             return NO
797
798         prev = leaf.prev_sibling
799         if prev and prev.type != token.COMMA:
800             return NO
801
802     elif p.type == syms.typedargslist:
803         # typed function signatures
804         prev = leaf.prev_sibling
805         if not prev:
806             return NO
807
808         if t == token.EQUAL:
809             if prev.type != syms.tname:
810                 return NO
811
812         elif prev.type == token.EQUAL:
813             # A bit hacky: if the equal sign has whitespace, it means we
814             # previously found it's a typed argument.  So, we're using that, too.
815             return prev.prefix
816
817         elif prev.type != token.COMMA:
818             return NO
819
820     elif p.type == syms.tname:
821         # type names
822         prev = leaf.prev_sibling
823         if not prev:
824             prevp = preceding_leaf(p)
825             if not prevp or prevp.type != token.COMMA:
826                 return NO
827
828     elif p.type == syms.trailer:
829         # attributes and calls
830         if t == token.LPAR or t == token.RPAR:
831             return NO
832
833         prev = leaf.prev_sibling
834         if not prev:
835             if t == token.DOT:
836                 prevp = preceding_leaf(p)
837                 if not prevp or prevp.type != token.NUMBER:
838                     return NO
839
840             elif t == token.LSQB:
841                 return NO
842
843         elif prev.type != token.COMMA:
844             return NO
845
846     elif p.type == syms.argument:
847         # single argument
848         if t == token.EQUAL:
849             return NO
850
851         prev = leaf.prev_sibling
852         if not prev:
853             prevp = preceding_leaf(p)
854             if not prevp or prevp.type == token.LPAR:
855                 return NO
856
857         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
858             return NO
859
860     elif p.type == syms.decorator:
861         # decorators
862         return NO
863
864     elif p.type == syms.dotted_name:
865         prev = leaf.prev_sibling
866         if prev:
867             return NO
868
869         prevp = preceding_leaf(p)
870         if not prevp or prevp.type == token.AT:
871             return NO
872
873     elif p.type == syms.classdef:
874         if t == token.LPAR:
875             return NO
876
877         prev = leaf.prev_sibling
878         if prev and prev.type == token.LPAR:
879             return NO
880
881     elif p.type == syms.subscript:
882         # indexing
883         if t == token.COLON:
884             return NO
885
886         prev = leaf.prev_sibling
887         if not prev or prev.type == token.COLON:
888             return NO
889
890     elif p.type in {
891         syms.test,
892         syms.not_test,
893         syms.xor_expr,
894         syms.or_test,
895         syms.and_test,
896         syms.arith_expr,
897         syms.shift_expr,
898         syms.yield_expr,
899         syms.term,
900         syms.power,
901         syms.comparison,
902     }:
903         # various arithmetic and logic expressions
904         prev = leaf.prev_sibling
905         if not prev:
906             prevp = preceding_leaf(p)
907             if not prevp or prevp.type in OPENING_BRACKETS:
908                 return NO
909
910             if prevp.type == token.EQUAL:
911                 if prevp.parent and prevp.parent.type in {
912                     syms.varargslist, syms.parameters, syms.arglist, syms.argument
913                 }:
914                     return NO
915
916         return SPACE
917
918     elif p.type == syms.atom:
919         if t in CLOSING_BRACKETS:
920             return NO
921
922         prev = leaf.prev_sibling
923         if not prev:
924             prevp = preceding_leaf(p)
925             if not prevp:
926                 return NO
927
928             if prevp.type in OPENING_BRACKETS:
929                 return NO
930
931             if prevp.type == token.EQUAL:
932                 if prevp.parent and prevp.parent.type in {
933                     syms.varargslist, syms.parameters, syms.arglist, syms.argument
934                 }:
935                     return NO
936
937             if prevp.type == token.DOUBLESTAR:
938                 if prevp.parent and prevp.parent.type in {
939                     syms.varargslist, syms.parameters, syms.arglist, syms.dictsetmaker
940                 }:
941                     return NO
942
943         elif prev.type in OPENING_BRACKETS:
944             return NO
945
946         elif t == token.DOT:
947             # dots, but not the first one.
948             return NO
949
950     elif (
951         p.type == syms.listmaker or
952         p.type == syms.testlist_gexp or
953         p.type == syms.subscriptlist
954     ):
955         # list interior, including unpacking
956         prev = leaf.prev_sibling
957         if not prev:
958             return NO
959
960     elif p.type == syms.dictsetmaker:
961         # dict and set interior, including unpacking
962         prev = leaf.prev_sibling
963         if not prev:
964             return NO
965
966         if prev.type == token.DOUBLESTAR:
967             return NO
968
969     elif p.type == syms.factor or p.type == syms.star_expr:
970         # unary ops
971         prev = leaf.prev_sibling
972         if not prev:
973             prevp = preceding_leaf(p)
974             if not prevp or prevp.type in OPENING_BRACKETS:
975                 return NO
976
977             prevp_parent = prevp.parent
978             assert prevp_parent is not None
979             if prevp.type == token.COLON and prevp_parent.type in {
980                 syms.subscript, syms.sliceop
981             }:
982                 return NO
983
984             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
985                 return NO
986
987         elif t == token.NAME or t == token.NUMBER:
988             return NO
989
990     elif p.type == syms.import_from and t == token.NAME:
991         prev = leaf.prev_sibling
992         if prev and prev.type == token.DOT:
993             return NO
994
995     elif p.type == syms.sliceop:
996         return NO
997
998     return SPACE
999
1000
1001 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1002     """Returns the first leaf that precedes `node`, if any."""
1003     while node:
1004         res = node.prev_sibling
1005         if res:
1006             if isinstance(res, Leaf):
1007                 return res
1008
1009             try:
1010                 return list(res.leaves())[-1]
1011
1012             except IndexError:
1013                 return None
1014
1015         node = node.parent
1016     return None
1017
1018
1019 def is_delimiter(leaf: Leaf) -> int:
1020     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1021
1022     Higher numbers are higher priority.
1023     """
1024     if leaf.type == token.COMMA:
1025         return COMMA_PRIORITY
1026
1027     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1028         return LOGIC_PRIORITY
1029
1030     if leaf.type in COMPARATORS:
1031         return COMPARATOR_PRIORITY
1032
1033     if (
1034         leaf.type in MATH_OPERATORS and
1035         leaf.parent and
1036         leaf.parent.type not in {syms.factor, syms.star_expr}
1037     ):
1038         return MATH_PRIORITY
1039
1040     return 0
1041
1042
1043 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1044     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1045
1046     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1047     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1048     move because it does away with modifying the grammar to include all the
1049     possible places in which comments can be placed.
1050
1051     The sad consequence for us though is that comments don't "belong" anywhere.
1052     This is why this function generates simple parentless Leaf objects for
1053     comments.  We simply don't know what the correct parent should be.
1054
1055     No matter though, we can live without this.  We really only need to
1056     differentiate between inline and standalone comments.  The latter don't
1057     share the line with any code.
1058
1059     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1060     are emitted with a fake STANDALONE_COMMENT token identifier.
1061     """
1062     if not leaf.prefix:
1063         return
1064
1065     if '#' not in leaf.prefix:
1066         return
1067
1068     before_comment, content = leaf.prefix.split('#', 1)
1069     content = content.rstrip()
1070     if content and (content[0] not in {' ', '!', '#'}):
1071         content = ' ' + content
1072     is_standalone_comment = (
1073         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1074     )
1075     if not is_standalone_comment:
1076         # simple trailing comment
1077         yield Leaf(token.COMMENT, value='#' + content)
1078         return
1079
1080     for line in ('#' + content).split('\n'):
1081         line = line.lstrip()
1082         if not line.startswith('#'):
1083             continue
1084
1085         yield Leaf(STANDALONE_COMMENT, line)
1086
1087
1088 def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]:
1089     """Splits a `line` into potentially many lines.
1090
1091     They should fit in the allotted `line_length` but might not be able to.
1092     `inner` signifies that there were a pair of brackets somewhere around the
1093     current `line`, possibly transitively. This means we can fallback to splitting
1094     by delimiters if the LHS/RHS don't yield any results.
1095     """
1096     line_str = str(line).strip('\n')
1097     if len(line_str) <= line_length and '\n' not in line_str:
1098         yield line
1099         return
1100
1101     if line.is_def:
1102         split_funcs = [left_hand_split]
1103     elif line.inside_brackets:
1104         split_funcs = [delimiter_split]
1105         if '\n' not in line_str:
1106             # Only attempt RHS if we don't have multiline strings or comments
1107             # on this line.
1108             split_funcs.append(right_hand_split)
1109     else:
1110         split_funcs = [right_hand_split]
1111     for split_func in split_funcs:
1112         # We are accumulating lines in `result` because we might want to abort
1113         # mission and return the original line in the end, or attempt a different
1114         # split altogether.
1115         result: List[Line] = []
1116         try:
1117             for l in split_func(line):
1118                 if str(l).strip('\n') == line_str:
1119                     raise CannotSplit("Split function returned an unchanged result")
1120
1121                 result.extend(split_line(l, line_length=line_length, inner=True))
1122         except CannotSplit as cs:
1123             continue
1124
1125         else:
1126             yield from result
1127             break
1128
1129     else:
1130         yield line
1131
1132
1133 def left_hand_split(line: Line) -> Iterator[Line]:
1134     """Split line into many lines, starting with the first matching bracket pair.
1135
1136     Note: this usually looks weird, only use this for function definitions.
1137     Prefer RHS otherwise.
1138     """
1139     head = Line(depth=line.depth)
1140     body = Line(depth=line.depth + 1, inside_brackets=True)
1141     tail = Line(depth=line.depth)
1142     tail_leaves: List[Leaf] = []
1143     body_leaves: List[Leaf] = []
1144     head_leaves: List[Leaf] = []
1145     current_leaves = head_leaves
1146     matching_bracket = None
1147     for leaf in line.leaves:
1148         if (
1149             current_leaves is body_leaves and
1150             leaf.type in CLOSING_BRACKETS and
1151             leaf.opening_bracket is matching_bracket  # type: ignore
1152         ):
1153             current_leaves = tail_leaves
1154         current_leaves.append(leaf)
1155         if current_leaves is head_leaves:
1156             if leaf.type in OPENING_BRACKETS:
1157                 matching_bracket = leaf
1158                 current_leaves = body_leaves
1159     # Since body is a new indent level, remove spurious leading whitespace.
1160     if body_leaves:
1161         normalize_prefix(body_leaves[0])
1162     # Build the new lines.
1163     for result, leaves in (
1164         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1165     ):
1166         for leaf in leaves:
1167             result.append(leaf, preformatted=True)
1168             comment_after = line.comments.get(id(leaf))
1169             if comment_after:
1170                 result.append(comment_after, preformatted=True)
1171     # Check if the split succeeded.
1172     tail_len = len(str(tail))
1173     if not body:
1174         if tail_len == 0:
1175             raise CannotSplit("Splitting brackets produced the same line")
1176
1177         elif tail_len < 3:
1178             raise CannotSplit(
1179                 f"Splitting brackets on an empty body to save "
1180                 f"{tail_len} characters is not worth it"
1181             )
1182
1183     for result in (head, body, tail):
1184         if result:
1185             yield result
1186
1187
1188 def right_hand_split(line: Line) -> Iterator[Line]:
1189     """Split line into many lines, starting with the last matching bracket pair."""
1190     head = Line(depth=line.depth)
1191     body = Line(depth=line.depth + 1, inside_brackets=True)
1192     tail = Line(depth=line.depth)
1193     tail_leaves: List[Leaf] = []
1194     body_leaves: List[Leaf] = []
1195     head_leaves: List[Leaf] = []
1196     current_leaves = tail_leaves
1197     opening_bracket = None
1198     for leaf in reversed(line.leaves):
1199         if current_leaves is body_leaves:
1200             if leaf is opening_bracket:
1201                 current_leaves = head_leaves
1202         current_leaves.append(leaf)
1203         if current_leaves is tail_leaves:
1204             if leaf.type in CLOSING_BRACKETS:
1205                 opening_bracket = leaf.opening_bracket  # type: ignore
1206                 current_leaves = body_leaves
1207     tail_leaves.reverse()
1208     body_leaves.reverse()
1209     head_leaves.reverse()
1210     # Since body is a new indent level, remove spurious leading whitespace.
1211     if body_leaves:
1212         normalize_prefix(body_leaves[0])
1213     # Build the new lines.
1214     for result, leaves in (
1215         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1216     ):
1217         for leaf in leaves:
1218             result.append(leaf, preformatted=True)
1219             comment_after = line.comments.get(id(leaf))
1220             if comment_after:
1221                 result.append(comment_after, preformatted=True)
1222     # Check if the split succeeded.
1223     tail_len = len(str(tail).strip('\n'))
1224     if not body:
1225         if tail_len == 0:
1226             raise CannotSplit("Splitting brackets produced the same line")
1227
1228         elif tail_len < 3:
1229             raise CannotSplit(
1230                 f"Splitting brackets on an empty body to save "
1231                 f"{tail_len} characters is not worth it"
1232             )
1233
1234     for result in (head, body, tail):
1235         if result:
1236             yield result
1237
1238
1239 def delimiter_split(line: Line) -> Iterator[Line]:
1240     """Split according to delimiters of the highest priority.
1241
1242     This kind of split doesn't increase indentation.
1243     """
1244     try:
1245         last_leaf = line.leaves[-1]
1246     except IndexError:
1247         raise CannotSplit("Line empty")
1248
1249     delimiters = line.bracket_tracker.delimiters
1250     try:
1251         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1252     except ValueError:
1253         raise CannotSplit("No delimiters found")
1254
1255     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1256     for leaf in line.leaves:
1257         current_line.append(leaf, preformatted=True)
1258         comment_after = line.comments.get(id(leaf))
1259         if comment_after:
1260             current_line.append(comment_after, preformatted=True)
1261         leaf_priority = delimiters.get(id(leaf))
1262         if leaf_priority == delimiter_priority:
1263             normalize_prefix(current_line.leaves[0])
1264             yield current_line
1265
1266             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1267     if current_line:
1268         if (
1269             delimiter_priority == COMMA_PRIORITY and
1270             current_line.leaves[-1].type != token.COMMA
1271         ):
1272             current_line.append(Leaf(token.COMMA, ','))
1273         normalize_prefix(current_line.leaves[0])
1274         yield current_line
1275
1276
1277 def is_import(leaf: Leaf) -> bool:
1278     """Returns True if the given leaf starts an import statement."""
1279     p = leaf.parent
1280     t = leaf.type
1281     v = leaf.value
1282     return bool(
1283         t == token.NAME and
1284         (
1285             (v == 'import' and p and p.type == syms.import_name) or
1286             (v == 'from' and p and p.type == syms.import_from)
1287         )
1288     )
1289
1290
1291 def normalize_prefix(leaf: Leaf) -> None:
1292     """Leave existing extra newlines for imports.  Remove everything else."""
1293     if is_import(leaf):
1294         spl = leaf.prefix.split('#', 1)
1295         nl_count = spl[0].count('\n')
1296         if len(spl) > 1:
1297             # Skip one newline since it was for a standalone comment.
1298             nl_count -= 1
1299         leaf.prefix = '\n' * nl_count
1300         return
1301
1302     leaf.prefix = ''
1303
1304
1305 PYTHON_EXTENSIONS = {'.py'}
1306 BLACKLISTED_DIRECTORIES = {
1307     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1308 }
1309
1310
1311 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1312     for child in path.iterdir():
1313         if child.is_dir():
1314             if child.name in BLACKLISTED_DIRECTORIES:
1315                 continue
1316
1317             yield from gen_python_files_in_dir(child)
1318
1319         elif child.suffix in PYTHON_EXTENSIONS:
1320             yield child
1321
1322
1323 @dataclass
1324 class Report:
1325     """Provides a reformatting counter."""
1326     change_count: int = attrib(default=0)
1327     same_count: int = attrib(default=0)
1328     failure_count: int = attrib(default=0)
1329
1330     def done(self, src: Path, changed: bool) -> None:
1331         """Increment the counter for successful reformatting. Write out a message."""
1332         if changed:
1333             out(f'reformatted {src}')
1334             self.change_count += 1
1335         else:
1336             out(f'{src} already well formatted, good job.', bold=False)
1337             self.same_count += 1
1338
1339     def failed(self, src: Path, message: str) -> None:
1340         """Increment the counter for failed reformatting. Write out a message."""
1341         err(f'error: cannot format {src}: {message}')
1342         self.failure_count += 1
1343
1344     @property
1345     def return_code(self) -> int:
1346         """Which return code should the app use considering the current state."""
1347         return 1 if self.failure_count else 0
1348
1349     def __str__(self) -> str:
1350         """A color report of the current state.
1351
1352         Use `click.unstyle` to remove colors.
1353         """
1354         report = []
1355         if self.change_count:
1356             s = 's' if self.change_count > 1 else ''
1357             report.append(
1358                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1359             )
1360         if self.same_count:
1361             s = 's' if self.same_count > 1 else ''
1362             report.append(f'{self.same_count} file{s} left unchanged')
1363         if self.failure_count:
1364             s = 's' if self.failure_count > 1 else ''
1365             report.append(
1366                 click.style(
1367                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1368                 )
1369             )
1370         return ', '.join(report) + '.'
1371
1372
1373 def assert_equivalent(src: str, dst: str) -> None:
1374     """Raises AssertionError if `src` and `dst` aren't equivalent.
1375
1376     This is a temporary sanity check until Black becomes stable.
1377     """
1378
1379     import ast
1380     import traceback
1381
1382     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1383         """Simple visitor generating strings to compare ASTs by content."""
1384         yield f"{'  ' * depth}{node.__class__.__name__}("
1385
1386         for field in sorted(node._fields):
1387             try:
1388                 value = getattr(node, field)
1389             except AttributeError:
1390                 continue
1391
1392             yield f"{'  ' * (depth+1)}{field}="
1393
1394             if isinstance(value, list):
1395                 for item in value:
1396                     if isinstance(item, ast.AST):
1397                         yield from _v(item, depth + 2)
1398
1399             elif isinstance(value, ast.AST):
1400                 yield from _v(value, depth + 2)
1401
1402             else:
1403                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1404
1405         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1406
1407     try:
1408         src_ast = ast.parse(src)
1409     except Exception as exc:
1410         raise AssertionError(f"cannot parse source: {exc}") from None
1411
1412     try:
1413         dst_ast = ast.parse(dst)
1414     except Exception as exc:
1415         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1416         raise AssertionError(
1417             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1418             f"Please report a bug on https://github.com/ambv/black/issues.  "
1419             f"This invalid output might be helpful: {log}",
1420         ) from None
1421
1422     src_ast_str = '\n'.join(_v(src_ast))
1423     dst_ast_str = '\n'.join(_v(dst_ast))
1424     if src_ast_str != dst_ast_str:
1425         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1426         raise AssertionError(
1427             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1428             f"the source.  "
1429             f"Please report a bug on https://github.com/ambv/black/issues.  "
1430             f"This diff might be helpful: {log}",
1431         ) from None
1432
1433
1434 def assert_stable(src: str, dst: str, line_length: int) -> None:
1435     """Raises AssertionError if `dst` reformats differently the second time.
1436
1437     This is a temporary sanity check until Black becomes stable.
1438     """
1439     newdst = format_str(dst, line_length=line_length)
1440     if dst != newdst:
1441         log = dump_to_file(
1442             diff(src, dst, 'source', 'first pass'),
1443             diff(dst, newdst, 'first pass', 'second pass'),
1444         )
1445         raise AssertionError(
1446             f"INTERNAL ERROR: Black produced different code on the second pass "
1447             f"of the formatter.  "
1448             f"Please report a bug on https://github.com/ambv/black/issues.  "
1449             f"This diff might be helpful: {log}",
1450         ) from None
1451
1452
1453 def dump_to_file(*output: str) -> str:
1454     """Dumps `output` to a temporary file. Returns path to the file."""
1455     import tempfile
1456
1457     with tempfile.NamedTemporaryFile(
1458         mode='w', prefix='blk_', suffix='.log', delete=False
1459     ) as f:
1460         for lines in output:
1461             f.write(lines)
1462             f.write('\n')
1463     return f.name
1464
1465
1466 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1467     """Returns a udiff string between strings `a` and `b`."""
1468     import difflib
1469
1470     a_lines = [line + '\n' for line in a.split('\n')]
1471     b_lines = [line + '\n' for line in b.split('\n')]
1472     return ''.join(
1473         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1474     )
1475
1476
1477 if __name__ == '__main__':
1478     main()