All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you'd read over the Git project's submission guidelines and adhered to them,
I'd be especially grateful.
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
9 from pathlib import Path
13 Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
16 from attr import dataclass, Factory
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
25 __version__ = "18.3a4"
26 DEFAULT_LINE_LENGTH = 88
28 syms = pygram.python_symbols
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
40 class NothingChanged(UserWarning):
41 """Raised by `format_file` when the reformatted code is the same as source."""
44 class CannotSplit(Exception):
45 """A readable split that fits the allotted line length is impossible.
47 Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`.
51 class FormatError(Exception):
52 """Base fmt: on/off error.
54 It holds the number of bytes of the prefix consumed before the format
55 control comment appeared.
58 def __init__(self, consumed: int) -> None:
59 super().__init__(consumed)
60 self.consumed = consumed
62 def trim_prefix(self, leaf: Leaf) -> None:
63 leaf.prefix = leaf.prefix[self.consumed:]  # mutate `leaf` in place: drop the first `self.consumed` characters of its prefix
65 def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
66 """Returns a new Leaf from the consumed part of the prefix."""
67 unformatted_prefix = leaf.prefix[:self.consumed]  # exactly the slice that `trim_prefix()` removes; `leaf` itself is not modified here
68 return Leaf(token.NEWLINE, unformatted_prefix)  # new NEWLINE-typed leaf whose value is that prefix text
71 class FormatOn(FormatError):
72 """Found a comment like `# fmt: on` in the file."""
75 class FormatOff(FormatError):
76 """Found a comment like `# fmt: off` in the file."""
84 default=DEFAULT_LINE_LENGTH,
85 help='How many character per line to allow.',
92 "Don't write back the files, just return the status. Return code 0 "
93 "means nothing would change. Return code 1 means some files would be "
94 "reformatted. Return code 123 means there was an internal error."
100 help='If --fast given, skip temporary sanity checks. [default: --safe]',
102 @click.version_option(version=__version__)
107 exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
112 ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
114 """The uncompromising code formatter."""
115 sources: List[Path] = []
119 sources.extend(gen_python_files_in_dir(p))
121 # if a file was explicitly given, we don't care about its extension
124 sources.append(Path('-'))
126 err(f'invalid path: {s}')
127 if len(sources) == 0:
129 elif len(sources) == 1:
131 report = Report(check=check)
133 if not p.is_file() and str(p) == '-':
134 changed = format_stdin_to_stdout(
135 line_length=line_length, fast=fast, write_back=not check
138 changed = format_file_in_place(
139 p, line_length=line_length, fast=fast, write_back=not check
141 report.done(p, changed)
142 except Exception as exc:
143 report.failed(p, str(exc))
144 ctx.exit(report.return_code)
146 loop = asyncio.get_event_loop()
147 executor = ProcessPoolExecutor(max_workers=os.cpu_count())
150 return_code = loop.run_until_complete(
152 sources, line_length, not check, fast, loop, executor
157 ctx.exit(return_code)
160 async def schedule_formatting(
169 src: loop.run_in_executor(
170 executor, format_file_in_place, src, line_length, fast, write_back
174 await asyncio.wait(tasks.values())
177 for src, task in tasks.items():
179 report.failed(src, 'timed out, cancelling')
181 cancelled.append(task)
182 elif task.exception():
183 report.failed(src, str(task.exception()))
185 report.done(src, task.result())
187 await asyncio.wait(cancelled, timeout=2)
188 out('All done! ✨ 🍰 ✨')
189 click.echo(str(report))
190 return report.return_code
193 def format_file_in_place(
194 src: Path, line_length: int, fast: bool, write_back: bool = False
196 """Format the file and rewrite if changed. Return True if changed."""
197 with tokenize.open(src) as src_buffer:
198 src_contents = src_buffer.read()
200 contents = format_file_contents(
201 src_contents, line_length=line_length, fast=fast
203 except NothingChanged:
207 with open(src, "w", encoding=src_buffer.encoding) as f:
212 def format_stdin_to_stdout(
213 line_length: int, fast: bool, write_back: bool = False
215 """Format file on stdin and pipe output to stdout. Return True if changed."""
216 contents = sys.stdin.read()
218 contents = format_file_contents(contents, line_length=line_length, fast=fast)
221 except NothingChanged:
226 sys.stdout.write(contents)
229 def format_file_contents(
230 src_contents: str, line_length: int, fast: bool
232 """Reformats a file and returns its contents and encoding."""
233 if src_contents.strip() == '':
236 dst_contents = format_str(src_contents, line_length=line_length)
237 if src_contents == dst_contents:
241 assert_equivalent(src_contents, dst_contents)
242 assert_stable(src_contents, dst_contents, line_length=line_length)
246 def format_str(src_contents: str, line_length: int) -> FileContent:
247 """Reformats a string and returns new contents."""
248 src_node = lib2to3_parse(src_contents)
250 lines = LineGenerator()
251 elt = EmptyLineTracker()
252 py36 = is_python36(src_node)
255 for current_line in lines.visit(src_node):
256 for _ in range(after):
257 dst_contents += str(empty_line)
258 before, after = elt.maybe_empty_lines(current_line)
259 for _ in range(before):
260 dst_contents += str(empty_line)
261 for line in split_line(current_line, line_length=line_length, py36=py36):
262 dst_contents += str(line)
267 pygram.python_grammar_no_print_statement_no_exec_statement,
268 pygram.python_grammar_no_print_statement,
269 pygram.python_grammar_no_exec_statement,
270 pygram.python_grammar,
274 def lib2to3_parse(src_txt: str) -> Node:
275 """Given a string with source, return the lib2to3 Node."""
276 grammar = pygram.python_grammar_no_print_statement
277 if src_txt[-1] != '\n':
278 nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
280 for grammar in GRAMMARS:
281 drv = driver.Driver(grammar, pytree.convert)
283 result = drv.parse_string(src_txt, True)
286 except ParseError as pe:
287 lineno, column = pe.context[1]
288 lines = src_txt.splitlines()
290 faulty_line = lines[lineno - 1]
292 faulty_line = "<line number missing in source>"
293 exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
297 if isinstance(result, Leaf):
298 result = Node(syms.file_input, [result])
302 def lib2to3_unparse(node: Node) -> str:
303 """Given a lib2to3 node, return its string representation."""
311 class Visitor(Generic[T]):
312 """Basic lib2to3 visitor that yields things on visiting."""
314 def visit(self, node: LN) -> Iterator[T]:
316 name = token.tok_name[node.type]
318 name = type_repr(node.type)
319 yield from getattr(self, f'visit_{name}', self.visit_default)(node)
321 def visit_default(self, node: LN) -> Iterator[T]:  # fallback that `visit()` dispatches to when no `visit_<name>` attribute exists
322 if isinstance(node, Node):  # recurse only into `Node` objects
323 for child in node.children:
324 yield from self.visit(child)  # go back through `visit()` so each child gets its own `visit_<name>` dispatch
328 class DebugVisitor(Visitor[T]):
331 def visit_default(self, node: LN) -> Iterator[T]:
332 indent = ' ' * (2 * self.tree_depth)
333 if isinstance(node, Node):
334 _type = type_repr(node.type)
335 out(f'{indent}{_type}', fg='yellow')
337 for child in node.children:
338 yield from self.visit(child)
341 out(f'{indent}/{_type}', fg='yellow', bold=False)
343 _type = token.tok_name.get(node.type, str(node.type))
344 out(f'{indent}{_type}', fg='blue', nl=False)
346 # We don't have to handle prefixes for `Node` objects since
347 # that delegates to the first child anyway.
348 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
349 out(f' {node.value!r}', fg='blue', bold=False)
352 def show(cls, code: str) -> None:
353 """Pretty-prints a given string of `code`.
355 Convenience method for debugging.
357 v: DebugVisitor[None] = DebugVisitor()
358 list(v.visit(lib2to3_parse(code)))
361 KEYWORDS = set(keyword.kwlist)
362 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
363 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
374 STANDALONE_COMMENT = 153
375 LOGIC_OPERATORS = {'and', 'or'}
399 COMPREHENSION_PRIORITY = 20
403 COMPARATOR_PRIORITY = 3
408 class BracketTracker:
410 bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
411 delimiters: Dict[LeafID, Priority] = Factory(dict)
412 previous: Optional[Leaf] = None
414 def mark(self, leaf: Leaf) -> None:
415 if leaf.type == token.COMMENT:
418 if leaf.type in CLOSING_BRACKETS:
420 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
421 leaf.opening_bracket = opening_bracket
422 leaf.bracket_depth = self.depth
424 delim = is_delimiter(leaf)
426 self.delimiters[id(leaf)] = delim
427 elif self.previous is not None:
428 if leaf.type == token.STRING and self.previous.type == token.STRING:
429 self.delimiters[id(self.previous)] = STRING_PRIORITY
431 leaf.type == token.NAME
432 and leaf.value == 'for'
434 and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
436 self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
438 leaf.type == token.NAME
439 and leaf.value == 'if'
441 and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
443 self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
445 leaf.type == token.NAME
446 and leaf.value in LOGIC_OPERATORS
449 self.delimiters[id(self.previous)] = LOGIC_PRIORITY
450 if leaf.type in OPENING_BRACKETS:
451 self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
455 def any_open_brackets(self) -> bool:
456 """Returns True if there is a yet unmatched open bracket on the line."""
457 return bool(self.bracket_match)
459 def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
460 """Returns the highest priority of a delimiter found on the line.
462 Values are consistent with what `is_delimiter()` returns.
464 return max(v for k, v in self.delimiters.items() if k not in exclude)
470 leaves: List[Leaf] = Factory(list)
471 comments: Dict[LeafID, Leaf] = Factory(dict)
472 bracket_tracker: BracketTracker = Factory(BracketTracker)
473 inside_brackets: bool = False
474 has_for: bool = False
475 _for_loop_variable: bool = False
477 def append(self, leaf: Leaf, preformatted: bool = False) -> None:
478 has_value = leaf.value.strip()
482 if self.leaves and not preformatted:
483 # Note: at this point leaf.prefix should be empty except for
484 # imports, for which we only preserve newlines.
485 leaf.prefix += whitespace(leaf)
486 if self.inside_brackets or not preformatted:
487 self.maybe_decrement_after_for_loop_variable(leaf)
488 self.bracket_tracker.mark(leaf)
489 self.maybe_remove_trailing_comma(leaf)
490 self.maybe_increment_for_loop_variable(leaf)
491 if self.maybe_adapt_standalone_comment(leaf):
494 if not self.append_comment(leaf):
495 self.leaves.append(leaf)
498 def is_comment(self) -> bool:
499 return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
502 def is_decorator(self) -> bool:
503 return bool(self) and self.leaves[0].type == token.AT
506 def is_import(self) -> bool:
507 return bool(self) and is_import(self.leaves[0])
510 def is_class(self) -> bool:
513 and self.leaves[0].type == token.NAME
514 and self.leaves[0].value == 'class'
518 def is_def(self) -> bool:
519 """Also returns True for async defs."""
521 first_leaf = self.leaves[0]
526 second_leaf: Optional[Leaf] = self.leaves[1]
530 (first_leaf.type == token.NAME and first_leaf.value == 'def')
532 first_leaf.type == token.ASYNC
533 and second_leaf is not None
534 and second_leaf.type == token.NAME
535 and second_leaf.value == 'def'
540 def is_flow_control(self) -> bool:
543 and self.leaves[0].type == token.NAME
544 and self.leaves[0].value in FLOW_CONTROL
548 def is_yield(self) -> bool:
551 and self.leaves[0].type == token.NAME
552 and self.leaves[0].value == 'yield'
555 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
558 and self.leaves[-1].type == token.COMMA
559 and closing.type in CLOSING_BRACKETS
563 if closing.type == token.RBRACE:
567 if closing.type == token.RSQB:
568 comma = self.leaves[-1]
569 if comma.parent and comma.parent.type == syms.listmaker:
573 # For parens let's check if it's safe to remove the comma. If the
574 # trailing one is the only one, we might mistakenly change a tuple
575 # into a different type by removing the comma.
576 depth = closing.bracket_depth + 1
578 opening = closing.opening_bracket
579 for _opening_index, leaf in enumerate(self.leaves):
586 for leaf in self.leaves[_opening_index + 1:]:
590 bracket_depth = leaf.bracket_depth
591 if bracket_depth == depth and leaf.type == token.COMMA:
593 if leaf.parent and leaf.parent.type == syms.arglist:
603 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
604 """In a for loop, or comprehension, the variables are often unpacks.
606 To avoid splitting on the comma in this situation, we will increase
607 the depth of tokens between `for` and `in`.
609 if leaf.type == token.NAME and leaf.value == 'for':
611 self.bracket_tracker.depth += 1
612 self._for_loop_variable = True
617 def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
618 # See `maybe_increment_for_loop_variable` above for explanation.
619 if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
620 self.bracket_tracker.depth -= 1
621 self._for_loop_variable = False
626 def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
627 """Hack a standalone comment to act as a trailing comment for line splitting.
629 If this line has brackets and a standalone `comment`, we need to adapt
630 it to be able to still reformat the line.
632 This is not perfect, the line to which the standalone comment gets
633 appended will appear "too long" when splitting.
636 comment.type == STANDALONE_COMMENT
637 and self.bracket_tracker.any_open_brackets()
641 comment.type = token.COMMENT
642 comment.prefix = '\n' + ' ' * (self.depth + 1)
643 return self.append_comment(comment)
645 def append_comment(self, comment: Leaf) -> bool:
646 if comment.type != token.COMMENT:
650 after = id(self.last_non_delimiter())
652 comment.type = STANDALONE_COMMENT
657 if after in self.comments:
658 self.comments[after].value += str(comment)
660 self.comments[after] = comment
663 def last_non_delimiter(self) -> Leaf:
664 for i in range(len(self.leaves)):
665 last = self.leaves[-i - 1]
666 if not is_delimiter(last):
669 raise LookupError("No non-delimiters found")
671 def __str__(self) -> str:
675 indent = ' ' * self.depth
676 leaves = iter(self.leaves)
678 res = f'{first.prefix}{indent}{first.value}'
681 for comment in self.comments.values():
685 def __bool__(self) -> bool:
686 return bool(self.leaves or self.comments)
689 class UnformattedLines(Line):
691 def append(self, leaf: Leaf, preformatted: bool = False) -> None:
693 list(generate_comments(leaf))
694 except FormatOn as f_on:
695 self.leaves.append(f_on.leaf_from_consumed(leaf))
698 self.leaves.append(leaf)
699 if leaf.type == token.INDENT:
701 elif leaf.type == token.DEDENT:
704 def append_comment(self, comment: Leaf) -> bool:
705 raise NotImplementedError("Unformatted lines don't store comments separately.")
707 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
710 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
713 def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
716 def __str__(self) -> str:
721 for leaf in self.leaves:
727 class EmptyLineTracker:
728 """Provides a stateful method that returns the number of potential extra
729 empty lines needed before and after the currently processed line.
731 Note: this tracker works on lines that haven't been split yet. It assumes
732 the prefix of the first leaf consists of optional newlines. Those newlines
733 are consumed by `maybe_empty_lines()` and included in the computation.
735 previous_line: Optional[Line] = None
736 previous_after: int = 0
737 previous_defs: List[int] = Factory(list)
739 def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
740 """Returns the number of extra empty lines before and after the `current_line`.
742 This is for separating `def`, `async def` and `class` with extra empty lines
743 (two on module-level), as well as providing an extra empty line after flow
744 control keywords to make them more prominent.
746 if isinstance(current_line, UnformattedLines):
749 before, after = self._maybe_empty_lines(current_line)
750 before -= self.previous_after
751 self.previous_after = after
752 self.previous_line = current_line
755 def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
757 if current_line.depth == 0:
759 if current_line.leaves:
760 # Consume the first leaf's extra newlines.
761 first_leaf = current_line.leaves[0]
762 before = first_leaf.prefix.count('\n')
763 before = min(before, max_allowed)
764 first_leaf.prefix = ''
767 depth = current_line.depth
768 while self.previous_defs and self.previous_defs[-1] >= depth:
769 self.previous_defs.pop()
770 before = 1 if depth else 2
771 is_decorator = current_line.is_decorator
772 if is_decorator or current_line.is_def or current_line.is_class:
774 self.previous_defs.append(depth)
775 if self.previous_line is None:
776 # Don't insert empty lines before the first line in the file.
779 if self.previous_line and self.previous_line.is_decorator:
780 # Don't insert empty lines between decorators.
784 if current_line.depth:
788 if current_line.is_flow_control:
793 and self.previous_line.is_import
794 and not current_line.is_import
795 and depth == self.previous_line.depth
797 return (before or 1), 0
801 and self.previous_line.is_yield
802 and (not current_line.is_yield or depth != self.previous_line.depth)
804 return (before or 1), 0
810 class LineGenerator(Visitor[Line]):
811 """Generates reformatted Line objects. Empty lines are not emitted.
813 Note: destroys the tree it's visiting by mutating prefixes of its leaves
814 in ways that will no longer stringify to valid Python code on the tree.
816 current_line: Line = Factory(Line)
818 def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
821 If the line is empty, only emit if it makes sense.
822 If the line is too long, split it first and then generate.
824 If any lines were generated, set up a new current_line.
826 if not self.current_line:
827 if self.current_line.__class__ == type:
828 self.current_line.depth += indent
830 self.current_line = type(depth=self.current_line.depth + indent)
831 return # Line is empty, don't emit. Creating a new one unnecessary.
833 complete_line = self.current_line
834 self.current_line = type(depth=complete_line.depth + indent)
837 def visit(self, node: LN) -> Iterator[Line]:
838 """High-level entry point to the visitor."""
839 if isinstance(self.current_line, UnformattedLines):
840 # File contained `# fmt: off`
841 yield from self.visit_unformatted(node)
844 yield from super().visit(node)
846 def visit_default(self, node: LN) -> Iterator[Line]:
847 if isinstance(node, Leaf):
848 any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
850 for comment in generate_comments(node):
851 if any_open_brackets:
852 # any comment within brackets is subject to splitting
853 self.current_line.append(comment)
854 elif comment.type == token.COMMENT:
855 # regular trailing comment
856 self.current_line.append(comment)
857 yield from self.line()
860 # regular standalone comment
861 yield from self.line()
863 self.current_line.append(comment)
864 yield from self.line()
866 except FormatOff as f_off:
867 f_off.trim_prefix(node)
868 yield from self.line(type=UnformattedLines)
869 yield from self.visit(node)
871 except FormatOn as f_on:
872 # This only happens here if somebody says "fmt: on" multiple
874 f_on.trim_prefix(node)
875 yield from self.visit_default(node)
878 normalize_prefix(node, inside_brackets=any_open_brackets)
879 if node.type not in WHITESPACE:
880 self.current_line.append(node)
881 yield from super().visit_default(node)
883 def visit_INDENT(self, node: Node) -> Iterator[Line]:
884 yield from self.line(+1)  # finish the pending line; the new current line is one depth level deeper
885 yield from self.visit_default(node)  # then process the INDENT token through the generic path
887 def visit_DEDENT(self, node: Node) -> Iterator[Line]:
888 # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
889 yield from self.line(-1)  # finish the pending line; the new current line is one depth level shallower
891 def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
892 """Visit a statement.
894 The relevant Python language keywords for this statement are NAME leaves
897 for child in node.children:
898 if child.type == token.NAME and child.value in keywords: # type: ignore
899 yield from self.line()
901 yield from self.visit(child)
903 def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
904 """A statement without nested statements."""
905 is_suite_like = node.parent and node.parent.type in STATEMENT
907 yield from self.line(+1)
908 yield from self.visit_default(node)
909 yield from self.line(-1)
912 yield from self.line()
913 yield from self.visit_default(node)
915 def visit_async_stmt(self, node: Node) -> Iterator[Line]:
916 yield from self.line()
918 children = iter(node.children)
919 for child in children:
920 yield from self.visit(child)
922 if child.type == token.ASYNC:
925 internal_stmt = next(children)
926 for child in internal_stmt.children:
927 yield from self.visit(child)
929 def visit_decorators(self, node: Node) -> Iterator[Line]:
930 for child in node.children:
931 yield from self.line()  # flush the pending line before each child so every child starts a fresh line
932 yield from self.visit(child)
934 def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
935 yield from self.line()
937 def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
938 yield from self.visit_default(leaf)  # generic leaf handling, which also emits comments found in the leaf's prefix
939 yield from self.line()  # flush whatever line is still pending once the end marker has been handled
941 def visit_unformatted(self, node: LN) -> Iterator[Line]:
942 if isinstance(node, Node):
943 for child in node.children:
944 yield from self.visit(child)
948 self.current_line.append(node)
949 except FormatOn as f_on:
950 f_on.trim_prefix(node)
951 yield from self.line()
952 yield from self.visit(node)
954 def __attrs_post_init__(self) -> None:
955 """You are in a twisty little maze of passages."""
957 self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
958 self.visit_while_stmt = partial(v, keywords={'while', 'else'})
959 self.visit_for_stmt = partial(v, keywords={'for', 'else'})
960 self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
961 self.visit_except_clause = partial(v, keywords={'except'})
962 self.visit_funcdef = partial(v, keywords={'def'})
963 self.visit_with_stmt = partial(v, keywords={'with'})
964 self.visit_classdef = partial(v, keywords={'class'})
965 self.visit_async_funcdef = self.visit_async_stmt
966 self.visit_decorated = self.visit_decorators
969 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
970 OPENING_BRACKETS = set(BRACKET.keys())
971 CLOSING_BRACKETS = set(BRACKET.values())
972 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
973 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
976 def whitespace(leaf: Leaf) -> str: # noqa C901
977 """Return whitespace prefix if needed for the given `leaf`."""
984 if t in ALWAYS_NO_SPACE:
987 if t == token.COMMENT:
990 assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
991 if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
994 prev = leaf.prev_sibling
996 prevp = preceding_leaf(p)
997 if not prevp or prevp.type in OPENING_BRACKETS:
1000 if t == token.COLON:
1001 return SPACE if prevp.type == token.COMMA else NO
1003 if prevp.type == token.EQUAL:
1005 if prevp.parent.type in {
1006 syms.arglist, syms.argument, syms.parameters, syms.varargslist
1010 elif prevp.parent.type == syms.typedargslist:
1011 # A bit hacky: if the equal sign has whitespace, it means we
1012 # previously found it's a typed argument. So, we're using
1016 elif prevp.type == token.DOUBLESTAR:
1017 if prevp.parent and prevp.parent.type in {
1027 elif prevp.type == token.COLON:
1028 if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1033 and prevp.parent.type in {syms.factor, syms.star_expr}
1034 and prevp.type in MATH_OPERATORS
1039 prevp.type == token.RIGHTSHIFT
1041 and prevp.parent.type == syms.shift_expr
1042 and prevp.prev_sibling
1043 and prevp.prev_sibling.type == token.NAME
1044 and prevp.prev_sibling.value == 'print' # type: ignore
1046 # Python 2 print chevron
1049 elif prev.type in OPENING_BRACKETS:
1052 if p.type in {syms.parameters, syms.arglist}:
1053 # untyped function signatures or calls
1057 if not prev or prev.type != token.COMMA:
1060 elif p.type == syms.varargslist:
1065 if prev and prev.type != token.COMMA:
1068 elif p.type == syms.typedargslist:
1069 # typed function signatures
1073 if t == token.EQUAL:
1074 if prev.type != syms.tname:
1077 elif prev.type == token.EQUAL:
1078 # A bit hacky: if the equal sign has whitespace, it means we
1079 # previously found it's a typed argument. So, we're using that, too.
1082 elif prev.type != token.COMMA:
1085 elif p.type == syms.tname:
1088 prevp = preceding_leaf(p)
1089 if not prevp or prevp.type != token.COMMA:
1092 elif p.type == syms.trailer:
1093 # attributes and calls
1094 if t == token.LPAR or t == token.RPAR:
1099 prevp = preceding_leaf(p)
1100 if not prevp or prevp.type != token.NUMBER:
1103 elif t == token.LSQB:
1106 elif prev.type != token.COMMA:
1109 elif p.type == syms.argument:
1111 if t == token.EQUAL:
1115 prevp = preceding_leaf(p)
1116 if not prevp or prevp.type == token.LPAR:
1119 elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1122 elif p.type == syms.decorator:
1126 elif p.type == syms.dotted_name:
1130 prevp = preceding_leaf(p)
1131 if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1134 elif p.type == syms.classdef:
1138 if prev and prev.type == token.LPAR:
1141 elif p.type == syms.subscript:
1144 assert p.parent is not None, "subscripts are always parented"
1145 if p.parent.type == syms.subscriptlist:
1153 elif p.type == syms.atom:
1154 if prev and t == token.DOT:
1155 # dots, but not the first one.
1159 p.type == syms.listmaker
1160 or p.type == syms.testlist_gexp
1161 or p.type == syms.subscriptlist
1163 # list interior, including unpacking
1167 elif p.type == syms.dictsetmaker:
1168 # dict and set interior, including unpacking
1172 if prev.type == token.DOUBLESTAR:
1175 elif p.type in {syms.factor, syms.star_expr}:
1178 prevp = preceding_leaf(p)
1179 if not prevp or prevp.type in OPENING_BRACKETS:
1182 prevp_parent = prevp.parent
1183 assert prevp_parent is not None
1184 if prevp.type == token.COLON and prevp_parent.type in {
1185 syms.subscript, syms.sliceop
1189 elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1192 elif t == token.NAME or t == token.NUMBER:
1195 elif p.type == syms.import_from:
1197 if prev and prev.type == token.DOT:
1200 elif t == token.NAME:
1204 if prev and prev.type == token.DOT:
1207 elif p.type == syms.sliceop:
1213 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1214 """Returns the first leaf that precedes `node`, if any."""
1216 res = node.prev_sibling
1218 if isinstance(res, Leaf):
1222 return list(res.leaves())[-1]
1231 def is_delimiter(leaf: Leaf) -> int:
1232 """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1234 Higher numbers are higher priority.
1236 if leaf.type == token.COMMA:
1237 return COMMA_PRIORITY
1239 if leaf.type in COMPARATORS:
1240 return COMPARATOR_PRIORITY
1243 leaf.type in MATH_OPERATORS
1245 and leaf.parent.type not in {syms.factor, syms.star_expr}
1247 return MATH_PRIORITY
1252 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1253 """Cleans the prefix of the `leaf` and generates comments from it, if any.
1255 Comments in lib2to3 are shoved into the whitespace prefix. This happens
1256 in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation
1257 move because it does away with modifying the grammar to include all the
1258 possible places in which comments can be placed.
1260 The sad consequence for us though is that comments don't "belong" anywhere.
1261 This is why this function generates simple parentless Leaf objects for
1262 comments. We simply don't know what the correct parent should be.
1264 No matter though, we can live without this. We really only need to
1265 differentiate between inline and standalone comments. The latter don't
1266 share the line with any code.
1268 Inline comments are emitted as regular token.COMMENT leaves. Standalone
1269 are emitted with a fake STANDALONE_COMMENT token identifier.
1280 for index, line in enumerate(p.split('\n')):
1281 consumed += len(line) + 1 # adding the length of the split '\n'
1282 line = line.lstrip()
1285 if not line.startswith('#'):
1288 if index == 0 and leaf.type != token.ENDMARKER:
1289 comment_type = token.COMMENT # simple trailing comment
1291 comment_type = STANDALONE_COMMENT
1292 comment = make_comment(line)
1293 yield Leaf(comment_type, comment, prefix='\n' * nlines)
1295 if comment in {'# fmt: on', '# yapf: enable'}:
1296 raise FormatOn(consumed)
1298 if comment in {'# fmt: off', '# yapf: disable'}:
1299 raise FormatOff(consumed)
1304 def make_comment(content: str) -> str:
1305 content = content.rstrip()
1309 if content[0] == '#':
1310 content = content[1:]
1311 if content and content[0] not in ' !:#':
1312 content = ' ' + content
1313 return '#' + content
1317 line: Line, line_length: int, inner: bool = False, py36: bool = False
1318 ) -> Iterator[Line]:
1319 """Splits a `line` into potentially many lines.
1321 They should fit in the allotted `line_length` but might not be able to.
1322 `inner` signifies that there was a pair of brackets somewhere around the
1323 current `line`, possibly transitively. This means we can fall back to splitting
1324 by delimiters if the LHS/RHS don't yield any results.
1326 If `py36` is True, splitting may generate syntax that is only compatible
1327 with Python 3.6 and later.
1329 if isinstance(line, UnformattedLines):
1333 line_str = str(line).strip('\n')
1334 if len(line_str) <= line_length and '\n' not in line_str:
1339 split_funcs = [left_hand_split]
1340 elif line.inside_brackets:
1341 split_funcs = [delimiter_split]
1342 if '\n' not in line_str:
1343 # Only attempt RHS if we don't have multiline strings or comments
1345 split_funcs.append(right_hand_split)
1347 split_funcs = [right_hand_split]
1348 for split_func in split_funcs:
1349 # We are accumulating lines in `result` because we might want to abort
1350 # mission and return the original line in the end, or attempt a different
1352 result: List[Line] = []
1354 for l in split_func(line, py36=py36):
1355 if str(l).strip('\n') == line_str:
1356 raise CannotSplit("Split function returned an unchanged result")
1359 split_line(l, line_length=line_length, inner=True, py36=py36)
1361 except CannotSplit as cs:
def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split `line` around its first bracket pair into head, body and tail.

    Leaves up to and including the first opening bracket form the head, leaves
    up to its matching closing bracket form the body (one indentation level
    deeper), and the remaining leaves form the tail.  Empty parts are not
    yielded.

    Note: this usually looks weird, only use this for function definitions.
    Prefer RHS otherwise.  `py36` is accepted for signature parity with the
    other split functions and is not used here.
    """
    head_leaves: List[Leaf] = []
    body_leaves: List[Leaf] = []
    tail_leaves: List[Leaf] = []
    head = Line(depth=line.depth)
    body = Line(depth=line.depth + 1, inside_brackets=True)
    tail = Line(depth=line.depth)

    bucket = head_leaves
    matching_bracket = None
    for leaf in line.leaves:
        closes_body = (
            bucket is body_leaves
            and leaf.type in CLOSING_BRACKETS
            and leaf.opening_bracket is matching_bracket
        )
        if closes_body:
            # An empty bracket pair stays in the head; keep looking for a body.
            bucket = tail_leaves if body_leaves else head_leaves
        bucket.append(leaf)
        if bucket is head_leaves and leaf.type in OPENING_BRACKETS:
            matching_bracket = leaf
            bucket = body_leaves

    # Since body is a new indent level, remove spurious leading whitespace.
    if body_leaves:
        normalize_prefix(body_leaves[0], inside_brackets=True)

    # Build the new lines, keeping each leaf's trailing comment next to it.
    parts = ((head, head_leaves), (body, body_leaves), (tail, tail_leaves))
    for result, leaves in parts:
        for leaf in leaves:
            result.append(leaf, preformatted=True)
            comment_after = line.comments.get(id(leaf))
            if comment_after:
                result.append(comment_after, preformatted=True)

    split_succeeded_or_raise(head, body, tail)
    yield from (part for part in (head, body, tail) if part)
def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split `line` around its last bracket pair into head, body and tail.

    Walking the leaves backwards, everything from the end of the line down to
    and including the last closing bracket forms the tail, leaves up to its
    opening bracket form the body (one indentation level deeper), and the rest
    forms the head.  Empty parts are not yielded.  `py36` is accepted for
    signature parity with the other split functions and is not used here.
    """
    head_leaves: List[Leaf] = []
    body_leaves: List[Leaf] = []
    tail_leaves: List[Leaf] = []
    head = Line(depth=line.depth)
    body = Line(depth=line.depth + 1, inside_brackets=True)
    tail = Line(depth=line.depth)

    bucket = tail_leaves
    opening_bracket = None
    for leaf in reversed(line.leaves):
        if bucket is body_leaves and leaf is opening_bracket:
            # An empty bracket pair stays in the tail; keep looking for a body.
            bucket = head_leaves if body_leaves else tail_leaves
        bucket.append(leaf)
        if bucket is tail_leaves and leaf.type in CLOSING_BRACKETS:
            opening_bracket = leaf.opening_bracket
            bucket = body_leaves

    # The leaves were collected back to front; restore source order.
    for leaves in (tail_leaves, body_leaves, head_leaves):
        leaves.reverse()

    # Since body is a new indent level, remove spurious leading whitespace.
    if body_leaves:
        normalize_prefix(body_leaves[0], inside_brackets=True)

    # Build the new lines, keeping each leaf's trailing comment next to it.
    parts = ((head, head_leaves), (body, body_leaves), (tail, tail_leaves))
    for result, leaves in parts:
        for leaf in leaves:
            result.append(leaf, preformatted=True)
            comment_after = line.comments.get(id(leaf))
            if comment_after:
                result.append(comment_after, preformatted=True)

    split_succeeded_or_raise(head, body, tail)
    yield from (part for part in (head, body, tail) if part)
def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
    """Raise CannotSplit if a bracket split produced nothing worthwhile.

    A split with a non-empty `body` always passes.  With an empty body the
    split is rejected when the stripped `tail` is empty (nothing changed) or
    shorter than three characters (not worth a new line).  `head` is not
    inspected.
    """
    tail_len = len(str(tail).strip())
    if body:
        return

    if tail_len == 0:
        raise CannotSplit("Splitting brackets produced the same line")

    if tail_len < 3:
        raise CannotSplit(
            f"Splitting brackets on an empty body to save "
            f"{tail_len} characters is not worth it"
        )
def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split according to delimiters of the highest priority.

    This kind of split doesn't increase indentation.
    If `py36` is True, the split will add trailing commas also in function
    signatures that contain * and **.

    Raises CannotSplit if `line` has no leaves or contains no delimiters.
    """
    try:
        last_leaf = line.leaves[-1]
    except IndexError:
        raise CannotSplit("Line empty")

    delimiters = line.bracket_tracker.delimiters
    try:
        # The last leaf is excluded so a trailing delimiter doesn't count.
        delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
    except ValueError:
        raise CannotSplit("No delimiters found")

    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    lowest_depth = sys.maxsize
    trailing_comma_safe = True
    for leaf in line.leaves:
        current_line.append(leaf, preformatted=True)
        comment_after = line.comments.get(id(leaf))
        if comment_after:
            current_line.append(comment_after, preformatted=True)
        lowest_depth = min(lowest_depth, leaf.bracket_depth)
        # Only a star at the depth being split affects trailing comma safety.
        # The depth check has to guard both token types: without the explicit
        # grouping, `and` binds tighter than `or` and a `**` nested at any
        # depth would suppress the trailing comma.
        is_star = leaf.type == token.STAR or leaf.type == token.DOUBLESTAR
        if leaf.bracket_depth == lowest_depth and is_star:
            trailing_comma_safe = trailing_comma_safe and py36
        leaf_priority = delimiters.get(id(leaf))
        if leaf_priority == delimiter_priority:
            normalize_prefix(current_line.leaves[0], inside_brackets=True)
            yield current_line

            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    if current_line:
        if (
            delimiter_priority == COMMA_PRIORITY
            and current_line.leaves[-1].type != token.COMMA
            and trailing_comma_safe
        ):
            # Finish a comma-separated split with a trailing comma.
            current_line.append(Leaf(token.COMMA, ','))
        normalize_prefix(current_line.leaves[0], inside_brackets=True)
        yield current_line
def is_import(leaf: Leaf) -> bool:
    """Returns True if the given leaf starts an import statement.

    That is, `leaf` is the NAME `import` whose parent is an `import_name`
    node, or the NAME `from` whose parent is an `import_from` node.
    """
    if leaf.type != token.NAME:
        return False

    parent = leaf.parent
    keyword = leaf.value
    if keyword == 'import':
        return bool(parent and parent.type == syms.import_name)

    if keyword == 'from':
        return bool(parent and parent.type == syms.import_from)

    return False
def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
    """Leave existing extra newlines if not `inside_brackets`.

    Remove everything else.  The newlines kept are those found after the last
    `#` in the prefix, minus one when the prefix contains a comment.  A prefix
    with a backslash before its first `#` is cleared entirely.  Note: don't
    use backslashes for formatting or you'll lose your voting rights.
    """
    if not inside_brackets:
        before_comment, *after_comment = leaf.prefix.split('#')
        if '\\' not in before_comment:
            if after_comment:
                newlines = after_comment[-1].count('\n') - 1
            else:
                newlines = before_comment.count('\n')
            # A negative count multiplies out to the empty string.
            leaf.prefix = '\n' * newlines
            return

    leaf.prefix = ''
def is_python36(node: Node) -> bool:
    """Returns True if the current file is using Python 3.6+ features.

    Currently looking for:
    - f-strings (any casing or ordering of the `f`/`rf` prefix); and
    - trailing commas after * or ** in function signatures.
    """
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive, so `Rf`, `fR`, etc. count.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", 'rf', 'fr'}:
                return True

        elif (
            n.type == syms.typedargslist
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The signature ends with a trailing comma; that is only 3.6+
            # syntax when the signature also contains * or **.
            for ch in n.children:
                if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
                    return True

    return False
# File suffixes that `gen_python_files_in_dir` treats as Python sources.
PYTHON_EXTENSIONS = {'.py'}
# Directory names that `gen_python_files_in_dir` never descends into.
BLACKLISTED_DIRECTORIES = {
    'build',
    'buck-out',
    'dist',
    '_build',
    '.git',
    '.hg',
    '.mypy_cache',
    '.tox',
    '.venv',
}
def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
    """Recursively yield the Python files found under directory `path`.

    Files are selected by suffix (`PYTHON_EXTENSIONS`); directories whose name
    is in `BLACKLISTED_DIRECTORIES` are not descended into.
    """
    for entry in path.iterdir():
        if not entry.is_dir():
            if entry.suffix in PYTHON_EXTENSIONS:
                yield entry
            continue

        if entry.name not in BLACKLISTED_DIRECTORIES:
            yield from gen_python_files_in_dir(entry)
@dataclass
class Report:
    """Provides a reformatting counter."""
    # When true, messages are phrased as "would ..." and pending changes make
    # `return_code` non-zero.
    check: bool = False
    # Number of files reported through `done()` as changed.
    change_count: int = 0
    # Number of files reported through `done()` as unchanged.
    same_count: int = 0
    # Number of files reported through `failed()`.
    failure_count: int = 0

    def done(self, src: Path, changed: bool) -> None:
        """Increment the counter for successful reformatting. Write out a message."""
        if not changed:
            out(f'{src} already well formatted, good job.', bold=False)
            self.same_count += 1
            return

        reformatted = 'would reformat' if self.check else 'reformatted'
        out(f'{reformatted} {src}')
        self.change_count += 1

    def failed(self, src: Path, message: str) -> None:
        """Increment the counter for failed reformatting. Write out a message."""
        err(f'error: cannot format {src}: {message}')
        self.failure_count += 1

    @property
    def return_code(self) -> int:
        """Which return code should the app use considering the current state."""
        # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
        # 126 we have special returncodes reserved by the shell.
        if self.failure_count:
            return 123

        if self.change_count and self.check:
            return 1

        return 0

    def __str__(self) -> str:
        """A color report of the current state.

        Use `click.unstyle` to remove colors.
        """
        if self.check:
            reformatted, unchanged, failed = (
                "would be reformatted",
                "would be left unchanged",
                "would fail to reformat",
            )
        else:
            reformatted, unchanged, failed = (
                "reformatted",
                "left unchanged",
                "failed to reformat",
            )

        # Only counters that are non-zero contribute a section to the report.
        report = []
        if self.change_count:
            s = 's' if self.change_count > 1 else ''
            summary = f'{self.change_count} file{s} {reformatted}'
            report.append(click.style(summary, bold=True))
        if self.same_count:
            s = 's' if self.same_count > 1 else ''
            report.append(f'{self.same_count} file{s} {unchanged}')
        if self.failure_count:
            s = 's' if self.failure_count > 1 else ''
            summary = f'{self.failure_count} file{s} {failed}'
            report.append(click.style(summary, fg='red'))
        return ', '.join(report) + '.'
def assert_equivalent(src: str, dst: str) -> None:
    """Raises AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed with the builtin `ast` module and compared through
    a textual dump of their node classes and fields.  AssertionError is also
    raised when either string fails to parse.

    This is a temporary sanity check until Black becomes stable.
    """

    import ast
    import traceback

    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
        """Simple visitor generating strings to compare ASTs by content."""
        node_name = node.__class__.__name__
        yield f"{' ' * depth}{node_name}("

        for field in sorted(node._fields):
            try:
                value = getattr(node, field)
            except AttributeError:
                # Optional fields may be absent on a given node.
                continue

            yield f"{' ' * (depth+1)}{field}="

            if isinstance(value, ast.AST):
                yield from _v(value, depth + 2)

            elif isinstance(value, list):
                # Only AST children are compared; other list items are skipped.
                for item in value:
                    if isinstance(item, ast.AST):
                        yield from _v(item, depth + 2)

            else:
                yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"

        yield f"{' ' * depth}) # /{node_name}"

    try:
        src_ast = ast.parse(src)
    except Exception as exc:
        major, minor = sys.version_info[:2]
        raise AssertionError(
            f"cannot use --safe with this file; failed to parse source file "
            f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
            f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
        )

    try:
        dst_ast = ast.parse(dst)
    except Exception as exc:
        # The parser traceback and the invalid output both go to the log file.
        log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = '\n'.join(_v(src_ast))
    dst_ast_str = '\n'.join(_v(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
    raise AssertionError(
        f"INTERNAL ERROR: Black produced code that is not equivalent to "
        f"the source. "
        f"Please report a bug on https://github.com/ambv/black/issues. "
        f"This diff might be helpful: {log}"
    )
def assert_stable(src: str, dst: str, line_length: int) -> None:
    """Raises AssertionError if `dst` reformats differently the second time.

    `dst` is formatted again with `format_str`; when the result differs, the
    source-to-first-pass and first-pass-to-second-pass diffs are written to a
    log file whose path is included in the error message.

    This is a temporary sanity check until Black becomes stable.
    """
    newdst = format_str(dst, line_length=line_length)
    if dst == newdst:
        return

    first_pass = diff(src, dst, 'source', 'first pass')
    second_pass = diff(dst, newdst, 'first pass', 'second pass')
    log = dump_to_file(first_pass, second_pass)
    raise AssertionError(
        f"INTERNAL ERROR: Black produced different code on the second pass "
        f"of the formatter. "
        f"Please report a bug on https://github.com/ambv/black/issues. "
        f"This diff might be helpful: {log}"
    ) from None
def dump_to_file(*output: str) -> str:
    """Dumps `output` to a temporary file. Returns path to the file.

    Each item is written followed by a newline.  The file is named
    `blk_*.log`, is not deleted on close, and is written as UTF-8 so that
    non-ASCII text doesn't fail on locales with a narrower default encoding.
    """
    import tempfile

    with tempfile.NamedTemporaryFile(
        mode='w', prefix='blk_', suffix='.log', delete=False, encoding='utf8'
    ) as f:
        for lines in output:
            f.write(lines)
            f.write('\n')
    return f.name
def diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Returns a udiff string between strings `a` and `b`.

    `a_name` and `b_name` label the two sides in the diff header.  Five lines
    of context are shown around each change.
    """
    import difflib

    def _as_lines(text: str) -> list:
        # Every piece gets a newline back, including the one after the last
        # '\n', so that all diff lines are newline-terminated.
        return [f'{piece}\n' for piece in text.split('\n')]

    delta = difflib.unified_diff(
        _as_lines(a), _as_lines(b), fromfile=a_name, tofile=b_name, n=5
    )
    return ''.join(delta)
1769 if __name__ == '__main__':