All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
8 from pathlib import Path
12 Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
15 from attr import dataclass, Factory
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
27 syms = pygram.python_symbols
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
class NothingChanged(UserWarning):
    """Raised by `format_file_contents` when reformatting leaves the source unchanged.

    `format_file_in_place` and `format_stdin_to_stdout` catch it around their
    calls to `format_file_contents`.
    """
43 class CannotSplit(Exception):
44 """A readable split that fits the allotted line length is impossible.
46 Raised by `left_hand_split()` and `right_hand_split()`.
55 default=DEFAULT_LINE_LENGTH,
56 help='How many character per line to allow.',
63 "Don't write back the files, just return the status. Return code 0 "
64 "means nothing changed. Return code 1 means some files were "
65 "reformatted. Return code 123 means there was an internal error."
71 help='If --fast given, skip temporary sanity checks. [default: --safe]',
73 @click.version_option(version=__version__)
78 exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
83 ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
85 """The uncompromising code formatter."""
86 sources: List[Path] = []
90 sources.extend(gen_python_files_in_dir(p))
92 # if a file was explicitly given, we don't care about its extension
95 sources.append(Path('-'))
97 err(f'invalid path: {s}')
100 elif len(sources) == 1:
104 if not p.is_file() and str(p) == '-':
105 changed = format_stdin_to_stdout(line_length=line_length, fast=fast)
107 changed = format_file_in_place(
108 p, line_length=line_length, fast=fast, write_back=not check
110 report.done(p, changed)
111 except Exception as exc:
112 report.failed(p, str(exc))
113 ctx.exit(report.return_code)
115 loop = asyncio.get_event_loop()
116 executor = ProcessPoolExecutor(max_workers=os.cpu_count())
119 return_code = loop.run_until_complete(
121 sources, line_length, not check, fast, loop, executor
126 ctx.exit(return_code)
129 async def schedule_formatting(
138 src: loop.run_in_executor(
139 executor, format_file_in_place, src, line_length, fast, write_back
143 await asyncio.wait(tasks.values())
146 for src, task in tasks.items():
148 report.failed(src, 'timed out, cancelling')
150 cancelled.append(task)
151 elif task.exception():
152 report.failed(src, str(task.exception()))
154 report.done(src, task.result())
156 await asyncio.wait(cancelled, timeout=2)
157 out('All done! ✨ 🍰 ✨')
158 click.echo(str(report))
159 return report.return_code
162 def format_file_in_place(
163 src: Path, line_length: int, fast: bool, write_back: bool = False
165 """Format the file and rewrite if changed. Return True if changed."""
166 with tokenize.open(src) as src_buffer:
167 src_contents = src_buffer.read()
169 contents = format_file_contents(
170 src_contents, line_length=line_length, fast=fast
172 except NothingChanged:
176 with open(src, "w", encoding=src_buffer.encoding) as f:
181 def format_stdin_to_stdout(line_length: int, fast: bool) -> bool:
182 """Format file on stdin and pipe output to stdout. Return True if changed."""
183 contents = sys.stdin.read()
185 contents = format_file_contents(contents, line_length=line_length, fast=fast)
188 except NothingChanged:
192 sys.stdout.write(contents)
195 def format_file_contents(
196 src_contents: str, line_length: int, fast: bool
198 """Reformats a file and returns its contents and encoding."""
199 if src_contents.strip() == '':
202 dst_contents = format_str(src_contents, line_length=line_length)
203 if src_contents == dst_contents:
207 assert_equivalent(src_contents, dst_contents)
208 assert_stable(src_contents, dst_contents, line_length=line_length)
212 def format_str(src_contents: str, line_length: int) -> FileContent:
213 """Reformats a string and returns new contents."""
214 src_node = lib2to3_parse(src_contents)
216 comments: List[Line] = []
217 lines = LineGenerator()
218 elt = EmptyLineTracker()
219 py36 = is_python36(src_node)
222 for current_line in lines.visit(src_node):
223 for _ in range(after):
224 dst_contents += str(empty_line)
225 before, after = elt.maybe_empty_lines(current_line)
226 for _ in range(before):
227 dst_contents += str(empty_line)
228 if not current_line.is_comment:
229 for comment in comments:
230 dst_contents += str(comment)
232 for line in split_line(current_line, line_length=line_length, py36=py36):
233 dst_contents += str(line)
235 comments.append(current_line)
237 if elt.previous_defs:
238 # Separate postscriptum comments from the last module-level def.
239 dst_contents += str(empty_line)
240 dst_contents += str(empty_line)
241 for comment in comments:
242 dst_contents += str(comment)
246 def lib2to3_parse(src_txt: str) -> Node:
247 """Given a string with source, return the lib2to3 Node."""
248 grammar = pygram.python_grammar_no_print_statement
249 drv = driver.Driver(grammar, pytree.convert)
250 if src_txt[-1] != '\n':
251 nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
254 result = drv.parse_string(src_txt, True)
255 except ParseError as pe:
256 lineno, column = pe.context[1]
257 lines = src_txt.splitlines()
259 faulty_line = lines[lineno - 1]
261 faulty_line = "<line number missing in source>"
262 raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
264 if isinstance(result, Leaf):
265 result = Node(syms.file_input, [result])
269 def lib2to3_unparse(node: Node) -> str:
270 """Given a lib2to3 node, return its string representation."""
278 class Visitor(Generic[T]):
279 """Basic lib2to3 visitor that yields things on visiting."""
281 def visit(self, node: LN) -> Iterator[T]:
283 name = token.tok_name[node.type]
285 name = type_repr(node.type)
286 yield from getattr(self, f'visit_{name}', self.visit_default)(node)
def visit_default(self, node: LN) -> Iterator[T]:
    """Fallback visitor: descend into a `Node`'s children; yield nothing for leaves.

    `visit` dispatches here when no `visit_<type>` method exists; subclasses
    override this to actually produce values.
    """
    if not isinstance(node, Node):
        return
    for child in node.children:
        yield from self.visit(child)
295 class DebugVisitor(Visitor[T]):
298 def visit_default(self, node: LN) -> Iterator[T]:
299 indent = ' ' * (2 * self.tree_depth)
300 if isinstance(node, Node):
301 _type = type_repr(node.type)
302 out(f'{indent}{_type}', fg='yellow')
304 for child in node.children:
305 yield from self.visit(child)
308 out(f'{indent}/{_type}', fg='yellow', bold=False)
310 _type = token.tok_name.get(node.type, str(node.type))
311 out(f'{indent}{_type}', fg='blue', nl=False)
313 # We don't have to handle prefixes for `Node` objects since
314 # that delegates to the first child anyway.
315 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
316 out(f' {node.value!r}', fg='blue', bold=False)
319 KEYWORDS = set(keyword.kwlist)
320 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
321 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
332 STANDALONE_COMMENT = 153
333 LOGIC_OPERATORS = {'and', 'or'}
356 COMPREHENSION_PRIORITY = 20
360 COMPARATOR_PRIORITY = 3
365 class BracketTracker:
367 bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
368 delimiters: Dict[LeafID, Priority] = Factory(dict)
369 previous: Optional[Leaf] = None
371 def mark(self, leaf: Leaf) -> None:
372 if leaf.type == token.COMMENT:
375 if leaf.type in CLOSING_BRACKETS:
377 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
378 leaf.opening_bracket = opening_bracket
379 leaf.bracket_depth = self.depth
381 delim = is_delimiter(leaf)
383 self.delimiters[id(leaf)] = delim
384 elif self.previous is not None:
385 if leaf.type == token.STRING and self.previous.type == token.STRING:
386 self.delimiters[id(self.previous)] = STRING_PRIORITY
388 leaf.type == token.NAME
389 and leaf.value == 'for'
391 and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
393 self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
395 leaf.type == token.NAME
396 and leaf.value == 'if'
398 and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
400 self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
402 leaf.type == token.NAME
403 and leaf.value in LOGIC_OPERATORS
406 self.delimiters[id(self.previous)] = LOGIC_PRIORITY
407 if leaf.type in OPENING_BRACKETS:
408 self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
def any_open_brackets(self) -> bool:
    """Return True if the line still has a bracket waiting for its closing match."""
    # `bracket_match` gains an entry per opening bracket and loses it when the
    # matching closing bracket is seen, so a non-empty mapping means "open".
    return len(self.bracket_match) > 0
416 def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
417 """Returns the highest priority of a delimiter found on the line.
419 Values are consistent with what `is_delimiter()` returns.
421 return max(v for k, v in self.delimiters.items() if k not in exclude)
427 leaves: List[Leaf] = Factory(list)
428 comments: Dict[LeafID, Leaf] = Factory(dict)
429 bracket_tracker: BracketTracker = Factory(BracketTracker)
430 inside_brackets: bool = False
431 has_for: bool = False
432 _for_loop_variable: bool = False
434 def append(self, leaf: Leaf, preformatted: bool = False) -> None:
435 has_value = leaf.value.strip()
439 if self.leaves and not preformatted:
440 # Note: at this point leaf.prefix should be empty except for
441 # imports, for which we only preserve newlines.
442 leaf.prefix += whitespace(leaf)
443 if self.inside_brackets or not preformatted:
444 self.maybe_decrement_after_for_loop_variable(leaf)
445 self.bracket_tracker.mark(leaf)
446 self.maybe_remove_trailing_comma(leaf)
447 self.maybe_increment_for_loop_variable(leaf)
448 if self.maybe_adapt_standalone_comment(leaf):
451 if not self.append_comment(leaf):
452 self.leaves.append(leaf)
455 def is_comment(self) -> bool:
456 return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
459 def is_decorator(self) -> bool:
460 return bool(self) and self.leaves[0].type == token.AT
463 def is_import(self) -> bool:
464 return bool(self) and is_import(self.leaves[0])
467 def is_class(self) -> bool:
470 and self.leaves[0].type == token.NAME
471 and self.leaves[0].value == 'class'
475 def is_def(self) -> bool:
476 """Also returns True for async defs."""
478 first_leaf = self.leaves[0]
483 second_leaf: Optional[Leaf] = self.leaves[1]
487 (first_leaf.type == token.NAME and first_leaf.value == 'def')
489 first_leaf.type == token.ASYNC
490 and second_leaf is not None
491 and second_leaf.type == token.NAME
492 and second_leaf.value == 'def'
497 def is_flow_control(self) -> bool:
500 and self.leaves[0].type == token.NAME
501 and self.leaves[0].value in FLOW_CONTROL
505 def is_yield(self) -> bool:
508 and self.leaves[0].type == token.NAME
509 and self.leaves[0].value == 'yield'
512 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
515 and self.leaves[-1].type == token.COMMA
516 and closing.type in CLOSING_BRACKETS
520 if closing.type == token.RSQB or closing.type == token.RBRACE:
524 # For parens let's check if it's safe to remove the comma. If the
525 # trailing one is the only one, we might mistakenly change a tuple
526 # into a different type by removing the comma.
527 depth = closing.bracket_depth + 1
529 opening = closing.opening_bracket
530 for _opening_index, leaf in enumerate(self.leaves):
537 for leaf in self.leaves[_opening_index + 1:]:
541 bracket_depth = leaf.bracket_depth
542 if bracket_depth == depth and leaf.type == token.COMMA:
544 if leaf.parent and leaf.parent.type == syms.arglist:
554 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
555 """In a for loop, or comprehension, the variables are often unpacks.
557 To avoid splitting on the comma in this situation, we will increase
558 the depth of tokens between `for` and `in`.
560 if leaf.type == token.NAME and leaf.value == 'for':
562 self.bracket_tracker.depth += 1
563 self._for_loop_variable = True
568 def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
569 # See `maybe_increment_for_loop_variable` above for explanation.
570 if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
571 self.bracket_tracker.depth -= 1
572 self._for_loop_variable = False
577 def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
578 """Hack a standalone comment to act as a trailing comment for line splitting.
580 If this line has brackets and a standalone `comment`, we need to adapt
581 it to be able to still reformat the line.
583 This is not perfect, the line to which the standalone comment gets
584 appended will appear "too long" when splitting.
587 comment.type == STANDALONE_COMMENT
588 and self.bracket_tracker.any_open_brackets()
592 comment.type = token.COMMENT
593 comment.prefix = '\n' + ' ' * (self.depth + 1)
594 return self.append_comment(comment)
596 def append_comment(self, comment: Leaf) -> bool:
597 if comment.type != token.COMMENT:
601 after = id(self.last_non_delimiter())
603 comment.type = STANDALONE_COMMENT
608 if after in self.comments:
609 self.comments[after].value += str(comment)
611 self.comments[after] = comment
614 def last_non_delimiter(self) -> Leaf:
615 for i in range(len(self.leaves)):
616 last = self.leaves[-i - 1]
617 if not is_delimiter(last):
620 raise LookupError("No non-delimiters found")
622 def __str__(self) -> str:
626 indent = ' ' * self.depth
627 leaves = iter(self.leaves)
629 res = f'{first.prefix}{indent}{first.value}'
632 for comment in self.comments.values():
636 def __bool__(self) -> bool:
637 return bool(self.leaves or self.comments)
641 class EmptyLineTracker:
642 """Provides a stateful method that returns the number of potential extra
643 empty lines needed before and after the currently processed line.
645 Note: this tracker works on lines that haven't been split yet. It assumes
646 the prefix of the first leaf consists of optional newlines. Those newlines
647 are consumed by `maybe_empty_lines()` and included in the computation.
649 previous_line: Optional[Line] = None
650 previous_after: int = 0
651 previous_defs: List[int] = Factory(list)
653 def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
654 """Returns the number of extra empty lines before and after the `current_line`.
656 This is for separating `def`, `async def` and `class` with extra empty lines
657 (two on module-level), as well as providing an extra empty line after flow
658 control keywords to make them more prominent.
660 if current_line.is_comment:
661 # Don't count standalone comments towards previous empty lines.
664 before, after = self._maybe_empty_lines(current_line)
665 before -= self.previous_after
666 self.previous_after = after
667 self.previous_line = current_line
670 def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
671 if current_line.leaves:
672 # Consume the first leaf's extra newlines.
673 first_leaf = current_line.leaves[0]
674 before = int('\n' in first_leaf.prefix)
675 first_leaf.prefix = ''
678 depth = current_line.depth
679 while self.previous_defs and self.previous_defs[-1] >= depth:
680 self.previous_defs.pop()
681 before = 1 if depth else 2
682 is_decorator = current_line.is_decorator
683 if is_decorator or current_line.is_def or current_line.is_class:
685 self.previous_defs.append(depth)
686 if self.previous_line is None:
687 # Don't insert empty lines before the first line in the file.
690 if self.previous_line and self.previous_line.is_decorator:
691 # Don't insert empty lines between decorators.
695 if current_line.depth:
699 if current_line.is_flow_control:
704 and self.previous_line.is_import
705 and not current_line.is_import
706 and depth == self.previous_line.depth
708 return (before or 1), 0
712 and self.previous_line.is_yield
713 and (not current_line.is_yield or depth != self.previous_line.depth)
715 return (before or 1), 0
721 class LineGenerator(Visitor[Line]):
722 """Generates reformatted Line objects. Empty lines are not emitted.
724 Note: destroys the tree it's visiting by mutating prefixes of its leaves
725 in ways that will no longer stringify to valid Python code on the tree.
727 current_line: Line = Factory(Line)
728 standalone_comments: List[Leaf] = Factory(list)
730 def line(self, indent: int = 0) -> Iterator[Line]:
733 If the line is empty, only emit if it makes sense.
734 If the line is too long, split it first and then generate.
736 If any lines were generated, set up a new current_line.
738 if not self.current_line:
739 self.current_line.depth += indent
740 return # Line is empty, don't emit. Creating a new one unnecessary.
742 complete_line = self.current_line
743 self.current_line = Line(depth=complete_line.depth + indent)
746 def visit_default(self, node: LN) -> Iterator[Line]:
747 if isinstance(node, Leaf):
748 any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
749 for comment in generate_comments(node):
750 if any_open_brackets:
751 # any comment within brackets is subject to splitting
752 self.current_line.append(comment)
753 elif comment.type == token.COMMENT:
754 # regular trailing comment
755 self.current_line.append(comment)
756 yield from self.line()
759 # regular standalone comment, to be processed later (see
760 # docstring in `generate_comments()`
761 self.standalone_comments.append(comment)
762 normalize_prefix(node, inside_brackets=any_open_brackets)
763 if node.type not in WHITESPACE:
764 for comment in self.standalone_comments:
765 yield from self.line()
767 self.current_line.append(comment)
768 yield from self.line()
770 self.standalone_comments = []
771 self.current_line.append(node)
772 yield from super().visit_default(node)
def visit_suite(self, node: Node) -> Iterator[Line]:
    """Visit an indented block: the body that follows a statement's colon.

    The leading newline child is visited before indentation is raised, since
    it may carry an inline comment that belongs right after the colon. The
    remaining children are then visited one level deeper, and the indentation
    is lowered again afterwards.
    """
    remaining = iter(node.children)
    yield from self.visit(next(remaining))
    yield from self.line(+1)
    for stmt in remaining:
        yield from self.visit(stmt)
    yield from self.line(-1)
788 def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
789 """Visit a statement.
791 The relevant Python language keywords for this statement are NAME leaves
794 for child in node.children:
795 if child.type == token.NAME and child.value in keywords: # type: ignore
796 yield from self.line()
798 yield from self.visit(child)
800 def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
801 """A statement without nested statements."""
802 is_suite_like = node.parent and node.parent.type in STATEMENT
804 yield from self.line(+1)
805 yield from self.visit_default(node)
806 yield from self.line(-1)
809 yield from self.line()
810 yield from self.visit_default(node)
812 def visit_async_stmt(self, node: Node) -> Iterator[Line]:
813 yield from self.line()
815 children = iter(node.children)
816 for child in children:
817 yield from self.visit(child)
819 if child.type == token.ASYNC:
822 internal_stmt = next(children)
823 for child in internal_stmt.children:
824 yield from self.visit(child)
def visit_decorators(self, node: Node) -> Iterator[Line]:
    """Start each child of `node` on a line of its own.

    `self.line()` is called before every child is visited, closing whatever
    is pending so the child begins a new line.
    """
    for decorator_part in node.children:
        yield from self.line()
        yield from self.visit(decorator_part)
def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
    """Treat a semicolon as a statement boundary and close the current line.

    The semicolon leaf is never appended to a line, so it is dropped from the
    output and the following statement starts on a fresh line.
    """
    yield from self.line()
def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
    """Handle the end-of-file marker, then emit the final pending line.

    Passing the marker through `visit_default` processes any comments stored
    in its prefix (comments at the very end of the file). The closing
    `self.line()` call then emits whatever remains on the current line.
    """
    yield from self.visit_default(leaf)
    yield from self.line()
838 def __attrs_post_init__(self) -> None:
839 """You are in a twisty little maze of passages."""
841 self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
842 self.visit_while_stmt = partial(v, keywords={'while', 'else'})
843 self.visit_for_stmt = partial(v, keywords={'for', 'else'})
844 self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
845 self.visit_except_clause = partial(v, keywords={'except'})
846 self.visit_funcdef = partial(v, keywords={'def'})
847 self.visit_with_stmt = partial(v, keywords={'with'})
848 self.visit_classdef = partial(v, keywords={'class'})
849 self.visit_async_funcdef = self.visit_async_stmt
850 self.visit_decorated = self.visit_decorators
# Each opening bracket token type mapped to its closing counterpart.
BRACKET = {
    token.LPAR: token.RPAR,
    token.LSQB: token.RSQB,
    token.LBRACE: token.RBRACE,
}
OPENING_BRACKETS = set(BRACKET)
CLOSING_BRACKETS = set(BRACKET.values())
BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
# Checked first by `whitespace()`; by its name, leaves of these types take no
# leading space.
ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
860 def whitespace(leaf: Leaf) -> str: # noqa C901
861 """Return whitespace prefix if needed for the given `leaf`."""
868 if t in ALWAYS_NO_SPACE:
871 if t == token.COMMENT:
874 assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
875 if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
878 prev = leaf.prev_sibling
880 prevp = preceding_leaf(p)
881 if not prevp or prevp.type in OPENING_BRACKETS:
885 return SPACE if prevp.type == token.COMMA else NO
887 if prevp.type == token.EQUAL:
888 if prevp.parent and prevp.parent.type in {
897 elif prevp.type == token.DOUBLESTAR:
898 if prevp.parent and prevp.parent.type in {
907 elif prevp.type == token.COLON:
908 if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
911 elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
914 elif prev.type in OPENING_BRACKETS:
917 if p.type in {syms.parameters, syms.arglist}:
918 # untyped function signatures or calls
922 if not prev or prev.type != token.COMMA:
925 if p.type == syms.varargslist:
930 if prev and prev.type != token.COMMA:
933 elif p.type == syms.typedargslist:
934 # typed function signatures
939 if prev.type != syms.tname:
942 elif prev.type == token.EQUAL:
943 # A bit hacky: if the equal sign has whitespace, it means we
944 # previously found it's a typed argument. So, we're using that, too.
947 elif prev.type != token.COMMA:
950 elif p.type == syms.tname:
953 prevp = preceding_leaf(p)
954 if not prevp or prevp.type != token.COMMA:
957 elif p.type == syms.trailer:
958 # attributes and calls
959 if t == token.LPAR or t == token.RPAR:
964 prevp = preceding_leaf(p)
965 if not prevp or prevp.type != token.NUMBER:
968 elif t == token.LSQB:
971 elif prev.type != token.COMMA:
974 elif p.type == syms.argument:
980 prevp = preceding_leaf(p)
981 if not prevp or prevp.type == token.LPAR:
984 elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
987 elif p.type == syms.decorator:
991 elif p.type == syms.dotted_name:
995 prevp = preceding_leaf(p)
996 if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
999 elif p.type == syms.classdef:
1003 if prev and prev.type == token.LPAR:
1006 elif p.type == syms.subscript:
1009 assert p.parent is not None, "subscripts are always parented"
1010 if p.parent.type == syms.subscriptlist:
1018 elif p.type == syms.atom:
1019 if prev and t == token.DOT:
1020 # dots, but not the first one.
1024 p.type == syms.listmaker
1025 or p.type == syms.testlist_gexp
1026 or p.type == syms.subscriptlist
1028 # list interior, including unpacking
1032 elif p.type == syms.dictsetmaker:
1033 # dict and set interior, including unpacking
1037 if prev.type == token.DOUBLESTAR:
1040 elif p.type in {syms.factor, syms.star_expr}:
1043 prevp = preceding_leaf(p)
1044 if not prevp or prevp.type in OPENING_BRACKETS:
1047 prevp_parent = prevp.parent
1048 assert prevp_parent is not None
1049 if prevp.type == token.COLON and prevp_parent.type in {
1050 syms.subscript, syms.sliceop
1054 elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1057 elif t == token.NAME or t == token.NUMBER:
1060 elif p.type == syms.import_from:
1062 if prev and prev.type == token.DOT:
1065 elif t == token.NAME:
1069 if prev and prev.type == token.DOT:
1072 elif p.type == syms.sliceop:
1078 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1079 """Returns the first leaf that precedes `node`, if any."""
1081 res = node.prev_sibling
1083 if isinstance(res, Leaf):
1087 return list(res.leaves())[-1]
1096 def is_delimiter(leaf: Leaf) -> int:
1097 """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1099 Higher numbers are higher priority.
1101 if leaf.type == token.COMMA:
1102 return COMMA_PRIORITY
1104 if leaf.type in COMPARATORS:
1105 return COMPARATOR_PRIORITY
1108 leaf.type in MATH_OPERATORS
1110 and leaf.parent.type not in {syms.factor, syms.star_expr}
1112 return MATH_PRIORITY
1117 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1118 """Cleans the prefix of the `leaf` and generates comments from it, if any.
1120 Comments in lib2to3 are shoved into the whitespace prefix. This happens
1121 in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation
1122 move because it does away with modifying the grammar to include all the
1123 possible places in which comments can be placed.
1125 The sad consequence for us though is that comments don't "belong" anywhere.
1126 This is why this function generates simple parentless Leaf objects for
1127 comments. We simply don't know what the correct parent should be.
1129 No matter though, we can live without this. We really only need to
1130 differentiate between inline and standalone comments. The latter don't
1131 share the line with any code.
1133 Inline comments are emitted as regular token.COMMENT leaves. Standalone
1134 are emitted with a fake STANDALONE_COMMENT token identifier.
1139 if '#' not in leaf.prefix:
1142 before_comment, content = leaf.prefix.split('#', 1)
1143 content = content.rstrip()
1144 if content and (content[0] not in {' ', '!', '#'}):
1145 content = ' ' + content
1146 is_standalone_comment = (
1147 '\n' in before_comment or '\n' in content or leaf.type == token.ENDMARKER
1149 if not is_standalone_comment:
1150 # simple trailing comment
1151 yield Leaf(token.COMMENT, value='#' + content)
1154 for line in ('#' + content).split('\n'):
1155 line = line.lstrip()
1156 if not line.startswith('#'):
1159 yield Leaf(STANDALONE_COMMENT, line)
1163 line: Line, line_length: int, inner: bool = False, py36: bool = False
1164 ) -> Iterator[Line]:
1165 """Splits a `line` into potentially many lines.
1167 They should fit in the allotted `line_length` but might not be able to.
1168 `inner` signifies that there were a pair of brackets somewhere around the
1169 current `line`, possibly transitively. This means we can fallback to splitting
1170 by delimiters if the LHS/RHS don't yield any results.
1172 If `py36` is True, splitting may generate syntax that is only compatible
1173 with Python 3.6 and later.
1175 line_str = str(line).strip('\n')
1176 if len(line_str) <= line_length and '\n' not in line_str:
1181 split_funcs = [left_hand_split]
1182 elif line.inside_brackets:
1183 split_funcs = [delimiter_split]
1184 if '\n' not in line_str:
1185 # Only attempt RHS if we don't have multiline strings or comments
1187 split_funcs.append(right_hand_split)
1189 split_funcs = [right_hand_split]
1190 for split_func in split_funcs:
1191 # We are accumulating lines in `result` because we might want to abort
1192 # mission and return the original line in the end, or attempt a different
1194 result: List[Line] = []
1196 for l in split_func(line, py36=py36):
1197 if str(l).strip('\n') == line_str:
1198 raise CannotSplit("Split function returned an unchanged result")
1201 split_line(l, line_length=line_length, inner=True, py36=py36)
1203 except CannotSplit as cs:
1214 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1215 """Split line into many lines, starting with the first matching bracket pair.
1217 Note: this usually looks weird, only use this for function definitions.
1218 Prefer RHS otherwise.
1220 head = Line(depth=line.depth)
1221 body = Line(depth=line.depth + 1, inside_brackets=True)
1222 tail = Line(depth=line.depth)
1223 tail_leaves: List[Leaf] = []
1224 body_leaves: List[Leaf] = []
1225 head_leaves: List[Leaf] = []
1226 current_leaves = head_leaves
1227 matching_bracket = None
1228 for leaf in line.leaves:
1230 current_leaves is body_leaves
1231 and leaf.type in CLOSING_BRACKETS
1232 and leaf.opening_bracket is matching_bracket
1234 current_leaves = tail_leaves if body_leaves else head_leaves
1235 current_leaves.append(leaf)
1236 if current_leaves is head_leaves:
1237 if leaf.type in OPENING_BRACKETS:
1238 matching_bracket = leaf
1239 current_leaves = body_leaves
1240 # Since body is a new indent level, remove spurious leading whitespace.
1242 normalize_prefix(body_leaves[0], inside_brackets=True)
1243 # Build the new lines.
1244 for result, leaves in (
1245 (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1248 result.append(leaf, preformatted=True)
1249 comment_after = line.comments.get(id(leaf))
1251 result.append(comment_after, preformatted=True)
1252 split_succeeded_or_raise(head, body, tail)
1253 for result in (head, body, tail):
1258 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1259 """Split line into many lines, starting with the last matching bracket pair."""
1260 head = Line(depth=line.depth)
1261 body = Line(depth=line.depth + 1, inside_brackets=True)
1262 tail = Line(depth=line.depth)
1263 tail_leaves: List[Leaf] = []
1264 body_leaves: List[Leaf] = []
1265 head_leaves: List[Leaf] = []
1266 current_leaves = tail_leaves
1267 opening_bracket = None
1268 for leaf in reversed(line.leaves):
1269 if current_leaves is body_leaves:
1270 if leaf is opening_bracket:
1271 current_leaves = head_leaves if body_leaves else tail_leaves
1272 current_leaves.append(leaf)
1273 if current_leaves is tail_leaves:
1274 if leaf.type in CLOSING_BRACKETS:
1275 opening_bracket = leaf.opening_bracket
1276 current_leaves = body_leaves
1277 tail_leaves.reverse()
1278 body_leaves.reverse()
1279 head_leaves.reverse()
1280 # Since body is a new indent level, remove spurious leading whitespace.
1282 normalize_prefix(body_leaves[0], inside_brackets=True)
1283 # Build the new lines.
1284 for result, leaves in (
1285 (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1288 result.append(leaf, preformatted=True)
1289 comment_after = line.comments.get(id(leaf))
1291 result.append(comment_after, preformatted=True)
1292 split_succeeded_or_raise(head, body, tail)
1293 for result in (head, body, tail):
1298 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1299 tail_len = len(str(tail).strip())
1302 raise CannotSplit("Splitting brackets produced the same line")
1306 f"Splitting brackets on an empty body to save "
1307 f"{tail_len} characters is not worth it"
def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split according to delimiters of the highest priority.

    This kind of split doesn't increase indentation.
    If `py36` is True, the split will add trailing commas also in function
    signatures that contain * and **.

    The highest delimiter priority is computed while excluding the line's
    last leaf. A new line is yielded after every leaf carrying that
    priority; the remainder is yielded last.

    Raises CannotSplit when the line has no leaves or no delimiters.
    """
    try:
        last_leaf = line.leaves[-1]
    except IndexError:
        raise CannotSplit("Line empty")

    delimiters = line.bracket_tracker.delimiters
    try:
        delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
    except ValueError:
        raise CannotSplit("No delimiters found")

    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    lowest_depth = sys.maxsize
    trailing_comma_safe = True
    for leaf in line.leaves:
        current_line.append(leaf, preformatted=True)
        comment_after = line.comments.get(id(leaf))
        if comment_after:
            current_line.append(comment_after, preformatted=True)
        lowest_depth = min(lowest_depth, leaf.bracket_depth)
        # Only a `*` or `**` at the lowest bracket depth seen so far affects
        # whether a trailing comma may be added (without `py36` it may not).
        # The depth guard must cover BOTH operators: previously `and` bound
        # tighter than `or`, so any nested `**` (e.g. `f(a, g(**kw))`)
        # needlessly suppressed the trailing comma.
        if leaf.bracket_depth == lowest_depth and leaf.type in {
            token.STAR,
            token.DOUBLESTAR,
        }:
            trailing_comma_safe = trailing_comma_safe and py36
        leaf_priority = delimiters.get(id(leaf))
        if leaf_priority == delimiter_priority:
            # Each split-off line starts a fresh line: drop leading whitespace.
            normalize_prefix(current_line.leaves[0], inside_brackets=True)
            yield current_line

            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    if current_line:
        if (
            delimiter_priority == COMMA_PRIORITY
            and current_line.leaves[-1].type != token.COMMA
            and trailing_comma_safe
        ):
            current_line.append(Leaf(token.COMMA, ','))
        normalize_prefix(current_line.leaves[0], inside_brackets=True)
        yield current_line
def is_import(leaf: Leaf) -> bool:
    """Tell whether `leaf` is the keyword that opens an import statement.

    True for a NAME leaf `import` whose parent is an `import_name` node, or a
    NAME leaf `from` whose parent is an `import_from` node; False otherwise
    (including when the leaf has no parent).
    """
    parent = leaf.parent
    if leaf.type != token.NAME or not parent:
        return False

    if leaf.value == 'import':
        return parent.type == syms.import_name

    if leaf.value == 'from':
        return parent.type == syms.import_from

    return False
def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
    """Reset the whitespace prefix of `leaf`.

    Outside brackets, the newlines found before the first `#` in the prefix
    are kept (as bare newlines) and everything else is dropped. Inside
    brackets, or when that pre-comment part contains a backslash, the prefix
    is cleared completely. Note: don't use backslashes for formatting or
    you'll lose your voting rights.
    """
    if inside_brackets:
        leaf.prefix = ''
        return

    before_comment = leaf.prefix.split('#', 1)[0]
    if '\\' in before_comment:
        leaf.prefix = ''
    else:
        leaf.prefix = '\n' * before_comment.count('\n')
def is_python36(node: Node) -> bool:
    """Returns True if the current file is using Python 3.6+ features.

    Currently looking for:
    - f-strings, with any prefix casing or ordering (`f`, `F`, `rf`, `Rf`,
      `fR`, ...); and
    - trailing commas after * or ** in function signatures.
    """
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive and `r`/`f` may appear in
            # either order, so normalize case instead of listing every
            # spelling (the old set missed `Rf`, `rF`, `fR` and `Fr`).
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", 'rf', 'fr'}:
                return True

        elif (
            n.type == syms.typedargslist
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if any(ch.type in {token.STAR, token.DOUBLESTAR} for ch in n.children):
                return True

    return False
# File suffixes treated as Python sources when walking a directory tree.
PYTHON_EXTENSIONS = {'.py'}
# Directory names that are never descended into while collecting sources.
BLACKLISTED_DIRECTORIES = {
    'build',
    'buck-out',
    'dist',
    '_build',
    '.git',
    '.hg',
    '.mypy_cache',
    '.tox',
    '.venv',
}
def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
    """Recursively yield the Python files found under directory `path`.

    Entries are visited in `Path.iterdir()` order. Subdirectories whose name
    is in BLACKLISTED_DIRECTORIES are skipped entirely; any other entry is
    yielded when its suffix is listed in PYTHON_EXTENSIONS.
    """
    for entry in path.iterdir():
        if not entry.is_dir():
            if entry.suffix in PYTHON_EXTENSIONS:
                yield entry
        elif entry.name not in BLACKLISTED_DIRECTORIES:
            yield from gen_python_files_in_dir(entry)
@dataclass
class Report:
    """Tally of per-file formatting outcomes, plus the messages shown for them."""
    # Files that were rewritten.
    change_count: int = 0
    # Files that were already formatted.
    same_count: int = 0
    # Files that could not be formatted.
    failure_count: int = 0

    def done(self, src: Path, changed: bool) -> None:
        """Record a successfully processed `src` and print what happened to it."""
        if not changed:
            out(f'{src} already well formatted, good job.', bold=False)
            self.same_count += 1
            return

        out(f'reformatted {src}')
        self.change_count += 1

    def failed(self, src: Path, message: str) -> None:
        """Record that `src` could not be formatted and print `message` as an error."""
        err(f'error: cannot format {src}: {message}')
        self.failure_count += 1

    @property
    def return_code(self) -> int:
        """Exit status: 123 on any failure, else 1 if anything changed, else 0."""
        # Per http://tldp.org/LDP/abs/html/exitcodes.html, codes from 126
        # upward are reserved by the shell, so the error code stays below.
        if self.failure_count:
            return 123
        return 1 if self.change_count else 0

    def __str__(self) -> str:
        """Summarize the counters as one colorized sentence.

        Use `click.unstyle` to remove colors.
        """
        def plural(count: int) -> str:
            return 's' if count > 1 else ''

        parts = []
        if self.change_count:
            n = self.change_count
            parts.append(click.style(f'{n} file{plural(n)} reformatted', bold=True))
        if self.same_count:
            n = self.same_count
            parts.append(f'{n} file{plural(n)} left unchanged')
        if self.failure_count:
            n = self.failure_count
            parts.append(
                click.style(f'{n} file{plural(n)} failed to reformat', fg='red')
            )
        return ', '.join(parts) + '.'
def assert_equivalent(src: str, dst: str) -> None:
    """Raises AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed with `ast` and compared through a textual dump of
    every node field (fields sorted by name). This is a temporary sanity
    check until Black becomes stable.
    """

    import ast
    import traceback

    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
        """Simple visitor generating strings to compare ASTs by content."""
        yield f"{' ' * depth}{node.__class__.__name__}("

        for field in sorted(node._fields):
            try:
                value = getattr(node, field)
            except AttributeError:
                continue

            yield f"{' ' * (depth+1)}{field}="

            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        yield from _v(item, depth + 2)
                    else:
                        # Lists can hold plain values too (e.g. the identifier
                        # strings of `global`/`nonlocal`). Skipping them let
                        # such renames slip through, so dump them as well.
                        yield f"{' ' * (depth+2)}{item!r}, # {item.__class__.__name__}"

            elif isinstance(value, ast.AST):
                yield from _v(value, depth + 2)

            else:
                yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"

        yield f"{' ' * depth}) # /{node.__class__.__name__}"

    try:
        src_ast = ast.parse(src)
    except Exception as exc:
        raise AssertionError(f"cannot parse source: {exc}") from None

    try:
        dst_ast = ast.parse(dst)
    except Exception as exc:
        log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = '\n'.join(_v(src_ast))
    dst_ast_str = '\n'.join(_v(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
        raise AssertionError(
            f"INTERNAL ERROR: Black produced code that is not equivalent to "
            f"the source. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This diff might be helpful: {log}"
        ) from None
def assert_stable(src: str, dst: str, line_length: int) -> None:
    """Raise AssertionError unless formatting `dst` again leaves it unchanged.

    `src` (the original input) is only used to build the diagnostic log.
    This is a temporary sanity check until Black becomes stable.
    """
    second_pass = format_str(dst, line_length=line_length)
    if dst == second_pass:
        return

    log = dump_to_file(
        diff(src, dst, 'source', 'first pass'),
        diff(dst, second_pass, 'first pass', 'second pass'),
    )
    raise AssertionError(
        f"INTERNAL ERROR: Black produced different code on the second pass "
        f"of the formatter. "
        f"Please report a bug on https://github.com/ambv/black/issues. "
        f"This diff might be helpful: {log}"
    ) from None
def dump_to_file(*output: str) -> str:
    """Dumps `output` to a temporary file. Returns path to the file.

    Each positional string is written followed by a newline. The file is
    created with `delete=False` so it outlives this call and the user can
    inspect it. It is always written as UTF-8: the platform default encoding
    (e.g. cp1252) may be unable to represent non-ASCII source code.
    """
    import tempfile

    with tempfile.NamedTemporaryFile(
        mode='w', prefix='blk_', suffix='.log', delete=False, encoding='utf8'
    ) as f:
        for lines in output:
            f.write(lines)
            f.write('\n')
    return f.name
def diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Return a unified diff (5 context lines) that turns `a` into `b`.

    `a_name` and `b_name` label the `---`/`+++` headers. Both inputs are split
    on '\\n' and every resulting piece is re-terminated with a newline.
    """
    import difflib

    before, after = ([piece + '\n' for piece in text.split('\n')] for text in (a, b))
    hunks = difflib.unified_diff(
        before, after, fromfile=a_name, tofile=b_name, n=5
    )
    return ''.join(hunks)
1596 if __name__ == '__main__':