]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix invalid spacing of dots in relative imports
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 from typing import (
11     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
12 )
13
14 from attr import attrib, dataclass, Factory
15 import click
16
17 # lib2to3 fork
18 from blib2to3.pytree import Node, Leaf, type_repr
19 from blib2to3 import pygram, pytree
20 from blib2to3.pgen2 import driver, token
21 from blib2to3.pgen2.parse import ParseError
22
23 __version__ = "18.3a1"
24 DEFAULT_LINE_LENGTH = 88
25 # types
26 syms = pygram.python_symbols
27 FileContent = str
28 Encoding = str
29 Depth = int
30 NodeType = int
31 LeafID = int
32 Priority = int
33 LN = Union[Leaf, Node]
34 out = partial(click.secho, bold=True, err=True)
35 err = partial(click.secho, fg='red', err=True)
36
37
38 class NothingChanged(UserWarning):
39     """Raised by `format_file` when the reformatted code is the same as source."""
40
41
42 class CannotSplit(Exception):
43     """A readable split that fits the allotted line length is impossible.
44
45     Raised by `left_hand_split()` and `right_hand_split()`.
46     """
47
48
49 @click.command()
50 @click.option(
51     '-l',
52     '--line-length',
53     type=int,
54     default=DEFAULT_LINE_LENGTH,
55     help='How many character per line to allow.',
56     show_default=True,
57 )
58 @click.option(
59     '--fast/--safe',
60     is_flag=True,
61     help='If --fast given, skip temporary sanity checks. [default: --safe]',
62 )
63 @click.version_option(version=__version__)
64 @click.argument(
65     'src',
66     nargs=-1,
67     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
68 )
69 @click.pass_context
70 def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> None:
71     """The uncompromising code formatter."""
72     sources: List[Path] = []
73     for s in src:
74         p = Path(s)
75         if p.is_dir():
76             sources.extend(gen_python_files_in_dir(p))
77         elif p.is_file():
78             # if a file was explicitly given, we don't care about its extension
79             sources.append(p)
80         else:
81             err(f'invalid path: {s}')
82     if len(sources) == 0:
83         ctx.exit(0)
84     elif len(sources) == 1:
85         p = sources[0]
86         report = Report()
87         try:
88             changed = format_file_in_place(p, line_length=line_length, fast=fast)
89             report.done(p, changed)
90         except Exception as exc:
91             report.failed(p, str(exc))
92         ctx.exit(report.return_code)
93     else:
94         loop = asyncio.get_event_loop()
95         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
96         return_code = 1
97         try:
98             return_code = loop.run_until_complete(
99                 schedule_formatting(sources, line_length, fast, loop, executor)
100             )
101         finally:
102             loop.close()
103             ctx.exit(return_code)
104
105
106 async def schedule_formatting(
107     sources: List[Path],
108     line_length: int,
109     fast: bool,
110     loop: BaseEventLoop,
111     executor: Executor,
112 ) -> int:
113     tasks = {
114         src: loop.run_in_executor(
115             executor, format_file_in_place, src, line_length, fast
116         )
117         for src in sources
118     }
119     await asyncio.wait(tasks.values())
120     cancelled = []
121     report = Report()
122     for src, task in tasks.items():
123         if not task.done():
124             report.failed(src, 'timed out, cancelling')
125             task.cancel()
126             cancelled.append(task)
127         elif task.exception():
128             report.failed(src, str(task.exception()))
129         else:
130             report.done(src, task.result())
131     if cancelled:
132         await asyncio.wait(cancelled, timeout=2)
133     out('All done! ✨ 🍰 ✨')
134     click.echo(str(report))
135     return report.return_code
136
137
138 def format_file_in_place(src: Path, line_length: int, fast: bool) -> bool:
139     """Format the file and rewrite if changed. Return True if changed."""
140     try:
141         contents, encoding = format_file(src, line_length=line_length, fast=fast)
142     except NothingChanged:
143         return False
144
145     with open(src, "w", encoding=encoding) as f:
146         f.write(contents)
147     return True
148
149
150 def format_file(
151     src: Path, line_length: int, fast: bool
152 ) -> Tuple[FileContent, Encoding]:
153     """Reformats a file and returns its contents and encoding."""
154     with tokenize.open(src) as src_buffer:
155         src_contents = src_buffer.read()
156     if src_contents.strip() == '':
157         raise NothingChanged(src)
158
159     dst_contents = format_str(src_contents, line_length=line_length)
160     if src_contents == dst_contents:
161         raise NothingChanged(src)
162
163     if not fast:
164         assert_equivalent(src_contents, dst_contents)
165         assert_stable(src_contents, dst_contents, line_length=line_length)
166     return dst_contents, src_buffer.encoding
167
168
169 def format_str(src_contents: str, line_length: int) -> FileContent:
170     """Reformats a string and returns new contents."""
171     src_node = lib2to3_parse(src_contents)
172     dst_contents = ""
173     comments: List[Line] = []
174     lines = LineGenerator()
175     elt = EmptyLineTracker()
176     empty_line = Line()
177     after = 0
178     for current_line in lines.visit(src_node):
179         for _ in range(after):
180             dst_contents += str(empty_line)
181         before, after = elt.maybe_empty_lines(current_line)
182         for _ in range(before):
183             dst_contents += str(empty_line)
184         if not current_line.is_comment:
185             for comment in comments:
186                 dst_contents += str(comment)
187             comments = []
188             for line in split_line(current_line, line_length=line_length):
189                 dst_contents += str(line)
190         else:
191             comments.append(current_line)
192     for comment in comments:
193         dst_contents += str(comment)
194     return dst_contents
195
196
197 def lib2to3_parse(src_txt: str) -> Node:
198     """Given a string with source, return the lib2to3 Node."""
199     grammar = pygram.python_grammar_no_print_statement
200     drv = driver.Driver(grammar, pytree.convert)
201     if src_txt[-1] != '\n':
202         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
203         src_txt += nl
204     try:
205         result = drv.parse_string(src_txt, True)
206     except ParseError as pe:
207         lineno, column = pe.context[1]
208         lines = src_txt.splitlines()
209         try:
210             faulty_line = lines[lineno - 1]
211         except IndexError:
212             faulty_line = "<line number missing in source>"
213         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
214
215     if isinstance(result, Leaf):
216         result = Node(syms.file_input, [result])
217     return result
218
219
220 def lib2to3_unparse(node: Node) -> str:
221     """Given a lib2to3 node, return its string representation."""
222     code = str(node)
223     return code
224
225
226 T = TypeVar('T')
227
228
229 class Visitor(Generic[T]):
230     """Basic lib2to3 visitor that yields things on visiting."""
231
232     def visit(self, node: LN) -> Iterator[T]:
233         if node.type < 256:
234             name = token.tok_name[node.type]
235         else:
236             name = type_repr(node.type)
237         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
238
239     def visit_default(self, node: LN) -> Iterator[T]:
240         if isinstance(node, Node):
241             for child in node.children:
242                 yield from self.visit(child)
243
244
245 @dataclass
246 class DebugVisitor(Visitor[T]):
247     tree_depth: int = attrib(default=0)
248
249     def visit_default(self, node: LN) -> Iterator[T]:
250         indent = ' ' * (2 * self.tree_depth)
251         if isinstance(node, Node):
252             _type = type_repr(node.type)
253             out(f'{indent}{_type}', fg='yellow')
254             self.tree_depth += 1
255             for child in node.children:
256                 yield from self.visit(child)
257
258             self.tree_depth -= 1
259             out(f'{indent}/{_type}', fg='yellow', bold=False)
260         else:
261             _type = token.tok_name.get(node.type, str(node.type))
262             out(f'{indent}{_type}', fg='blue', nl=False)
263             if node.prefix:
264                 # We don't have to handle prefixes for `Node` objects since
265                 # that delegates to the first child anyway.
266                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
267             out(f' {node.value!r}', fg='blue', bold=False)
268
269
270 KEYWORDS = set(keyword.kwlist)
271 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
272 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
273 STATEMENT = {
274     syms.if_stmt,
275     syms.while_stmt,
276     syms.for_stmt,
277     syms.try_stmt,
278     syms.except_clause,
279     syms.with_stmt,
280     syms.funcdef,
281     syms.classdef,
282 }
283 STANDALONE_COMMENT = 153
284 LOGIC_OPERATORS = {'and', 'or'}
285 COMPARATORS = {
286     token.LESS,
287     token.GREATER,
288     token.EQEQUAL,
289     token.NOTEQUAL,
290     token.LESSEQUAL,
291     token.GREATEREQUAL,
292 }
293 MATH_OPERATORS = {
294     token.PLUS,
295     token.MINUS,
296     token.STAR,
297     token.SLASH,
298     token.VBAR,
299     token.AMPER,
300     token.PERCENT,
301     token.CIRCUMFLEX,
302     token.LEFTSHIFT,
303     token.RIGHTSHIFT,
304     token.DOUBLESTAR,
305     token.DOUBLESLASH,
306 }
307 COMPREHENSION_PRIORITY = 20
308 COMMA_PRIORITY = 10
309 LOGIC_PRIORITY = 5
310 STRING_PRIORITY = 4
311 COMPARATOR_PRIORITY = 3
312 MATH_PRIORITY = 1
313
314
315 @dataclass
316 class BracketTracker:
317     depth: int = attrib(default=0)
318     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict))
319     delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict))
320     previous: Optional[Leaf] = attrib(default=None)
321
322     def mark(self, leaf: Leaf) -> None:
323         if leaf.type == token.COMMENT:
324             return
325
326         if leaf.type in CLOSING_BRACKETS:
327             self.depth -= 1
328             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
329             leaf.opening_bracket = opening_bracket  # type: ignore
330         leaf.bracket_depth = self.depth  # type: ignore
331         if self.depth == 0:
332             delim = is_delimiter(leaf)
333             if delim:
334                 self.delimiters[id(leaf)] = delim
335             elif self.previous is not None:
336                 if leaf.type == token.STRING and self.previous.type == token.STRING:
337                     self.delimiters[id(self.previous)] = STRING_PRIORITY
338                 elif (
339                     leaf.type == token.NAME and
340                     leaf.value == 'for' and
341                     leaf.parent and
342                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
343                 ):
344                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
345                 elif (
346                     leaf.type == token.NAME and
347                     leaf.value == 'if' and
348                     leaf.parent and
349                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
350                 ):
351                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
352         if leaf.type in OPENING_BRACKETS:
353             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
354             self.depth += 1
355         self.previous = leaf
356
357     def any_open_brackets(self) -> bool:
358         """Returns True if there is an yet unmatched open bracket on the line."""
359         return bool(self.bracket_match)
360
361     def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
362         """Returns the highest priority of a delimiter found on the line.
363
364         Values are consistent with what `is_delimiter()` returns.
365         """
366         return max(v for k, v in self.delimiters.items() if k not in exclude)
367
368
369 @dataclass
370 class Line:
371     depth: int = attrib(default=0)
372     leaves: List[Leaf] = attrib(default=Factory(list))
373     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
374     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
375     inside_brackets: bool = attrib(default=False)
376
377     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
378         has_value = leaf.value.strip()
379         if not has_value:
380             return
381
382         if self.leaves and not preformatted:
383             # Note: at this point leaf.prefix should be empty except for
384             # imports, for which we only preserve newlines.
385             leaf.prefix += whitespace(leaf)
386         if self.inside_brackets or not preformatted:
387             self.bracket_tracker.mark(leaf)
388             self.maybe_remove_trailing_comma(leaf)
389             if self.maybe_adapt_standalone_comment(leaf):
390                 return
391
392         if not self.append_comment(leaf):
393             self.leaves.append(leaf)
394
395     @property
396     def is_comment(self) -> bool:
397         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
398
399     @property
400     def is_decorator(self) -> bool:
401         return bool(self) and self.leaves[0].type == token.AT
402
403     @property
404     def is_import(self) -> bool:
405         return bool(self) and is_import(self.leaves[0])
406
407     @property
408     def is_class(self) -> bool:
409         return (
410             bool(self) and
411             self.leaves[0].type == token.NAME and
412             self.leaves[0].value == 'class'
413         )
414
415     @property
416     def is_def(self) -> bool:
417         """Also returns True for async defs."""
418         try:
419             first_leaf = self.leaves[0]
420         except IndexError:
421             return False
422
423         try:
424             second_leaf: Optional[Leaf] = self.leaves[1]
425         except IndexError:
426             second_leaf = None
427         return (
428             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
429             (
430                 first_leaf.type == token.NAME and
431                 first_leaf.value == 'async' and
432                 second_leaf is not None and
433                 second_leaf.type == token.NAME and
434                 second_leaf.value == 'def'
435             )
436         )
437
438     @property
439     def is_flow_control(self) -> bool:
440         return (
441             bool(self) and
442             self.leaves[0].type == token.NAME and
443             self.leaves[0].value in FLOW_CONTROL
444         )
445
446     @property
447     def is_yield(self) -> bool:
448         return (
449             bool(self) and
450             self.leaves[0].type == token.NAME and
451             self.leaves[0].value == 'yield'
452         )
453
454     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
455         if not (
456             self.leaves and
457             self.leaves[-1].type == token.COMMA and
458             closing.type in CLOSING_BRACKETS
459         ):
460             return False
461
462         if closing.type == token.RSQB or closing.type == token.RBRACE:
463             self.leaves.pop()
464             return True
465
466         # For parens let's check if it's safe to remove the comma.  If the
467         # trailing one is the only one, we might mistakenly change a tuple
468         # into a different type by removing the comma.
469         depth = closing.bracket_depth + 1  # type: ignore
470         commas = 0
471         opening = closing.opening_bracket  # type: ignore
472         for _opening_index, leaf in enumerate(self.leaves):
473             if leaf is opening:
474                 break
475
476         else:
477             return False
478
479         for leaf in self.leaves[_opening_index + 1:]:
480             if leaf is closing:
481                 break
482
483             bracket_depth = leaf.bracket_depth  # type: ignore
484             if bracket_depth == depth and leaf.type == token.COMMA:
485                 commas += 1
486         if commas > 1:
487             self.leaves.pop()
488             return True
489
490         return False
491
492     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
493         """Hack a standalone comment to act as a trailing comment for line splitting.
494
495         If this line has brackets and a standalone `comment`, we need to adapt
496         it to be able to still reformat the line.
497
498         This is not perfect, the line to which the standalone comment gets
499         appended will appear "too long" when splitting.
500         """
501         if not (
502             comment.type == STANDALONE_COMMENT and
503             self.bracket_tracker.any_open_brackets()
504         ):
505             return False
506
507         comment.type = token.COMMENT
508         comment.prefix = '\n' + '    ' * (self.depth + 1)
509         return self.append_comment(comment)
510
511     def append_comment(self, comment: Leaf) -> bool:
512         if comment.type != token.COMMENT:
513             return False
514
515         try:
516             after = id(self.last_non_delimiter())
517         except LookupError:
518             comment.type = STANDALONE_COMMENT
519             comment.prefix = ''
520             return False
521
522         else:
523             if after in self.comments:
524                 self.comments[after].value += str(comment)
525             else:
526                 self.comments[after] = comment
527             return True
528
529     def last_non_delimiter(self) -> Leaf:
530         for i in range(len(self.leaves)):
531             last = self.leaves[-i - 1]
532             if not is_delimiter(last):
533                 return last
534
535         raise LookupError("No non-delimiters found")
536
537     def __str__(self) -> str:
538         if not self:
539             return '\n'
540
541         indent = '    ' * self.depth
542         leaves = iter(self.leaves)
543         first = next(leaves)
544         res = f'{first.prefix}{indent}{first.value}'
545         for leaf in leaves:
546             res += str(leaf)
547         for comment in self.comments.values():
548             res += str(comment)
549         return res + '\n'
550
551     def __bool__(self) -> bool:
552         return bool(self.leaves or self.comments)
553
554
555 @dataclass
556 class EmptyLineTracker:
557     """Provides a stateful method that returns the number of potential extra
558     empty lines needed before and after the currently processed line.
559
560     Note: this tracker works on lines that haven't been split yet.
561     """
562     previous_line: Optional[Line] = attrib(default=None)
563     previous_after: int = attrib(default=0)
564     previous_defs: List[int] = attrib(default=Factory(list))
565
566     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
567         """Returns the number of extra empty lines before and after the `current_line`.
568
569         This is for separating `def`, `async def` and `class` with extra empty lines
570         (two on module-level), as well as providing an extra empty line after flow
571         control keywords to make them more prominent.
572         """
573         before, after = self._maybe_empty_lines(current_line)
574         self.previous_after = after
575         self.previous_line = current_line
576         return before, after
577
578     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
579         before = 0
580         depth = current_line.depth
581         while self.previous_defs and self.previous_defs[-1] >= depth:
582             self.previous_defs.pop()
583             before = (1 if depth else 2) - self.previous_after
584         is_decorator = current_line.is_decorator
585         if is_decorator or current_line.is_def or current_line.is_class:
586             if not is_decorator:
587                 self.previous_defs.append(depth)
588             if self.previous_line is None:
589                 # Don't insert empty lines before the first line in the file.
590                 return 0, 0
591
592             if self.previous_line and self.previous_line.is_decorator:
593                 # Don't insert empty lines between decorators.
594                 return 0, 0
595
596             newlines = 2
597             if current_line.depth:
598                 newlines -= 1
599             newlines -= self.previous_after
600             return newlines, 0
601
602         if current_line.is_flow_control:
603             return before, 1
604
605         if (
606             self.previous_line and
607             self.previous_line.is_import and
608             not current_line.is_import and
609             depth == self.previous_line.depth
610         ):
611             return (before or 1), 0
612
613         if (
614             self.previous_line and
615             self.previous_line.is_yield and
616             (not current_line.is_yield or depth != self.previous_line.depth)
617         ):
618             return (before or 1), 0
619
620         return before, 0
621
622
623 @dataclass
624 class LineGenerator(Visitor[Line]):
625     """Generates reformatted Line objects.  Empty lines are not emitted.
626
627     Note: destroys the tree it's visiting by mutating prefixes of its leaves
628     in ways that will no longer stringify to valid Python code on the tree.
629     """
630     current_line: Line = attrib(default=Factory(Line))
631     standalone_comments: List[Leaf] = attrib(default=Factory(list))
632
633     def line(self, indent: int = 0) -> Iterator[Line]:
634         """Generate a line.
635
636         If the line is empty, only emit if it makes sense.
637         If the line is too long, split it first and then generate.
638
639         If any lines were generated, set up a new current_line.
640         """
641         if not self.current_line:
642             self.current_line.depth += indent
643             return  # Line is empty, don't emit. Creating a new one unnecessary.
644
645         complete_line = self.current_line
646         self.current_line = Line(depth=complete_line.depth + indent)
647         yield complete_line
648
649     def visit_default(self, node: LN) -> Iterator[Line]:
650         if isinstance(node, Leaf):
651             for comment in generate_comments(node):
652                 if self.current_line.bracket_tracker.any_open_brackets():
653                     # any comment within brackets is subject to splitting
654                     self.current_line.append(comment)
655                 elif comment.type == token.COMMENT:
656                     # regular trailing comment
657                     self.current_line.append(comment)
658                     yield from self.line()
659
660                 else:
661                     # regular standalone comment, to be processed later (see
662                     # docstring in `generate_comments()`
663                     self.standalone_comments.append(comment)
664             normalize_prefix(node)
665             if node.type not in WHITESPACE:
666                 for comment in self.standalone_comments:
667                     yield from self.line()
668
669                     self.current_line.append(comment)
670                     yield from self.line()
671
672                 self.standalone_comments = []
673                 self.current_line.append(node)
674         yield from super().visit_default(node)
675
676     def visit_suite(self, node: Node) -> Iterator[Line]:
677         """Body of a statement after a colon."""
678         children = iter(node.children)
679         # Process newline before indenting.  It might contain an inline
680         # comment that should go right after the colon.
681         newline = next(children)
682         yield from self.visit(newline)
683         yield from self.line(+1)
684
685         for child in children:
686             yield from self.visit(child)
687
688         yield from self.line(-1)
689
690     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
691         """Visit a statement.
692
693         The relevant Python language keywords for this statement are NAME leaves
694         within it.
695         """
696         for child in node.children:
697             if child.type == token.NAME and child.value in keywords:  # type: ignore
698                 yield from self.line()
699
700             yield from self.visit(child)
701
702     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
703         """A statement without nested statements."""
704         is_suite_like = node.parent and node.parent.type in STATEMENT
705         if is_suite_like:
706             yield from self.line(+1)
707             yield from self.visit_default(node)
708             yield from self.line(-1)
709
710         else:
711             yield from self.line()
712             yield from self.visit_default(node)
713
714     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
715         yield from self.line()
716
717         children = iter(node.children)
718         for child in children:
719             yield from self.visit(child)
720
721             if child.type == token.NAME and child.value == 'async':  # type: ignore
722                 break
723
724         internal_stmt = next(children)
725         for child in internal_stmt.children:
726             yield from self.visit(child)
727
728     def visit_decorators(self, node: Node) -> Iterator[Line]:
729         for child in node.children:
730             yield from self.line()
731             yield from self.visit(child)
732
733     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
734         yield from self.line()
735
736     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
737         yield from self.visit_default(leaf)
738         yield from self.line()
739
740     def __attrs_post_init__(self) -> None:
741         """You are in a twisty little maze of passages."""
742         v = self.visit_stmt
743         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
744         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
745         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
746         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
747         self.visit_except_clause = partial(v, keywords={'except'})
748         self.visit_funcdef = partial(v, keywords={'def'})
749         self.visit_with_stmt = partial(v, keywords={'with'})
750         self.visit_classdef = partial(v, keywords={'class'})
751         self.visit_async_funcdef = self.visit_async_stmt
752         self.visit_decorated = self.visit_decorators
753
754
755 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
756 OPENING_BRACKETS = set(BRACKET.keys())
757 CLOSING_BRACKETS = set(BRACKET.values())
758 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
759
760
761 def whitespace(leaf: Leaf) -> str:
762     """Return whitespace prefix if needed for the given `leaf`."""
763     NO = ''
764     SPACE = ' '
765     DOUBLESPACE = '  '
766     t = leaf.type
767     p = leaf.parent
768     v = leaf.value
769     if t == token.COLON:
770         return NO
771
772     if t == token.COMMA:
773         return NO
774
775     if t == token.RPAR:
776         return NO
777
778     if t == token.COMMENT:
779         return DOUBLESPACE
780
781     if t == STANDALONE_COMMENT:
782         return NO
783
784     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
785     if p.type in {syms.parameters, syms.arglist}:
786         # untyped function signatures or calls
787         if t == token.RPAR:
788             return NO
789
790         prev = leaf.prev_sibling
791         if not prev or prev.type != token.COMMA:
792             return NO
793
794     if p.type == syms.varargslist:
795         # lambdas
796         if t == token.RPAR:
797             return NO
798
799         prev = leaf.prev_sibling
800         if prev and prev.type != token.COMMA:
801             return NO
802
803     elif p.type == syms.typedargslist:
804         # typed function signatures
805         prev = leaf.prev_sibling
806         if not prev:
807             return NO
808
809         if t == token.EQUAL:
810             if prev.type != syms.tname:
811                 return NO
812
813         elif prev.type == token.EQUAL:
814             # A bit hacky: if the equal sign has whitespace, it means we
815             # previously found it's a typed argument.  So, we're using that, too.
816             return prev.prefix
817
818         elif prev.type != token.COMMA:
819             return NO
820
821     elif p.type == syms.tname:
822         # type names
823         prev = leaf.prev_sibling
824         if not prev:
825             prevp = preceding_leaf(p)
826             if not prevp or prevp.type != token.COMMA:
827                 return NO
828
829     elif p.type == syms.trailer:
830         # attributes and calls
831         if t == token.LPAR or t == token.RPAR:
832             return NO
833
834         prev = leaf.prev_sibling
835         if not prev:
836             if t == token.DOT:
837                 prevp = preceding_leaf(p)
838                 if not prevp or prevp.type != token.NUMBER:
839                     return NO
840
841             elif t == token.LSQB:
842                 return NO
843
844         elif prev.type != token.COMMA:
845             return NO
846
847     elif p.type == syms.argument:
848         # single argument
849         if t == token.EQUAL:
850             return NO
851
852         prev = leaf.prev_sibling
853         if not prev:
854             prevp = preceding_leaf(p)
855             if not prevp or prevp.type == token.LPAR:
856                 return NO
857
858         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
859             return NO
860
861     elif p.type == syms.decorator:
862         # decorators
863         return NO
864
865     elif p.type == syms.dotted_name:
866         prev = leaf.prev_sibling
867         if prev:
868             return NO
869
870         prevp = preceding_leaf(p)
871         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
872             return NO
873
874     elif p.type == syms.classdef:
875         if t == token.LPAR:
876             return NO
877
878         prev = leaf.prev_sibling
879         if prev and prev.type == token.LPAR:
880             return NO
881
882     elif p.type == syms.subscript:
883         # indexing
884         if t == token.COLON:
885             return NO
886
887         prev = leaf.prev_sibling
888         if not prev or prev.type == token.COLON:
889             return NO
890
891     elif p.type in {
892         syms.test,
893         syms.not_test,
894         syms.xor_expr,
895         syms.or_test,
896         syms.and_test,
897         syms.arith_expr,
898         syms.shift_expr,
899         syms.yield_expr,
900         syms.term,
901         syms.power,
902         syms.comparison,
903     }:
904         # various arithmetic and logic expressions
905         prev = leaf.prev_sibling
906         if not prev:
907             prevp = preceding_leaf(p)
908             if not prevp or prevp.type in OPENING_BRACKETS:
909                 return NO
910
911             if prevp.type == token.EQUAL:
912                 if prevp.parent and prevp.parent.type in {
913                     syms.varargslist, syms.parameters, syms.arglist, syms.argument
914                 }:
915                     return NO
916
917         return SPACE
918
919     elif p.type == syms.atom:
920         if t in CLOSING_BRACKETS:
921             return NO
922
923         prev = leaf.prev_sibling
924         if not prev:
925             prevp = preceding_leaf(p)
926             if not prevp:
927                 return NO
928
929             if prevp.type in OPENING_BRACKETS:
930                 return NO
931
932             if prevp.type == token.EQUAL:
933                 if prevp.parent and prevp.parent.type in {
934                     syms.varargslist, syms.parameters, syms.arglist, syms.argument
935                 }:
936                     return NO
937
938             if prevp.type == token.DOUBLESTAR:
939                 if prevp.parent and prevp.parent.type in {
940                     syms.varargslist, syms.parameters, syms.arglist, syms.dictsetmaker
941                 }:
942                     return NO
943
944         elif prev.type in OPENING_BRACKETS:
945             return NO
946
947         elif t == token.DOT:
948             # dots, but not the first one.
949             return NO
950
951     elif (
952         p.type == syms.listmaker or
953         p.type == syms.testlist_gexp or
954         p.type == syms.subscriptlist
955     ):
956         # list interior, including unpacking
957         prev = leaf.prev_sibling
958         if not prev:
959             return NO
960
961     elif p.type == syms.dictsetmaker:
962         # dict and set interior, including unpacking
963         prev = leaf.prev_sibling
964         if not prev:
965             return NO
966
967         if prev.type == token.DOUBLESTAR:
968             return NO
969
970     elif p.type == syms.factor or p.type == syms.star_expr:
971         # unary ops
972         prev = leaf.prev_sibling
973         if not prev:
974             prevp = preceding_leaf(p)
975             if not prevp or prevp.type in OPENING_BRACKETS:
976                 return NO
977
978             prevp_parent = prevp.parent
979             assert prevp_parent is not None
980             if prevp.type == token.COLON and prevp_parent.type in {
981                 syms.subscript, syms.sliceop
982             }:
983                 return NO
984
985             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
986                 return NO
987
988         elif t == token.NAME or t == token.NUMBER:
989             return NO
990
991     elif p.type == syms.import_from:
992         if t == token.DOT:
993             prev = leaf.prev_sibling
994             if prev and prev.type == token.DOT:
995                 return NO
996
997         elif t == token.NAME:
998             if v == 'import':
999                 return SPACE
1000
1001             prev = leaf.prev_sibling
1002             if prev and prev.type == token.DOT:
1003                 return NO
1004
1005     elif p.type == syms.sliceop:
1006         return NO
1007
1008     return SPACE
1009
1010
1011 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1012     """Returns the first leaf that precedes `node`, if any."""
1013     while node:
1014         res = node.prev_sibling
1015         if res:
1016             if isinstance(res, Leaf):
1017                 return res
1018
1019             try:
1020                 return list(res.leaves())[-1]
1021
1022             except IndexError:
1023                 return None
1024
1025         node = node.parent
1026     return None
1027
1028
1029 def is_delimiter(leaf: Leaf) -> int:
1030     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1031
1032     Higher numbers are higher priority.
1033     """
1034     if leaf.type == token.COMMA:
1035         return COMMA_PRIORITY
1036
1037     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1038         return LOGIC_PRIORITY
1039
1040     if leaf.type in COMPARATORS:
1041         return COMPARATOR_PRIORITY
1042
1043     if (
1044         leaf.type in MATH_OPERATORS and
1045         leaf.parent and
1046         leaf.parent.type not in {syms.factor, syms.star_expr}
1047     ):
1048         return MATH_PRIORITY
1049
1050     return 0
1051
1052
1053 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1054     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1055
1056     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1057     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1058     move because it does away with modifying the grammar to include all the
1059     possible places in which comments can be placed.
1060
1061     The sad consequence for us though is that comments don't "belong" anywhere.
1062     This is why this function generates simple parentless Leaf objects for
1063     comments.  We simply don't know what the correct parent should be.
1064
1065     No matter though, we can live without this.  We really only need to
1066     differentiate between inline and standalone comments.  The latter don't
1067     share the line with any code.
1068
1069     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1070     are emitted with a fake STANDALONE_COMMENT token identifier.
1071     """
1072     if not leaf.prefix:
1073         return
1074
1075     if '#' not in leaf.prefix:
1076         return
1077
1078     before_comment, content = leaf.prefix.split('#', 1)
1079     content = content.rstrip()
1080     if content and (content[0] not in {' ', '!', '#'}):
1081         content = ' ' + content
1082     is_standalone_comment = (
1083         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1084     )
1085     if not is_standalone_comment:
1086         # simple trailing comment
1087         yield Leaf(token.COMMENT, value='#' + content)
1088         return
1089
1090     for line in ('#' + content).split('\n'):
1091         line = line.lstrip()
1092         if not line.startswith('#'):
1093             continue
1094
1095         yield Leaf(STANDALONE_COMMENT, line)
1096
1097
1098 def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]:
1099     """Splits a `line` into potentially many lines.
1100
1101     They should fit in the allotted `line_length` but might not be able to.
1102     `inner` signifies that there were a pair of brackets somewhere around the
1103     current `line`, possibly transitively. This means we can fallback to splitting
1104     by delimiters if the LHS/RHS don't yield any results.
1105     """
1106     line_str = str(line).strip('\n')
1107     if len(line_str) <= line_length and '\n' not in line_str:
1108         yield line
1109         return
1110
1111     if line.is_def:
1112         split_funcs = [left_hand_split]
1113     elif line.inside_brackets:
1114         split_funcs = [delimiter_split]
1115         if '\n' not in line_str:
1116             # Only attempt RHS if we don't have multiline strings or comments
1117             # on this line.
1118             split_funcs.append(right_hand_split)
1119     else:
1120         split_funcs = [right_hand_split]
1121     for split_func in split_funcs:
1122         # We are accumulating lines in `result` because we might want to abort
1123         # mission and return the original line in the end, or attempt a different
1124         # split altogether.
1125         result: List[Line] = []
1126         try:
1127             for l in split_func(line):
1128                 if str(l).strip('\n') == line_str:
1129                     raise CannotSplit("Split function returned an unchanged result")
1130
1131                 result.extend(split_line(l, line_length=line_length, inner=True))
1132         except CannotSplit as cs:
1133             continue
1134
1135         else:
1136             yield from result
1137             break
1138
1139     else:
1140         yield line
1141
1142
1143 def left_hand_split(line: Line) -> Iterator[Line]:
1144     """Split line into many lines, starting with the first matching bracket pair.
1145
1146     Note: this usually looks weird, only use this for function definitions.
1147     Prefer RHS otherwise.
1148     """
1149     head = Line(depth=line.depth)
1150     body = Line(depth=line.depth + 1, inside_brackets=True)
1151     tail = Line(depth=line.depth)
1152     tail_leaves: List[Leaf] = []
1153     body_leaves: List[Leaf] = []
1154     head_leaves: List[Leaf] = []
1155     current_leaves = head_leaves
1156     matching_bracket = None
1157     for leaf in line.leaves:
1158         if (
1159             current_leaves is body_leaves and
1160             leaf.type in CLOSING_BRACKETS and
1161             leaf.opening_bracket is matching_bracket  # type: ignore
1162         ):
1163             current_leaves = tail_leaves
1164         current_leaves.append(leaf)
1165         if current_leaves is head_leaves:
1166             if leaf.type in OPENING_BRACKETS:
1167                 matching_bracket = leaf
1168                 current_leaves = body_leaves
1169     # Since body is a new indent level, remove spurious leading whitespace.
1170     if body_leaves:
1171         normalize_prefix(body_leaves[0])
1172     # Build the new lines.
1173     for result, leaves in (
1174         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1175     ):
1176         for leaf in leaves:
1177             result.append(leaf, preformatted=True)
1178             comment_after = line.comments.get(id(leaf))
1179             if comment_after:
1180                 result.append(comment_after, preformatted=True)
1181     # Check if the split succeeded.
1182     tail_len = len(str(tail))
1183     if not body:
1184         if tail_len == 0:
1185             raise CannotSplit("Splitting brackets produced the same line")
1186
1187         elif tail_len < 3:
1188             raise CannotSplit(
1189                 f"Splitting brackets on an empty body to save "
1190                 f"{tail_len} characters is not worth it"
1191             )
1192
1193     for result in (head, body, tail):
1194         if result:
1195             yield result
1196
1197
1198 def right_hand_split(line: Line) -> Iterator[Line]:
1199     """Split line into many lines, starting with the last matching bracket pair."""
1200     head = Line(depth=line.depth)
1201     body = Line(depth=line.depth + 1, inside_brackets=True)
1202     tail = Line(depth=line.depth)
1203     tail_leaves: List[Leaf] = []
1204     body_leaves: List[Leaf] = []
1205     head_leaves: List[Leaf] = []
1206     current_leaves = tail_leaves
1207     opening_bracket = None
1208     for leaf in reversed(line.leaves):
1209         if current_leaves is body_leaves:
1210             if leaf is opening_bracket:
1211                 current_leaves = head_leaves
1212         current_leaves.append(leaf)
1213         if current_leaves is tail_leaves:
1214             if leaf.type in CLOSING_BRACKETS:
1215                 opening_bracket = leaf.opening_bracket  # type: ignore
1216                 current_leaves = body_leaves
1217     tail_leaves.reverse()
1218     body_leaves.reverse()
1219     head_leaves.reverse()
1220     # Since body is a new indent level, remove spurious leading whitespace.
1221     if body_leaves:
1222         normalize_prefix(body_leaves[0])
1223     # Build the new lines.
1224     for result, leaves in (
1225         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1226     ):
1227         for leaf in leaves:
1228             result.append(leaf, preformatted=True)
1229             comment_after = line.comments.get(id(leaf))
1230             if comment_after:
1231                 result.append(comment_after, preformatted=True)
1232     # Check if the split succeeded.
1233     tail_len = len(str(tail).strip('\n'))
1234     if not body:
1235         if tail_len == 0:
1236             raise CannotSplit("Splitting brackets produced the same line")
1237
1238         elif tail_len < 3:
1239             raise CannotSplit(
1240                 f"Splitting brackets on an empty body to save "
1241                 f"{tail_len} characters is not worth it"
1242             )
1243
1244     for result in (head, body, tail):
1245         if result:
1246             yield result
1247
1248
1249 def delimiter_split(line: Line) -> Iterator[Line]:
1250     """Split according to delimiters of the highest priority.
1251
1252     This kind of split doesn't increase indentation.
1253     """
1254     try:
1255         last_leaf = line.leaves[-1]
1256     except IndexError:
1257         raise CannotSplit("Line empty")
1258
1259     delimiters = line.bracket_tracker.delimiters
1260     try:
1261         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1262     except ValueError:
1263         raise CannotSplit("No delimiters found")
1264
1265     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1266     for leaf in line.leaves:
1267         current_line.append(leaf, preformatted=True)
1268         comment_after = line.comments.get(id(leaf))
1269         if comment_after:
1270             current_line.append(comment_after, preformatted=True)
1271         leaf_priority = delimiters.get(id(leaf))
1272         if leaf_priority == delimiter_priority:
1273             normalize_prefix(current_line.leaves[0])
1274             yield current_line
1275
1276             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1277     if current_line:
1278         if (
1279             delimiter_priority == COMMA_PRIORITY and
1280             current_line.leaves[-1].type != token.COMMA
1281         ):
1282             current_line.append(Leaf(token.COMMA, ','))
1283         normalize_prefix(current_line.leaves[0])
1284         yield current_line
1285
1286
1287 def is_import(leaf: Leaf) -> bool:
1288     """Returns True if the given leaf starts an import statement."""
1289     p = leaf.parent
1290     t = leaf.type
1291     v = leaf.value
1292     return bool(
1293         t == token.NAME and
1294         (
1295             (v == 'import' and p and p.type == syms.import_name) or
1296             (v == 'from' and p and p.type == syms.import_from)
1297         )
1298     )
1299
1300
1301 def normalize_prefix(leaf: Leaf) -> None:
1302     """Leave existing extra newlines for imports.  Remove everything else."""
1303     if is_import(leaf):
1304         spl = leaf.prefix.split('#', 1)
1305         nl_count = spl[0].count('\n')
1306         if len(spl) > 1:
1307             # Skip one newline since it was for a standalone comment.
1308             nl_count -= 1
1309         leaf.prefix = '\n' * nl_count
1310         return
1311
1312     leaf.prefix = ''
1313
1314
1315 PYTHON_EXTENSIONS = {'.py'}
1316 BLACKLISTED_DIRECTORIES = {
1317     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1318 }
1319
1320
1321 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1322     for child in path.iterdir():
1323         if child.is_dir():
1324             if child.name in BLACKLISTED_DIRECTORIES:
1325                 continue
1326
1327             yield from gen_python_files_in_dir(child)
1328
1329         elif child.suffix in PYTHON_EXTENSIONS:
1330             yield child
1331
1332
1333 @dataclass
1334 class Report:
1335     """Provides a reformatting counter."""
1336     change_count: int = attrib(default=0)
1337     same_count: int = attrib(default=0)
1338     failure_count: int = attrib(default=0)
1339
1340     def done(self, src: Path, changed: bool) -> None:
1341         """Increment the counter for successful reformatting. Write out a message."""
1342         if changed:
1343             out(f'reformatted {src}')
1344             self.change_count += 1
1345         else:
1346             out(f'{src} already well formatted, good job.', bold=False)
1347             self.same_count += 1
1348
1349     def failed(self, src: Path, message: str) -> None:
1350         """Increment the counter for failed reformatting. Write out a message."""
1351         err(f'error: cannot format {src}: {message}')
1352         self.failure_count += 1
1353
1354     @property
1355     def return_code(self) -> int:
1356         """Which return code should the app use considering the current state."""
1357         return 1 if self.failure_count else 0
1358
1359     def __str__(self) -> str:
1360         """A color report of the current state.
1361
1362         Use `click.unstyle` to remove colors.
1363         """
1364         report = []
1365         if self.change_count:
1366             s = 's' if self.change_count > 1 else ''
1367             report.append(
1368                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1369             )
1370         if self.same_count:
1371             s = 's' if self.same_count > 1 else ''
1372             report.append(f'{self.same_count} file{s} left unchanged')
1373         if self.failure_count:
1374             s = 's' if self.failure_count > 1 else ''
1375             report.append(
1376                 click.style(
1377                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1378                 )
1379             )
1380         return ', '.join(report) + '.'
1381
1382
1383 def assert_equivalent(src: str, dst: str) -> None:
1384     """Raises AssertionError if `src` and `dst` aren't equivalent.
1385
1386     This is a temporary sanity check until Black becomes stable.
1387     """
1388
1389     import ast
1390     import traceback
1391
1392     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1393         """Simple visitor generating strings to compare ASTs by content."""
1394         yield f"{'  ' * depth}{node.__class__.__name__}("
1395
1396         for field in sorted(node._fields):
1397             try:
1398                 value = getattr(node, field)
1399             except AttributeError:
1400                 continue
1401
1402             yield f"{'  ' * (depth+1)}{field}="
1403
1404             if isinstance(value, list):
1405                 for item in value:
1406                     if isinstance(item, ast.AST):
1407                         yield from _v(item, depth + 2)
1408
1409             elif isinstance(value, ast.AST):
1410                 yield from _v(value, depth + 2)
1411
1412             else:
1413                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1414
1415         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1416
1417     try:
1418         src_ast = ast.parse(src)
1419     except Exception as exc:
1420         raise AssertionError(f"cannot parse source: {exc}") from None
1421
1422     try:
1423         dst_ast = ast.parse(dst)
1424     except Exception as exc:
1425         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1426         raise AssertionError(
1427             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1428             f"Please report a bug on https://github.com/ambv/black/issues.  "
1429             f"This invalid output might be helpful: {log}",
1430         ) from None
1431
1432     src_ast_str = '\n'.join(_v(src_ast))
1433     dst_ast_str = '\n'.join(_v(dst_ast))
1434     if src_ast_str != dst_ast_str:
1435         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1436         raise AssertionError(
1437             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1438             f"the source.  "
1439             f"Please report a bug on https://github.com/ambv/black/issues.  "
1440             f"This diff might be helpful: {log}",
1441         ) from None
1442
1443
1444 def assert_stable(src: str, dst: str, line_length: int) -> None:
1445     """Raises AssertionError if `dst` reformats differently the second time.
1446
1447     This is a temporary sanity check until Black becomes stable.
1448     """
1449     newdst = format_str(dst, line_length=line_length)
1450     if dst != newdst:
1451         log = dump_to_file(
1452             diff(src, dst, 'source', 'first pass'),
1453             diff(dst, newdst, 'first pass', 'second pass'),
1454         )
1455         raise AssertionError(
1456             f"INTERNAL ERROR: Black produced different code on the second pass "
1457             f"of the formatter.  "
1458             f"Please report a bug on https://github.com/ambv/black/issues.  "
1459             f"This diff might be helpful: {log}",
1460         ) from None
1461
1462
1463 def dump_to_file(*output: str) -> str:
1464     """Dumps `output` to a temporary file. Returns path to the file."""
1465     import tempfile
1466
1467     with tempfile.NamedTemporaryFile(
1468         mode='w', prefix='blk_', suffix='.log', delete=False
1469     ) as f:
1470         for lines in output:
1471             f.write(lines)
1472             f.write('\n')
1473     return f.name
1474
1475
1476 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1477     """Returns a udiff string between strings `a` and `b`."""
1478     import difflib
1479
1480     a_lines = [line + '\n' for line in a.split('\n')]
1481     b_lines = [line + '\n' for line in b.split('\n')]
1482     return ''.join(
1483         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1484     )
1485
1486
1487 if __name__ == '__main__':
1488     main()