]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

e2c427cb536c752b6e5d0e00c07ca2b2a2a6a23e
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 from typing import (
11     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
12 )
13
14 from attr import attrib, dataclass, Factory
15 import click
16
17 # lib2to3 fork
18 from blib2to3.pytree import Node, Leaf, type_repr
19 from blib2to3 import pygram, pytree
20 from blib2to3.pgen2 import driver, token
21 from blib2to3.pgen2.parse import ParseError
22
23 __version__ = "18.3a1"
24 DEFAULT_LINE_LENGTH = 88
25 # types
26 syms = pygram.python_symbols
27 FileContent = str
28 Encoding = str
29 Depth = int
30 NodeType = int
31 LeafID = int
32 Priority = int
33 LN = Union[Leaf, Node]
34 out = partial(click.secho, bold=True, err=True)
35 err = partial(click.secho, fg='red', err=True)
36
37
38 class NothingChanged(UserWarning):
39     """Raised by `format_file` when the reformatted code is the same as source."""
40
41
42 class CannotSplit(Exception):
43     """A readable split that fits the allotted line length is impossible.
44
45     Raised by `left_hand_split()` and `right_hand_split()`.
46     """
47
48
49 @click.command()
50 @click.option(
51     '-l',
52     '--line-length',
53     type=int,
54     default=DEFAULT_LINE_LENGTH,
55     help='How many character per line to allow.',
56     show_default=True,
57 )
58 @click.option(
59     '--check',
60     is_flag=True,
61     help=(
62         "Don't write back the files, just return the status.  Return code 0 "
63         "means nothing changed.  Return code 1 means some files were "
64         "reformatted.  Return code 123 means there was an internal error."
65     ),
66 )
67 @click.option(
68     '--fast/--safe',
69     is_flag=True,
70     help='If --fast given, skip temporary sanity checks. [default: --safe]',
71 )
72 @click.version_option(version=__version__)
73 @click.argument(
74     'src',
75     nargs=-1,
76     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
77 )
78 @click.pass_context
79 def main(
80     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
81 ) -> None:
82     """The uncompromising code formatter."""
83     sources: List[Path] = []
84     for s in src:
85         p = Path(s)
86         if p.is_dir():
87             sources.extend(gen_python_files_in_dir(p))
88         elif p.is_file():
89             # if a file was explicitly given, we don't care about its extension
90             sources.append(p)
91         else:
92             err(f'invalid path: {s}')
93     if len(sources) == 0:
94         ctx.exit(0)
95     elif len(sources) == 1:
96         p = sources[0]
97         report = Report()
98         try:
99             changed = format_file_in_place(
100                 p, line_length=line_length, fast=fast, write_back=not check
101             )
102             report.done(p, changed)
103         except Exception as exc:
104             report.failed(p, str(exc))
105         ctx.exit(report.return_code)
106     else:
107         loop = asyncio.get_event_loop()
108         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
109         return_code = 1
110         try:
111             return_code = loop.run_until_complete(
112                 schedule_formatting(
113                     sources, line_length, not check, fast, loop, executor
114                 )
115             )
116         finally:
117             loop.close()
118             ctx.exit(return_code)
119
120
121 async def schedule_formatting(
122     sources: List[Path],
123     line_length: int,
124     write_back: bool,
125     fast: bool,
126     loop: BaseEventLoop,
127     executor: Executor,
128 ) -> int:
129     tasks = {
130         src: loop.run_in_executor(
131             executor, format_file_in_place, src, line_length, fast, write_back
132         )
133         for src in sources
134     }
135     await asyncio.wait(tasks.values())
136     cancelled = []
137     report = Report()
138     for src, task in tasks.items():
139         if not task.done():
140             report.failed(src, 'timed out, cancelling')
141             task.cancel()
142             cancelled.append(task)
143         elif task.exception():
144             report.failed(src, str(task.exception()))
145         else:
146             report.done(src, task.result())
147     if cancelled:
148         await asyncio.wait(cancelled, timeout=2)
149     out('All done! ✨ 🍰 ✨')
150     click.echo(str(report))
151     return report.return_code
152
153
154 def format_file_in_place(
155     src: Path, line_length: int, fast: bool, write_back: bool = False
156 ) -> bool:
157     """Format the file and rewrite if changed. Return True if changed."""
158     try:
159         contents, encoding = format_file(src, line_length=line_length, fast=fast)
160     except NothingChanged:
161         return False
162
163     if write_back:
164         with open(src, "w", encoding=encoding) as f:
165             f.write(contents)
166     return True
167
168
169 def format_file(
170     src: Path, line_length: int, fast: bool
171 ) -> Tuple[FileContent, Encoding]:
172     """Reformats a file and returns its contents and encoding."""
173     with tokenize.open(src) as src_buffer:
174         src_contents = src_buffer.read()
175     if src_contents.strip() == '':
176         raise NothingChanged(src)
177
178     dst_contents = format_str(src_contents, line_length=line_length)
179     if src_contents == dst_contents:
180         raise NothingChanged(src)
181
182     if not fast:
183         assert_equivalent(src_contents, dst_contents)
184         assert_stable(src_contents, dst_contents, line_length=line_length)
185     return dst_contents, src_buffer.encoding
186
187
188 def format_str(src_contents: str, line_length: int) -> FileContent:
189     """Reformats a string and returns new contents."""
190     src_node = lib2to3_parse(src_contents)
191     dst_contents = ""
192     comments: List[Line] = []
193     lines = LineGenerator()
194     elt = EmptyLineTracker()
195     empty_line = Line()
196     after = 0
197     for current_line in lines.visit(src_node):
198         for _ in range(after):
199             dst_contents += str(empty_line)
200         before, after = elt.maybe_empty_lines(current_line)
201         for _ in range(before):
202             dst_contents += str(empty_line)
203         if not current_line.is_comment:
204             for comment in comments:
205                 dst_contents += str(comment)
206             comments = []
207             for line in split_line(current_line, line_length=line_length):
208                 dst_contents += str(line)
209         else:
210             comments.append(current_line)
211     for comment in comments:
212         dst_contents += str(comment)
213     return dst_contents
214
215
216 def lib2to3_parse(src_txt: str) -> Node:
217     """Given a string with source, return the lib2to3 Node."""
218     grammar = pygram.python_grammar_no_print_statement
219     drv = driver.Driver(grammar, pytree.convert)
220     if src_txt[-1] != '\n':
221         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
222         src_txt += nl
223     try:
224         result = drv.parse_string(src_txt, True)
225     except ParseError as pe:
226         lineno, column = pe.context[1]
227         lines = src_txt.splitlines()
228         try:
229             faulty_line = lines[lineno - 1]
230         except IndexError:
231             faulty_line = "<line number missing in source>"
232         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
233
234     if isinstance(result, Leaf):
235         result = Node(syms.file_input, [result])
236     return result
237
238
239 def lib2to3_unparse(node: Node) -> str:
240     """Given a lib2to3 node, return its string representation."""
241     code = str(node)
242     return code
243
244
245 T = TypeVar('T')
246
247
248 class Visitor(Generic[T]):
249     """Basic lib2to3 visitor that yields things on visiting."""
250
251     def visit(self, node: LN) -> Iterator[T]:
252         if node.type < 256:
253             name = token.tok_name[node.type]
254         else:
255             name = type_repr(node.type)
256         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
257
258     def visit_default(self, node: LN) -> Iterator[T]:
259         if isinstance(node, Node):
260             for child in node.children:
261                 yield from self.visit(child)
262
263
264 @dataclass
265 class DebugVisitor(Visitor[T]):
266     tree_depth: int = attrib(default=0)
267
268     def visit_default(self, node: LN) -> Iterator[T]:
269         indent = ' ' * (2 * self.tree_depth)
270         if isinstance(node, Node):
271             _type = type_repr(node.type)
272             out(f'{indent}{_type}', fg='yellow')
273             self.tree_depth += 1
274             for child in node.children:
275                 yield from self.visit(child)
276
277             self.tree_depth -= 1
278             out(f'{indent}/{_type}', fg='yellow', bold=False)
279         else:
280             _type = token.tok_name.get(node.type, str(node.type))
281             out(f'{indent}{_type}', fg='blue', nl=False)
282             if node.prefix:
283                 # We don't have to handle prefixes for `Node` objects since
284                 # that delegates to the first child anyway.
285                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
286             out(f' {node.value!r}', fg='blue', bold=False)
287
288
289 KEYWORDS = set(keyword.kwlist)
290 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
291 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
292 STATEMENT = {
293     syms.if_stmt,
294     syms.while_stmt,
295     syms.for_stmt,
296     syms.try_stmt,
297     syms.except_clause,
298     syms.with_stmt,
299     syms.funcdef,
300     syms.classdef,
301 }
302 STANDALONE_COMMENT = 153
303 LOGIC_OPERATORS = {'and', 'or'}
304 COMPARATORS = {
305     token.LESS,
306     token.GREATER,
307     token.EQEQUAL,
308     token.NOTEQUAL,
309     token.LESSEQUAL,
310     token.GREATEREQUAL,
311 }
312 MATH_OPERATORS = {
313     token.PLUS,
314     token.MINUS,
315     token.STAR,
316     token.SLASH,
317     token.VBAR,
318     token.AMPER,
319     token.PERCENT,
320     token.CIRCUMFLEX,
321     token.LEFTSHIFT,
322     token.RIGHTSHIFT,
323     token.DOUBLESTAR,
324     token.DOUBLESLASH,
325 }
326 COMPREHENSION_PRIORITY = 20
327 COMMA_PRIORITY = 10
328 LOGIC_PRIORITY = 5
329 STRING_PRIORITY = 4
330 COMPARATOR_PRIORITY = 3
331 MATH_PRIORITY = 1
332
333
334 @dataclass
335 class BracketTracker:
336     depth: int = attrib(default=0)
337     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict))
338     delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict))
339     previous: Optional[Leaf] = attrib(default=None)
340
341     def mark(self, leaf: Leaf) -> None:
342         if leaf.type == token.COMMENT:
343             return
344
345         if leaf.type in CLOSING_BRACKETS:
346             self.depth -= 1
347             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
348             leaf.opening_bracket = opening_bracket  # type: ignore
349         leaf.bracket_depth = self.depth  # type: ignore
350         if self.depth == 0:
351             delim = is_delimiter(leaf)
352             if delim:
353                 self.delimiters[id(leaf)] = delim
354             elif self.previous is not None:
355                 if leaf.type == token.STRING and self.previous.type == token.STRING:
356                     self.delimiters[id(self.previous)] = STRING_PRIORITY
357                 elif (
358                     leaf.type == token.NAME and
359                     leaf.value == 'for' and
360                     leaf.parent and
361                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
362                 ):
363                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
364                 elif (
365                     leaf.type == token.NAME and
366                     leaf.value == 'if' and
367                     leaf.parent and
368                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
369                 ):
370                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
371         if leaf.type in OPENING_BRACKETS:
372             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
373             self.depth += 1
374         self.previous = leaf
375
376     def any_open_brackets(self) -> bool:
377         """Returns True if there is an yet unmatched open bracket on the line."""
378         return bool(self.bracket_match)
379
380     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
381         """Returns the highest priority of a delimiter found on the line.
382
383         Values are consistent with what `is_delimiter()` returns.
384         """
385         return max(v for k, v in self.delimiters.items() if k not in exclude)
386
387
388 @dataclass
389 class Line:
390     depth: int = attrib(default=0)
391     leaves: List[Leaf] = attrib(default=Factory(list))
392     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
393     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
394     inside_brackets: bool = attrib(default=False)
395
396     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
397         has_value = leaf.value.strip()
398         if not has_value:
399             return
400
401         if self.leaves and not preformatted:
402             # Note: at this point leaf.prefix should be empty except for
403             # imports, for which we only preserve newlines.
404             leaf.prefix += whitespace(leaf)
405         if self.inside_brackets or not preformatted:
406             self.bracket_tracker.mark(leaf)
407             self.maybe_remove_trailing_comma(leaf)
408             if self.maybe_adapt_standalone_comment(leaf):
409                 return
410
411         if not self.append_comment(leaf):
412             self.leaves.append(leaf)
413
414     @property
415     def is_comment(self) -> bool:
416         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
417
418     @property
419     def is_decorator(self) -> bool:
420         return bool(self) and self.leaves[0].type == token.AT
421
422     @property
423     def is_import(self) -> bool:
424         return bool(self) and is_import(self.leaves[0])
425
426     @property
427     def is_class(self) -> bool:
428         return (
429             bool(self) and
430             self.leaves[0].type == token.NAME and
431             self.leaves[0].value == 'class'
432         )
433
434     @property
435     def is_def(self) -> bool:
436         """Also returns True for async defs."""
437         try:
438             first_leaf = self.leaves[0]
439         except IndexError:
440             return False
441
442         try:
443             second_leaf: Optional[Leaf] = self.leaves[1]
444         except IndexError:
445             second_leaf = None
446         return (
447             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
448             (
449                 first_leaf.type == token.NAME and
450                 first_leaf.value == 'async' and
451                 second_leaf is not None and
452                 second_leaf.type == token.NAME and
453                 second_leaf.value == 'def'
454             )
455         )
456
457     @property
458     def is_flow_control(self) -> bool:
459         return (
460             bool(self) and
461             self.leaves[0].type == token.NAME and
462             self.leaves[0].value in FLOW_CONTROL
463         )
464
465     @property
466     def is_yield(self) -> bool:
467         return (
468             bool(self) and
469             self.leaves[0].type == token.NAME and
470             self.leaves[0].value == 'yield'
471         )
472
473     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
474         if not (
475             self.leaves and
476             self.leaves[-1].type == token.COMMA and
477             closing.type in CLOSING_BRACKETS
478         ):
479             return False
480
481         if closing.type == token.RSQB or closing.type == token.RBRACE:
482             self.leaves.pop()
483             return True
484
485         # For parens let's check if it's safe to remove the comma.  If the
486         # trailing one is the only one, we might mistakenly change a tuple
487         # into a different type by removing the comma.
488         depth = closing.bracket_depth + 1  # type: ignore
489         commas = 0
490         opening = closing.opening_bracket  # type: ignore
491         for _opening_index, leaf in enumerate(self.leaves):
492             if leaf is opening:
493                 break
494
495         else:
496             return False
497
498         for leaf in self.leaves[_opening_index + 1:]:
499             if leaf is closing:
500                 break
501
502             bracket_depth = leaf.bracket_depth  # type: ignore
503             if bracket_depth == depth and leaf.type == token.COMMA:
504                 commas += 1
505         if commas > 1:
506             self.leaves.pop()
507             return True
508
509         return False
510
511     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
512         """Hack a standalone comment to act as a trailing comment for line splitting.
513
514         If this line has brackets and a standalone `comment`, we need to adapt
515         it to be able to still reformat the line.
516
517         This is not perfect, the line to which the standalone comment gets
518         appended will appear "too long" when splitting.
519         """
520         if not (
521             comment.type == STANDALONE_COMMENT and
522             self.bracket_tracker.any_open_brackets()
523         ):
524             return False
525
526         comment.type = token.COMMENT
527         comment.prefix = '\n' + '    ' * (self.depth + 1)
528         return self.append_comment(comment)
529
530     def append_comment(self, comment: Leaf) -> bool:
531         if comment.type != token.COMMENT:
532             return False
533
534         try:
535             after = id(self.last_non_delimiter())
536         except LookupError:
537             comment.type = STANDALONE_COMMENT
538             comment.prefix = ''
539             return False
540
541         else:
542             if after in self.comments:
543                 self.comments[after].value += str(comment)
544             else:
545                 self.comments[after] = comment
546             return True
547
548     def last_non_delimiter(self) -> Leaf:
549         for i in range(len(self.leaves)):
550             last = self.leaves[-i - 1]
551             if not is_delimiter(last):
552                 return last
553
554         raise LookupError("No non-delimiters found")
555
556     def __str__(self) -> str:
557         if not self:
558             return '\n'
559
560         indent = '    ' * self.depth
561         leaves = iter(self.leaves)
562         first = next(leaves)
563         res = f'{first.prefix}{indent}{first.value}'
564         for leaf in leaves:
565             res += str(leaf)
566         for comment in self.comments.values():
567             res += str(comment)
568         return res + '\n'
569
570     def __bool__(self) -> bool:
571         return bool(self.leaves or self.comments)
572
573
574 @dataclass
575 class EmptyLineTracker:
576     """Provides a stateful method that returns the number of potential extra
577     empty lines needed before and after the currently processed line.
578
579     Note: this tracker works on lines that haven't been split yet.
580     """
581     previous_line: Optional[Line] = attrib(default=None)
582     previous_after: int = attrib(default=0)
583     previous_defs: List[int] = attrib(default=Factory(list))
584
585     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
586         """Returns the number of extra empty lines before and after the `current_line`.
587
588         This is for separating `def`, `async def` and `class` with extra empty lines
589         (two on module-level), as well as providing an extra empty line after flow
590         control keywords to make them more prominent.
591         """
592         before, after = self._maybe_empty_lines(current_line)
593         self.previous_after = after
594         self.previous_line = current_line
595         return before, after
596
597     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
598         before = 0
599         depth = current_line.depth
600         while self.previous_defs and self.previous_defs[-1] >= depth:
601             self.previous_defs.pop()
602             before = (1 if depth else 2) - self.previous_after
603         is_decorator = current_line.is_decorator
604         if is_decorator or current_line.is_def or current_line.is_class:
605             if not is_decorator:
606                 self.previous_defs.append(depth)
607             if self.previous_line is None:
608                 # Don't insert empty lines before the first line in the file.
609                 return 0, 0
610
611             if self.previous_line and self.previous_line.is_decorator:
612                 # Don't insert empty lines between decorators.
613                 return 0, 0
614
615             newlines = 2
616             if current_line.depth:
617                 newlines -= 1
618             newlines -= self.previous_after
619             return newlines, 0
620
621         if current_line.is_flow_control:
622             return before, 1
623
624         if (
625             self.previous_line and
626             self.previous_line.is_import and
627             not current_line.is_import and
628             depth == self.previous_line.depth
629         ):
630             return (before or 1), 0
631
632         if (
633             self.previous_line and
634             self.previous_line.is_yield and
635             (not current_line.is_yield or depth != self.previous_line.depth)
636         ):
637             return (before or 1), 0
638
639         return before, 0
640
641
642 @dataclass
643 class LineGenerator(Visitor[Line]):
644     """Generates reformatted Line objects.  Empty lines are not emitted.
645
646     Note: destroys the tree it's visiting by mutating prefixes of its leaves
647     in ways that will no longer stringify to valid Python code on the tree.
648     """
649     current_line: Line = attrib(default=Factory(Line))
650     standalone_comments: List[Leaf] = attrib(default=Factory(list))
651
652     def line(self, indent: int = 0) -> Iterator[Line]:
653         """Generate a line.
654
655         If the line is empty, only emit if it makes sense.
656         If the line is too long, split it first and then generate.
657
658         If any lines were generated, set up a new current_line.
659         """
660         if not self.current_line:
661             self.current_line.depth += indent
662             return  # Line is empty, don't emit. Creating a new one unnecessary.
663
664         complete_line = self.current_line
665         self.current_line = Line(depth=complete_line.depth + indent)
666         yield complete_line
667
668     def visit_default(self, node: LN) -> Iterator[Line]:
669         if isinstance(node, Leaf):
670             for comment in generate_comments(node):
671                 if self.current_line.bracket_tracker.any_open_brackets():
672                     # any comment within brackets is subject to splitting
673                     self.current_line.append(comment)
674                 elif comment.type == token.COMMENT:
675                     # regular trailing comment
676                     self.current_line.append(comment)
677                     yield from self.line()
678
679                 else:
680                     # regular standalone comment, to be processed later (see
681                     # docstring in `generate_comments()`
682                     self.standalone_comments.append(comment)
683             normalize_prefix(node)
684             if node.type not in WHITESPACE:
685                 for comment in self.standalone_comments:
686                     yield from self.line()
687
688                     self.current_line.append(comment)
689                     yield from self.line()
690
691                 self.standalone_comments = []
692                 self.current_line.append(node)
693         yield from super().visit_default(node)
694
695     def visit_suite(self, node: Node) -> Iterator[Line]:
696         """Body of a statement after a colon."""
697         children = iter(node.children)
698         # Process newline before indenting.  It might contain an inline
699         # comment that should go right after the colon.
700         newline = next(children)
701         yield from self.visit(newline)
702         yield from self.line(+1)
703
704         for child in children:
705             yield from self.visit(child)
706
707         yield from self.line(-1)
708
709     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
710         """Visit a statement.
711
712         The relevant Python language keywords for this statement are NAME leaves
713         within it.
714         """
715         for child in node.children:
716             if child.type == token.NAME and child.value in keywords:  # type: ignore
717                 yield from self.line()
718
719             yield from self.visit(child)
720
721     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
722         """A statement without nested statements."""
723         is_suite_like = node.parent and node.parent.type in STATEMENT
724         if is_suite_like:
725             yield from self.line(+1)
726             yield from self.visit_default(node)
727             yield from self.line(-1)
728
729         else:
730             yield from self.line()
731             yield from self.visit_default(node)
732
733     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
734         yield from self.line()
735
736         children = iter(node.children)
737         for child in children:
738             yield from self.visit(child)
739
740             if child.type == token.NAME and child.value == 'async':  # type: ignore
741                 break
742
743         internal_stmt = next(children)
744         for child in internal_stmt.children:
745             yield from self.visit(child)
746
747     def visit_decorators(self, node: Node) -> Iterator[Line]:
748         for child in node.children:
749             yield from self.line()
750             yield from self.visit(child)
751
752     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
753         yield from self.line()
754
755     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
756         yield from self.visit_default(leaf)
757         yield from self.line()
758
759     def __attrs_post_init__(self) -> None:
760         """You are in a twisty little maze of passages."""
761         v = self.visit_stmt
762         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
763         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
764         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
765         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
766         self.visit_except_clause = partial(v, keywords={'except'})
767         self.visit_funcdef = partial(v, keywords={'def'})
768         self.visit_with_stmt = partial(v, keywords={'with'})
769         self.visit_classdef = partial(v, keywords={'class'})
770         self.visit_async_funcdef = self.visit_async_stmt
771         self.visit_decorated = self.visit_decorators
772
773
774 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
775 OPENING_BRACKETS = set(BRACKET.keys())
776 CLOSING_BRACKETS = set(BRACKET.values())
777 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
778
779
780 def whitespace(leaf: Leaf) -> str:
781     """Return whitespace prefix if needed for the given `leaf`."""
782     NO = ''
783     SPACE = ' '
784     DOUBLESPACE = '  '
785     t = leaf.type
786     p = leaf.parent
787     v = leaf.value
788     if t == token.COLON:
789         return NO
790
791     if t == token.COMMA:
792         return NO
793
794     if t == token.RPAR:
795         return NO
796
797     if t == token.COMMENT:
798         return DOUBLESPACE
799
800     if t == STANDALONE_COMMENT:
801         return NO
802
803     if t in CLOSING_BRACKETS:
804         return NO
805
806     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
807     prev = leaf.prev_sibling
808     if not prev:
809         prevp = preceding_leaf(p)
810         if not prevp or prevp.type in OPENING_BRACKETS:
811             return NO
812
813         if prevp.type == token.EQUAL:
814             if prevp.parent and prevp.parent.type in {
815                 syms.typedargslist,
816                 syms.varargslist,
817                 syms.parameters,
818                 syms.arglist,
819                 syms.argument,
820             }:
821                 return NO
822
823         elif prevp.type == token.DOUBLESTAR:
824             if prevp.parent and prevp.parent.type in {
825                 syms.typedargslist,
826                 syms.varargslist,
827                 syms.parameters,
828                 syms.arglist,
829                 syms.dictsetmaker,
830             }:
831                 return NO
832
833         elif prevp.type == token.COLON:
834             if prevp.parent and prevp.parent.type == syms.subscript:
835                 return NO
836
837         elif prevp.parent and prevp.parent.type == syms.factor:
838             return NO
839
840     elif prev.type in OPENING_BRACKETS:
841         return NO
842
843     if p.type in {syms.parameters, syms.arglist}:
844         # untyped function signatures or calls
845         if t == token.RPAR:
846             return NO
847
848         if not prev or prev.type != token.COMMA:
849             return NO
850
851     if p.type == syms.varargslist:
852         # lambdas
853         if t == token.RPAR:
854             return NO
855
856         if prev and prev.type != token.COMMA:
857             return NO
858
859     elif p.type == syms.typedargslist:
860         # typed function signatures
861         if not prev:
862             return NO
863
864         if t == token.EQUAL:
865             if prev.type != syms.tname:
866                 return NO
867
868         elif prev.type == token.EQUAL:
869             # A bit hacky: if the equal sign has whitespace, it means we
870             # previously found it's a typed argument.  So, we're using that, too.
871             return prev.prefix
872
873         elif prev.type != token.COMMA:
874             return NO
875
876     elif p.type == syms.tname:
877         # type names
878         if not prev:
879             prevp = preceding_leaf(p)
880             if not prevp or prevp.type != token.COMMA:
881                 return NO
882
883     elif p.type == syms.trailer:
884         # attributes and calls
885         if t == token.LPAR or t == token.RPAR:
886             return NO
887
888         if not prev:
889             if t == token.DOT:
890                 prevp = preceding_leaf(p)
891                 if not prevp or prevp.type != token.NUMBER:
892                     return NO
893
894             elif t == token.LSQB:
895                 return NO
896
897         elif prev.type != token.COMMA:
898             return NO
899
900     elif p.type == syms.argument:
901         # single argument
902         if t == token.EQUAL:
903             return NO
904
905         if not prev:
906             prevp = preceding_leaf(p)
907             if not prevp or prevp.type == token.LPAR:
908                 return NO
909
910         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
911             return NO
912
913     elif p.type == syms.decorator:
914         # decorators
915         return NO
916
917     elif p.type == syms.dotted_name:
918         if prev:
919             return NO
920
921         prevp = preceding_leaf(p)
922         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
923             return NO
924
925     elif p.type == syms.classdef:
926         if t == token.LPAR:
927             return NO
928
929         if prev and prev.type == token.LPAR:
930             return NO
931
932     elif p.type == syms.subscript:
933         # indexing
934         if not prev or prev.type == token.COLON:
935             return NO
936
937     elif p.type == syms.atom:
938         if prev and t == token.DOT:
939             # dots, but not the first one.
940             return NO
941
942     elif (
943         p.type == syms.listmaker or
944         p.type == syms.testlist_gexp or
945         p.type == syms.subscriptlist
946     ):
947         # list interior, including unpacking
948         if not prev:
949             return NO
950
951     elif p.type == syms.dictsetmaker:
952         # dict and set interior, including unpacking
953         if not prev:
954             return NO
955
956         if prev.type == token.DOUBLESTAR:
957             return NO
958
959     elif p.type == syms.factor or p.type == syms.star_expr:
960         # unary ops
961         if not prev:
962             prevp = preceding_leaf(p)
963             if not prevp or prevp.type in OPENING_BRACKETS:
964                 return NO
965
966             prevp_parent = prevp.parent
967             assert prevp_parent is not None
968             if prevp.type == token.COLON and prevp_parent.type in {
969                 syms.subscript, syms.sliceop
970             }:
971                 return NO
972
973             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
974                 return NO
975
976         elif t == token.NAME or t == token.NUMBER:
977             return NO
978
979     elif p.type == syms.import_from:
980         if t == token.DOT:
981             if prev and prev.type == token.DOT:
982                 return NO
983
984         elif t == token.NAME:
985             if v == 'import':
986                 return SPACE
987
988             if prev and prev.type == token.DOT:
989                 return NO
990
991     elif p.type == syms.sliceop:
992         return NO
993
994     return SPACE
995
996
997 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
998     """Returns the first leaf that precedes `node`, if any."""
999     while node:
1000         res = node.prev_sibling
1001         if res:
1002             if isinstance(res, Leaf):
1003                 return res
1004
1005             try:
1006                 return list(res.leaves())[-1]
1007
1008             except IndexError:
1009                 return None
1010
1011         node = node.parent
1012     return None
1013
1014
1015 def is_delimiter(leaf: Leaf) -> int:
1016     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1017
1018     Higher numbers are higher priority.
1019     """
1020     if leaf.type == token.COMMA:
1021         return COMMA_PRIORITY
1022
1023     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1024         return LOGIC_PRIORITY
1025
1026     if leaf.type in COMPARATORS:
1027         return COMPARATOR_PRIORITY
1028
1029     if (
1030         leaf.type in MATH_OPERATORS and
1031         leaf.parent and
1032         leaf.parent.type not in {syms.factor, syms.star_expr}
1033     ):
1034         return MATH_PRIORITY
1035
1036     return 0
1037
1038
1039 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1040     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1041
1042     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1043     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1044     move because it does away with modifying the grammar to include all the
1045     possible places in which comments can be placed.
1046
1047     The sad consequence for us though is that comments don't "belong" anywhere.
1048     This is why this function generates simple parentless Leaf objects for
1049     comments.  We simply don't know what the correct parent should be.
1050
1051     No matter though, we can live without this.  We really only need to
1052     differentiate between inline and standalone comments.  The latter don't
1053     share the line with any code.
1054
1055     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1056     are emitted with a fake STANDALONE_COMMENT token identifier.
1057     """
1058     if not leaf.prefix:
1059         return
1060
1061     if '#' not in leaf.prefix:
1062         return
1063
1064     before_comment, content = leaf.prefix.split('#', 1)
1065     content = content.rstrip()
1066     if content and (content[0] not in {' ', '!', '#'}):
1067         content = ' ' + content
1068     is_standalone_comment = (
1069         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1070     )
1071     if not is_standalone_comment:
1072         # simple trailing comment
1073         yield Leaf(token.COMMENT, value='#' + content)
1074         return
1075
1076     for line in ('#' + content).split('\n'):
1077         line = line.lstrip()
1078         if not line.startswith('#'):
1079             continue
1080
1081         yield Leaf(STANDALONE_COMMENT, line)
1082
1083
1084 def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]:
1085     """Splits a `line` into potentially many lines.
1086
1087     They should fit in the allotted `line_length` but might not be able to.
1088     `inner` signifies that there were a pair of brackets somewhere around the
1089     current `line`, possibly transitively. This means we can fallback to splitting
1090     by delimiters if the LHS/RHS don't yield any results.
1091     """
1092     line_str = str(line).strip('\n')
1093     if len(line_str) <= line_length and '\n' not in line_str:
1094         yield line
1095         return
1096
1097     if line.is_def:
1098         split_funcs = [left_hand_split]
1099     elif line.inside_brackets:
1100         split_funcs = [delimiter_split]
1101         if '\n' not in line_str:
1102             # Only attempt RHS if we don't have multiline strings or comments
1103             # on this line.
1104             split_funcs.append(right_hand_split)
1105     else:
1106         split_funcs = [right_hand_split]
1107     for split_func in split_funcs:
1108         # We are accumulating lines in `result` because we might want to abort
1109         # mission and return the original line in the end, or attempt a different
1110         # split altogether.
1111         result: List[Line] = []
1112         try:
1113             for l in split_func(line):
1114                 if str(l).strip('\n') == line_str:
1115                     raise CannotSplit("Split function returned an unchanged result")
1116
1117                 result.extend(split_line(l, line_length=line_length, inner=True))
1118         except CannotSplit as cs:
1119             continue
1120
1121         else:
1122             yield from result
1123             break
1124
1125     else:
1126         yield line
1127
1128
1129 def left_hand_split(line: Line) -> Iterator[Line]:
1130     """Split line into many lines, starting with the first matching bracket pair.
1131
1132     Note: this usually looks weird, only use this for function definitions.
1133     Prefer RHS otherwise.
1134     """
1135     head = Line(depth=line.depth)
1136     body = Line(depth=line.depth + 1, inside_brackets=True)
1137     tail = Line(depth=line.depth)
1138     tail_leaves: List[Leaf] = []
1139     body_leaves: List[Leaf] = []
1140     head_leaves: List[Leaf] = []
1141     current_leaves = head_leaves
1142     matching_bracket = None
1143     for leaf in line.leaves:
1144         if (
1145             current_leaves is body_leaves and
1146             leaf.type in CLOSING_BRACKETS and
1147             leaf.opening_bracket is matching_bracket  # type: ignore
1148         ):
1149             current_leaves = tail_leaves
1150         current_leaves.append(leaf)
1151         if current_leaves is head_leaves:
1152             if leaf.type in OPENING_BRACKETS:
1153                 matching_bracket = leaf
1154                 current_leaves = body_leaves
1155     # Since body is a new indent level, remove spurious leading whitespace.
1156     if body_leaves:
1157         normalize_prefix(body_leaves[0])
1158     # Build the new lines.
1159     for result, leaves in (
1160         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1161     ):
1162         for leaf in leaves:
1163             result.append(leaf, preformatted=True)
1164             comment_after = line.comments.get(id(leaf))
1165             if comment_after:
1166                 result.append(comment_after, preformatted=True)
1167     # Check if the split succeeded.
1168     tail_len = len(str(tail))
1169     if not body:
1170         if tail_len == 0:
1171             raise CannotSplit("Splitting brackets produced the same line")
1172
1173         elif tail_len < 3:
1174             raise CannotSplit(
1175                 f"Splitting brackets on an empty body to save "
1176                 f"{tail_len} characters is not worth it"
1177             )
1178
1179     for result in (head, body, tail):
1180         if result:
1181             yield result
1182
1183
1184 def right_hand_split(line: Line) -> Iterator[Line]:
1185     """Split line into many lines, starting with the last matching bracket pair."""
1186     head = Line(depth=line.depth)
1187     body = Line(depth=line.depth + 1, inside_brackets=True)
1188     tail = Line(depth=line.depth)
1189     tail_leaves: List[Leaf] = []
1190     body_leaves: List[Leaf] = []
1191     head_leaves: List[Leaf] = []
1192     current_leaves = tail_leaves
1193     opening_bracket = None
1194     for leaf in reversed(line.leaves):
1195         if current_leaves is body_leaves:
1196             if leaf is opening_bracket:
1197                 current_leaves = head_leaves
1198         current_leaves.append(leaf)
1199         if current_leaves is tail_leaves:
1200             if leaf.type in CLOSING_BRACKETS:
1201                 opening_bracket = leaf.opening_bracket  # type: ignore
1202                 current_leaves = body_leaves
1203     tail_leaves.reverse()
1204     body_leaves.reverse()
1205     head_leaves.reverse()
1206     # Since body is a new indent level, remove spurious leading whitespace.
1207     if body_leaves:
1208         normalize_prefix(body_leaves[0])
1209     # Build the new lines.
1210     for result, leaves in (
1211         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1212     ):
1213         for leaf in leaves:
1214             result.append(leaf, preformatted=True)
1215             comment_after = line.comments.get(id(leaf))
1216             if comment_after:
1217                 result.append(comment_after, preformatted=True)
1218     # Check if the split succeeded.
1219     tail_len = len(str(tail).strip('\n'))
1220     if not body:
1221         if tail_len == 0:
1222             raise CannotSplit("Splitting brackets produced the same line")
1223
1224         elif tail_len < 3:
1225             raise CannotSplit(
1226                 f"Splitting brackets on an empty body to save "
1227                 f"{tail_len} characters is not worth it"
1228             )
1229
1230     for result in (head, body, tail):
1231         if result:
1232             yield result
1233
1234
1235 def delimiter_split(line: Line) -> Iterator[Line]:
1236     """Split according to delimiters of the highest priority.
1237
1238     This kind of split doesn't increase indentation.
1239     """
1240     try:
1241         last_leaf = line.leaves[-1]
1242     except IndexError:
1243         raise CannotSplit("Line empty")
1244
1245     delimiters = line.bracket_tracker.delimiters
1246     try:
1247         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1248     except ValueError:
1249         raise CannotSplit("No delimiters found")
1250
1251     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1252     for leaf in line.leaves:
1253         current_line.append(leaf, preformatted=True)
1254         comment_after = line.comments.get(id(leaf))
1255         if comment_after:
1256             current_line.append(comment_after, preformatted=True)
1257         leaf_priority = delimiters.get(id(leaf))
1258         if leaf_priority == delimiter_priority:
1259             normalize_prefix(current_line.leaves[0])
1260             yield current_line
1261
1262             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1263     if current_line:
1264         if (
1265             delimiter_priority == COMMA_PRIORITY and
1266             current_line.leaves[-1].type != token.COMMA
1267         ):
1268             current_line.append(Leaf(token.COMMA, ','))
1269         normalize_prefix(current_line.leaves[0])
1270         yield current_line
1271
1272
1273 def is_import(leaf: Leaf) -> bool:
1274     """Returns True if the given leaf starts an import statement."""
1275     p = leaf.parent
1276     t = leaf.type
1277     v = leaf.value
1278     return bool(
1279         t == token.NAME and
1280         (
1281             (v == 'import' and p and p.type == syms.import_name) or
1282             (v == 'from' and p and p.type == syms.import_from)
1283         )
1284     )
1285
1286
1287 def normalize_prefix(leaf: Leaf) -> None:
1288     """Leave existing extra newlines for imports.  Remove everything else."""
1289     if is_import(leaf):
1290         spl = leaf.prefix.split('#', 1)
1291         nl_count = spl[0].count('\n')
1292         if len(spl) > 1:
1293             # Skip one newline since it was for a standalone comment.
1294             nl_count -= 1
1295         leaf.prefix = '\n' * nl_count
1296         return
1297
1298     leaf.prefix = ''
1299
1300
1301 PYTHON_EXTENSIONS = {'.py'}
1302 BLACKLISTED_DIRECTORIES = {
1303     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1304 }
1305
1306
1307 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1308     for child in path.iterdir():
1309         if child.is_dir():
1310             if child.name in BLACKLISTED_DIRECTORIES:
1311                 continue
1312
1313             yield from gen_python_files_in_dir(child)
1314
1315         elif child.suffix in PYTHON_EXTENSIONS:
1316             yield child
1317
1318
1319 @dataclass
1320 class Report:
1321     """Provides a reformatting counter."""
1322     change_count: int = attrib(default=0)
1323     same_count: int = attrib(default=0)
1324     failure_count: int = attrib(default=0)
1325
1326     def done(self, src: Path, changed: bool) -> None:
1327         """Increment the counter for successful reformatting. Write out a message."""
1328         if changed:
1329             out(f'reformatted {src}')
1330             self.change_count += 1
1331         else:
1332             out(f'{src} already well formatted, good job.', bold=False)
1333             self.same_count += 1
1334
1335     def failed(self, src: Path, message: str) -> None:
1336         """Increment the counter for failed reformatting. Write out a message."""
1337         err(f'error: cannot format {src}: {message}')
1338         self.failure_count += 1
1339
1340     @property
1341     def return_code(self) -> int:
1342         """Which return code should the app use considering the current state."""
1343         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1344         # 126 we have special returncodes reserved by the shell.
1345         if self.failure_count:
1346             return 123
1347
1348         elif self.change_count:
1349             return 1
1350
1351         return 0
1352
1353     def __str__(self) -> str:
1354         """A color report of the current state.
1355
1356         Use `click.unstyle` to remove colors.
1357         """
1358         report = []
1359         if self.change_count:
1360             s = 's' if self.change_count > 1 else ''
1361             report.append(
1362                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1363             )
1364         if self.same_count:
1365             s = 's' if self.same_count > 1 else ''
1366             report.append(f'{self.same_count} file{s} left unchanged')
1367         if self.failure_count:
1368             s = 's' if self.failure_count > 1 else ''
1369             report.append(
1370                 click.style(
1371                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1372                 )
1373             )
1374         return ', '.join(report) + '.'
1375
1376
1377 def assert_equivalent(src: str, dst: str) -> None:
1378     """Raises AssertionError if `src` and `dst` aren't equivalent.
1379
1380     This is a temporary sanity check until Black becomes stable.
1381     """
1382
1383     import ast
1384     import traceback
1385
1386     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1387         """Simple visitor generating strings to compare ASTs by content."""
1388         yield f"{'  ' * depth}{node.__class__.__name__}("
1389
1390         for field in sorted(node._fields):
1391             try:
1392                 value = getattr(node, field)
1393             except AttributeError:
1394                 continue
1395
1396             yield f"{'  ' * (depth+1)}{field}="
1397
1398             if isinstance(value, list):
1399                 for item in value:
1400                     if isinstance(item, ast.AST):
1401                         yield from _v(item, depth + 2)
1402
1403             elif isinstance(value, ast.AST):
1404                 yield from _v(value, depth + 2)
1405
1406             else:
1407                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1408
1409         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1410
1411     try:
1412         src_ast = ast.parse(src)
1413     except Exception as exc:
1414         raise AssertionError(f"cannot parse source: {exc}") from None
1415
1416     try:
1417         dst_ast = ast.parse(dst)
1418     except Exception as exc:
1419         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1420         raise AssertionError(
1421             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1422             f"Please report a bug on https://github.com/ambv/black/issues.  "
1423             f"This invalid output might be helpful: {log}",
1424         ) from None
1425
1426     src_ast_str = '\n'.join(_v(src_ast))
1427     dst_ast_str = '\n'.join(_v(dst_ast))
1428     if src_ast_str != dst_ast_str:
1429         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1430         raise AssertionError(
1431             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1432             f"the source.  "
1433             f"Please report a bug on https://github.com/ambv/black/issues.  "
1434             f"This diff might be helpful: {log}",
1435         ) from None
1436
1437
1438 def assert_stable(src: str, dst: str, line_length: int) -> None:
1439     """Raises AssertionError if `dst` reformats differently the second time.
1440
1441     This is a temporary sanity check until Black becomes stable.
1442     """
1443     newdst = format_str(dst, line_length=line_length)
1444     if dst != newdst:
1445         log = dump_to_file(
1446             diff(src, dst, 'source', 'first pass'),
1447             diff(dst, newdst, 'first pass', 'second pass'),
1448         )
1449         raise AssertionError(
1450             f"INTERNAL ERROR: Black produced different code on the second pass "
1451             f"of the formatter.  "
1452             f"Please report a bug on https://github.com/ambv/black/issues.  "
1453             f"This diff might be helpful: {log}",
1454         ) from None
1455
1456
1457 def dump_to_file(*output: str) -> str:
1458     """Dumps `output` to a temporary file. Returns path to the file."""
1459     import tempfile
1460
1461     with tempfile.NamedTemporaryFile(
1462         mode='w', prefix='blk_', suffix='.log', delete=False
1463     ) as f:
1464         for lines in output:
1465             f.write(lines)
1466             f.write('\n')
1467     return f.name
1468
1469
1470 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1471     """Returns a udiff string between strings `a` and `b`."""
1472     import difflib
1473
1474     a_lines = [line + '\n' for line in a.split('\n')]
1475     b_lines = [line + '\n' for line in b.split('\n')]
1476     return ''.join(
1477         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1478     )
1479
1480
1481 if __name__ == '__main__':
1482     main()