All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
9 from pathlib import Path
13 Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
16 from attr import dataclass, Factory
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
25 __version__ = "18.3a3"
26 DEFAULT_LINE_LENGTH = 88
28 syms = pygram.python_symbols
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
40 class NothingChanged(UserWarning):
41 """Raised by `format_file` when the reformatted code is the same as source."""
44 class CannotSplit(Exception):
45 """A readable split that fits the allotted line length is impossible.
47 Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`.
56 default=DEFAULT_LINE_LENGTH,
57 help='How many character per line to allow.',
64 "Don't write back the files, just return the status. Return code 0 "
65 "means nothing would change. Return code 1 means some files would be "
66 "reformatted. Return code 123 means there was an internal error."
72 help='If --fast given, skip temporary sanity checks. [default: --safe]',
74 @click.version_option(version=__version__)
79 exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
84 ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
86 """The uncompromising code formatter."""
87 sources: List[Path] = []
91 sources.extend(gen_python_files_in_dir(p))
93 # if a file was explicitly given, we don't care about its extension
96 sources.append(Path('-'))
98 err(f'invalid path: {s}')
101 elif len(sources) == 1:
103 report = Report(check=check)
105 if not p.is_file() and str(p) == '-':
106 changed = format_stdin_to_stdout(
107 line_length=line_length, fast=fast, write_back=not check
110 changed = format_file_in_place(
111 p, line_length=line_length, fast=fast, write_back=not check
113 report.done(p, changed)
114 except Exception as exc:
115 report.failed(p, str(exc))
116 ctx.exit(report.return_code)
118 loop = asyncio.get_event_loop()
119 executor = ProcessPoolExecutor(max_workers=os.cpu_count())
122 return_code = loop.run_until_complete(
124 sources, line_length, not check, fast, loop, executor
129 ctx.exit(return_code)
132 async def schedule_formatting(
141 src: loop.run_in_executor(
142 executor, format_file_in_place, src, line_length, fast, write_back
146 await asyncio.wait(tasks.values())
149 for src, task in tasks.items():
151 report.failed(src, 'timed out, cancelling')
153 cancelled.append(task)
154 elif task.exception():
155 report.failed(src, str(task.exception()))
157 report.done(src, task.result())
159 await asyncio.wait(cancelled, timeout=2)
160 out('All done! ✨ 🍰 ✨')
161 click.echo(str(report))
162 return report.return_code
165 def format_file_in_place(
166 src: Path, line_length: int, fast: bool, write_back: bool = False
168 """Format the file and rewrite if changed. Return True if changed."""
169 with tokenize.open(src) as src_buffer:
170 src_contents = src_buffer.read()
172 contents = format_file_contents(
173 src_contents, line_length=line_length, fast=fast
175 except NothingChanged:
179 with open(src, "w", encoding=src_buffer.encoding) as f:
184 def format_stdin_to_stdout(
185 line_length: int, fast: bool, write_back: bool = False
187 """Format file on stdin and pipe output to stdout. Return True if changed."""
188 contents = sys.stdin.read()
190 contents = format_file_contents(contents, line_length=line_length, fast=fast)
193 except NothingChanged:
198 sys.stdout.write(contents)
201 def format_file_contents(
202 src_contents: str, line_length: int, fast: bool
204 """Reformats a file and returns its contents and encoding."""
205 if src_contents.strip() == '':
208 dst_contents = format_str(src_contents, line_length=line_length)
209 if src_contents == dst_contents:
213 assert_equivalent(src_contents, dst_contents)
214 assert_stable(src_contents, dst_contents, line_length=line_length)
218 def format_str(src_contents: str, line_length: int) -> FileContent:
219 """Reformats a string and returns new contents."""
220 src_node = lib2to3_parse(src_contents)
222 lines = LineGenerator()
223 elt = EmptyLineTracker()
224 py36 = is_python36(src_node)
227 for current_line in lines.visit(src_node):
228 for _ in range(after):
229 dst_contents += str(empty_line)
230 before, after = elt.maybe_empty_lines(current_line)
231 for _ in range(before):
232 dst_contents += str(empty_line)
233 for line in split_line(current_line, line_length=line_length, py36=py36):
234 dst_contents += str(line)
239 pygram.python_grammar_no_print_statement_no_exec_statement,
240 pygram.python_grammar_no_print_statement,
241 pygram.python_grammar_no_exec_statement,
242 pygram.python_grammar,
246 def lib2to3_parse(src_txt: str) -> Node:
247 """Given a string with source, return the lib2to3 Node."""
248 grammar = pygram.python_grammar_no_print_statement
249 if src_txt[-1] != '\n':
250 nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
252 for grammar in GRAMMARS:
253 drv = driver.Driver(grammar, pytree.convert)
255 result = drv.parse_string(src_txt, True)
258 except ParseError as pe:
259 lineno, column = pe.context[1]
260 lines = src_txt.splitlines()
262 faulty_line = lines[lineno - 1]
264 faulty_line = "<line number missing in source>"
265 exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
269 if isinstance(result, Leaf):
270 result = Node(syms.file_input, [result])
274 def lib2to3_unparse(node: Node) -> str:
275 """Given a lib2to3 node, return its string representation."""
283 class Visitor(Generic[T]):
284 """Basic lib2to3 visitor that yields things on visiting."""
286 def visit(self, node: LN) -> Iterator[T]:
288 name = token.tok_name[node.type]
290 name = type_repr(node.type)
291 yield from getattr(self, f'visit_{name}', self.visit_default)(node)
293 def visit_default(self, node: LN) -> Iterator[T]:
294 if isinstance(node, Node):
295 for child in node.children:
296 yield from self.visit(child)
300 class DebugVisitor(Visitor[T]):
303 def visit_default(self, node: LN) -> Iterator[T]:
304 indent = ' ' * (2 * self.tree_depth)
305 if isinstance(node, Node):
306 _type = type_repr(node.type)
307 out(f'{indent}{_type}', fg='yellow')
309 for child in node.children:
310 yield from self.visit(child)
313 out(f'{indent}/{_type}', fg='yellow', bold=False)
315 _type = token.tok_name.get(node.type, str(node.type))
316 out(f'{indent}{_type}', fg='blue', nl=False)
318 # We don't have to handle prefixes for `Node` objects since
319 # that delegates to the first child anyway.
320 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
321 out(f' {node.value!r}', fg='blue', bold=False)
324 def show(cls, code: str) -> None:
325 """Pretty-prints a given string of `code`.
327 Convenience method for debugging.
329 v: DebugVisitor[None] = DebugVisitor()
330 list(v.visit(lib2to3_parse(code)))
333 KEYWORDS = set(keyword.kwlist)
334 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
335 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
346 STANDALONE_COMMENT = 153
347 LOGIC_OPERATORS = {'and', 'or'}
371 COMPREHENSION_PRIORITY = 20
375 COMPARATOR_PRIORITY = 3
380 class BracketTracker:
382 bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
383 delimiters: Dict[LeafID, Priority] = Factory(dict)
384 previous: Optional[Leaf] = None
386 def mark(self, leaf: Leaf) -> None:
387 if leaf.type == token.COMMENT:
390 if leaf.type in CLOSING_BRACKETS:
392 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
393 leaf.opening_bracket = opening_bracket
394 leaf.bracket_depth = self.depth
396 delim = is_delimiter(leaf)
398 self.delimiters[id(leaf)] = delim
399 elif self.previous is not None:
400 if leaf.type == token.STRING and self.previous.type == token.STRING:
401 self.delimiters[id(self.previous)] = STRING_PRIORITY
403 leaf.type == token.NAME
404 and leaf.value == 'for'
406 and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
408 self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
410 leaf.type == token.NAME
411 and leaf.value == 'if'
413 and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
415 self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
417 leaf.type == token.NAME
418 and leaf.value in LOGIC_OPERATORS
421 self.delimiters[id(self.previous)] = LOGIC_PRIORITY
422 if leaf.type in OPENING_BRACKETS:
423 self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
427 def any_open_brackets(self) -> bool:
428 """Returns True if there is an yet unmatched open bracket on the line."""
429 return bool(self.bracket_match)
431 def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
432 """Returns the highest priority of a delimiter found on the line.
434 Values are consistent with what `is_delimiter()` returns.
436 return max(v for k, v in self.delimiters.items() if k not in exclude)
442 leaves: List[Leaf] = Factory(list)
443 comments: Dict[LeafID, Leaf] = Factory(dict)
444 bracket_tracker: BracketTracker = Factory(BracketTracker)
445 inside_brackets: bool = False
446 has_for: bool = False
447 _for_loop_variable: bool = False
449 def append(self, leaf: Leaf, preformatted: bool = False) -> None:
450 has_value = leaf.value.strip()
454 if self.leaves and not preformatted:
455 # Note: at this point leaf.prefix should be empty except for
456 # imports, for which we only preserve newlines.
457 leaf.prefix += whitespace(leaf)
458 if self.inside_brackets or not preformatted:
459 self.maybe_decrement_after_for_loop_variable(leaf)
460 self.bracket_tracker.mark(leaf)
461 self.maybe_remove_trailing_comma(leaf)
462 self.maybe_increment_for_loop_variable(leaf)
463 if self.maybe_adapt_standalone_comment(leaf):
466 if not self.append_comment(leaf):
467 self.leaves.append(leaf)
470 def is_comment(self) -> bool:
471 return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
474 def is_decorator(self) -> bool:
475 return bool(self) and self.leaves[0].type == token.AT
478 def is_import(self) -> bool:
479 return bool(self) and is_import(self.leaves[0])
482 def is_class(self) -> bool:
485 and self.leaves[0].type == token.NAME
486 and self.leaves[0].value == 'class'
490 def is_def(self) -> bool:
491 """Also returns True for async defs."""
493 first_leaf = self.leaves[0]
498 second_leaf: Optional[Leaf] = self.leaves[1]
502 (first_leaf.type == token.NAME and first_leaf.value == 'def')
504 first_leaf.type == token.ASYNC
505 and second_leaf is not None
506 and second_leaf.type == token.NAME
507 and second_leaf.value == 'def'
512 def is_flow_control(self) -> bool:
515 and self.leaves[0].type == token.NAME
516 and self.leaves[0].value in FLOW_CONTROL
520 def is_yield(self) -> bool:
523 and self.leaves[0].type == token.NAME
524 and self.leaves[0].value == 'yield'
527 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
530 and self.leaves[-1].type == token.COMMA
531 and closing.type in CLOSING_BRACKETS
535 if closing.type == token.RBRACE:
539 if closing.type == token.RSQB:
540 comma = self.leaves[-1]
541 if comma.parent and comma.parent.type == syms.listmaker:
545 # For parens let's check if it's safe to remove the comma. If the
546 # trailing one is the only one, we might mistakenly change a tuple
547 # into a different type by removing the comma.
548 depth = closing.bracket_depth + 1
550 opening = closing.opening_bracket
551 for _opening_index, leaf in enumerate(self.leaves):
558 for leaf in self.leaves[_opening_index + 1:]:
562 bracket_depth = leaf.bracket_depth
563 if bracket_depth == depth and leaf.type == token.COMMA:
565 if leaf.parent and leaf.parent.type == syms.arglist:
575 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
576 """In a for loop, or comprehension, the variables are often unpacks.
578 To avoid splitting on the comma in this situation, we will increase
579 the depth of tokens between `for` and `in`.
581 if leaf.type == token.NAME and leaf.value == 'for':
583 self.bracket_tracker.depth += 1
584 self._for_loop_variable = True
589 def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
590 # See `maybe_increment_for_loop_variable` above for explanation.
591 if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
592 self.bracket_tracker.depth -= 1
593 self._for_loop_variable = False
598 def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
599 """Hack a standalone comment to act as a trailing comment for line splitting.
601 If this line has brackets and a standalone `comment`, we need to adapt
602 it to be able to still reformat the line.
604 This is not perfect, the line to which the standalone comment gets
605 appended will appear "too long" when splitting.
608 comment.type == STANDALONE_COMMENT
609 and self.bracket_tracker.any_open_brackets()
613 comment.type = token.COMMENT
614 comment.prefix = '\n' + ' ' * (self.depth + 1)
615 return self.append_comment(comment)
617 def append_comment(self, comment: Leaf) -> bool:
618 if comment.type != token.COMMENT:
622 after = id(self.last_non_delimiter())
624 comment.type = STANDALONE_COMMENT
629 if after in self.comments:
630 self.comments[after].value += str(comment)
632 self.comments[after] = comment
635 def last_non_delimiter(self) -> Leaf:
636 for i in range(len(self.leaves)):
637 last = self.leaves[-i - 1]
638 if not is_delimiter(last):
641 raise LookupError("No non-delimiters found")
643 def __str__(self) -> str:
647 indent = ' ' * self.depth
648 leaves = iter(self.leaves)
650 res = f'{first.prefix}{indent}{first.value}'
653 for comment in self.comments.values():
657 def __bool__(self) -> bool:
658 return bool(self.leaves or self.comments)
662 class EmptyLineTracker:
663 """Provides a stateful method that returns the number of potential extra
664 empty lines needed before and after the currently processed line.
666 Note: this tracker works on lines that haven't been split yet. It assumes
667 the prefix of the first leaf consists of optional newlines. Those newlines
668 are consumed by `maybe_empty_lines()` and included in the computation.
670 previous_line: Optional[Line] = None
671 previous_after: int = 0
672 previous_defs: List[int] = Factory(list)
674 def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
675 """Returns the number of extra empty lines before and after the `current_line`.
677 This is for separating `def`, `async def` and `class` with extra empty lines
678 (two on module-level), as well as providing an extra empty line after flow
679 control keywords to make them more prominent.
681 before, after = self._maybe_empty_lines(current_line)
682 before -= self.previous_after
683 self.previous_after = after
684 self.previous_line = current_line
687 def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
689 if current_line.is_comment and current_line.depth == 0:
691 if current_line.leaves:
692 # Consume the first leaf's extra newlines.
693 first_leaf = current_line.leaves[0]
694 before = first_leaf.prefix.count('\n')
695 before = min(before, max(before, max_allowed))
696 first_leaf.prefix = ''
699 depth = current_line.depth
700 while self.previous_defs and self.previous_defs[-1] >= depth:
701 self.previous_defs.pop()
702 before = 1 if depth else 2
703 is_decorator = current_line.is_decorator
704 if is_decorator or current_line.is_def or current_line.is_class:
706 self.previous_defs.append(depth)
707 if self.previous_line is None:
708 # Don't insert empty lines before the first line in the file.
711 if self.previous_line and self.previous_line.is_decorator:
712 # Don't insert empty lines between decorators.
716 if current_line.depth:
720 if current_line.is_flow_control:
725 and self.previous_line.is_import
726 and not current_line.is_import
727 and depth == self.previous_line.depth
729 return (before or 1), 0
733 and self.previous_line.is_yield
734 and (not current_line.is_yield or depth != self.previous_line.depth)
736 return (before or 1), 0
742 class LineGenerator(Visitor[Line]):
743 """Generates reformatted Line objects. Empty lines are not emitted.
745 Note: destroys the tree it's visiting by mutating prefixes of its leaves
746 in ways that will no longer stringify to valid Python code on the tree.
748 current_line: Line = Factory(Line)
750 def line(self, indent: int = 0) -> Iterator[Line]:
753 If the line is empty, only emit if it makes sense.
754 If the line is too long, split it first and then generate.
756 If any lines were generated, set up a new current_line.
758 if not self.current_line:
759 self.current_line.depth += indent
760 return # Line is empty, don't emit. Creating a new one unnecessary.
762 complete_line = self.current_line
763 self.current_line = Line(depth=complete_line.depth + indent)
766 def visit_default(self, node: LN) -> Iterator[Line]:
767 if isinstance(node, Leaf):
768 any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
769 for comment in generate_comments(node):
770 if any_open_brackets:
771 # any comment within brackets is subject to splitting
772 self.current_line.append(comment)
773 elif comment.type == token.COMMENT:
774 # regular trailing comment
775 self.current_line.append(comment)
776 yield from self.line()
779 # regular standalone comment
780 yield from self.line()
782 self.current_line.append(comment)
783 yield from self.line()
785 normalize_prefix(node, inside_brackets=any_open_brackets)
786 if node.type not in WHITESPACE:
787 self.current_line.append(node)
788 yield from super().visit_default(node)
790 def visit_INDENT(self, node: Node) -> Iterator[Line]:
791 yield from self.line(+1)
792 yield from self.visit_default(node)
794 def visit_DEDENT(self, node: Node) -> Iterator[Line]:
795 yield from self.line(-1)
797 def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
798 """Visit a statement.
800 The relevant Python language keywords for this statement are NAME leaves
803 for child in node.children:
804 if child.type == token.NAME and child.value in keywords: # type: ignore
805 yield from self.line()
807 yield from self.visit(child)
809 def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
810 """A statement without nested statements."""
811 is_suite_like = node.parent and node.parent.type in STATEMENT
813 yield from self.line(+1)
814 yield from self.visit_default(node)
815 yield from self.line(-1)
818 yield from self.line()
819 yield from self.visit_default(node)
821 def visit_async_stmt(self, node: Node) -> Iterator[Line]:
822 yield from self.line()
824 children = iter(node.children)
825 for child in children:
826 yield from self.visit(child)
828 if child.type == token.ASYNC:
831 internal_stmt = next(children)
832 for child in internal_stmt.children:
833 yield from self.visit(child)
835 def visit_decorators(self, node: Node) -> Iterator[Line]:
836 for child in node.children:
837 yield from self.line()
838 yield from self.visit(child)
840 def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
841 yield from self.line()
843 def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
844 yield from self.visit_default(leaf)
845 yield from self.line()
847 def __attrs_post_init__(self) -> None:
848 """You are in a twisty little maze of passages."""
850 self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
851 self.visit_while_stmt = partial(v, keywords={'while', 'else'})
852 self.visit_for_stmt = partial(v, keywords={'for', 'else'})
853 self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
854 self.visit_except_clause = partial(v, keywords={'except'})
855 self.visit_funcdef = partial(v, keywords={'def'})
856 self.visit_with_stmt = partial(v, keywords={'with'})
857 self.visit_classdef = partial(v, keywords={'class'})
858 self.visit_async_funcdef = self.visit_async_stmt
859 self.visit_decorated = self.visit_decorators
862 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
863 OPENING_BRACKETS = set(BRACKET.keys())
864 CLOSING_BRACKETS = set(BRACKET.values())
865 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
866 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
869 def whitespace(leaf: Leaf) -> str: # noqa C901
870 """Return whitespace prefix if needed for the given `leaf`."""
877 if t in ALWAYS_NO_SPACE:
880 if t == token.COMMENT:
883 assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
884 if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
887 prev = leaf.prev_sibling
889 prevp = preceding_leaf(p)
890 if not prevp or prevp.type in OPENING_BRACKETS:
894 return SPACE if prevp.type == token.COMMA else NO
896 if prevp.type == token.EQUAL:
898 if prevp.parent.type in {
899 syms.arglist, syms.argument, syms.parameters, syms.varargslist
903 elif prevp.parent.type == syms.typedargslist:
904 # A bit hacky: if the equal sign has whitespace, it means we
905 # previously found it's a typed argument. So, we're using
909 elif prevp.type == token.DOUBLESTAR:
910 if prevp.parent and prevp.parent.type in {
920 elif prevp.type == token.COLON:
921 if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
926 and prevp.parent.type in {syms.factor, syms.star_expr}
927 and prevp.type in MATH_OPERATORS
932 prevp.type == token.RIGHTSHIFT
934 and prevp.parent.type == syms.shift_expr
935 and prevp.prev_sibling
936 and prevp.prev_sibling.type == token.NAME
937 and prevp.prev_sibling.value == 'print' # type: ignore
939 # Python 2 print chevron
942 elif prev.type in OPENING_BRACKETS:
945 if p.type in {syms.parameters, syms.arglist}:
946 # untyped function signatures or calls
950 if not prev or prev.type != token.COMMA:
953 elif p.type == syms.varargslist:
958 if prev and prev.type != token.COMMA:
961 elif p.type == syms.typedargslist:
962 # typed function signatures
967 if prev.type != syms.tname:
970 elif prev.type == token.EQUAL:
971 # A bit hacky: if the equal sign has whitespace, it means we
972 # previously found it's a typed argument. So, we're using that, too.
975 elif prev.type != token.COMMA:
978 elif p.type == syms.tname:
981 prevp = preceding_leaf(p)
982 if not prevp or prevp.type != token.COMMA:
985 elif p.type == syms.trailer:
986 # attributes and calls
987 if t == token.LPAR or t == token.RPAR:
992 prevp = preceding_leaf(p)
993 if not prevp or prevp.type != token.NUMBER:
996 elif t == token.LSQB:
999 elif prev.type != token.COMMA:
1002 elif p.type == syms.argument:
1004 if t == token.EQUAL:
1008 prevp = preceding_leaf(p)
1009 if not prevp or prevp.type == token.LPAR:
1012 elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1015 elif p.type == syms.decorator:
1019 elif p.type == syms.dotted_name:
1023 prevp = preceding_leaf(p)
1024 if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1027 elif p.type == syms.classdef:
1031 if prev and prev.type == token.LPAR:
1034 elif p.type == syms.subscript:
1037 assert p.parent is not None, "subscripts are always parented"
1038 if p.parent.type == syms.subscriptlist:
1046 elif p.type == syms.atom:
1047 if prev and t == token.DOT:
1048 # dots, but not the first one.
1052 p.type == syms.listmaker
1053 or p.type == syms.testlist_gexp
1054 or p.type == syms.subscriptlist
1056 # list interior, including unpacking
1060 elif p.type == syms.dictsetmaker:
1061 # dict and set interior, including unpacking
1065 if prev.type == token.DOUBLESTAR:
1068 elif p.type in {syms.factor, syms.star_expr}:
1071 prevp = preceding_leaf(p)
1072 if not prevp or prevp.type in OPENING_BRACKETS:
1075 prevp_parent = prevp.parent
1076 assert prevp_parent is not None
1077 if prevp.type == token.COLON and prevp_parent.type in {
1078 syms.subscript, syms.sliceop
1082 elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1085 elif t == token.NAME or t == token.NUMBER:
1088 elif p.type == syms.import_from:
1090 if prev and prev.type == token.DOT:
1093 elif t == token.NAME:
1097 if prev and prev.type == token.DOT:
1100 elif p.type == syms.sliceop:
1106 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1107 """Returns the first leaf that precedes `node`, if any."""
1109 res = node.prev_sibling
1111 if isinstance(res, Leaf):
1115 return list(res.leaves())[-1]
1124 def is_delimiter(leaf: Leaf) -> int:
1125 """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1127 Higher numbers are higher priority.
1129 if leaf.type == token.COMMA:
1130 return COMMA_PRIORITY
1132 if leaf.type in COMPARATORS:
1133 return COMPARATOR_PRIORITY
1136 leaf.type in MATH_OPERATORS
1138 and leaf.parent.type not in {syms.factor, syms.star_expr}
1140 return MATH_PRIORITY
1145 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1146 """Cleans the prefix of the `leaf` and generates comments from it, if any.
1148 Comments in lib2to3 are shoved into the whitespace prefix. This happens
1149 in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation
1150 move because it does away with modifying the grammar to include all the
1151 possible places in which comments can be placed.
1153 The sad consequence for us though is that comments don't "belong" anywhere.
1154 This is why this function generates simple parentless Leaf objects for
1155 comments. We simply don't know what the correct parent should be.
1157 No matter though, we can live without this. We really only need to
1158 differentiate between inline and standalone comments. The latter don't
1159 share the line with any code.
1161 Inline comments are emitted as regular token.COMMENT leaves. Standalone
1162 are emitted with a fake STANDALONE_COMMENT token identifier.
1172 for index, line in enumerate(p.split('\n')):
1173 line = line.lstrip()
1176 if not line.startswith('#'):
1179 if index == 0 and leaf.type != token.ENDMARKER:
1180 comment_type = token.COMMENT # simple trailing comment
1182 comment_type = STANDALONE_COMMENT
1183 yield Leaf(comment_type, make_comment(line), prefix='\n' * nlines)
1188 def make_comment(content: str) -> str:
1189 content = content.rstrip()
1193 if content[0] == '#':
1194 content = content[1:]
1195 if content and content[0] not in ' !:#':
1196 content = ' ' + content
1197 return '#' + content
1201 line: Line, line_length: int, inner: bool = False, py36: bool = False
1202 ) -> Iterator[Line]:
1203 """Splits a `line` into potentially many lines.
1205 They should fit in the allotted `line_length` but might not be able to.
1206 `inner` signifies that there were a pair of brackets somewhere around the
1207 current `line`, possibly transitively. This means we can fallback to splitting
1208 by delimiters if the LHS/RHS don't yield any results.
1210 If `py36` is True, splitting may generate syntax that is only compatible
1211 with Python 3.6 and later.
1213 line_str = str(line).strip('\n')
1214 if len(line_str) <= line_length and '\n' not in line_str:
1219 split_funcs = [left_hand_split]
1220 elif line.inside_brackets:
1221 split_funcs = [delimiter_split]
1222 if '\n' not in line_str:
1223 # Only attempt RHS if we don't have multiline strings or comments
1225 split_funcs.append(right_hand_split)
1227 split_funcs = [right_hand_split]
1228 for split_func in split_funcs:
1229 # We are accumulating lines in `result` because we might want to abort
1230 # mission and return the original line in the end, or attempt a different
1232 result: List[Line] = []
1234 for l in split_func(line, py36=py36):
1235 if str(l).strip('\n') == line_str:
1236 raise CannotSplit("Split function returned an unchanged result")
1239 split_line(l, line_length=line_length, inner=True, py36=py36)
1241 except CannotSplit as cs:
1252 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1253 """Split line into many lines, starting with the first matching bracket pair.
1255 Note: this usually looks weird, only use this for function definitions.
1256 Prefer RHS otherwise.
1258 head = Line(depth=line.depth)
1259 body = Line(depth=line.depth + 1, inside_brackets=True)
1260 tail = Line(depth=line.depth)
1261 tail_leaves: List[Leaf] = []
1262 body_leaves: List[Leaf] = []
1263 head_leaves: List[Leaf] = []
1264 current_leaves = head_leaves
1265 matching_bracket = None
1266 for leaf in line.leaves:
1268 current_leaves is body_leaves
1269 and leaf.type in CLOSING_BRACKETS
1270 and leaf.opening_bracket is matching_bracket
1272 current_leaves = tail_leaves if body_leaves else head_leaves
1273 current_leaves.append(leaf)
1274 if current_leaves is head_leaves:
1275 if leaf.type in OPENING_BRACKETS:
1276 matching_bracket = leaf
1277 current_leaves = body_leaves
1278 # Since body is a new indent level, remove spurious leading whitespace.
1280 normalize_prefix(body_leaves[0], inside_brackets=True)
1281 # Build the new lines.
1282 for result, leaves in (
1283 (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1286 result.append(leaf, preformatted=True)
1287 comment_after = line.comments.get(id(leaf))
1289 result.append(comment_after, preformatted=True)
1290 split_succeeded_or_raise(head, body, tail)
1291 for result in (head, body, tail):
1296 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1297 """Split line into many lines, starting with the last matching bracket pair."""
1298 head = Line(depth=line.depth)
1299 body = Line(depth=line.depth + 1, inside_brackets=True)
1300 tail = Line(depth=line.depth)
1301 tail_leaves: List[Leaf] = []
1302 body_leaves: List[Leaf] = []
1303 head_leaves: List[Leaf] = []
1304 current_leaves = tail_leaves
1305 opening_bracket = None
1306 for leaf in reversed(line.leaves):
1307 if current_leaves is body_leaves:
1308 if leaf is opening_bracket:
1309 current_leaves = head_leaves if body_leaves else tail_leaves
1310 current_leaves.append(leaf)
1311 if current_leaves is tail_leaves:
1312 if leaf.type in CLOSING_BRACKETS:
1313 opening_bracket = leaf.opening_bracket
1314 current_leaves = body_leaves
1315 tail_leaves.reverse()
1316 body_leaves.reverse()
1317 head_leaves.reverse()
1318 # Since body is a new indent level, remove spurious leading whitespace.
1320 normalize_prefix(body_leaves[0], inside_brackets=True)
1321 # Build the new lines.
1322 for result, leaves in (
1323 (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1326 result.append(leaf, preformatted=True)
1327 comment_after = line.comments.get(id(leaf))
1329 result.append(comment_after, preformatted=True)
1330 split_succeeded_or_raise(head, body, tail)
1331 for result in (head, body, tail):
1336 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1337 tail_len = len(str(tail).strip())
1340 raise CannotSplit("Splitting brackets produced the same line")
1344 f"Splitting brackets on an empty body to save "
1345 f"{tail_len} characters is not worth it"
def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split according to delimiters of the highest priority.

    This kind of split doesn't increase indentation.
    If `py36` is True, the split will add trailing commas also in function
    signatures that contain * and **.

    Raises:
        CannotSplit: if `line` has no leaves or no delimiters to split on.
    """
    try:
        last_leaf = line.leaves[-1]
    except IndexError:
        raise CannotSplit("Line empty")

    delimiters = line.bracket_tracker.delimiters
    try:
        # The last leaf is excluded so a trailing delimiter doesn't decide the
        # priority of the split.
        delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
    # NOTE(review): assumes `max_priority` signals "nothing found" with
    # ValueError -- confirm against its definition.
    except ValueError:
        raise CannotSplit("No delimiters found")

    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    lowest_depth = sys.maxsize
    trailing_comma_safe = True
    for leaf in line.leaves:
        current_line.append(leaf, preformatted=True)
        comment_after = line.comments.get(id(leaf))
        if comment_after:
            current_line.append(comment_after, preformatted=True)
        lowest_depth = min(lowest_depth, leaf.bracket_depth)
        # Both * and ** must be checked at the lowest depth.  The previous
        # `a and b or c` form bound as `(a and b) or c`, so a ** nested at any
        # depth wrongly disabled the trailing comma.
        is_unpacking = leaf.type == token.STAR or leaf.type == token.DOUBLESTAR
        if leaf.bracket_depth == lowest_depth and is_unpacking:
            trailing_comma_safe = trailing_comma_safe and py36
        leaf_priority = delimiters.get(id(leaf))
        if leaf_priority == delimiter_priority:
            normalize_prefix(current_line.leaves[0], inside_brackets=True)
            yield current_line

            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    if current_line:
        if (
            delimiter_priority == COMMA_PRIORITY
            and current_line.leaves[-1].type != token.COMMA
            and trailing_comma_safe
        ):
            current_line.append(Leaf(token.COMMA, ','))
        normalize_prefix(current_line.leaves[0], inside_brackets=True)
        yield current_line
def is_import(leaf: Leaf) -> bool:
    """Returns True if the given leaf starts an import statement.

    A leaf qualifies when its value is `import` and its parent is an
    `import_name` node, or its value is `from` and its parent is an
    `import_from` node.
    """
    p = leaf.parent
    t = leaf.type
    v = leaf.value
    # NOTE(review): the `t == token.NAME` guard is reconstructed from a line
    # dropped in the garbled source -- confirm against upstream.
    return bool(
        t == token.NAME
        and (
            (v == 'import' and p and p.type == syms.import_name)
            or (v == 'from' and p and p.type == syms.import_from)
        )
    )
def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
    """Leave existing extra newlines if not `inside_brackets`.

    Remove everything else.  Note: don't use backslashes for formatting or
    you'll lose your voting rights.

    Mutates `leaf.prefix` in place and returns nothing.
    """
    if not inside_brackets:
        spl = leaf.prefix.split('#')
        # Only keep newlines when the text before any comment has no backslash.
        if '\\' not in spl[0]:
            nl_count = spl[-1].count('\n')
            # NOTE(review): reconstructed from dropped lines -- when the prefix
            # held a comment, one newline belongs to that comment's own line.
            if len(spl) > 1:
                nl_count -= 1
            leaf.prefix = '\n' * nl_count
            return

    leaf.prefix = ''
def is_python36(node: Node) -> bool:
    """Returns True if the current file is using Python 3.6+ features.

    Currently looking for:
    - f-strings; and
    - trailing commas after * or ** in function signatures.
    """
    for n in node.pre_order():
        if n.type == token.STRING:
            # The first two characters are enough to spot an f-string prefix.
            value_head = n.value[:2]  # type: ignore
            if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
                return True

        elif (
            n.type == syms.typedargslist
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The signature ends with a comma; that needs 3.6+ only when it
            # also contains * or **.
            for ch in n.children:
                if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
                    return True

    return False
# File suffixes treated as Python sources when walking a directory.
PYTHON_EXTENSIONS = {'.py'}
# Directory names that are never descended into.
BLACKLISTED_DIRECTORIES = {
    'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
}


def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
    """Yield every Python file found under `path`, recursively.

    Directories named in `BLACKLISTED_DIRECTORIES` are skipped entirely.
    Other entries are yielded when their suffix is in `PYTHON_EXTENSIONS`.
    """
    for child in path.iterdir():
        if child.is_dir():
            if child.name in BLACKLISTED_DIRECTORIES:
                continue

            yield from gen_python_files_in_dir(child)

        elif child.suffix in PYTHON_EXTENSIONS:
            yield child
# NOTE(review): the class header and the `check` / `same_count` fields were
# dropped from the garbled source and are reconstructed here -- confirm the
# class name against its callers.
@dataclass
class Report:
    """Provides a reformatting counter."""
    # True when files are only checked, not rewritten; changes the wording of
    # messages and makes pending changes affect the return code.
    check: bool = False
    # Number of files that were (or would be) reformatted.
    change_count: int = 0
    # Number of files that were already well formatted.
    same_count: int = 0
    # Number of files that could not be formatted.
    failure_count: int = 0

    def done(self, src: Path, changed: bool) -> None:
        """Increment the counter for successful reformatting. Write out a message."""
        if changed:
            reformatted = 'would reformat' if self.check else 'reformatted'
            out(f'{reformatted} {src}')
            self.change_count += 1
        else:
            out(f'{src} already well formatted, good job.', bold=False)
            self.same_count += 1

    def failed(self, src: Path, message: str) -> None:
        """Increment the counter for failed reformatting. Write out a message."""
        err(f'error: cannot format {src}: {message}')
        self.failure_count += 1

    @property
    def return_code(self) -> int:
        """Which return code should the app use considering the current state.

        123 means at least one file failed; 1 means `check` is set and some
        files would be reformatted; 0 otherwise.
        """
        # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
        # 126 we have special returncodes reserved by the shell.
        if self.failure_count:
            return 123

        elif self.change_count and self.check:
            return 1

        return 0

    def __str__(self) -> str:
        """A color report of the current state.

        Use `click.unstyle` to remove colors.
        """
        if self.check:
            reformatted = "would be reformatted"
            unchanged = "would be left unchanged"
            failed = "would fail to reformat"
        else:
            reformatted = "reformatted"
            unchanged = "left unchanged"
            failed = "failed to reformat"
        report = []
        if self.change_count:
            s = 's' if self.change_count > 1 else ''
            report.append(
                click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
            )
        if self.same_count:
            s = 's' if self.same_count > 1 else ''
            report.append(f'{self.same_count} file{s} {unchanged}')
        if self.failure_count:
            s = 's' if self.failure_count > 1 else ''
            report.append(
                click.style(f'{self.failure_count} file{s} {failed}', fg='red')
            )
        return ', '.join(report) + '.'
def assert_equivalent(src: str, dst: str) -> None:
    """Raises AssertionError if `src` and `dst` aren't equivalent.

    This is a temporary sanity check until Black becomes stable.

    Both strings are parsed with the builtin `ast` module and compared through
    a textual dump of their trees.

    Raises:
        AssertionError: if `src` cannot be parsed, if `dst` cannot be parsed,
            or if the two dumps differ.
    """

    import ast
    import traceback

    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
        """Simple visitor generating strings to compare ASTs by content."""
        yield f"{' ' * depth}{node.__class__.__name__}("

        for field in sorted(node._fields):
            try:
                value = getattr(node, field)
            except AttributeError:
                # A declared field may be unset on the node; skip it.
                continue

            yield f"{' ' * (depth+1)}{field}="

            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        yield from _v(item, depth + 2)

            elif isinstance(value, ast.AST):
                yield from _v(value, depth + 2)

            else:
                yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"

        yield f"{' ' * depth}) # /{node.__class__.__name__}"

    try:
        src_ast = ast.parse(src)
    except Exception as exc:
        major, minor = sys.version_info[:2]
        # The message already carries the parser error, so the chained
        # traceback is suppressed.
        raise AssertionError(
            f"cannot use --safe with this file; failed to parse source file "
            f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
            f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
        ) from None

    try:
        dst_ast = ast.parse(dst)
    except Exception as exc:
        log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = '\n'.join(_v(src_ast))
    dst_ast_str = '\n'.join(_v(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
        raise AssertionError(
            f"INTERNAL ERROR: Black produced code that is not equivalent to "
            f"the source. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This diff might be helpful: {log}"
        ) from None
def assert_stable(src: str, dst: str, line_length: int) -> None:
    """Raises AssertionError if `dst` reformats differently the second time.

    This is a temporary sanity check until Black becomes stable.

    `dst` is formatted again with the same `line_length`.  If the result
    differs, both diffs (source to first pass, first pass to second pass) are
    dumped to a log file whose path is included in the error.
    """
    newdst = format_str(dst, line_length=line_length)
    if dst != newdst:
        log = dump_to_file(
            diff(src, dst, 'source', 'first pass'),
            diff(dst, newdst, 'first pass', 'second pass'),
        )
        raise AssertionError(
            f"INTERNAL ERROR: Black produced different code on the second pass "
            f"of the formatter. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This diff might be helpful: {log}"
        ) from None
def dump_to_file(*output: str) -> str:
    """Dumps `output` to a temporary file. Returns path to the file.

    Each item is written followed by a newline.  The file is created with the
    `blk_` prefix and `.log` suffix and is not deleted on close, so the caller
    can point the user at it.
    """
    import tempfile

    # Explicit UTF-8 so dumping non-ASCII source doesn't fail under a
    # non-UTF-8 default locale.
    with tempfile.NamedTemporaryFile(
        mode='w', prefix='blk_', suffix='.log', delete=False, encoding='utf8'
    ) as f:
        for lines in output:
            f.write(lines)
            f.write('\n')
    return f.name
def diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Returns a udiff string between strings `a` and `b`.

    `a_name` and `b_name` label the two sides in the diff header.  The diff
    uses five lines of context; it is the empty string when `a == b`.
    """
    import difflib

    # unified_diff expects newline-terminated lines.
    a_lines = [line + '\n' for line in a.split('\n')]
    b_lines = [line + '\n' for line in b.split('\n')]
    return ''.join(
        difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
    )
1649 if __name__ == '__main__':