import ast
import asyncio
import difflib
import keyword
import os
import sys
import tempfile
import tokenize
import traceback

from asyncio.base_events import BaseEventLoop
from concurrent.futures import Executor, ProcessPoolExecutor
from functools import partial
from pathlib import Path
from typing import (
    Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
)

import click

from attr import dataclass, Factory
from blib2to3.pytree import Node, Leaf, type_repr
from blib2to3 import pygram, pytree
from blib2to3.pgen2 import driver, token
from blib2to3.pgen2.parse import ParseError

__version__ = "18.3a2"
DEFAULT_LINE_LENGTH = 88

syms = pygram.python_symbols

FileContent = str
Depth = int
NodeType = int
LeafID = int
Priority = int
LN = Union[Leaf, Node]
out = partial(click.secho, bold=True, err=True)
err = partial(click.secho, fg='red', err=True)
class NothingChanged(UserWarning):
    """Raised by `format_file_contents` when the reformatted code is the same as source."""


class CannotSplit(Exception):
    """A readable split that fits the allotted line length is impossible.

    Raised by `left_hand_split()` and `right_hand_split()`.
    """
    default=DEFAULT_LINE_LENGTH,
    help='How many characters per line to allow.',
        "Don't write back the files, just return the status. Return code 0 "
        "means nothing changed. Return code 1 means some files were "
        "reformatted. Return code 123 means there was an internal error."
    help='If --fast given, skip temporary sanity checks. [default: --safe]',
@click.version_option(version=__version__)
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
def main(
    ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
) -> None:
    """The uncompromising code formatter."""
    sources: List[Path] = []
            sources.extend(gen_python_files_in_dir(p))
            # if a file was explicitly given, we don't care about its extension
            sources.append(Path('-'))
            err(f'invalid path: {s}')
    elif len(sources) == 1:
        try:
            if not p.is_file() and str(p) == '-':
                changed = format_stdin_to_stdout(line_length=line_length, fast=fast)
                changed = format_file_in_place(
                    p, line_length=line_length, fast=fast, write_back=not check
                )
            report.done(p, changed)
        except Exception as exc:
            report.failed(p, str(exc))
        ctx.exit(report.return_code)
        loop = asyncio.get_event_loop()
        executor = ProcessPoolExecutor(max_workers=os.cpu_count())
            return_code = loop.run_until_complete(
                    sources, line_length, not check, fast, loop, executor
            ctx.exit(return_code)
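
# Illustrative only: a minimal sketch of exercising the command line interface
# above in-process with click's test runner. The sample path is hypothetical
# and this helper is not part of Black itself.
def _example_cli_invocation() -> None:
    from click.testing import CliRunner

    runner = CliRunner()
    # `--check` reports the status without writing files back; exit code 1
    # means at least one file would be reformatted.
    result = runner.invoke(main, ['--check', 'example.py'])
    print(result.exit_code, result.output)
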
async def schedule_formatting(
        src: loop.run_in_executor(
            executor, format_file_in_place, src, line_length, fast, write_back
    await asyncio.wait(tasks.values())
    for src, task in tasks.items():
            report.failed(src, 'timed out, cancelling')
            cancelled.append(task)
        elif task.exception():
            report.failed(src, str(task.exception()))
            report.done(src, task.result())
        await asyncio.wait(cancelled, timeout=2)
    out('All done! ✨ 🍰 ✨')
    click.echo(str(report))
    return report.return_code
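
# A self-contained sketch of the scheduling pattern used in
# `schedule_formatting` above: fan work out to a process pool with
# `loop.run_in_executor` and await all the futures at once. The worker (`pow`)
# and inputs are illustrative, not part of Black.
def _example_schedule_pattern() -> None:
    values = [1, 2, 3]
    loop = asyncio.new_event_loop()

    async def run_all(executor: Executor) -> Dict[int, int]:
        tasks = {v: loop.run_in_executor(executor, pow, v, 2) for v in values}
        await asyncio.wait(list(tasks.values()))
        return {v: task.result() for v, task in tasks.items()}

    try:
        with ProcessPoolExecutor() as executor:
            print(loop.run_until_complete(run_all(executor)))
    finally:
        loop.close()
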
def format_file_in_place(
    src: Path, line_length: int, fast: bool, write_back: bool = False
    """Format the file and rewrite if changed. Return True if changed."""
    with tokenize.open(src) as src_buffer:
        src_contents = src_buffer.read()
    try:
        contents = format_file_contents(
            src_contents, line_length=line_length, fast=fast
    except NothingChanged:
        with open(src, "w", encoding=src_buffer.encoding) as f:


def format_stdin_to_stdout(line_length: int, fast: bool) -> bool:
    """Format file on stdin and pipe output to stdout. Return True if changed."""
    contents = sys.stdin.read()
    try:
        contents = format_file_contents(contents, line_length=line_length, fast=fast)
    except NothingChanged:
        sys.stdout.write(contents)


def format_file_contents(
    src_contents: str, line_length: int, fast: bool
    """Reformat contents of a file and return new contents."""
    if src_contents.strip() == '':
    dst_contents = format_str(src_contents, line_length=line_length)
    if src_contents == dst_contents:
        assert_equivalent(src_contents, dst_contents)
        assert_stable(src_contents, dst_contents, line_length=line_length)
def format_str(src_contents: str, line_length: int) -> FileContent:
    """Reformats a string and returns new contents."""
    src_node = lib2to3_parse(src_contents)
    comments: List[Line] = []
    lines = LineGenerator()
    elt = EmptyLineTracker()
    py36 = is_python36(src_node)
    for current_line in lines.visit(src_node):
        for _ in range(after):
            dst_contents += str(empty_line)
        before, after = elt.maybe_empty_lines(current_line)
        for _ in range(before):
            dst_contents += str(empty_line)
        if not current_line.is_comment:
            for comment in comments:
                dst_contents += str(comment)
            for line in split_line(current_line, line_length=line_length, py36=py36):
                dst_contents += str(line)
            comments.append(current_line)
    if elt.previous_defs:
        # Separate postscriptum comments from the last module-level def.
        dst_contents += str(empty_line)
        dst_contents += str(empty_line)
    for comment in comments:
        dst_contents += str(comment)
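
# A small usage sketch for `format_str`; the input snippet is made up and the
# exact output depends on parts of this module elided from this listing.
def _example_format_str() -> None:
    ugly = "x = {  'a':37,'b':42,\n'c':927}\n"
    print(format_str(ugly, line_length=DEFAULT_LINE_LENGTH))
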
def lib2to3_parse(src_txt: str) -> Node:
    """Given a string with source, return the lib2to3 Node."""
    grammar = pygram.python_grammar_no_print_statement
    drv = driver.Driver(grammar, pytree.convert)
    if src_txt[-1] != '\n':
        nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
    try:
        result = drv.parse_string(src_txt, True)
    except ParseError as pe:
        lineno, column = pe.context[1]
        lines = src_txt.splitlines()
            faulty_line = lines[lineno - 1]
            faulty_line = "<line number missing in source>"
        raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
    if isinstance(result, Leaf):
        result = Node(syms.file_input, [result])
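
# A sketch of the round-trip property the rest of this module relies on: the
# blib2to3 concrete syntax tree stringifies back to the original source,
# which is what makes prefix-mutating reformatting possible.
def _example_lib2to3_roundtrip() -> None:
    src = "print('hello')\n"
    node = lib2to3_parse(src)
    assert str(node) == src
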
def lib2to3_unparse(node: Node) -> str:
    """Given a lib2to3 node, return its string representation."""


T = TypeVar('T')


class Visitor(Generic[T]):
    """Basic lib2to3 visitor that yields things on visiting."""

    def visit(self, node: LN) -> Iterator[T]:
            name = token.tok_name[node.type]
            name = type_repr(node.type)
        yield from getattr(self, f'visit_{name}', self.visit_default)(node)

    def visit_default(self, node: LN) -> Iterator[T]:
        if isinstance(node, Node):
            for child in node.children:
                yield from self.visit(child)
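
# A minimal sketch of a Visitor subclass: collect every NAME leaf in a tree.
# It relies only on the `visit()`/`visit_default()` dispatch shown above; the
# class is illustrative and not part of Black.
class _NameCollector(Visitor[Leaf]):
    def visit_NAME(self, node: Leaf) -> Iterator[Leaf]:
        yield node


def _example_name_collector() -> None:
    tree = lib2to3_parse("spam = eggs + 1\n")
    print([leaf.value for leaf in _NameCollector().visit(tree)])
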
class DebugVisitor(Visitor[T]):
    def visit_default(self, node: LN) -> Iterator[T]:
        indent = ' ' * (2 * self.tree_depth)
        if isinstance(node, Node):
            _type = type_repr(node.type)
            out(f'{indent}{_type}', fg='yellow')
            for child in node.children:
                yield from self.visit(child)
            out(f'{indent}/{_type}', fg='yellow', bold=False)
            _type = token.tok_name.get(node.type, str(node.type))
            out(f'{indent}{_type}', fg='blue', nl=False)
                # We don't have to handle prefixes for `Node` objects since
                # that delegates to the first child anyway.
                out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
            out(f' {node.value!r}', fg='blue', bold=False)


KEYWORDS = set(keyword.kwlist)
WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
STANDALONE_COMMENT = 153
LOGIC_OPERATORS = {'and', 'or'}
COMPREHENSION_PRIORITY = 20
COMPARATOR_PRIORITY = 3
class BracketTracker:
    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
    delimiters: Dict[LeafID, Priority] = Factory(dict)
    previous: Optional[Leaf] = None

    def mark(self, leaf: Leaf) -> None:
        if leaf.type == token.COMMENT:
        if leaf.type in CLOSING_BRACKETS:
            opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
            leaf.opening_bracket = opening_bracket
        leaf.bracket_depth = self.depth
        delim = is_delimiter(leaf)
            self.delimiters[id(leaf)] = delim
        elif self.previous is not None:
            if leaf.type == token.STRING and self.previous.type == token.STRING:
                self.delimiters[id(self.previous)] = STRING_PRIORITY
                leaf.type == token.NAME
                and leaf.value == 'for'
                and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
                self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
                leaf.type == token.NAME
                and leaf.value == 'if'
                and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
                self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
                leaf.type == token.NAME
                and leaf.value in LOGIC_OPERATORS
                self.delimiters[id(self.previous)] = LOGIC_PRIORITY
        if leaf.type in OPENING_BRACKETS:
            self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf

    def any_open_brackets(self) -> bool:
        """Returns True if there is an as yet unmatched open bracket on the line."""
        return bool(self.bracket_match)

    def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
        """Returns the highest priority of a delimiter found on the line.

        Values are consistent with what `is_delimiter()` returns.
        """
        return max(v for k, v in self.delimiters.items() if k not in exclude)
class Line:
    leaves: List[Leaf] = Factory(list)
    comments: Dict[LeafID, Leaf] = Factory(dict)
    bracket_tracker: BracketTracker = Factory(BracketTracker)
    inside_brackets: bool = False
    has_for: bool = False
    _for_loop_variable: bool = False

    def append(self, leaf: Leaf, preformatted: bool = False) -> None:
        has_value = leaf.value.strip()
        if self.leaves and not preformatted:
            # Note: at this point leaf.prefix should be empty except for
            # imports, for which we only preserve newlines.
            leaf.prefix += whitespace(leaf)
        if self.inside_brackets or not preformatted:
            self.maybe_decrement_after_for_loop_variable(leaf)
            self.bracket_tracker.mark(leaf)
            self.maybe_remove_trailing_comma(leaf)
            self.maybe_increment_for_loop_variable(leaf)
            if self.maybe_adapt_standalone_comment(leaf):
        if not self.append_comment(leaf):
            self.leaves.append(leaf)

    def is_comment(self) -> bool:
        return bool(self) and self.leaves[0].type == STANDALONE_COMMENT

    def is_decorator(self) -> bool:
        return bool(self) and self.leaves[0].type == token.AT

    def is_import(self) -> bool:
        return bool(self) and is_import(self.leaves[0])

    def is_class(self) -> bool:
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value == 'class'

    def is_def(self) -> bool:
        """Also returns True for async defs."""
            first_leaf = self.leaves[0]
            second_leaf: Optional[Leaf] = self.leaves[1]
            (first_leaf.type == token.NAME and first_leaf.value == 'def')
                first_leaf.type == token.NAME
                and first_leaf.value == 'async'
                and second_leaf is not None
                and second_leaf.type == token.NAME
                and second_leaf.value == 'def'

    def is_flow_control(self) -> bool:
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value in FLOW_CONTROL

    def is_yield(self) -> bool:
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value == 'yield'

    def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
            and self.leaves[-1].type == token.COMMA
            and closing.type in CLOSING_BRACKETS
        if closing.type == token.RSQB or closing.type == token.RBRACE:
        # For parens let's check if it's safe to remove the comma. If the
        # trailing one is the only one, we might mistakenly change a tuple
        # into a different type by removing the comma.
        depth = closing.bracket_depth + 1
        opening = closing.opening_bracket
        for _opening_index, leaf in enumerate(self.leaves):
        for leaf in self.leaves[_opening_index + 1:]:
            bracket_depth = leaf.bracket_depth
            if bracket_depth == depth and leaf.type == token.COMMA:
                if leaf.parent and leaf.parent.type == syms.arglist:
    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
        """In a for loop or comprehension, the variables are often unpacked.

        To avoid splitting on the comma in this situation, we will increase
        the depth of tokens between `for` and `in`.
        """
        if leaf.type == token.NAME and leaf.value == 'for':
            self.bracket_tracker.depth += 1
            self._for_loop_variable = True

    def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
        # See `maybe_increment_for_loop_variable` above for explanation.
        if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
            self.bracket_tracker.depth -= 1
            self._for_loop_variable = False

    def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
        """Hack a standalone comment to act as a trailing comment for line splitting.

        If this line has brackets and a standalone `comment`, we need to adapt
        it to be able to still reformat the line.

        This is not perfect; the line to which the standalone comment gets
        appended will appear "too long" when splitting.
        """
            comment.type == STANDALONE_COMMENT
            and self.bracket_tracker.any_open_brackets()
        comment.type = token.COMMENT
        comment.prefix = '\n' + '    ' * (self.depth + 1)
        return self.append_comment(comment)

    def append_comment(self, comment: Leaf) -> bool:
        if comment.type != token.COMMENT:
            after = id(self.last_non_delimiter())
            comment.type = STANDALONE_COMMENT
        if after in self.comments:
            self.comments[after].value += str(comment)
            self.comments[after] = comment

    def last_non_delimiter(self) -> Leaf:
        for i in range(len(self.leaves)):
            last = self.leaves[-i - 1]
            if not is_delimiter(last):
        raise LookupError("No non-delimiters found")

    def __str__(self) -> str:
        indent = '    ' * self.depth
        leaves = iter(self.leaves)
        res = f'{first.prefix}{indent}{first.value}'
        for comment in self.comments.values():

    def __bool__(self) -> bool:
        return bool(self.leaves or self.comments)
class EmptyLineTracker:
    """Provides a stateful method that returns the number of potential extra
    empty lines needed before and after the currently processed line.

    Note: this tracker works on lines that haven't been split yet. It assumes
    the prefix of the first leaf consists of optional newlines. Those newlines
    are consumed by `maybe_empty_lines()` and included in the computation.
    """
    previous_line: Optional[Line] = None
    previous_after: int = 0
    previous_defs: List[int] = Factory(list)

    def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
        """Returns the number of extra empty lines before and after the `current_line`.

        This is for separating `def`, `async def` and `class` with extra empty lines
        (two on module-level), as well as providing an extra empty line after flow
        control keywords to make them more prominent.
        """
        if current_line.is_comment:
            # Don't count standalone comments towards previous empty lines.
        before, after = self._maybe_empty_lines(current_line)
        before -= self.previous_after
        self.previous_after = after
        self.previous_line = current_line

    def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
        if current_line.leaves:
            # Consume the first leaf's extra newlines.
            first_leaf = current_line.leaves[0]
            before = int('\n' in first_leaf.prefix)
            first_leaf.prefix = ''
        depth = current_line.depth
        while self.previous_defs and self.previous_defs[-1] >= depth:
            self.previous_defs.pop()
            before = 1 if depth else 2
        is_decorator = current_line.is_decorator
        if is_decorator or current_line.is_def or current_line.is_class:
                self.previous_defs.append(depth)
            if self.previous_line is None:
                # Don't insert empty lines before the first line in the file.

            if self.previous_line and self.previous_line.is_decorator:
                # Don't insert empty lines between decorators.

            if current_line.depth:
        if current_line.is_flow_control:
            and self.previous_line.is_import
            and not current_line.is_import
            and depth == self.previous_line.depth
            return (before or 1), 0

            and self.previous_line.is_yield
            and (not current_line.is_yield or depth != self.previous_line.depth)
            return (before or 1), 0
class LineGenerator(Visitor[Line]):
    """Generates reformatted Line objects. Empty lines are not emitted.

    Note: destroys the tree it's visiting by mutating prefixes of its leaves
    in ways that will no longer stringify to valid Python code on the tree.
    """
    current_line: Line = Factory(Line)
    standalone_comments: List[Leaf] = Factory(list)

    def line(self, indent: int = 0) -> Iterator[Line]:
        """If the line is empty, only emit if it makes sense.
        If the line is too long, split it first and then generate.

        If any lines were generated, set up a new current_line.
        """
        if not self.current_line:
            self.current_line.depth += indent
            return  # Line is empty, don't emit. Creating a new one is unnecessary.

        complete_line = self.current_line
        self.current_line = Line(depth=complete_line.depth + indent)

    def visit_default(self, node: LN) -> Iterator[Line]:
        if isinstance(node, Leaf):
            for comment in generate_comments(node):
                if self.current_line.bracket_tracker.any_open_brackets():
                    # any comment within brackets is subject to splitting
                    self.current_line.append(comment)
                elif comment.type == token.COMMENT:
                    # regular trailing comment
                    self.current_line.append(comment)
                    yield from self.line()
                    # regular standalone comment, to be processed later (see
                    # docstring in `generate_comments()`)
                    self.standalone_comments.append(comment)
            normalize_prefix(node)
            if node.type not in WHITESPACE:
                for comment in self.standalone_comments:
                    yield from self.line()
                    self.current_line.append(comment)
                    yield from self.line()
                self.standalone_comments = []
                self.current_line.append(node)
        yield from super().visit_default(node)

    def visit_suite(self, node: Node) -> Iterator[Line]:
        """Body of a statement after a colon."""
        children = iter(node.children)
        # Process newline before indenting. It might contain an inline
        # comment that should go right after the colon.
        newline = next(children)
        yield from self.visit(newline)
        yield from self.line(+1)

        for child in children:
            yield from self.visit(child)

        yield from self.line(-1)

    def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
        """Visit a statement.

        The relevant Python language keywords for this statement are NAME leaves
        """
        for child in node.children:
            if child.type == token.NAME and child.value in keywords:  # type: ignore
                yield from self.line()

            yield from self.visit(child)

    def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
        """A statement without nested statements."""
        is_suite_like = node.parent and node.parent.type in STATEMENT
            yield from self.line(+1)
            yield from self.visit_default(node)
            yield from self.line(-1)
            yield from self.line()
            yield from self.visit_default(node)

    def visit_async_stmt(self, node: Node) -> Iterator[Line]:
        yield from self.line()

        children = iter(node.children)
        for child in children:
            yield from self.visit(child)

            if child.type == token.NAME and child.value == 'async':  # type: ignore

        internal_stmt = next(children)
        for child in internal_stmt.children:
            yield from self.visit(child)

    def visit_decorators(self, node: Node) -> Iterator[Line]:
        for child in node.children:
            yield from self.line()
            yield from self.visit(child)

    def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
        yield from self.line()

    def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
        yield from self.visit_default(leaf)
        yield from self.line()

    def __attrs_post_init__(self) -> None:
        """You are in a twisty little maze of passages."""
        v = self.visit_stmt
        self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
        self.visit_while_stmt = partial(v, keywords={'while', 'else'})
        self.visit_for_stmt = partial(v, keywords={'for', 'else'})
        self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
        self.visit_except_clause = partial(v, keywords={'except'})
        self.visit_funcdef = partial(v, keywords={'def'})
        self.visit_with_stmt = partial(v, keywords={'with'})
        self.visit_classdef = partial(v, keywords={'class'})
        self.visit_async_funcdef = self.visit_async_stmt
        self.visit_decorated = self.visit_decorators
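
# A toy sketch of the `partial`-based dispatch set up in `__attrs_post_init__`
# above: one generic handler is specialized per statement type by binding its
# keyword set. The names here are illustrative only.
def _example_partial_dispatch() -> None:
    def visit_stmt(name: str, keywords: Set[str]) -> str:
        return f"{name}: split before {sorted(keywords)}"

    visit_if_stmt = partial(visit_stmt, keywords={'if', 'elif', 'else'})
    visit_while_stmt = partial(visit_stmt, keywords={'while', 'else'})
    print(visit_if_stmt('if_stmt'))
    print(visit_while_stmt('while_stmt'))
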
BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
OPENING_BRACKETS = set(BRACKET.keys())
CLOSING_BRACKETS = set(BRACKET.values())
BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
def whitespace(leaf: Leaf) -> str:  # noqa C901
    """Return whitespace prefix if needed for the given `leaf`."""
    if t in ALWAYS_NO_SPACE:

    if t == token.COMMENT:

    assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
    if t == token.COLON and p.type != syms.subscript:

    prev = leaf.prev_sibling
        prevp = preceding_leaf(p)
        if not prevp or prevp.type in OPENING_BRACKETS:

            return SPACE if prevp.type == token.COMMA else NO

        if prevp.type == token.EQUAL:
            if prevp.parent and prevp.parent.type in {

        elif prevp.type == token.DOUBLESTAR:
            if prevp.parent and prevp.parent.type in {

        elif prevp.type == token.COLON:
            if prevp.parent and prevp.parent.type == syms.subscript:

        elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:

    elif prev.type in OPENING_BRACKETS:

    if p.type in {syms.parameters, syms.arglist}:
        # untyped function signatures or calls
        if not prev or prev.type != token.COMMA:

    if p.type == syms.varargslist:
        if prev and prev.type != token.COMMA:

    elif p.type == syms.typedargslist:
        # typed function signatures
        if prev.type != syms.tname:

        elif prev.type == token.EQUAL:
            # A bit hacky: if the equal sign has whitespace, it means we
            # previously found it's a typed argument. So, we're using that, too.

        elif prev.type != token.COMMA:

    elif p.type == syms.tname:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type != token.COMMA:

    elif p.type == syms.trailer:
        # attributes and calls
        if t == token.LPAR or t == token.RPAR:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type != token.NUMBER:

        elif t == token.LSQB:

        elif prev.type != token.COMMA:

    elif p.type == syms.argument:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type == token.LPAR:

        elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:

    elif p.type == syms.decorator:

    elif p.type == syms.dotted_name:
        prevp = preceding_leaf(p)
        if not prevp or prevp.type == token.AT or prevp.type == token.DOT:

    elif p.type == syms.classdef:
        if prev and prev.type == token.LPAR:

    elif p.type == syms.subscript:
        assert p.parent is not None, "subscripts are always parented"
        if p.parent.type == syms.subscriptlist:

    elif p.type == syms.atom:
        if prev and t == token.DOT:
            # dots, but not the first one.

        p.type == syms.listmaker
        or p.type == syms.testlist_gexp
        or p.type == syms.subscriptlist
        # list interior, including unpacking

    elif p.type == syms.dictsetmaker:
        # dict and set interior, including unpacking
        if prev.type == token.DOUBLESTAR:

    elif p.type in {syms.factor, syms.star_expr}:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type in OPENING_BRACKETS:

            prevp_parent = prevp.parent
            assert prevp_parent is not None
            if prevp.type == token.COLON and prevp_parent.type in {
                syms.subscript, syms.sliceop

            elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:

        elif t == token.NAME or t == token.NUMBER:

    elif p.type == syms.import_from:
            if prev and prev.type == token.DOT:

        elif t == token.NAME:
            if prev and prev.type == token.DOT:

    elif p.type == syms.sliceop:
def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
    """Returns the first leaf that precedes `node`, if any."""
        res = node.prev_sibling
            if isinstance(res, Leaf):
                return list(res.leaves())[-1]


def is_delimiter(leaf: Leaf) -> int:
    """Returns the priority of the `leaf` delimiter. Returns 0 if the leaf is not a delimiter.

    Higher numbers are higher priority.
    """
    if leaf.type == token.COMMA:
        return COMMA_PRIORITY

    if leaf.type in COMPARATORS:
        return COMPARATOR_PRIORITY

        leaf.type in MATH_OPERATORS
        and leaf.parent.type not in {syms.factor, syms.star_expr}
        return MATH_PRIORITY
def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
    """Cleans the prefix of the `leaf` and generates comments from it, if any.

    Comments in lib2to3 are shoved into the whitespace prefix. This happens
    in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation
    move because it does away with modifying the grammar to include all the
    possible places in which comments can be placed.

    The sad consequence for us though is that comments don't "belong" anywhere.
    This is why this function generates simple parentless Leaf objects for
    comments. We simply don't know what the correct parent should be.

    No matter though, we can live without this. We really only need to
    differentiate between inline and standalone comments. The latter don't
    share the line with any code.

    Inline comments are emitted as regular token.COMMENT leaves. Standalone
    are emitted with a fake STANDALONE_COMMENT token identifier.
    """
    if '#' not in leaf.prefix:

    before_comment, content = leaf.prefix.split('#', 1)
    content = content.rstrip()
    if content and (content[0] not in {' ', '!', '#'}):
        content = ' ' + content
    is_standalone_comment = (
        '\n' in before_comment or '\n' in content or leaf.type == token.ENDMARKER
    if not is_standalone_comment:
        # simple trailing comment
        yield Leaf(token.COMMENT, value='#' + content)

    for line in ('#' + content).split('\n'):
        line = line.lstrip()
        if not line.startswith('#'):

        yield Leaf(STANDALONE_COMMENT, line)
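
# A sketch of the lib2to3 behaviour `generate_comments` works around: comments
# are attached to the whitespace prefix of the following leaf instead of
# getting their own nodes.
def _example_comment_prefixes() -> None:
    tree = lib2to3_parse("# leading comment\nx = 1  # trailing comment\n")
    for leaf in tree.leaves():
        if '#' in leaf.prefix:
            print(repr(leaf), 'carries prefix', repr(leaf.prefix))
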
def split_line(
    line: Line, line_length: int, inner: bool = False, py36: bool = False
) -> Iterator[Line]:
    """Splits a `line` into potentially many lines.

    They should fit in the allotted `line_length` but might not be able to.
    `inner` signifies that there was a pair of brackets somewhere around the
    current `line`, possibly transitively. This means we can fall back to splitting
    by delimiters if the LHS/RHS don't yield any results.

    If `py36` is True, splitting may generate syntax that is only compatible
    with Python 3.6 and later.
    """
    line_str = str(line).strip('\n')
    if len(line_str) <= line_length and '\n' not in line_str:

        split_funcs = [left_hand_split]
    elif line.inside_brackets:
        split_funcs = [delimiter_split]
        if '\n' not in line_str:
            # Only attempt RHS if we don't have multiline strings or comments
            split_funcs.append(right_hand_split)
        split_funcs = [right_hand_split]
    for split_func in split_funcs:
        # We are accumulating lines in `result` because we might want to abort
        # mission and return the original line in the end, or attempt a different
        result: List[Line] = []
        try:
            for l in split_func(line, py36=py36):
                if str(l).strip('\n') == line_str:
                    raise CannotSplit("Split function returned an unchanged result")

                    split_line(l, line_length=line_length, inner=True, py36=py36)
        except CannotSplit as cs:
def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split line into many lines, starting with the first matching bracket pair.

    Note: this usually looks weird, only use this for function definitions.
    Prefer RHS otherwise.
    """
    head = Line(depth=line.depth)
    body = Line(depth=line.depth + 1, inside_brackets=True)
    tail = Line(depth=line.depth)
    tail_leaves: List[Leaf] = []
    body_leaves: List[Leaf] = []
    head_leaves: List[Leaf] = []
    current_leaves = head_leaves
    matching_bracket = None
    for leaf in line.leaves:
            current_leaves is body_leaves
            and leaf.type in CLOSING_BRACKETS
            and leaf.opening_bracket is matching_bracket
            current_leaves = tail_leaves if body_leaves else head_leaves
        current_leaves.append(leaf)
        if current_leaves is head_leaves:
            if leaf.type in OPENING_BRACKETS:
                matching_bracket = leaf
                current_leaves = body_leaves
    # Since body is a new indent level, remove spurious leading whitespace.
        normalize_prefix(body_leaves[0])
    # Build the new lines.
    for result, leaves in (
        (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
            result.append(leaf, preformatted=True)
            comment_after = line.comments.get(id(leaf))
                result.append(comment_after, preformatted=True)
    split_succeeded_or_raise(head, body, tail)
    for result in (head, body, tail):


def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split line into many lines, starting with the last matching bracket pair."""
    head = Line(depth=line.depth)
    body = Line(depth=line.depth + 1, inside_brackets=True)
    tail = Line(depth=line.depth)
    tail_leaves: List[Leaf] = []
    body_leaves: List[Leaf] = []
    head_leaves: List[Leaf] = []
    current_leaves = tail_leaves
    opening_bracket = None
    for leaf in reversed(line.leaves):
        if current_leaves is body_leaves:
            if leaf is opening_bracket:
                current_leaves = head_leaves if body_leaves else tail_leaves
        current_leaves.append(leaf)
        if current_leaves is tail_leaves:
            if leaf.type in CLOSING_BRACKETS:
                opening_bracket = leaf.opening_bracket
                current_leaves = body_leaves
    tail_leaves.reverse()
    body_leaves.reverse()
    head_leaves.reverse()
    # Since body is a new indent level, remove spurious leading whitespace.
        normalize_prefix(body_leaves[0])
    # Build the new lines.
    for result, leaves in (
        (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
            result.append(leaf, preformatted=True)
            comment_after = line.comments.get(id(leaf))
                result.append(comment_after, preformatted=True)
    split_succeeded_or_raise(head, body, tail)
    for result in (head, body, tail):
def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
    tail_len = len(str(tail).strip())
            raise CannotSplit("Splitting brackets produced the same line")

                f"Splitting brackets on an empty body to save "
                f"{tail_len} characters is not worth it"


def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split according to delimiters of the highest priority.

    This kind of split doesn't increase indentation.
    If `py36` is True, the split will add trailing commas also in function
    signatures that contain * and **.
    """
        last_leaf = line.leaves[-1]
        raise CannotSplit("Line empty")

    delimiters = line.bracket_tracker.delimiters
        delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
        raise CannotSplit("No delimiters found")

    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    lowest_depth = sys.maxsize
    trailing_comma_safe = True
    for leaf in line.leaves:
        current_line.append(leaf, preformatted=True)
        comment_after = line.comments.get(id(leaf))
            current_line.append(comment_after, preformatted=True)
        lowest_depth = min(lowest_depth, leaf.bracket_depth)
            leaf.bracket_depth == lowest_depth
            and leaf.type == token.STAR
            or leaf.type == token.DOUBLESTAR
            trailing_comma_safe = trailing_comma_safe and py36
        leaf_priority = delimiters.get(id(leaf))
        if leaf_priority == delimiter_priority:
            normalize_prefix(current_line.leaves[0])

            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
            delimiter_priority == COMMA_PRIORITY
            and current_line.leaves[-1].type != token.COMMA
            and trailing_comma_safe
            current_line.append(Leaf(token.COMMA, ','))
        normalize_prefix(current_line.leaves[0])
def is_import(leaf: Leaf) -> bool:
    """Returns True if the given leaf starts an import statement."""
        (v == 'import' and p and p.type == syms.import_name)
        or (v == 'from' and p and p.type == syms.import_from)


def normalize_prefix(leaf: Leaf) -> None:
    """Leave existing extra newlines for imports. Remove everything else."""
        spl = leaf.prefix.split('#', 1)
        nl_count = spl[0].count('\n')
        leaf.prefix = '\n' * nl_count


def is_python36(node: Node) -> bool:
    """Returns True if the current file is using Python 3.6+ features.

    Currently looking for:
    - trailing commas after * or ** in function signatures.
    """
    for n in node.pre_order():
        if n.type == token.STRING:
            value_head = n.value[:2]  # type: ignore
            if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:

            n.type == syms.typedargslist
            and n.children[-1].type == token.COMMA
            for ch in n.children:
                if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
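
# A usage sketch for `is_python36` (whose body is only partially shown above);
# the snippets are made up.
def _example_is_python36() -> None:
    assert is_python36(lib2to3_parse("f'{x}'\n"))
    assert not is_python36(lib2to3_parse("x = 1\n"))
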
PYTHON_EXTENSIONS = {'.py'}
BLACKLISTED_DIRECTORIES = {
    'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
}


def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
    for child in path.iterdir():
            if child.name in BLACKLISTED_DIRECTORIES:

            yield from gen_python_files_in_dir(child)

        elif child.suffix in PYTHON_EXTENSIONS:
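
# A usage sketch for the directory walker above; the starting path is an
# assumption.
def _example_collect_sources() -> None:
    for path in gen_python_files_in_dir(Path('.')):
        print(path)
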
1431 """Provides a reformatting counter."""
1432 change_count: int = 0
1434 failure_count: int = 0
1436 def done(self, src: Path, changed: bool) -> None:
1437 """Increment the counter for successful reformatting. Write out a message."""
1439 out(f'reformatted {src}')
1440 self.change_count += 1
1442 out(f'{src} already well formatted, good job.', bold=False)
1443 self.same_count += 1
1445 def failed(self, src: Path, message: str) -> None:
1446 """Increment the counter for failed reformatting. Write out a message."""
1447 err(f'error: cannot format {src}: {message}')
1448 self.failure_count += 1
1451 def return_code(self) -> int:
1452 """Which return code should the app use considering the current state."""
1453 # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1454 # 126 we have special returncodes reserved by the shell.
1455 if self.failure_count:
1458 elif self.change_count:
1463 def __str__(self) -> str:
1464 """A color report of the current state.
1466 Use `click.unstyle` to remove colors.
1469 if self.change_count:
1470 s = 's' if self.change_count > 1 else ''
1472 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1475 s = 's' if self.same_count > 1 else ''
1476 report.append(f'{self.same_count} file{s} left unchanged')
1477 if self.failure_count:
1478 s = 's' if self.failure_count > 1 else ''
1481 f'{self.failure_count} file{s} failed to reformat', fg='red'
1484 return ', '.join(report) + '.'
def assert_equivalent(src: str, dst: str) -> None:
    """Raises AssertionError if `src` and `dst` aren't equivalent.

    This is a temporary sanity check until Black becomes stable.
    """

    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
        """Simple visitor generating strings to compare ASTs by content."""
        yield f"{' ' * depth}{node.__class__.__name__}("

        for field in sorted(node._fields):
            try:
                value = getattr(node, field)
            except AttributeError:

            yield f"{' ' * (depth+1)}{field}="
            if isinstance(value, list):
                    if isinstance(item, ast.AST):
                        yield from _v(item, depth + 2)

            elif isinstance(value, ast.AST):
                yield from _v(value, depth + 2)

                yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"

        yield f"{' ' * depth}) # /{node.__class__.__name__}"

    try:
        src_ast = ast.parse(src)
    except Exception as exc:
        raise AssertionError(f"cannot parse source: {exc}") from None

    try:
        dst_ast = ast.parse(dst)
    except Exception as exc:
        log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This invalid output might be helpful: {log}"

    src_ast_str = '\n'.join(_v(src_ast))
    dst_ast_str = '\n'.join(_v(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
        raise AssertionError(
            f"INTERNAL ERROR: Black produced code that is not equivalent to "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This diff might be helpful: {log}"


def assert_stable(src: str, dst: str, line_length: int) -> None:
    """Raises AssertionError if `dst` reformats differently the second time.

    This is a temporary sanity check until Black becomes stable.
    """
    newdst = format_str(dst, line_length=line_length)
            diff(src, dst, 'source', 'first pass'),
            diff(dst, newdst, 'first pass', 'second pass'),
        raise AssertionError(
            f"INTERNAL ERROR: Black produced different code on the second pass "
            f"of the formatter. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This diff might be helpful: {log}"
def dump_to_file(*output: str) -> str:
    """Dumps `output` to a temporary file. Returns path to the file."""
    with tempfile.NamedTemporaryFile(
        mode='w', prefix='blk_', suffix='.log', delete=False
        for lines in output:


def diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Returns a unified diff string between strings `a` and `b`."""
    a_lines = [line + '\n' for line in a.split('\n')]
    b_lines = [line + '\n' for line in b.split('\n')]
        difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
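
# A usage sketch for `diff`, which wraps difflib.unified_diff.
def _example_diff() -> None:
    print(diff("a = 1\n", "a = 2\n", 'before', 'after'))
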
if __name__ == '__main__':