]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

d48fddbe6526d6dc2753a71c482e53afe8c2f9ee
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 from typing import (
11     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
12 )
13
14 from attr import attrib, dataclass, Factory
15 import click
16
17 # lib2to3 fork
18 from blib2to3.pytree import Node, Leaf, type_repr
19 from blib2to3 import pygram, pytree
20 from blib2to3.pgen2 import driver, token
21 from blib2to3.pgen2.parse import ParseError
22
23 __version__ = "18.3a1"
24 DEFAULT_LINE_LENGTH = 88
25 # types
26 syms = pygram.python_symbols
27 FileContent = str
28 Encoding = str
29 Depth = int
30 NodeType = int
31 LeafID = int
32 Priority = int
33 LN = Union[Leaf, Node]
34 out = partial(click.secho, bold=True, err=True)
35 err = partial(click.secho, fg='red', err=True)
36
37
38 class NothingChanged(UserWarning):
39     """Raised by `format_file` when the reformatted code is the same as source."""
40
41
42 class CannotSplit(Exception):
43     """A readable split that fits the allotted line length is impossible.
44
45     Raised by `left_hand_split()` and `right_hand_split()`.
46     """
47
48
49 @click.command()
50 @click.option(
51     '-l',
52     '--line-length',
53     type=int,
54     default=DEFAULT_LINE_LENGTH,
55     help='How many character per line to allow.',
56     show_default=True,
57 )
58 @click.option(
59     '--fast/--safe',
60     is_flag=True,
61     help='If --fast given, skip temporary sanity checks. [default: --safe]',
62 )
63 @click.version_option(version=__version__)
64 @click.argument(
65     'src',
66     nargs=-1,
67     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
68 )
69 @click.pass_context
70 def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> None:
71     """The uncompromising code formatter."""
72     sources: List[Path] = []
73     for s in src:
74         p = Path(s)
75         if p.is_dir():
76             sources.extend(gen_python_files_in_dir(p))
77         elif p.is_file():
78             # if a file was explicitly given, we don't care about its extension
79             sources.append(p)
80         else:
81             err(f'invalid path: {s}')
82     if len(sources) == 0:
83         ctx.exit(0)
84     elif len(sources) == 1:
85         p = sources[0]
86         report = Report()
87         try:
88             changed = format_file_in_place(p, line_length=line_length, fast=fast)
89             report.done(p, changed)
90         except Exception as exc:
91             report.failed(p, str(exc))
92         ctx.exit(report.return_code)
93     else:
94         loop = asyncio.get_event_loop()
95         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
96         return_code = 1
97         try:
98             return_code = loop.run_until_complete(
99                 schedule_formatting(sources, line_length, fast, loop, executor)
100             )
101         finally:
102             loop.close()
103             ctx.exit(return_code)
104
105
106 async def schedule_formatting(
107     sources: List[Path],
108     line_length: int,
109     fast: bool,
110     loop: BaseEventLoop,
111     executor: Executor,
112 ) -> int:
113     tasks = {
114         src: loop.run_in_executor(
115             executor, format_file_in_place, src, line_length, fast
116         )
117         for src in sources
118     }
119     await asyncio.wait(tasks.values())
120     cancelled = []
121     report = Report()
122     for src, task in tasks.items():
123         if not task.done():
124             report.failed(src, 'timed out, cancelling')
125             task.cancel()
126             cancelled.append(task)
127         elif task.exception():
128             report.failed(src, str(task.exception()))
129         else:
130             report.done(src, task.result())
131     if cancelled:
132         await asyncio.wait(cancelled, timeout=2)
133     out('All done! ✨ 🍰 ✨')
134     click.echo(str(report))
135     return report.return_code
136
137
138 def format_file_in_place(src: Path, line_length: int, fast: bool) -> bool:
139     """Format the file and rewrite if changed. Return True if changed."""
140     try:
141         contents, encoding = format_file(src, line_length=line_length, fast=fast)
142     except NothingChanged:
143         return False
144
145     with open(src, "w", encoding=encoding) as f:
146         f.write(contents)
147     return True
148
149
150 def format_file(
151     src: Path, line_length: int, fast: bool
152 ) -> Tuple[FileContent, Encoding]:
153     """Reformats a file and returns its contents and encoding."""
154     with tokenize.open(src) as src_buffer:
155         src_contents = src_buffer.read()
156     if src_contents.strip() == '':
157         raise NothingChanged(src)
158
159     dst_contents = format_str(src_contents, line_length=line_length)
160     if src_contents == dst_contents:
161         raise NothingChanged(src)
162
163     if not fast:
164         assert_equivalent(src_contents, dst_contents)
165         assert_stable(src_contents, dst_contents, line_length=line_length)
166     return dst_contents, src_buffer.encoding
167
168
169 def format_str(src_contents: str, line_length: int) -> FileContent:
170     """Reformats a string and returns new contents."""
171     src_node = lib2to3_parse(src_contents)
172     dst_contents = ""
173     comments: List[Line] = []
174     lines = LineGenerator()
175     elt = EmptyLineTracker()
176     empty_line = Line()
177     after = 0
178     for current_line in lines.visit(src_node):
179         for _ in range(after):
180             dst_contents += str(empty_line)
181         before, after = elt.maybe_empty_lines(current_line)
182         for _ in range(before):
183             dst_contents += str(empty_line)
184         if not current_line.is_comment:
185             for comment in comments:
186                 dst_contents += str(comment)
187             comments = []
188             for line in split_line(current_line, line_length=line_length):
189                 dst_contents += str(line)
190         else:
191             comments.append(current_line)
192     for comment in comments:
193         dst_contents += str(comment)
194     return dst_contents
195
196
197 def lib2to3_parse(src_txt: str) -> Node:
198     """Given a string with source, return the lib2to3 Node."""
199     grammar = pygram.python_grammar_no_print_statement
200     drv = driver.Driver(grammar, pytree.convert)
201     if src_txt[-1] != '\n':
202         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
203         src_txt += nl
204     try:
205         result = drv.parse_string(src_txt, True)
206     except ParseError as pe:
207         lineno, column = pe.context[1]
208         lines = src_txt.splitlines()
209         try:
210             faulty_line = lines[lineno - 1]
211         except IndexError:
212             faulty_line = "<line number missing in source>"
213         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
214
215     if isinstance(result, Leaf):
216         result = Node(syms.file_input, [result])
217     return result
218
219
220 def lib2to3_unparse(node: Node) -> str:
221     """Given a lib2to3 node, return its string representation."""
222     code = str(node)
223     return code
224
225
226 T = TypeVar('T')
227
228
229 class Visitor(Generic[T]):
230     """Basic lib2to3 visitor that yields things on visiting."""
231
232     def visit(self, node: LN) -> Iterator[T]:
233         if node.type < 256:
234             name = token.tok_name[node.type]
235         else:
236             name = type_repr(node.type)
237         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
238
239     def visit_default(self, node: LN) -> Iterator[T]:
240         if isinstance(node, Node):
241             for child in node.children:
242                 yield from self.visit(child)
243
244
245 @dataclass
246 class DebugVisitor(Visitor[T]):
247     tree_depth: int = attrib(default=0)
248
249     def visit_default(self, node: LN) -> Iterator[T]:
250         indent = ' ' * (2 * self.tree_depth)
251         if isinstance(node, Node):
252             _type = type_repr(node.type)
253             out(f'{indent}{_type}', fg='yellow')
254             self.tree_depth += 1
255             for child in node.children:
256                 yield from self.visit(child)
257
258             self.tree_depth -= 1
259             out(f'{indent}/{_type}', fg='yellow', bold=False)
260         else:
261             _type = token.tok_name.get(node.type, str(node.type))
262             out(f'{indent}{_type}', fg='blue', nl=False)
263             if node.prefix:
264                 # We don't have to handle prefixes for `Node` objects since
265                 # that delegates to the first child anyway.
266                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
267             out(f' {node.value!r}', fg='blue', bold=False)
268
269
270 KEYWORDS = set(keyword.kwlist)
271 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
272 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
273 STATEMENT = {
274     syms.if_stmt,
275     syms.while_stmt,
276     syms.for_stmt,
277     syms.try_stmt,
278     syms.except_clause,
279     syms.with_stmt,
280     syms.funcdef,
281     syms.classdef,
282 }
283 STANDALONE_COMMENT = 153
284 LOGIC_OPERATORS = {'and', 'or'}
285 COMPARATORS = {
286     token.LESS,
287     token.GREATER,
288     token.EQEQUAL,
289     token.NOTEQUAL,
290     token.LESSEQUAL,
291     token.GREATEREQUAL,
292 }
293 MATH_OPERATORS = {
294     token.PLUS,
295     token.MINUS,
296     token.STAR,
297     token.SLASH,
298     token.VBAR,
299     token.AMPER,
300     token.PERCENT,
301     token.CIRCUMFLEX,
302     token.LEFTSHIFT,
303     token.RIGHTSHIFT,
304     token.DOUBLESTAR,
305     token.DOUBLESLASH,
306 }
307 COMPREHENSION_PRIORITY = 20
308 COMMA_PRIORITY = 10
309 LOGIC_PRIORITY = 5
310 STRING_PRIORITY = 4
311 COMPARATOR_PRIORITY = 3
312 MATH_PRIORITY = 1
313
314
315 @dataclass
316 class BracketTracker:
317     depth: int = attrib(default=0)
318     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict))
319     delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict))
320     previous: Optional[Leaf] = attrib(default=None)
321
322     def mark(self, leaf: Leaf) -> None:
323         if leaf.type == token.COMMENT:
324             return
325
326         if leaf.type in CLOSING_BRACKETS:
327             self.depth -= 1
328             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
329             leaf.opening_bracket = opening_bracket  # type: ignore
330         leaf.bracket_depth = self.depth  # type: ignore
331         if self.depth == 0:
332             delim = is_delimiter(leaf)
333             if delim:
334                 self.delimiters[id(leaf)] = delim
335             elif self.previous is not None:
336                 if leaf.type == token.STRING and self.previous.type == token.STRING:
337                     self.delimiters[id(self.previous)] = STRING_PRIORITY
338                 elif (
339                     leaf.type == token.NAME and
340                     leaf.value == 'for' and
341                     leaf.parent and
342                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
343                 ):
344                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
345                 elif (
346                     leaf.type == token.NAME and
347                     leaf.value == 'if' and
348                     leaf.parent and
349                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
350                 ):
351                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
352         if leaf.type in OPENING_BRACKETS:
353             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
354             self.depth += 1
355         self.previous = leaf
356
357     def any_open_brackets(self) -> bool:
358         """Returns True if there is an yet unmatched open bracket on the line."""
359         return bool(self.bracket_match)
360
361     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
362         """Returns the highest priority of a delimiter found on the line.
363
364         Values are consistent with what `is_delimiter()` returns.
365         """
366         return max(v for k, v in self.delimiters.items() if k not in exclude)
367
368
369 @dataclass
370 class Line:
371     depth: int = attrib(default=0)
372     leaves: List[Leaf] = attrib(default=Factory(list))
373     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
374     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
375     inside_brackets: bool = attrib(default=False)
376
377     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
378         has_value = leaf.value.strip()
379         if not has_value:
380             return
381
382         if self.leaves and not preformatted:
383             # Note: at this point leaf.prefix should be empty except for
384             # imports, for which we only preserve newlines.
385             leaf.prefix += whitespace(leaf)
386         if self.inside_brackets or not preformatted:
387             self.bracket_tracker.mark(leaf)
388             self.maybe_remove_trailing_comma(leaf)
389             if self.maybe_adapt_standalone_comment(leaf):
390                 return
391
392         if not self.append_comment(leaf):
393             self.leaves.append(leaf)
394
395     @property
396     def is_comment(self) -> bool:
397         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
398
399     @property
400     def is_decorator(self) -> bool:
401         return bool(self) and self.leaves[0].type == token.AT
402
403     @property
404     def is_import(self) -> bool:
405         return bool(self) and is_import(self.leaves[0])
406
407     @property
408     def is_class(self) -> bool:
409         return (
410             bool(self) and
411             self.leaves[0].type == token.NAME and
412             self.leaves[0].value == 'class'
413         )
414
415     @property
416     def is_def(self) -> bool:
417         """Also returns True for async defs."""
418         try:
419             first_leaf = self.leaves[0]
420         except IndexError:
421             return False
422
423         try:
424             second_leaf: Optional[Leaf] = self.leaves[1]
425         except IndexError:
426             second_leaf = None
427         return (
428             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
429             (
430                 first_leaf.type == token.NAME and
431                 first_leaf.value == 'async' and
432                 second_leaf is not None and
433                 second_leaf.type == token.NAME and
434                 second_leaf.value == 'def'
435             )
436         )
437
438     @property
439     def is_flow_control(self) -> bool:
440         return (
441             bool(self) and
442             self.leaves[0].type == token.NAME and
443             self.leaves[0].value in FLOW_CONTROL
444         )
445
446     @property
447     def is_yield(self) -> bool:
448         return (
449             bool(self) and
450             self.leaves[0].type == token.NAME and
451             self.leaves[0].value == 'yield'
452         )
453
454     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
455         if not (
456             self.leaves and
457             self.leaves[-1].type == token.COMMA and
458             closing.type in CLOSING_BRACKETS
459         ):
460             return False
461
462         if closing.type == token.RSQB or closing.type == token.RBRACE:
463             self.leaves.pop()
464             return True
465
466         # For parens let's check if it's safe to remove the comma.  If the
467         # trailing one is the only one, we might mistakenly change a tuple
468         # into a different type by removing the comma.
469         depth = closing.bracket_depth + 1  # type: ignore
470         commas = 0
471         opening = closing.opening_bracket  # type: ignore
472         for _opening_index, leaf in enumerate(self.leaves):
473             if leaf is opening:
474                 break
475
476         else:
477             return False
478
479         for leaf in self.leaves[_opening_index + 1:]:
480             if leaf is closing:
481                 break
482
483             bracket_depth = leaf.bracket_depth  # type: ignore
484             if bracket_depth == depth and leaf.type == token.COMMA:
485                 commas += 1
486         if commas > 1:
487             self.leaves.pop()
488             return True
489
490         return False
491
492     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
493         """Hack a standalone comment to act as a trailing comment for line splitting.
494
495         If this line has brackets and a standalone `comment`, we need to adapt
496         it to be able to still reformat the line.
497
498         This is not perfect, the line to which the standalone comment gets
499         appended will appear "too long" when splitting.
500         """
501         if not (
502             comment.type == STANDALONE_COMMENT and
503             self.bracket_tracker.any_open_brackets()
504         ):
505             return False
506
507         comment.type = token.COMMENT
508         comment.prefix = '\n' + '    ' * (self.depth + 1)
509         return self.append_comment(comment)
510
511     def append_comment(self, comment: Leaf) -> bool:
512         if comment.type != token.COMMENT:
513             return False
514
515         try:
516             after = id(self.last_non_delimiter())
517         except LookupError:
518             comment.type = STANDALONE_COMMENT
519             comment.prefix = ''
520             return False
521
522         else:
523             if after in self.comments:
524                 self.comments[after].value += str(comment)
525             else:
526                 self.comments[after] = comment
527             return True
528
529     def last_non_delimiter(self) -> Leaf:
530         for i in range(len(self.leaves)):
531             last = self.leaves[-i - 1]
532             if not is_delimiter(last):
533                 return last
534
535         raise LookupError("No non-delimiters found")
536
537     def __str__(self) -> str:
538         if not self:
539             return '\n'
540
541         indent = '    ' * self.depth
542         leaves = iter(self.leaves)
543         first = next(leaves)
544         res = f'{first.prefix}{indent}{first.value}'
545         for leaf in leaves:
546             res += str(leaf)
547         for comment in self.comments.values():
548             res += str(comment)
549         return res + '\n'
550
551     def __bool__(self) -> bool:
552         return bool(self.leaves or self.comments)
553
554
555 @dataclass
556 class EmptyLineTracker:
557     """Provides a stateful method that returns the number of potential extra
558     empty lines needed before and after the currently processed line.
559
560     Note: this tracker works on lines that haven't been split yet.
561     """
562     previous_line: Optional[Line] = attrib(default=None)
563     previous_after: int = attrib(default=0)
564     previous_defs: List[int] = attrib(default=Factory(list))
565
566     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
567         """Returns the number of extra empty lines before and after the `current_line`.
568
569         This is for separating `def`, `async def` and `class` with extra empty lines
570         (two on module-level), as well as providing an extra empty line after flow
571         control keywords to make them more prominent.
572         """
573         before, after = self._maybe_empty_lines(current_line)
574         self.previous_after = after
575         self.previous_line = current_line
576         return before, after
577
578     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
579         before = 0
580         depth = current_line.depth
581         while self.previous_defs and self.previous_defs[-1] >= depth:
582             self.previous_defs.pop()
583             before = (1 if depth else 2) - self.previous_after
584         is_decorator = current_line.is_decorator
585         if is_decorator or current_line.is_def or current_line.is_class:
586             if not is_decorator:
587                 self.previous_defs.append(depth)
588             if self.previous_line is None:
589                 # Don't insert empty lines before the first line in the file.
590                 return 0, 0
591
592             if self.previous_line and self.previous_line.is_decorator:
593                 # Don't insert empty lines between decorators.
594                 return 0, 0
595
596             newlines = 2
597             if current_line.depth:
598                 newlines -= 1
599             newlines -= self.previous_after
600             return newlines, 0
601
602         if current_line.is_flow_control:
603             return before, 1
604
605         if (
606             self.previous_line and
607             self.previous_line.is_import and
608             not current_line.is_import and
609             depth == self.previous_line.depth
610         ):
611             return (before or 1), 0
612
613         if (
614             self.previous_line and
615             self.previous_line.is_yield and
616             (not current_line.is_yield or depth != self.previous_line.depth)
617         ):
618             return (before or 1), 0
619
620         return before, 0
621
622
623 @dataclass
624 class LineGenerator(Visitor[Line]):
625     """Generates reformatted Line objects.  Empty lines are not emitted.
626
627     Note: destroys the tree it's visiting by mutating prefixes of its leaves
628     in ways that will no longer stringify to valid Python code on the tree.
629     """
630     current_line: Line = attrib(default=Factory(Line))
631     standalone_comments: List[Leaf] = attrib(default=Factory(list))
632
633     def line(self, indent: int = 0) -> Iterator[Line]:
634         """Generate a line.
635
636         If the line is empty, only emit if it makes sense.
637         If the line is too long, split it first and then generate.
638
639         If any lines were generated, set up a new current_line.
640         """
641         if not self.current_line:
642             self.current_line.depth += indent
643             return  # Line is empty, don't emit. Creating a new one unnecessary.
644
645         complete_line = self.current_line
646         self.current_line = Line(depth=complete_line.depth + indent)
647         yield complete_line
648
649     def visit_default(self, node: LN) -> Iterator[Line]:
650         if isinstance(node, Leaf):
651             for comment in generate_comments(node):
652                 if self.current_line.bracket_tracker.any_open_brackets():
653                     # any comment within brackets is subject to splitting
654                     self.current_line.append(comment)
655                 elif comment.type == token.COMMENT:
656                     # regular trailing comment
657                     self.current_line.append(comment)
658                     yield from self.line()
659
660                 else:
661                     # regular standalone comment, to be processed later (see
662                     # docstring in `generate_comments()`
663                     self.standalone_comments.append(comment)
664             normalize_prefix(node)
665             if node.type not in WHITESPACE:
666                 for comment in self.standalone_comments:
667                     yield from self.line()
668
669                     self.current_line.append(comment)
670                     yield from self.line()
671
672                 self.standalone_comments = []
673                 self.current_line.append(node)
674         yield from super().visit_default(node)
675
676     def visit_suite(self, node: Node) -> Iterator[Line]:
677         """Body of a statement after a colon."""
678         children = iter(node.children)
679         # Process newline before indenting.  It might contain an inline
680         # comment that should go right after the colon.
681         newline = next(children)
682         yield from self.visit(newline)
683         yield from self.line(+1)
684
685         for child in children:
686             yield from self.visit(child)
687
688         yield from self.line(-1)
689
690     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
691         """Visit a statement.
692
693         The relevant Python language keywords for this statement are NAME leaves
694         within it.
695         """
696         for child in node.children:
697             if child.type == token.NAME and child.value in keywords:  # type: ignore
698                 yield from self.line()
699
700             yield from self.visit(child)
701
702     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
703         """A statement without nested statements."""
704         is_suite_like = node.parent and node.parent.type in STATEMENT
705         if is_suite_like:
706             yield from self.line(+1)
707             yield from self.visit_default(node)
708             yield from self.line(-1)
709
710         else:
711             yield from self.line()
712             yield from self.visit_default(node)
713
714     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
715         yield from self.line()
716
717         children = iter(node.children)
718         for child in children:
719             yield from self.visit(child)
720
721             if child.type == token.NAME and child.value == 'async':  # type: ignore
722                 break
723
724         internal_stmt = next(children)
725         for child in internal_stmt.children:
726             yield from self.visit(child)
727
728     def visit_decorators(self, node: Node) -> Iterator[Line]:
729         for child in node.children:
730             yield from self.line()
731             yield from self.visit(child)
732
733     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
734         yield from self.line()
735
736     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
737         yield from self.visit_default(leaf)
738         yield from self.line()
739
740     def __attrs_post_init__(self) -> None:
741         """You are in a twisty little maze of passages."""
742         v = self.visit_stmt
743         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
744         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
745         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
746         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
747         self.visit_except_clause = partial(v, keywords={'except'})
748         self.visit_funcdef = partial(v, keywords={'def'})
749         self.visit_with_stmt = partial(v, keywords={'with'})
750         self.visit_classdef = partial(v, keywords={'class'})
751         self.visit_async_funcdef = self.visit_async_stmt
752         self.visit_decorated = self.visit_decorators
753
754
755 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
756 OPENING_BRACKETS = set(BRACKET.keys())
757 CLOSING_BRACKETS = set(BRACKET.values())
758 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
759
760
761 def whitespace(leaf: Leaf) -> str:
762     """Return whitespace prefix if needed for the given `leaf`."""
763     NO = ''
764     SPACE = ' '
765     DOUBLESPACE = '  '
766     t = leaf.type
767     p = leaf.parent
768     v = leaf.value
769     if t == token.COLON:
770         return NO
771
772     if t == token.COMMA:
773         return NO
774
775     if t == token.RPAR:
776         return NO
777
778     if t == token.COMMENT:
779         return DOUBLESPACE
780
781     if t == STANDALONE_COMMENT:
782         return NO
783
784     if t in CLOSING_BRACKETS:
785         return NO
786
787     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
788     prev = leaf.prev_sibling
789     if not prev:
790         prevp = preceding_leaf(p)
791         if not prevp or prevp.type in OPENING_BRACKETS:
792             return NO
793
794         if prevp.type == token.EQUAL:
795             if prevp.parent and prevp.parent.type in {
796                 syms.typedargslist,
797                 syms.varargslist,
798                 syms.parameters,
799                 syms.arglist,
800                 syms.argument,
801             }:
802                 return NO
803
804         elif prevp.type == token.DOUBLESTAR:
805             if prevp.parent and prevp.parent.type in {
806                 syms.typedargslist,
807                 syms.varargslist,
808                 syms.parameters,
809                 syms.arglist,
810                 syms.dictsetmaker,
811             }:
812                 return NO
813
814         elif prevp.type == token.COLON:
815             if prevp.parent and prevp.parent.type == syms.subscript:
816                 return NO
817
818         elif prevp.parent and prevp.parent.type == syms.factor:
819             return NO
820
821     elif prev.type in OPENING_BRACKETS:
822         return NO
823
824     if p.type in {syms.parameters, syms.arglist}:
825         # untyped function signatures or calls
826         if t == token.RPAR:
827             return NO
828
829         if not prev or prev.type != token.COMMA:
830             return NO
831
832     if p.type == syms.varargslist:
833         # lambdas
834         if t == token.RPAR:
835             return NO
836
837         if prev and prev.type != token.COMMA:
838             return NO
839
840     elif p.type == syms.typedargslist:
841         # typed function signatures
842         if not prev:
843             return NO
844
845         if t == token.EQUAL:
846             if prev.type != syms.tname:
847                 return NO
848
849         elif prev.type == token.EQUAL:
850             # A bit hacky: if the equal sign has whitespace, it means we
851             # previously found it's a typed argument.  So, we're using that, too.
852             return prev.prefix
853
854         elif prev.type != token.COMMA:
855             return NO
856
857     elif p.type == syms.tname:
858         # type names
859         if not prev:
860             prevp = preceding_leaf(p)
861             if not prevp or prevp.type != token.COMMA:
862                 return NO
863
864     elif p.type == syms.trailer:
865         # attributes and calls
866         if t == token.LPAR or t == token.RPAR:
867             return NO
868
869         if not prev:
870             if t == token.DOT:
871                 prevp = preceding_leaf(p)
872                 if not prevp or prevp.type != token.NUMBER:
873                     return NO
874
875             elif t == token.LSQB:
876                 return NO
877
878         elif prev.type != token.COMMA:
879             return NO
880
881     elif p.type == syms.argument:
882         # single argument
883         if t == token.EQUAL:
884             return NO
885
886         if not prev:
887             prevp = preceding_leaf(p)
888             if not prevp or prevp.type == token.LPAR:
889                 return NO
890
891         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
892             return NO
893
894     elif p.type == syms.decorator:
895         # decorators
896         return NO
897
898     elif p.type == syms.dotted_name:
899         if prev:
900             return NO
901
902         prevp = preceding_leaf(p)
903         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
904             return NO
905
906     elif p.type == syms.classdef:
907         if t == token.LPAR:
908             return NO
909
910         if prev and prev.type == token.LPAR:
911             return NO
912
913     elif p.type == syms.subscript:
914         # indexing
915         if not prev or prev.type == token.COLON:
916             return NO
917
918     elif p.type == syms.atom:
919         if prev and t == token.DOT:
920             # dots, but not the first one.
921             return NO
922
923     elif (
924         p.type == syms.listmaker or
925         p.type == syms.testlist_gexp or
926         p.type == syms.subscriptlist
927     ):
928         # list interior, including unpacking
929         if not prev:
930             return NO
931
932     elif p.type == syms.dictsetmaker:
933         # dict and set interior, including unpacking
934         if not prev:
935             return NO
936
937         if prev.type == token.DOUBLESTAR:
938             return NO
939
940     elif p.type == syms.factor or p.type == syms.star_expr:
941         # unary ops
942         if not prev:
943             prevp = preceding_leaf(p)
944             if not prevp or prevp.type in OPENING_BRACKETS:
945                 return NO
946
947             prevp_parent = prevp.parent
948             assert prevp_parent is not None
949             if prevp.type == token.COLON and prevp_parent.type in {
950                 syms.subscript, syms.sliceop
951             }:
952                 return NO
953
954             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
955                 return NO
956
957         elif t == token.NAME or t == token.NUMBER:
958             return NO
959
960     elif p.type == syms.import_from:
961         if t == token.DOT:
962             if prev and prev.type == token.DOT:
963                 return NO
964
965         elif t == token.NAME:
966             if v == 'import':
967                 return SPACE
968
969             if prev and prev.type == token.DOT:
970                 return NO
971
972     elif p.type == syms.sliceop:
973         return NO
974
975     return SPACE
976
977
978 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
979     """Returns the first leaf that precedes `node`, if any."""
980     while node:
981         res = node.prev_sibling
982         if res:
983             if isinstance(res, Leaf):
984                 return res
985
986             try:
987                 return list(res.leaves())[-1]
988
989             except IndexError:
990                 return None
991
992         node = node.parent
993     return None
994
995
996 def is_delimiter(leaf: Leaf) -> int:
997     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
998
999     Higher numbers are higher priority.
1000     """
1001     if leaf.type == token.COMMA:
1002         return COMMA_PRIORITY
1003
1004     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1005         return LOGIC_PRIORITY
1006
1007     if leaf.type in COMPARATORS:
1008         return COMPARATOR_PRIORITY
1009
1010     if (
1011         leaf.type in MATH_OPERATORS and
1012         leaf.parent and
1013         leaf.parent.type not in {syms.factor, syms.star_expr}
1014     ):
1015         return MATH_PRIORITY
1016
1017     return 0
1018
1019
1020 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1021     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1022
1023     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1024     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1025     move because it does away with modifying the grammar to include all the
1026     possible places in which comments can be placed.
1027
1028     The sad consequence for us though is that comments don't "belong" anywhere.
1029     This is why this function generates simple parentless Leaf objects for
1030     comments.  We simply don't know what the correct parent should be.
1031
1032     No matter though, we can live without this.  We really only need to
1033     differentiate between inline and standalone comments.  The latter don't
1034     share the line with any code.
1035
1036     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1037     are emitted with a fake STANDALONE_COMMENT token identifier.
1038     """
1039     if not leaf.prefix:
1040         return
1041
1042     if '#' not in leaf.prefix:
1043         return
1044
1045     before_comment, content = leaf.prefix.split('#', 1)
1046     content = content.rstrip()
1047     if content and (content[0] not in {' ', '!', '#'}):
1048         content = ' ' + content
1049     is_standalone_comment = (
1050         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1051     )
1052     if not is_standalone_comment:
1053         # simple trailing comment
1054         yield Leaf(token.COMMENT, value='#' + content)
1055         return
1056
1057     for line in ('#' + content).split('\n'):
1058         line = line.lstrip()
1059         if not line.startswith('#'):
1060             continue
1061
1062         yield Leaf(STANDALONE_COMMENT, line)
1063
1064
1065 def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]:
1066     """Splits a `line` into potentially many lines.
1067
1068     They should fit in the allotted `line_length` but might not be able to.
1069     `inner` signifies that there were a pair of brackets somewhere around the
1070     current `line`, possibly transitively. This means we can fallback to splitting
1071     by delimiters if the LHS/RHS don't yield any results.
1072     """
1073     line_str = str(line).strip('\n')
1074     if len(line_str) <= line_length and '\n' not in line_str:
1075         yield line
1076         return
1077
1078     if line.is_def:
1079         split_funcs = [left_hand_split]
1080     elif line.inside_brackets:
1081         split_funcs = [delimiter_split]
1082         if '\n' not in line_str:
1083             # Only attempt RHS if we don't have multiline strings or comments
1084             # on this line.
1085             split_funcs.append(right_hand_split)
1086     else:
1087         split_funcs = [right_hand_split]
1088     for split_func in split_funcs:
1089         # We are accumulating lines in `result` because we might want to abort
1090         # mission and return the original line in the end, or attempt a different
1091         # split altogether.
1092         result: List[Line] = []
1093         try:
1094             for l in split_func(line):
1095                 if str(l).strip('\n') == line_str:
1096                     raise CannotSplit("Split function returned an unchanged result")
1097
1098                 result.extend(split_line(l, line_length=line_length, inner=True))
1099         except CannotSplit as cs:
1100             continue
1101
1102         else:
1103             yield from result
1104             break
1105
1106     else:
1107         yield line
1108
1109
1110 def left_hand_split(line: Line) -> Iterator[Line]:
1111     """Split line into many lines, starting with the first matching bracket pair.
1112
1113     Note: this usually looks weird, only use this for function definitions.
1114     Prefer RHS otherwise.
1115     """
1116     head = Line(depth=line.depth)
1117     body = Line(depth=line.depth + 1, inside_brackets=True)
1118     tail = Line(depth=line.depth)
1119     tail_leaves: List[Leaf] = []
1120     body_leaves: List[Leaf] = []
1121     head_leaves: List[Leaf] = []
1122     current_leaves = head_leaves
1123     matching_bracket = None
1124     for leaf in line.leaves:
1125         if (
1126             current_leaves is body_leaves and
1127             leaf.type in CLOSING_BRACKETS and
1128             leaf.opening_bracket is matching_bracket  # type: ignore
1129         ):
1130             current_leaves = tail_leaves
1131         current_leaves.append(leaf)
1132         if current_leaves is head_leaves:
1133             if leaf.type in OPENING_BRACKETS:
1134                 matching_bracket = leaf
1135                 current_leaves = body_leaves
1136     # Since body is a new indent level, remove spurious leading whitespace.
1137     if body_leaves:
1138         normalize_prefix(body_leaves[0])
1139     # Build the new lines.
1140     for result, leaves in (
1141         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1142     ):
1143         for leaf in leaves:
1144             result.append(leaf, preformatted=True)
1145             comment_after = line.comments.get(id(leaf))
1146             if comment_after:
1147                 result.append(comment_after, preformatted=True)
1148     # Check if the split succeeded.
1149     tail_len = len(str(tail))
1150     if not body:
1151         if tail_len == 0:
1152             raise CannotSplit("Splitting brackets produced the same line")
1153
1154         elif tail_len < 3:
1155             raise CannotSplit(
1156                 f"Splitting brackets on an empty body to save "
1157                 f"{tail_len} characters is not worth it"
1158             )
1159
1160     for result in (head, body, tail):
1161         if result:
1162             yield result
1163
1164
1165 def right_hand_split(line: Line) -> Iterator[Line]:
1166     """Split line into many lines, starting with the last matching bracket pair."""
1167     head = Line(depth=line.depth)
1168     body = Line(depth=line.depth + 1, inside_brackets=True)
1169     tail = Line(depth=line.depth)
1170     tail_leaves: List[Leaf] = []
1171     body_leaves: List[Leaf] = []
1172     head_leaves: List[Leaf] = []
1173     current_leaves = tail_leaves
1174     opening_bracket = None
1175     for leaf in reversed(line.leaves):
1176         if current_leaves is body_leaves:
1177             if leaf is opening_bracket:
1178                 current_leaves = head_leaves
1179         current_leaves.append(leaf)
1180         if current_leaves is tail_leaves:
1181             if leaf.type in CLOSING_BRACKETS:
1182                 opening_bracket = leaf.opening_bracket  # type: ignore
1183                 current_leaves = body_leaves
1184     tail_leaves.reverse()
1185     body_leaves.reverse()
1186     head_leaves.reverse()
1187     # Since body is a new indent level, remove spurious leading whitespace.
1188     if body_leaves:
1189         normalize_prefix(body_leaves[0])
1190     # Build the new lines.
1191     for result, leaves in (
1192         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1193     ):
1194         for leaf in leaves:
1195             result.append(leaf, preformatted=True)
1196             comment_after = line.comments.get(id(leaf))
1197             if comment_after:
1198                 result.append(comment_after, preformatted=True)
1199     # Check if the split succeeded.
1200     tail_len = len(str(tail).strip('\n'))
1201     if not body:
1202         if tail_len == 0:
1203             raise CannotSplit("Splitting brackets produced the same line")
1204
1205         elif tail_len < 3:
1206             raise CannotSplit(
1207                 f"Splitting brackets on an empty body to save "
1208                 f"{tail_len} characters is not worth it"
1209             )
1210
1211     for result in (head, body, tail):
1212         if result:
1213             yield result
1214
1215
1216 def delimiter_split(line: Line) -> Iterator[Line]:
1217     """Split according to delimiters of the highest priority.
1218
1219     This kind of split doesn't increase indentation.
1220     """
1221     try:
1222         last_leaf = line.leaves[-1]
1223     except IndexError:
1224         raise CannotSplit("Line empty")
1225
1226     delimiters = line.bracket_tracker.delimiters
1227     try:
1228         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1229     except ValueError:
1230         raise CannotSplit("No delimiters found")
1231
1232     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1233     for leaf in line.leaves:
1234         current_line.append(leaf, preformatted=True)
1235         comment_after = line.comments.get(id(leaf))
1236         if comment_after:
1237             current_line.append(comment_after, preformatted=True)
1238         leaf_priority = delimiters.get(id(leaf))
1239         if leaf_priority == delimiter_priority:
1240             normalize_prefix(current_line.leaves[0])
1241             yield current_line
1242
1243             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1244     if current_line:
1245         if (
1246             delimiter_priority == COMMA_PRIORITY and
1247             current_line.leaves[-1].type != token.COMMA
1248         ):
1249             current_line.append(Leaf(token.COMMA, ','))
1250         normalize_prefix(current_line.leaves[0])
1251         yield current_line
1252
1253
1254 def is_import(leaf: Leaf) -> bool:
1255     """Returns True if the given leaf starts an import statement."""
1256     p = leaf.parent
1257     t = leaf.type
1258     v = leaf.value
1259     return bool(
1260         t == token.NAME and
1261         (
1262             (v == 'import' and p and p.type == syms.import_name) or
1263             (v == 'from' and p and p.type == syms.import_from)
1264         )
1265     )
1266
1267
1268 def normalize_prefix(leaf: Leaf) -> None:
1269     """Leave existing extra newlines for imports.  Remove everything else."""
1270     if is_import(leaf):
1271         spl = leaf.prefix.split('#', 1)
1272         nl_count = spl[0].count('\n')
1273         if len(spl) > 1:
1274             # Skip one newline since it was for a standalone comment.
1275             nl_count -= 1
1276         leaf.prefix = '\n' * nl_count
1277         return
1278
1279     leaf.prefix = ''
1280
1281
1282 PYTHON_EXTENSIONS = {'.py'}
1283 BLACKLISTED_DIRECTORIES = {
1284     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1285 }
1286
1287
1288 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1289     for child in path.iterdir():
1290         if child.is_dir():
1291             if child.name in BLACKLISTED_DIRECTORIES:
1292                 continue
1293
1294             yield from gen_python_files_in_dir(child)
1295
1296         elif child.suffix in PYTHON_EXTENSIONS:
1297             yield child
1298
1299
1300 @dataclass
1301 class Report:
1302     """Provides a reformatting counter."""
1303     change_count: int = attrib(default=0)
1304     same_count: int = attrib(default=0)
1305     failure_count: int = attrib(default=0)
1306
1307     def done(self, src: Path, changed: bool) -> None:
1308         """Increment the counter for successful reformatting. Write out a message."""
1309         if changed:
1310             out(f'reformatted {src}')
1311             self.change_count += 1
1312         else:
1313             out(f'{src} already well formatted, good job.', bold=False)
1314             self.same_count += 1
1315
1316     def failed(self, src: Path, message: str) -> None:
1317         """Increment the counter for failed reformatting. Write out a message."""
1318         err(f'error: cannot format {src}: {message}')
1319         self.failure_count += 1
1320
1321     @property
1322     def return_code(self) -> int:
1323         """Which return code should the app use considering the current state."""
1324         return 1 if self.failure_count else 0
1325
1326     def __str__(self) -> str:
1327         """A color report of the current state.
1328
1329         Use `click.unstyle` to remove colors.
1330         """
1331         report = []
1332         if self.change_count:
1333             s = 's' if self.change_count > 1 else ''
1334             report.append(
1335                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1336             )
1337         if self.same_count:
1338             s = 's' if self.same_count > 1 else ''
1339             report.append(f'{self.same_count} file{s} left unchanged')
1340         if self.failure_count:
1341             s = 's' if self.failure_count > 1 else ''
1342             report.append(
1343                 click.style(
1344                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1345                 )
1346             )
1347         return ', '.join(report) + '.'
1348
1349
1350 def assert_equivalent(src: str, dst: str) -> None:
1351     """Raises AssertionError if `src` and `dst` aren't equivalent.
1352
1353     This is a temporary sanity check until Black becomes stable.
1354     """
1355
1356     import ast
1357     import traceback
1358
1359     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1360         """Simple visitor generating strings to compare ASTs by content."""
1361         yield f"{'  ' * depth}{node.__class__.__name__}("
1362
1363         for field in sorted(node._fields):
1364             try:
1365                 value = getattr(node, field)
1366             except AttributeError:
1367                 continue
1368
1369             yield f"{'  ' * (depth+1)}{field}="
1370
1371             if isinstance(value, list):
1372                 for item in value:
1373                     if isinstance(item, ast.AST):
1374                         yield from _v(item, depth + 2)
1375
1376             elif isinstance(value, ast.AST):
1377                 yield from _v(value, depth + 2)
1378
1379             else:
1380                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1381
1382         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1383
1384     try:
1385         src_ast = ast.parse(src)
1386     except Exception as exc:
1387         raise AssertionError(f"cannot parse source: {exc}") from None
1388
1389     try:
1390         dst_ast = ast.parse(dst)
1391     except Exception as exc:
1392         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1393         raise AssertionError(
1394             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1395             f"Please report a bug on https://github.com/ambv/black/issues.  "
1396             f"This invalid output might be helpful: {log}",
1397         ) from None
1398
1399     src_ast_str = '\n'.join(_v(src_ast))
1400     dst_ast_str = '\n'.join(_v(dst_ast))
1401     if src_ast_str != dst_ast_str:
1402         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1403         raise AssertionError(
1404             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1405             f"the source.  "
1406             f"Please report a bug on https://github.com/ambv/black/issues.  "
1407             f"This diff might be helpful: {log}",
1408         ) from None
1409
1410
1411 def assert_stable(src: str, dst: str, line_length: int) -> None:
1412     """Raises AssertionError if `dst` reformats differently the second time.
1413
1414     This is a temporary sanity check until Black becomes stable.
1415     """
1416     newdst = format_str(dst, line_length=line_length)
1417     if dst != newdst:
1418         log = dump_to_file(
1419             diff(src, dst, 'source', 'first pass'),
1420             diff(dst, newdst, 'first pass', 'second pass'),
1421         )
1422         raise AssertionError(
1423             f"INTERNAL ERROR: Black produced different code on the second pass "
1424             f"of the formatter.  "
1425             f"Please report a bug on https://github.com/ambv/black/issues.  "
1426             f"This diff might be helpful: {log}",
1427         ) from None
1428
1429
1430 def dump_to_file(*output: str) -> str:
1431     """Dumps `output` to a temporary file. Returns path to the file."""
1432     import tempfile
1433
1434     with tempfile.NamedTemporaryFile(
1435         mode='w', prefix='blk_', suffix='.log', delete=False
1436     ) as f:
1437         for lines in output:
1438             f.write(lines)
1439             f.write('\n')
1440     return f.name
1441
1442
1443 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1444     """Returns a udiff string between strings `a` and `b`."""
1445     import difflib
1446
1447     a_lines = [line + '\n' for line in a.split('\n')]
1448     b_lines = [line + '\n' for line in b.split('\n')]
1449     return ''.join(
1450         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1451     )
1452
1453
1454 if __name__ == '__main__':
1455     main()