]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

75fbdcadd41159f0b38446288f1b05c5f1408995
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 from typing import (
11     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
12 )
13
14 from attr import attrib, dataclass, Factory
15 import click
16
17 # lib2to3 fork
18 from blib2to3.pytree import Node, Leaf, type_repr
19 from blib2to3 import pygram, pytree
20 from blib2to3.pgen2 import driver, token
21 from blib2to3.pgen2.parse import ParseError
22
23 __version__ = "18.3a1"
24 DEFAULT_LINE_LENGTH = 88
25 # types
26 syms = pygram.python_symbols
27 FileContent = str
28 Encoding = str
29 Depth = int
30 NodeType = int
31 LeafID = int
32 Priority = int
33 LN = Union[Leaf, Node]
34 out = partial(click.secho, bold=True, err=True)
35 err = partial(click.secho, fg='red', err=True)
36
37
38 class NothingChanged(UserWarning):
39     """Raised by `format_file` when the reformatted code is the same as source."""
40
41
42 class CannotSplit(Exception):
43     """A readable split that fits the allotted line length is impossible.
44
45     Raised by `left_hand_split()` and `right_hand_split()`.
46     """
47
48
49 @click.command()
50 @click.option(
51     '-l',
52     '--line-length',
53     type=int,
54     default=DEFAULT_LINE_LENGTH,
55     help='How many character per line to allow.',
56     show_default=True,
57 )
58 @click.option(
59     '--fast/--safe',
60     is_flag=True,
61     help='If --fast given, skip temporary sanity checks. [default: --safe]',
62 )
63 @click.version_option(version=__version__)
64 @click.argument(
65     'src',
66     nargs=-1,
67     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
68 )
69 @click.pass_context
70 def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> None:
71     """The uncompromising code formatter."""
72     sources: List[Path] = []
73     for s in src:
74         p = Path(s)
75         if p.is_dir():
76             sources.extend(gen_python_files_in_dir(p))
77         elif p.is_file():
78             # if a file was explicitly given, we don't care about its extension
79             sources.append(p)
80         else:
81             err(f'invalid path: {s}')
82     if len(sources) == 0:
83         ctx.exit(0)
84     elif len(sources) == 1:
85         p = sources[0]
86         report = Report()
87         try:
88             changed = format_file_in_place(p, line_length=line_length, fast=fast)
89             report.done(p, changed)
90         except Exception as exc:
91             report.failed(p, str(exc))
92         ctx.exit(report.return_code)
93     else:
94         loop = asyncio.get_event_loop()
95         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
96         return_code = 1
97         try:
98             return_code = loop.run_until_complete(
99                 schedule_formatting(sources, line_length, fast, loop, executor)
100             )
101         finally:
102             loop.close()
103             ctx.exit(return_code)
104
105
106 async def schedule_formatting(
107     sources: List[Path],
108     line_length: int,
109     fast: bool,
110     loop: BaseEventLoop,
111     executor: Executor,
112 ) -> int:
113     tasks = {
114         src: loop.run_in_executor(
115             executor, format_file_in_place, src, line_length, fast
116         )
117         for src in sources
118     }
119     await asyncio.wait(tasks.values())
120     cancelled = []
121     report = Report()
122     for src, task in tasks.items():
123         if not task.done():
124             report.failed(src, 'timed out, cancelling')
125             task.cancel()
126             cancelled.append(task)
127         elif task.exception():
128             report.failed(src, str(task.exception()))
129         else:
130             report.done(src, task.result())
131     if cancelled:
132         await asyncio.wait(cancelled, timeout=2)
133     out('All done! ✨ 🍰 ✨')
134     click.echo(str(report))
135     return report.return_code
136
137
138 def format_file_in_place(src: Path, line_length: int, fast: bool) -> bool:
139     """Format the file and rewrite if changed. Return True if changed."""
140     try:
141         contents, encoding = format_file(src, line_length=line_length, fast=fast)
142     except NothingChanged:
143         return False
144
145     with open(src, "w", encoding=encoding) as f:
146         f.write(contents)
147     return True
148
149
150 def format_file(
151     src: Path, line_length: int, fast: bool
152 ) -> Tuple[FileContent, Encoding]:
153     """Reformats a file and returns its contents and encoding."""
154     with tokenize.open(src) as src_buffer:
155         src_contents = src_buffer.read()
156     if src_contents.strip() == '':
157         raise NothingChanged(src)
158
159     dst_contents = format_str(src_contents, line_length=line_length)
160     if src_contents == dst_contents:
161         raise NothingChanged(src)
162
163     if not fast:
164         assert_equivalent(src_contents, dst_contents)
165         assert_stable(src_contents, dst_contents, line_length=line_length)
166     return dst_contents, src_buffer.encoding
167
168
169 def format_str(src_contents: str, line_length: int) -> FileContent:
170     """Reformats a string and returns new contents."""
171     src_node = lib2to3_parse(src_contents)
172     dst_contents = ""
173     comments: List[Line] = []
174     lines = LineGenerator()
175     elt = EmptyLineTracker()
176     empty_line = Line()
177     after = 0
178     for current_line in lines.visit(src_node):
179         for _ in range(after):
180             dst_contents += str(empty_line)
181         before, after = elt.maybe_empty_lines(current_line)
182         for _ in range(before):
183             dst_contents += str(empty_line)
184         if not current_line.is_comment:
185             for comment in comments:
186                 dst_contents += str(comment)
187             comments = []
188             for line in split_line(current_line, line_length=line_length):
189                 dst_contents += str(line)
190         else:
191             comments.append(current_line)
192     for comment in comments:
193         dst_contents += str(comment)
194     return dst_contents
195
196
197 def lib2to3_parse(src_txt: str) -> Node:
198     """Given a string with source, return the lib2to3 Node."""
199     grammar = pygram.python_grammar_no_print_statement
200     drv = driver.Driver(grammar, pytree.convert)
201     if src_txt[-1] != '\n':
202         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
203         src_txt += nl
204     try:
205         result = drv.parse_string(src_txt, True)
206     except ParseError as pe:
207         lineno, column = pe.context[1]
208         lines = src_txt.splitlines()
209         try:
210             faulty_line = lines[lineno - 1]
211         except IndexError:
212             faulty_line = "<line number missing in source>"
213         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
214
215     if isinstance(result, Leaf):
216         result = Node(syms.file_input, [result])
217     return result
218
219
220 def lib2to3_unparse(node: Node) -> str:
221     """Given a lib2to3 node, return its string representation."""
222     code = str(node)
223     return code
224
225
226 T = TypeVar('T')
227
228
229 class Visitor(Generic[T]):
230     """Basic lib2to3 visitor that yields things on visiting."""
231
232     def visit(self, node: LN) -> Iterator[T]:
233         if node.type < 256:
234             name = token.tok_name[node.type]
235         else:
236             name = type_repr(node.type)
237         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
238
239     def visit_default(self, node: LN) -> Iterator[T]:
240         if isinstance(node, Node):
241             for child in node.children:
242                 yield from self.visit(child)
243
244
245 @dataclass
246 class DebugVisitor(Visitor[T]):
247     tree_depth: int = attrib(default=0)
248
249     def visit_default(self, node: LN) -> Iterator[T]:
250         indent = ' ' * (2 * self.tree_depth)
251         if isinstance(node, Node):
252             _type = type_repr(node.type)
253             out(f'{indent}{_type}', fg='yellow')
254             self.tree_depth += 1
255             for child in node.children:
256                 yield from self.visit(child)
257
258             self.tree_depth -= 1
259             out(f'{indent}/{_type}', fg='yellow', bold=False)
260         else:
261             _type = token.tok_name.get(node.type, str(node.type))
262             out(f'{indent}{_type}', fg='blue', nl=False)
263             if node.prefix:
264                 # We don't have to handle prefixes for `Node` objects since
265                 # that delegates to the first child anyway.
266                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
267             out(f' {node.value!r}', fg='blue', bold=False)
268
269
270 KEYWORDS = set(keyword.kwlist)
271 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
272 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
273 STATEMENT = {
274     syms.if_stmt,
275     syms.while_stmt,
276     syms.for_stmt,
277     syms.try_stmt,
278     syms.except_clause,
279     syms.with_stmt,
280     syms.funcdef,
281     syms.classdef,
282 }
283 STANDALONE_COMMENT = 153
284 LOGIC_OPERATORS = {'and', 'or'}
285 COMPARATORS = {
286     token.LESS,
287     token.GREATER,
288     token.EQEQUAL,
289     token.NOTEQUAL,
290     token.LESSEQUAL,
291     token.GREATEREQUAL,
292 }
293 MATH_OPERATORS = {
294     token.PLUS,
295     token.MINUS,
296     token.STAR,
297     token.SLASH,
298     token.VBAR,
299     token.AMPER,
300     token.PERCENT,
301     token.CIRCUMFLEX,
302     token.LEFTSHIFT,
303     token.RIGHTSHIFT,
304     token.DOUBLESTAR,
305     token.DOUBLESLASH,
306 }
307 COMPREHENSION_PRIORITY = 20
308 COMMA_PRIORITY = 10
309 LOGIC_PRIORITY = 5
310 STRING_PRIORITY = 4
311 COMPARATOR_PRIORITY = 3
312 MATH_PRIORITY = 1
313
314
315 @dataclass
316 class BracketTracker:
317     depth: int = attrib(default=0)
318     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict))
319     delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict))
320     previous: Optional[Leaf] = attrib(default=None)
321
322     def mark(self, leaf: Leaf) -> None:
323         if leaf.type == token.COMMENT:
324             return
325
326         if leaf.type in CLOSING_BRACKETS:
327             self.depth -= 1
328             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
329             leaf.opening_bracket = opening_bracket  # type: ignore
330         leaf.bracket_depth = self.depth  # type: ignore
331         if self.depth == 0:
332             delim = is_delimiter(leaf)
333             if delim:
334                 self.delimiters[id(leaf)] = delim
335             elif self.previous is not None:
336                 if leaf.type == token.STRING and self.previous.type == token.STRING:
337                     self.delimiters[id(self.previous)] = STRING_PRIORITY
338                 elif (
339                     leaf.type == token.NAME and
340                     leaf.value == 'for' and
341                     leaf.parent and
342                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
343                 ):
344                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
345                 elif (
346                     leaf.type == token.NAME and
347                     leaf.value == 'if' and
348                     leaf.parent and
349                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
350                 ):
351                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
352         if leaf.type in OPENING_BRACKETS:
353             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
354             self.depth += 1
355         self.previous = leaf
356
357     def any_open_brackets(self) -> bool:
358         """Returns True if there is an yet unmatched open bracket on the line."""
359         return bool(self.bracket_match)
360
361     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
362         """Returns the highest priority of a delimiter found on the line.
363
364         Values are consistent with what `is_delimiter()` returns.
365         """
366         return max(v for k, v in self.delimiters.items() if k not in exclude)
367
368
369 @dataclass
370 class Line:
371     depth: int = attrib(default=0)
372     leaves: List[Leaf] = attrib(default=Factory(list))
373     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
374     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
375     inside_brackets: bool = attrib(default=False)
376
377     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
378         has_value = leaf.value.strip()
379         if not has_value:
380             return
381
382         if self.leaves and not preformatted:
383             # Note: at this point leaf.prefix should be empty except for
384             # imports, for which we only preserve newlines.
385             leaf.prefix += whitespace(leaf)
386         if self.inside_brackets or not preformatted:
387             self.bracket_tracker.mark(leaf)
388             self.maybe_remove_trailing_comma(leaf)
389             if self.maybe_adapt_standalone_comment(leaf):
390                 return
391
392         if not self.append_comment(leaf):
393             self.leaves.append(leaf)
394
395     @property
396     def is_comment(self) -> bool:
397         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
398
399     @property
400     def is_decorator(self) -> bool:
401         return bool(self) and self.leaves[0].type == token.AT
402
403     @property
404     def is_import(self) -> bool:
405         return bool(self) and is_import(self.leaves[0])
406
407     @property
408     def is_class(self) -> bool:
409         return (
410             bool(self) and
411             self.leaves[0].type == token.NAME and
412             self.leaves[0].value == 'class'
413         )
414
415     @property
416     def is_def(self) -> bool:
417         """Also returns True for async defs."""
418         try:
419             first_leaf = self.leaves[0]
420         except IndexError:
421             return False
422
423         try:
424             second_leaf: Optional[Leaf] = self.leaves[1]
425         except IndexError:
426             second_leaf = None
427         return (
428             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
429             (
430                 first_leaf.type == token.NAME and
431                 first_leaf.value == 'async' and
432                 second_leaf is not None and
433                 second_leaf.type == token.NAME and
434                 second_leaf.value == 'def'
435             )
436         )
437
438     @property
439     def is_flow_control(self) -> bool:
440         return (
441             bool(self) and
442             self.leaves[0].type == token.NAME and
443             self.leaves[0].value in FLOW_CONTROL
444         )
445
446     @property
447     def is_yield(self) -> bool:
448         return (
449             bool(self) and
450             self.leaves[0].type == token.NAME and
451             self.leaves[0].value == 'yield'
452         )
453
454     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
455         if not (
456             self.leaves and
457             self.leaves[-1].type == token.COMMA and
458             closing.type in CLOSING_BRACKETS
459         ):
460             return False
461
462         if closing.type == token.RSQB or closing.type == token.RBRACE:
463             self.leaves.pop()
464             return True
465
466         # For parens let's check if it's safe to remove the comma.  If the
467         # trailing one is the only one, we might mistakenly change a tuple
468         # into a different type by removing the comma.
469         depth = closing.bracket_depth + 1  # type: ignore
470         commas = 0
471         opening = closing.opening_bracket  # type: ignore
472         for _opening_index, leaf in enumerate(self.leaves):
473             if leaf is opening:
474                 break
475
476         else:
477             return False
478
479         for leaf in self.leaves[_opening_index + 1:]:
480             if leaf is closing:
481                 break
482
483             bracket_depth = leaf.bracket_depth  # type: ignore
484             if bracket_depth == depth and leaf.type == token.COMMA:
485                 commas += 1
486         if commas > 1:
487             self.leaves.pop()
488             return True
489
490         return False
491
492     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
493         """Hack a standalone comment to act as a trailing comment for line splitting.
494
495         If this line has brackets and a standalone `comment`, we need to adapt
496         it to be able to still reformat the line.
497
498         This is not perfect, the line to which the standalone comment gets
499         appended will appear "too long" when splitting.
500         """
501         if not (
502             comment.type == STANDALONE_COMMENT and
503             self.bracket_tracker.any_open_brackets()
504         ):
505             return False
506
507         comment.type = token.COMMENT
508         comment.prefix = '\n' + '    ' * (self.depth + 1)
509         return self.append_comment(comment)
510
511     def append_comment(self, comment: Leaf) -> bool:
512         if comment.type != token.COMMENT:
513             return False
514
515         try:
516             after = id(self.last_non_delimiter())
517         except LookupError:
518             comment.type = STANDALONE_COMMENT
519             comment.prefix = ''
520             return False
521
522         else:
523             if after in self.comments:
524                 self.comments[after].value += str(comment)
525             else:
526                 self.comments[after] = comment
527             return True
528
529     def last_non_delimiter(self) -> Leaf:
530         for i in range(len(self.leaves)):
531             last = self.leaves[-i - 1]
532             if not is_delimiter(last):
533                 return last
534
535         raise LookupError("No non-delimiters found")
536
537     def __str__(self) -> str:
538         if not self:
539             return '\n'
540
541         indent = '    ' * self.depth
542         leaves = iter(self.leaves)
543         first = next(leaves)
544         res = f'{first.prefix}{indent}{first.value}'
545         for leaf in leaves:
546             res += str(leaf)
547         for comment in self.comments.values():
548             res += str(comment)
549         return res + '\n'
550
551     def __bool__(self) -> bool:
552         return bool(self.leaves or self.comments)
553
554
555 @dataclass
556 class EmptyLineTracker:
557     """Provides a stateful method that returns the number of potential extra
558     empty lines needed before and after the currently processed line.
559
560     Note: this tracker works on lines that haven't been split yet.
561     """
562     previous_line: Optional[Line] = attrib(default=None)
563     previous_after: int = attrib(default=0)
564     previous_defs: List[int] = attrib(default=Factory(list))
565
566     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
567         """Returns the number of extra empty lines before and after the `current_line`.
568
569         This is for separating `def`, `async def` and `class` with extra empty lines
570         (two on module-level), as well as providing an extra empty line after flow
571         control keywords to make them more prominent.
572         """
573         before, after = self._maybe_empty_lines(current_line)
574         self.previous_after = after
575         self.previous_line = current_line
576         return before, after
577
578     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
579         before = 0
580         depth = current_line.depth
581         while self.previous_defs and self.previous_defs[-1] >= depth:
582             self.previous_defs.pop()
583             before = (1 if depth else 2) - self.previous_after
584         is_decorator = current_line.is_decorator
585         if is_decorator or current_line.is_def or current_line.is_class:
586             if not is_decorator:
587                 self.previous_defs.append(depth)
588             if self.previous_line is None:
589                 # Don't insert empty lines before the first line in the file.
590                 return 0, 0
591
592             if self.previous_line and self.previous_line.is_decorator:
593                 # Don't insert empty lines between decorators.
594                 return 0, 0
595
596             newlines = 2
597             if current_line.depth:
598                 newlines -= 1
599             newlines -= self.previous_after
600             return newlines, 0
601
602         if current_line.is_flow_control:
603             return before, 1
604
605         if (
606             self.previous_line and
607             self.previous_line.is_import and
608             not current_line.is_import and
609             depth == self.previous_line.depth
610         ):
611             return (before or 1), 0
612
613         if (
614             self.previous_line and
615             self.previous_line.is_yield and
616             (not current_line.is_yield or depth != self.previous_line.depth)
617         ):
618             return (before or 1), 0
619
620         return before, 0
621
622
623 @dataclass
624 class LineGenerator(Visitor[Line]):
625     """Generates reformatted Line objects.  Empty lines are not emitted.
626
627     Note: destroys the tree it's visiting by mutating prefixes of its leaves
628     in ways that will no longer stringify to valid Python code on the tree.
629     """
630     current_line: Line = attrib(default=Factory(Line))
631     standalone_comments: List[Leaf] = attrib(default=Factory(list))
632
633     def line(self, indent: int = 0) -> Iterator[Line]:
634         """Generate a line.
635
636         If the line is empty, only emit if it makes sense.
637         If the line is too long, split it first and then generate.
638
639         If any lines were generated, set up a new current_line.
640         """
641         if not self.current_line:
642             self.current_line.depth += indent
643             return  # Line is empty, don't emit. Creating a new one unnecessary.
644
645         complete_line = self.current_line
646         self.current_line = Line(depth=complete_line.depth + indent)
647         yield complete_line
648
649     def visit_default(self, node: LN) -> Iterator[Line]:
650         if isinstance(node, Leaf):
651             for comment in generate_comments(node):
652                 if self.current_line.bracket_tracker.any_open_brackets():
653                     # any comment within brackets is subject to splitting
654                     self.current_line.append(comment)
655                 elif comment.type == token.COMMENT:
656                     # regular trailing comment
657                     self.current_line.append(comment)
658                     yield from self.line()
659
660                 else:
661                     # regular standalone comment, to be processed later (see
662                     # docstring in `generate_comments()`
663                     self.standalone_comments.append(comment)
664             normalize_prefix(node)
665             if node.type not in WHITESPACE:
666                 for comment in self.standalone_comments:
667                     yield from self.line()
668
669                     self.current_line.append(comment)
670                     yield from self.line()
671
672                 self.standalone_comments = []
673                 self.current_line.append(node)
674         yield from super().visit_default(node)
675
676     def visit_suite(self, node: Node) -> Iterator[Line]:
677         """Body of a statement after a colon."""
678         children = iter(node.children)
679         # Process newline before indenting.  It might contain an inline
680         # comment that should go right after the colon.
681         newline = next(children)
682         yield from self.visit(newline)
683         yield from self.line(+1)
684
685         for child in children:
686             yield from self.visit(child)
687
688         yield from self.line(-1)
689
690     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
691         """Visit a statement.
692
693         The relevant Python language keywords for this statement are NAME leaves
694         within it.
695         """
696         for child in node.children:
697             if child.type == token.NAME and child.value in keywords:  # type: ignore
698                 yield from self.line()
699
700             yield from self.visit(child)
701
702     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
703         """A statement without nested statements."""
704         is_suite_like = node.parent and node.parent.type in STATEMENT
705         if is_suite_like:
706             yield from self.line(+1)
707             yield from self.visit_default(node)
708             yield from self.line(-1)
709
710         else:
711             yield from self.line()
712             yield from self.visit_default(node)
713
714     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
715         yield from self.line()
716
717         children = iter(node.children)
718         for child in children:
719             yield from self.visit(child)
720
721             if child.type == token.NAME and child.value == 'async':  # type: ignore
722                 break
723
724         internal_stmt = next(children)
725         for child in internal_stmt.children:
726             yield from self.visit(child)
727
728     def visit_decorators(self, node: Node) -> Iterator[Line]:
729         for child in node.children:
730             yield from self.line()
731             yield from self.visit(child)
732
733     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
734         yield from self.line()
735
736     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
737         yield from self.visit_default(leaf)
738         yield from self.line()
739
740     def __attrs_post_init__(self) -> None:
741         """You are in a twisty little maze of passages."""
742         v = self.visit_stmt
743         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
744         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
745         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
746         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
747         self.visit_except_clause = partial(v, keywords={'except'})
748         self.visit_funcdef = partial(v, keywords={'def'})
749         self.visit_with_stmt = partial(v, keywords={'with'})
750         self.visit_classdef = partial(v, keywords={'class'})
751         self.visit_async_funcdef = self.visit_async_stmt
752         self.visit_decorated = self.visit_decorators
753
754
755 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
756 OPENING_BRACKETS = set(BRACKET.keys())
757 CLOSING_BRACKETS = set(BRACKET.values())
758 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
759
760
761 def whitespace(leaf: Leaf) -> str:
762     """Return whitespace prefix if needed for the given `leaf`."""
763     NO = ''
764     SPACE = ' '
765     DOUBLESPACE = '  '
766     t = leaf.type
767     p = leaf.parent
768     v = leaf.value
769     if t == token.COLON:
770         return NO
771
772     if t == token.COMMA:
773         return NO
774
775     if t == token.RPAR:
776         return NO
777
778     if t == token.COMMENT:
779         return DOUBLESPACE
780
781     if t == STANDALONE_COMMENT:
782         return NO
783
784     if t in CLOSING_BRACKETS:
785         return NO
786
787     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
788     prev = leaf.prev_sibling
789     if not prev:
790         prevp = preceding_leaf(p)
791         if not prevp or prevp.type in OPENING_BRACKETS:
792             return NO
793
794         if prevp.type == token.EQUAL:
795             if prevp.parent and prevp.parent.type in {
796                 syms.typedargslist,
797                 syms.varargslist,
798                 syms.parameters,
799                 syms.arglist,
800                 syms.argument,
801             }:
802                 return NO
803
804         elif prevp.type == token.DOUBLESTAR:
805             if prevp.parent and prevp.parent.type in {
806                 syms.typedargslist,
807                 syms.varargslist,
808                 syms.parameters,
809                 syms.arglist,
810                 syms.dictsetmaker,
811             }:
812                 return NO
813
814         elif prevp.type == token.COLON:
815             if prevp.parent and prevp.parent.type == syms.subscript:
816                 return NO
817
818     elif prev.type in OPENING_BRACKETS:
819         return NO
820
821     if p.type in {syms.parameters, syms.arglist}:
822         # untyped function signatures or calls
823         if t == token.RPAR:
824             return NO
825
826         if not prev or prev.type != token.COMMA:
827             return NO
828
829     if p.type == syms.varargslist:
830         # lambdas
831         if t == token.RPAR:
832             return NO
833
834         if prev and prev.type != token.COMMA:
835             return NO
836
837     elif p.type == syms.typedargslist:
838         # typed function signatures
839         if not prev:
840             return NO
841
842         if t == token.EQUAL:
843             if prev.type != syms.tname:
844                 return NO
845
846         elif prev.type == token.EQUAL:
847             # A bit hacky: if the equal sign has whitespace, it means we
848             # previously found it's a typed argument.  So, we're using that, too.
849             return prev.prefix
850
851         elif prev.type != token.COMMA:
852             return NO
853
854     elif p.type == syms.tname:
855         # type names
856         if not prev:
857             prevp = preceding_leaf(p)
858             if not prevp or prevp.type != token.COMMA:
859                 return NO
860
861     elif p.type == syms.trailer:
862         # attributes and calls
863         if t == token.LPAR or t == token.RPAR:
864             return NO
865
866         if not prev:
867             if t == token.DOT:
868                 prevp = preceding_leaf(p)
869                 if not prevp or prevp.type != token.NUMBER:
870                     return NO
871
872             elif t == token.LSQB:
873                 return NO
874
875         elif prev.type != token.COMMA:
876             return NO
877
878     elif p.type == syms.argument:
879         # single argument
880         if t == token.EQUAL:
881             return NO
882
883         if not prev:
884             prevp = preceding_leaf(p)
885             if not prevp or prevp.type == token.LPAR:
886                 return NO
887
888         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
889             return NO
890
891     elif p.type == syms.decorator:
892         # decorators
893         return NO
894
895     elif p.type == syms.dotted_name:
896         if prev:
897             return NO
898
899         prevp = preceding_leaf(p)
900         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
901             return NO
902
903     elif p.type == syms.classdef:
904         if t == token.LPAR:
905             return NO
906
907         if prev and prev.type == token.LPAR:
908             return NO
909
910     elif p.type == syms.subscript:
911         # indexing
912         if not prev or prev.type == token.COLON:
913             return NO
914
915     elif p.type == syms.atom:
916         if prev and t == token.DOT:
917             # dots, but not the first one.
918             return NO
919
920     elif (
921         p.type == syms.listmaker or
922         p.type == syms.testlist_gexp or
923         p.type == syms.subscriptlist
924     ):
925         # list interior, including unpacking
926         if not prev:
927             return NO
928
929     elif p.type == syms.dictsetmaker:
930         # dict and set interior, including unpacking
931         if not prev:
932             return NO
933
934         if prev.type == token.DOUBLESTAR:
935             return NO
936
937     elif p.type == syms.factor or p.type == syms.star_expr:
938         # unary ops
939         if not prev:
940             prevp = preceding_leaf(p)
941             if not prevp or prevp.type in OPENING_BRACKETS:
942                 return NO
943
944             prevp_parent = prevp.parent
945             assert prevp_parent is not None
946             if prevp.type == token.COLON and prevp_parent.type in {
947                 syms.subscript, syms.sliceop
948             }:
949                 return NO
950
951             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
952                 return NO
953
954         elif t == token.NAME or t == token.NUMBER:
955             return NO
956
957     elif p.type == syms.import_from:
958         if t == token.DOT:
959             if prev and prev.type == token.DOT:
960                 return NO
961
962         elif t == token.NAME:
963             if v == 'import':
964                 return SPACE
965
966             if prev and prev.type == token.DOT:
967                 return NO
968
969     elif p.type == syms.sliceop:
970         return NO
971
972     return SPACE
973
974
975 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
976     """Returns the first leaf that precedes `node`, if any."""
977     while node:
978         res = node.prev_sibling
979         if res:
980             if isinstance(res, Leaf):
981                 return res
982
983             try:
984                 return list(res.leaves())[-1]
985
986             except IndexError:
987                 return None
988
989         node = node.parent
990     return None
991
992
993 def is_delimiter(leaf: Leaf) -> int:
994     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
995
996     Higher numbers are higher priority.
997     """
998     if leaf.type == token.COMMA:
999         return COMMA_PRIORITY
1000
1001     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1002         return LOGIC_PRIORITY
1003
1004     if leaf.type in COMPARATORS:
1005         return COMPARATOR_PRIORITY
1006
1007     if (
1008         leaf.type in MATH_OPERATORS and
1009         leaf.parent and
1010         leaf.parent.type not in {syms.factor, syms.star_expr}
1011     ):
1012         return MATH_PRIORITY
1013
1014     return 0
1015
1016
1017 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1018     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1019
1020     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1021     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1022     move because it does away with modifying the grammar to include all the
1023     possible places in which comments can be placed.
1024
1025     The sad consequence for us though is that comments don't "belong" anywhere.
1026     This is why this function generates simple parentless Leaf objects for
1027     comments.  We simply don't know what the correct parent should be.
1028
1029     No matter though, we can live without this.  We really only need to
1030     differentiate between inline and standalone comments.  The latter don't
1031     share the line with any code.
1032
1033     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1034     are emitted with a fake STANDALONE_COMMENT token identifier.
1035     """
1036     if not leaf.prefix:
1037         return
1038
1039     if '#' not in leaf.prefix:
1040         return
1041
1042     before_comment, content = leaf.prefix.split('#', 1)
1043     content = content.rstrip()
1044     if content and (content[0] not in {' ', '!', '#'}):
1045         content = ' ' + content
1046     is_standalone_comment = (
1047         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1048     )
1049     if not is_standalone_comment:
1050         # simple trailing comment
1051         yield Leaf(token.COMMENT, value='#' + content)
1052         return
1053
1054     for line in ('#' + content).split('\n'):
1055         line = line.lstrip()
1056         if not line.startswith('#'):
1057             continue
1058
1059         yield Leaf(STANDALONE_COMMENT, line)
1060
1061
1062 def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]:
1063     """Splits a `line` into potentially many lines.
1064
1065     They should fit in the allotted `line_length` but might not be able to.
1066     `inner` signifies that there were a pair of brackets somewhere around the
1067     current `line`, possibly transitively. This means we can fallback to splitting
1068     by delimiters if the LHS/RHS don't yield any results.
1069     """
1070     line_str = str(line).strip('\n')
1071     if len(line_str) <= line_length and '\n' not in line_str:
1072         yield line
1073         return
1074
1075     if line.is_def:
1076         split_funcs = [left_hand_split]
1077     elif line.inside_brackets:
1078         split_funcs = [delimiter_split]
1079         if '\n' not in line_str:
1080             # Only attempt RHS if we don't have multiline strings or comments
1081             # on this line.
1082             split_funcs.append(right_hand_split)
1083     else:
1084         split_funcs = [right_hand_split]
1085     for split_func in split_funcs:
1086         # We are accumulating lines in `result` because we might want to abort
1087         # mission and return the original line in the end, or attempt a different
1088         # split altogether.
1089         result: List[Line] = []
1090         try:
1091             for l in split_func(line):
1092                 if str(l).strip('\n') == line_str:
1093                     raise CannotSplit("Split function returned an unchanged result")
1094
1095                 result.extend(split_line(l, line_length=line_length, inner=True))
1096         except CannotSplit as cs:
1097             continue
1098
1099         else:
1100             yield from result
1101             break
1102
1103     else:
1104         yield line
1105
1106
1107 def left_hand_split(line: Line) -> Iterator[Line]:
1108     """Split line into many lines, starting with the first matching bracket pair.
1109
1110     Note: this usually looks weird, only use this for function definitions.
1111     Prefer RHS otherwise.
1112     """
1113     head = Line(depth=line.depth)
1114     body = Line(depth=line.depth + 1, inside_brackets=True)
1115     tail = Line(depth=line.depth)
1116     tail_leaves: List[Leaf] = []
1117     body_leaves: List[Leaf] = []
1118     head_leaves: List[Leaf] = []
1119     current_leaves = head_leaves
1120     matching_bracket = None
1121     for leaf in line.leaves:
1122         if (
1123             current_leaves is body_leaves and
1124             leaf.type in CLOSING_BRACKETS and
1125             leaf.opening_bracket is matching_bracket  # type: ignore
1126         ):
1127             current_leaves = tail_leaves
1128         current_leaves.append(leaf)
1129         if current_leaves is head_leaves:
1130             if leaf.type in OPENING_BRACKETS:
1131                 matching_bracket = leaf
1132                 current_leaves = body_leaves
1133     # Since body is a new indent level, remove spurious leading whitespace.
1134     if body_leaves:
1135         normalize_prefix(body_leaves[0])
1136     # Build the new lines.
1137     for result, leaves in (
1138         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1139     ):
1140         for leaf in leaves:
1141             result.append(leaf, preformatted=True)
1142             comment_after = line.comments.get(id(leaf))
1143             if comment_after:
1144                 result.append(comment_after, preformatted=True)
1145     # Check if the split succeeded.
1146     tail_len = len(str(tail))
1147     if not body:
1148         if tail_len == 0:
1149             raise CannotSplit("Splitting brackets produced the same line")
1150
1151         elif tail_len < 3:
1152             raise CannotSplit(
1153                 f"Splitting brackets on an empty body to save "
1154                 f"{tail_len} characters is not worth it"
1155             )
1156
1157     for result in (head, body, tail):
1158         if result:
1159             yield result
1160
1161
1162 def right_hand_split(line: Line) -> Iterator[Line]:
1163     """Split line into many lines, starting with the last matching bracket pair."""
1164     head = Line(depth=line.depth)
1165     body = Line(depth=line.depth + 1, inside_brackets=True)
1166     tail = Line(depth=line.depth)
1167     tail_leaves: List[Leaf] = []
1168     body_leaves: List[Leaf] = []
1169     head_leaves: List[Leaf] = []
1170     current_leaves = tail_leaves
1171     opening_bracket = None
1172     for leaf in reversed(line.leaves):
1173         if current_leaves is body_leaves:
1174             if leaf is opening_bracket:
1175                 current_leaves = head_leaves
1176         current_leaves.append(leaf)
1177         if current_leaves is tail_leaves:
1178             if leaf.type in CLOSING_BRACKETS:
1179                 opening_bracket = leaf.opening_bracket  # type: ignore
1180                 current_leaves = body_leaves
1181     tail_leaves.reverse()
1182     body_leaves.reverse()
1183     head_leaves.reverse()
1184     # Since body is a new indent level, remove spurious leading whitespace.
1185     if body_leaves:
1186         normalize_prefix(body_leaves[0])
1187     # Build the new lines.
1188     for result, leaves in (
1189         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1190     ):
1191         for leaf in leaves:
1192             result.append(leaf, preformatted=True)
1193             comment_after = line.comments.get(id(leaf))
1194             if comment_after:
1195                 result.append(comment_after, preformatted=True)
1196     # Check if the split succeeded.
1197     tail_len = len(str(tail).strip('\n'))
1198     if not body:
1199         if tail_len == 0:
1200             raise CannotSplit("Splitting brackets produced the same line")
1201
1202         elif tail_len < 3:
1203             raise CannotSplit(
1204                 f"Splitting brackets on an empty body to save "
1205                 f"{tail_len} characters is not worth it"
1206             )
1207
1208     for result in (head, body, tail):
1209         if result:
1210             yield result
1211
1212
1213 def delimiter_split(line: Line) -> Iterator[Line]:
1214     """Split according to delimiters of the highest priority.
1215
1216     This kind of split doesn't increase indentation.
1217     """
1218     try:
1219         last_leaf = line.leaves[-1]
1220     except IndexError:
1221         raise CannotSplit("Line empty")
1222
1223     delimiters = line.bracket_tracker.delimiters
1224     try:
1225         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1226     except ValueError:
1227         raise CannotSplit("No delimiters found")
1228
1229     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1230     for leaf in line.leaves:
1231         current_line.append(leaf, preformatted=True)
1232         comment_after = line.comments.get(id(leaf))
1233         if comment_after:
1234             current_line.append(comment_after, preformatted=True)
1235         leaf_priority = delimiters.get(id(leaf))
1236         if leaf_priority == delimiter_priority:
1237             normalize_prefix(current_line.leaves[0])
1238             yield current_line
1239
1240             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1241     if current_line:
1242         if (
1243             delimiter_priority == COMMA_PRIORITY and
1244             current_line.leaves[-1].type != token.COMMA
1245         ):
1246             current_line.append(Leaf(token.COMMA, ','))
1247         normalize_prefix(current_line.leaves[0])
1248         yield current_line
1249
1250
1251 def is_import(leaf: Leaf) -> bool:
1252     """Returns True if the given leaf starts an import statement."""
1253     p = leaf.parent
1254     t = leaf.type
1255     v = leaf.value
1256     return bool(
1257         t == token.NAME and
1258         (
1259             (v == 'import' and p and p.type == syms.import_name) or
1260             (v == 'from' and p and p.type == syms.import_from)
1261         )
1262     )
1263
1264
1265 def normalize_prefix(leaf: Leaf) -> None:
1266     """Leave existing extra newlines for imports.  Remove everything else."""
1267     if is_import(leaf):
1268         spl = leaf.prefix.split('#', 1)
1269         nl_count = spl[0].count('\n')
1270         if len(spl) > 1:
1271             # Skip one newline since it was for a standalone comment.
1272             nl_count -= 1
1273         leaf.prefix = '\n' * nl_count
1274         return
1275
1276     leaf.prefix = ''
1277
1278
1279 PYTHON_EXTENSIONS = {'.py'}
1280 BLACKLISTED_DIRECTORIES = {
1281     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1282 }
1283
1284
1285 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1286     for child in path.iterdir():
1287         if child.is_dir():
1288             if child.name in BLACKLISTED_DIRECTORIES:
1289                 continue
1290
1291             yield from gen_python_files_in_dir(child)
1292
1293         elif child.suffix in PYTHON_EXTENSIONS:
1294             yield child
1295
1296
1297 @dataclass
1298 class Report:
1299     """Provides a reformatting counter."""
1300     change_count: int = attrib(default=0)
1301     same_count: int = attrib(default=0)
1302     failure_count: int = attrib(default=0)
1303
1304     def done(self, src: Path, changed: bool) -> None:
1305         """Increment the counter for successful reformatting. Write out a message."""
1306         if changed:
1307             out(f'reformatted {src}')
1308             self.change_count += 1
1309         else:
1310             out(f'{src} already well formatted, good job.', bold=False)
1311             self.same_count += 1
1312
1313     def failed(self, src: Path, message: str) -> None:
1314         """Increment the counter for failed reformatting. Write out a message."""
1315         err(f'error: cannot format {src}: {message}')
1316         self.failure_count += 1
1317
1318     @property
1319     def return_code(self) -> int:
1320         """Which return code should the app use considering the current state."""
1321         return 1 if self.failure_count else 0
1322
1323     def __str__(self) -> str:
1324         """A color report of the current state.
1325
1326         Use `click.unstyle` to remove colors.
1327         """
1328         report = []
1329         if self.change_count:
1330             s = 's' if self.change_count > 1 else ''
1331             report.append(
1332                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1333             )
1334         if self.same_count:
1335             s = 's' if self.same_count > 1 else ''
1336             report.append(f'{self.same_count} file{s} left unchanged')
1337         if self.failure_count:
1338             s = 's' if self.failure_count > 1 else ''
1339             report.append(
1340                 click.style(
1341                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1342                 )
1343             )
1344         return ', '.join(report) + '.'
1345
1346
1347 def assert_equivalent(src: str, dst: str) -> None:
1348     """Raises AssertionError if `src` and `dst` aren't equivalent.
1349
1350     This is a temporary sanity check until Black becomes stable.
1351     """
1352
1353     import ast
1354     import traceback
1355
1356     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1357         """Simple visitor generating strings to compare ASTs by content."""
1358         yield f"{'  ' * depth}{node.__class__.__name__}("
1359
1360         for field in sorted(node._fields):
1361             try:
1362                 value = getattr(node, field)
1363             except AttributeError:
1364                 continue
1365
1366             yield f"{'  ' * (depth+1)}{field}="
1367
1368             if isinstance(value, list):
1369                 for item in value:
1370                     if isinstance(item, ast.AST):
1371                         yield from _v(item, depth + 2)
1372
1373             elif isinstance(value, ast.AST):
1374                 yield from _v(value, depth + 2)
1375
1376             else:
1377                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1378
1379         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1380
1381     try:
1382         src_ast = ast.parse(src)
1383     except Exception as exc:
1384         raise AssertionError(f"cannot parse source: {exc}") from None
1385
1386     try:
1387         dst_ast = ast.parse(dst)
1388     except Exception as exc:
1389         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1390         raise AssertionError(
1391             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1392             f"Please report a bug on https://github.com/ambv/black/issues.  "
1393             f"This invalid output might be helpful: {log}",
1394         ) from None
1395
1396     src_ast_str = '\n'.join(_v(src_ast))
1397     dst_ast_str = '\n'.join(_v(dst_ast))
1398     if src_ast_str != dst_ast_str:
1399         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1400         raise AssertionError(
1401             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1402             f"the source.  "
1403             f"Please report a bug on https://github.com/ambv/black/issues.  "
1404             f"This diff might be helpful: {log}",
1405         ) from None
1406
1407
1408 def assert_stable(src: str, dst: str, line_length: int) -> None:
1409     """Raises AssertionError if `dst` reformats differently the second time.
1410
1411     This is a temporary sanity check until Black becomes stable.
1412     """
1413     newdst = format_str(dst, line_length=line_length)
1414     if dst != newdst:
1415         log = dump_to_file(
1416             diff(src, dst, 'source', 'first pass'),
1417             diff(dst, newdst, 'first pass', 'second pass'),
1418         )
1419         raise AssertionError(
1420             f"INTERNAL ERROR: Black produced different code on the second pass "
1421             f"of the formatter.  "
1422             f"Please report a bug on https://github.com/ambv/black/issues.  "
1423             f"This diff might be helpful: {log}",
1424         ) from None
1425
1426
1427 def dump_to_file(*output: str) -> str:
1428     """Dumps `output` to a temporary file. Returns path to the file."""
1429     import tempfile
1430
1431     with tempfile.NamedTemporaryFile(
1432         mode='w', prefix='blk_', suffix='.log', delete=False
1433     ) as f:
1434         for lines in output:
1435             f.write(lines)
1436             f.write('\n')
1437     return f.name
1438
1439
1440 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1441     """Returns a udiff string between strings `a` and `b`."""
1442     import difflib
1443
1444     a_lines = [line + '\n' for line in a.split('\n')]
1445     b_lines = [line + '\n' for line in b.split('\n')]
1446     return ''.join(
1447         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1448     )
1449
1450
1451 if __name__ == '__main__':
1452     main()