All patches and comments are welcome. Please squash your changes into logical
commits, then use git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you'd read over the Git project's submission guidelines and adhered to them,
I'd be especially grateful.
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
8 from pathlib import Path
12 Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
15 from attr import dataclass, Factory
19 from blib2to3.pytree import Node, Leaf, type_repr
20 from blib2to3 import pygram, pytree
21 from blib2to3.pgen2 import driver, token
22 from blib2to3.pgen2.parse import ParseError
24 __version__ = "18.3a2"
25 DEFAULT_LINE_LENGTH = 88
27 syms = pygram.python_symbols
34 LN = Union[Leaf, Node]
35 out = partial(click.secho, bold=True, err=True)
36 err = partial(click.secho, fg='red', err=True)
39 class NothingChanged(UserWarning):
40 """Raised by `format_file` when the reformatted code is the same as source."""
43 class CannotSplit(Exception):
44 """A readable split that fits the allotted line length is impossible.
46 Raised by `left_hand_split()` and `right_hand_split()`.
55 default=DEFAULT_LINE_LENGTH,
56 help='How many character per line to allow.',
63 "Don't write back the files, just return the status. Return code 0 "
64 "means nothing changed. Return code 1 means some files were "
65 "reformatted. Return code 123 means there was an internal error."
71 help='If --fast given, skip temporary sanity checks. [default: --safe]',
73 @click.version_option(version=__version__)
77 type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
81 ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
83 """The uncompromising code formatter."""
84 sources: List[Path] = []
88 sources.extend(gen_python_files_in_dir(p))
90 # if a file was explicitly given, we don't care about its extension
93 err(f'invalid path: {s}')
96 elif len(sources) == 1:
100 changed = format_file_in_place(
101 p, line_length=line_length, fast=fast, write_back=not check
103 report.done(p, changed)
104 except Exception as exc:
105 report.failed(p, str(exc))
106 ctx.exit(report.return_code)
108 loop = asyncio.get_event_loop()
109 executor = ProcessPoolExecutor(max_workers=os.cpu_count())
112 return_code = loop.run_until_complete(
114 sources, line_length, not check, fast, loop, executor
119 ctx.exit(return_code)
122 async def schedule_formatting(
131 src: loop.run_in_executor(
132 executor, format_file_in_place, src, line_length, fast, write_back
136 await asyncio.wait(tasks.values())
139 for src, task in tasks.items():
141 report.failed(src, 'timed out, cancelling')
143 cancelled.append(task)
144 elif task.exception():
145 report.failed(src, str(task.exception()))
147 report.done(src, task.result())
149 await asyncio.wait(cancelled, timeout=2)
150 out('All done! ✨ 🍰 ✨')
151 click.echo(str(report))
152 return report.return_code
155 def format_file_in_place(
156 src: Path, line_length: int, fast: bool, write_back: bool = False
158 """Format the file and rewrite if changed. Return True if changed."""
160 contents, encoding = format_file(src, line_length=line_length, fast=fast)
161 except NothingChanged:
165 with open(src, "w", encoding=encoding) as f:
171 src: Path, line_length: int, fast: bool
172 ) -> Tuple[FileContent, Encoding]:
173 """Reformats a file and returns its contents and encoding."""
174 with tokenize.open(src) as src_buffer:
175 src_contents = src_buffer.read()
176 if src_contents.strip() == '':
177 raise NothingChanged(src)
179 dst_contents = format_str(src_contents, line_length=line_length)
180 if src_contents == dst_contents:
181 raise NothingChanged(src)
184 assert_equivalent(src_contents, dst_contents)
185 assert_stable(src_contents, dst_contents, line_length=line_length)
186 return dst_contents, src_buffer.encoding
189 def format_str(src_contents: str, line_length: int) -> FileContent:
190 """Reformats a string and returns new contents."""
191 src_node = lib2to3_parse(src_contents)
193 comments: List[Line] = []
194 lines = LineGenerator()
195 elt = EmptyLineTracker()
196 py36 = is_python36(src_node)
199 for current_line in lines.visit(src_node):
200 for _ in range(after):
201 dst_contents += str(empty_line)
202 before, after = elt.maybe_empty_lines(current_line)
203 for _ in range(before):
204 dst_contents += str(empty_line)
205 if not current_line.is_comment:
206 for comment in comments:
207 dst_contents += str(comment)
209 for line in split_line(current_line, line_length=line_length, py36=py36):
210 dst_contents += str(line)
212 comments.append(current_line)
213 for comment in comments:
214 dst_contents += str(comment)
218 def lib2to3_parse(src_txt: str) -> Node:
219 """Given a string with source, return the lib2to3 Node."""
220 grammar = pygram.python_grammar_no_print_statement
221 drv = driver.Driver(grammar, pytree.convert)
222 if src_txt[-1] != '\n':
223 nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
226 result = drv.parse_string(src_txt, True)
227 except ParseError as pe:
228 lineno, column = pe.context[1]
229 lines = src_txt.splitlines()
231 faulty_line = lines[lineno - 1]
233 faulty_line = "<line number missing in source>"
234 raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
236 if isinstance(result, Leaf):
237 result = Node(syms.file_input, [result])
241 def lib2to3_unparse(node: Node) -> str:
242 """Given a lib2to3 node, return its string representation."""
250 class Visitor(Generic[T]):
251 """Basic lib2to3 visitor that yields things on visiting."""
253 def visit(self, node: LN) -> Iterator[T]:
255 name = token.tok_name[node.type]
257 name = type_repr(node.type)
258 yield from getattr(self, f'visit_{name}', self.visit_default)(node)
260 def visit_default(self, node: LN) -> Iterator[T]:
261 if isinstance(node, Node):
262 for child in node.children:
263 yield from self.visit(child)
267 class DebugVisitor(Visitor[T]):
270 def visit_default(self, node: LN) -> Iterator[T]:
271 indent = ' ' * (2 * self.tree_depth)
272 if isinstance(node, Node):
273 _type = type_repr(node.type)
274 out(f'{indent}{_type}', fg='yellow')
276 for child in node.children:
277 yield from self.visit(child)
280 out(f'{indent}/{_type}', fg='yellow', bold=False)
282 _type = token.tok_name.get(node.type, str(node.type))
283 out(f'{indent}{_type}', fg='blue', nl=False)
285 # We don't have to handle prefixes for `Node` objects since
286 # that delegates to the first child anyway.
287 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
288 out(f' {node.value!r}', fg='blue', bold=False)
291 KEYWORDS = set(keyword.kwlist)
292 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
293 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
304 STANDALONE_COMMENT = 153
305 LOGIC_OPERATORS = {'and', 'or'}
328 COMPREHENSION_PRIORITY = 20
332 COMPARATOR_PRIORITY = 3
337 class BracketTracker:
339 bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
340 delimiters: Dict[LeafID, Priority] = Factory(dict)
341 previous: Optional[Leaf] = None
343 def mark(self, leaf: Leaf) -> None:
344 if leaf.type == token.COMMENT:
347 if leaf.type in CLOSING_BRACKETS:
349 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
350 leaf.opening_bracket = opening_bracket
351 leaf.bracket_depth = self.depth
353 delim = is_delimiter(leaf)
355 self.delimiters[id(leaf)] = delim
356 elif self.previous is not None:
357 if leaf.type == token.STRING and self.previous.type == token.STRING:
358 self.delimiters[id(self.previous)] = STRING_PRIORITY
360 leaf.type == token.NAME
361 and leaf.value == 'for'
363 and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
365 self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
367 leaf.type == token.NAME
368 and leaf.value == 'if'
370 and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
372 self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
374 leaf.type == token.NAME
375 and leaf.value in LOGIC_OPERATORS
378 self.delimiters[id(self.previous)] = LOGIC_PRIORITY
379 if leaf.type in OPENING_BRACKETS:
380 self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
384 def any_open_brackets(self) -> bool:
385 """Returns True if there is an yet unmatched open bracket on the line."""
386 return bool(self.bracket_match)
388 def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
389 """Returns the highest priority of a delimiter found on the line.
391 Values are consistent with what `is_delimiter()` returns.
393 return max(v for k, v in self.delimiters.items() if k not in exclude)
399 leaves: List[Leaf] = Factory(list)
400 comments: Dict[LeafID, Leaf] = Factory(dict)
401 bracket_tracker: BracketTracker = Factory(BracketTracker)
402 inside_brackets: bool = False
403 has_for: bool = False
404 _for_loop_variable: bool = False
406 def append(self, leaf: Leaf, preformatted: bool = False) -> None:
407 has_value = leaf.value.strip()
411 if self.leaves and not preformatted:
412 # Note: at this point leaf.prefix should be empty except for
413 # imports, for which we only preserve newlines.
414 leaf.prefix += whitespace(leaf)
415 if self.inside_brackets or not preformatted:
416 self.maybe_decrement_after_for_loop_variable(leaf)
417 self.bracket_tracker.mark(leaf)
418 self.maybe_remove_trailing_comma(leaf)
419 self.maybe_increment_for_loop_variable(leaf)
420 if self.maybe_adapt_standalone_comment(leaf):
423 if not self.append_comment(leaf):
424 self.leaves.append(leaf)
427 def is_comment(self) -> bool:
428 return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
431 def is_decorator(self) -> bool:
432 return bool(self) and self.leaves[0].type == token.AT
435 def is_import(self) -> bool:
436 return bool(self) and is_import(self.leaves[0])
439 def is_class(self) -> bool:
442 and self.leaves[0].type == token.NAME
443 and self.leaves[0].value == 'class'
447 def is_def(self) -> bool:
448 """Also returns True for async defs."""
450 first_leaf = self.leaves[0]
455 second_leaf: Optional[Leaf] = self.leaves[1]
459 (first_leaf.type == token.NAME and first_leaf.value == 'def')
461 first_leaf.type == token.NAME
462 and first_leaf.value == 'async'
463 and second_leaf is not None
464 and second_leaf.type == token.NAME
465 and second_leaf.value == 'def'
470 def is_flow_control(self) -> bool:
473 and self.leaves[0].type == token.NAME
474 and self.leaves[0].value in FLOW_CONTROL
478 def is_yield(self) -> bool:
481 and self.leaves[0].type == token.NAME
482 and self.leaves[0].value == 'yield'
485 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
488 and self.leaves[-1].type == token.COMMA
489 and closing.type in CLOSING_BRACKETS
493 if closing.type == token.RSQB or closing.type == token.RBRACE:
497 # For parens let's check if it's safe to remove the comma. If the
498 # trailing one is the only one, we might mistakenly change a tuple
499 # into a different type by removing the comma.
500 depth = closing.bracket_depth + 1
502 opening = closing.opening_bracket
503 for _opening_index, leaf in enumerate(self.leaves):
510 for leaf in self.leaves[_opening_index + 1:]:
514 bracket_depth = leaf.bracket_depth
515 if bracket_depth == depth and leaf.type == token.COMMA:
517 if leaf.parent and leaf.parent.type == syms.arglist:
527 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
528 """In a for loop, or comprehension, the variables are often unpacks.
530 To avoid splitting on the comma in this situation, we will increase
531 the depth of tokens between `for` and `in`.
533 if leaf.type == token.NAME and leaf.value == 'for':
535 self.bracket_tracker.depth += 1
536 self._for_loop_variable = True
541 def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
542 # See `maybe_increment_for_loop_variable` above for explanation.
543 if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
544 self.bracket_tracker.depth -= 1
545 self._for_loop_variable = False
550 def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
551 """Hack a standalone comment to act as a trailing comment for line splitting.
553 If this line has brackets and a standalone `comment`, we need to adapt
554 it to be able to still reformat the line.
556 This is not perfect, the line to which the standalone comment gets
557 appended will appear "too long" when splitting.
560 comment.type == STANDALONE_COMMENT
561 and self.bracket_tracker.any_open_brackets()
565 comment.type = token.COMMENT
566 comment.prefix = '\n' + ' ' * (self.depth + 1)
567 return self.append_comment(comment)
569 def append_comment(self, comment: Leaf) -> bool:
570 if comment.type != token.COMMENT:
574 after = id(self.last_non_delimiter())
576 comment.type = STANDALONE_COMMENT
581 if after in self.comments:
582 self.comments[after].value += str(comment)
584 self.comments[after] = comment
587 def last_non_delimiter(self) -> Leaf:
588 for i in range(len(self.leaves)):
589 last = self.leaves[-i - 1]
590 if not is_delimiter(last):
593 raise LookupError("No non-delimiters found")
595 def __str__(self) -> str:
599 indent = ' ' * self.depth
600 leaves = iter(self.leaves)
602 res = f'{first.prefix}{indent}{first.value}'
605 for comment in self.comments.values():
609 def __bool__(self) -> bool:
610 return bool(self.leaves or self.comments)
614 class EmptyLineTracker:
615 """Provides a stateful method that returns the number of potential extra
616 empty lines needed before and after the currently processed line.
618 Note: this tracker works on lines that haven't been split yet.
620 previous_line: Optional[Line] = None
621 previous_after: int = 0
622 previous_defs: List[int] = Factory(list)
624 def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
625 """Returns the number of extra empty lines before and after the `current_line`.
627 This is for separating `def`, `async def` and `class` with extra empty lines
628 (two on module-level), as well as providing an extra empty line after flow
629 control keywords to make them more prominent.
631 if current_line.is_comment:
632 # Don't count standalone comments towards previous empty lines.
635 before, after = self._maybe_empty_lines(current_line)
636 self.previous_after = after
637 self.previous_line = current_line
640 def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
642 depth = current_line.depth
643 while self.previous_defs and self.previous_defs[-1] >= depth:
644 self.previous_defs.pop()
645 before = (1 if depth else 2) - self.previous_after
646 is_decorator = current_line.is_decorator
647 if is_decorator or current_line.is_def or current_line.is_class:
649 self.previous_defs.append(depth)
650 if self.previous_line is None:
651 # Don't insert empty lines before the first line in the file.
654 if self.previous_line and self.previous_line.is_decorator:
655 # Don't insert empty lines between decorators.
659 if current_line.depth:
661 newlines -= self.previous_after
664 if current_line.is_flow_control:
669 and self.previous_line.is_import
670 and not current_line.is_import
671 and depth == self.previous_line.depth
673 return (before or 1), 0
677 and self.previous_line.is_yield
678 and (not current_line.is_yield or depth != self.previous_line.depth)
680 return (before or 1), 0
686 class LineGenerator(Visitor[Line]):
687 """Generates reformatted Line objects. Empty lines are not emitted.
689 Note: destroys the tree it's visiting by mutating prefixes of its leaves
690 in ways that will no longer stringify to valid Python code on the tree.
692 current_line: Line = Factory(Line)
693 standalone_comments: List[Leaf] = Factory(list)
695 def line(self, indent: int = 0) -> Iterator[Line]:
698 If the line is empty, only emit if it makes sense.
699 If the line is too long, split it first and then generate.
701 If any lines were generated, set up a new current_line.
703 if not self.current_line:
704 self.current_line.depth += indent
705 return # Line is empty, don't emit. Creating a new one unnecessary.
707 complete_line = self.current_line
708 self.current_line = Line(depth=complete_line.depth + indent)
711 def visit_default(self, node: LN) -> Iterator[Line]:
712 if isinstance(node, Leaf):
713 for comment in generate_comments(node):
714 if self.current_line.bracket_tracker.any_open_brackets():
715 # any comment within brackets is subject to splitting
716 self.current_line.append(comment)
717 elif comment.type == token.COMMENT:
718 # regular trailing comment
719 self.current_line.append(comment)
720 yield from self.line()
723 # regular standalone comment, to be processed later (see
724 # docstring in `generate_comments()`
725 self.standalone_comments.append(comment)
726 normalize_prefix(node)
727 if node.type not in WHITESPACE:
728 for comment in self.standalone_comments:
729 yield from self.line()
731 self.current_line.append(comment)
732 yield from self.line()
734 self.standalone_comments = []
735 self.current_line.append(node)
736 yield from super().visit_default(node)
738 def visit_suite(self, node: Node) -> Iterator[Line]:
739 """Body of a statement after a colon."""
740 children = iter(node.children)
741 # Process newline before indenting. It might contain an inline
742 # comment that should go right after the colon.
743 newline = next(children)
744 yield from self.visit(newline)
745 yield from self.line(+1)
747 for child in children:
748 yield from self.visit(child)
750 yield from self.line(-1)
752 def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
753 """Visit a statement.
755 The relevant Python language keywords for this statement are NAME leaves
758 for child in node.children:
759 if child.type == token.NAME and child.value in keywords: # type: ignore
760 yield from self.line()
762 yield from self.visit(child)
764 def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
765 """A statement without nested statements."""
766 is_suite_like = node.parent and node.parent.type in STATEMENT
768 yield from self.line(+1)
769 yield from self.visit_default(node)
770 yield from self.line(-1)
773 yield from self.line()
774 yield from self.visit_default(node)
776 def visit_async_stmt(self, node: Node) -> Iterator[Line]:
777 yield from self.line()
779 children = iter(node.children)
780 for child in children:
781 yield from self.visit(child)
783 if child.type == token.NAME and child.value == 'async': # type: ignore
786 internal_stmt = next(children)
787 for child in internal_stmt.children:
788 yield from self.visit(child)
790 def visit_decorators(self, node: Node) -> Iterator[Line]:
791 for child in node.children:
792 yield from self.line()
793 yield from self.visit(child)
795 def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
796 yield from self.line()
798 def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
799 yield from self.visit_default(leaf)
800 yield from self.line()
802 def __attrs_post_init__(self) -> None:
803 """You are in a twisty little maze of passages."""
805 self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
806 self.visit_while_stmt = partial(v, keywords={'while', 'else'})
807 self.visit_for_stmt = partial(v, keywords={'for', 'else'})
808 self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
809 self.visit_except_clause = partial(v, keywords={'except'})
810 self.visit_funcdef = partial(v, keywords={'def'})
811 self.visit_with_stmt = partial(v, keywords={'with'})
812 self.visit_classdef = partial(v, keywords={'class'})
813 self.visit_async_funcdef = self.visit_async_stmt
814 self.visit_decorated = self.visit_decorators
817 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
818 OPENING_BRACKETS = set(BRACKET.keys())
819 CLOSING_BRACKETS = set(BRACKET.values())
820 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
821 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
824 def whitespace(leaf: Leaf) -> str: # noqa C901
825 """Return whitespace prefix if needed for the given `leaf`."""
832 if t in ALWAYS_NO_SPACE:
835 if t == token.COMMENT:
838 assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
839 prev = leaf.prev_sibling
841 prevp = preceding_leaf(p)
842 if not prevp or prevp.type in OPENING_BRACKETS:
845 if prevp.type == token.EQUAL:
846 if prevp.parent and prevp.parent.type in {
855 elif prevp.type == token.DOUBLESTAR:
856 if prevp.parent and prevp.parent.type in {
865 elif prevp.type == token.COLON:
866 if prevp.parent and prevp.parent.type == syms.subscript:
869 elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
872 elif prev.type in OPENING_BRACKETS:
875 if p.type in {syms.parameters, syms.arglist}:
876 # untyped function signatures or calls
880 if not prev or prev.type != token.COMMA:
883 if p.type == syms.varargslist:
888 if prev and prev.type != token.COMMA:
891 elif p.type == syms.typedargslist:
892 # typed function signatures
897 if prev.type != syms.tname:
900 elif prev.type == token.EQUAL:
901 # A bit hacky: if the equal sign has whitespace, it means we
902 # previously found it's a typed argument. So, we're using that, too.
905 elif prev.type != token.COMMA:
908 elif p.type == syms.tname:
911 prevp = preceding_leaf(p)
912 if not prevp or prevp.type != token.COMMA:
915 elif p.type == syms.trailer:
916 # attributes and calls
917 if t == token.LPAR or t == token.RPAR:
922 prevp = preceding_leaf(p)
923 if not prevp or prevp.type != token.NUMBER:
926 elif t == token.LSQB:
929 elif prev.type != token.COMMA:
932 elif p.type == syms.argument:
938 prevp = preceding_leaf(p)
939 if not prevp or prevp.type == token.LPAR:
942 elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
945 elif p.type == syms.decorator:
949 elif p.type == syms.dotted_name:
953 prevp = preceding_leaf(p)
954 if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
957 elif p.type == syms.classdef:
961 if prev and prev.type == token.LPAR:
964 elif p.type == syms.subscript:
967 assert p.parent is not None, "subscripts are always parented"
968 if p.parent.type == syms.subscriptlist:
973 elif prev.type == token.COLON:
976 elif p.type == syms.atom:
977 if prev and t == token.DOT:
978 # dots, but not the first one.
982 p.type == syms.listmaker
983 or p.type == syms.testlist_gexp
984 or p.type == syms.subscriptlist
986 # list interior, including unpacking
990 elif p.type == syms.dictsetmaker:
991 # dict and set interior, including unpacking
995 if prev.type == token.DOUBLESTAR:
998 elif p.type in {syms.factor, syms.star_expr}:
1001 prevp = preceding_leaf(p)
1002 if not prevp or prevp.type in OPENING_BRACKETS:
1005 prevp_parent = prevp.parent
1006 assert prevp_parent is not None
1007 if prevp.type == token.COLON and prevp_parent.type in {
1008 syms.subscript, syms.sliceop
1012 elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1015 elif t == token.NAME or t == token.NUMBER:
1018 elif p.type == syms.import_from:
1020 if prev and prev.type == token.DOT:
1023 elif t == token.NAME:
1027 if prev and prev.type == token.DOT:
1030 elif p.type == syms.sliceop:
1036 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1037 """Returns the first leaf that precedes `node`, if any."""
1039 res = node.prev_sibling
1041 if isinstance(res, Leaf):
1045 return list(res.leaves())[-1]
1054 def is_delimiter(leaf: Leaf) -> int:
1055 """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1057 Higher numbers are higher priority.
1059 if leaf.type == token.COMMA:
1060 return COMMA_PRIORITY
1062 if leaf.type in COMPARATORS:
1063 return COMPARATOR_PRIORITY
1066 leaf.type in MATH_OPERATORS
1068 and leaf.parent.type not in {syms.factor, syms.star_expr}
1070 return MATH_PRIORITY
1075 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1076 """Cleans the prefix of the `leaf` and generates comments from it, if any.
1078 Comments in lib2to3 are shoved into the whitespace prefix. This happens
1079 in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation
1080 move because it does away with modifying the grammar to include all the
1081 possible places in which comments can be placed.
1083 The sad consequence for us though is that comments don't "belong" anywhere.
1084 This is why this function generates simple parentless Leaf objects for
1085 comments. We simply don't know what the correct parent should be.
1087 No matter though, we can live without this. We really only need to
1088 differentiate between inline and standalone comments. The latter don't
1089 share the line with any code.
1091 Inline comments are emitted as regular token.COMMENT leaves. Standalone
1092 are emitted with a fake STANDALONE_COMMENT token identifier.
1097 if '#' not in leaf.prefix:
1100 before_comment, content = leaf.prefix.split('#', 1)
1101 content = content.rstrip()
1102 if content and (content[0] not in {' ', '!', '#'}):
1103 content = ' ' + content
1104 is_standalone_comment = (
1105 '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1107 if not is_standalone_comment:
1108 # simple trailing comment
1109 yield Leaf(token.COMMENT, value='#' + content)
1112 for line in ('#' + content).split('\n'):
1113 line = line.lstrip()
1114 if not line.startswith('#'):
1117 yield Leaf(STANDALONE_COMMENT, line)
1121 line: Line, line_length: int, inner: bool = False, py36: bool = False
1122 ) -> Iterator[Line]:
1123 """Splits a `line` into potentially many lines.
1125 They should fit in the allotted `line_length` but might not be able to.
1126 `inner` signifies that there were a pair of brackets somewhere around the
1127 current `line`, possibly transitively. This means we can fallback to splitting
1128 by delimiters if the LHS/RHS don't yield any results.
1130 If `py36` is True, splitting may generate syntax that is only compatible
1131 with Python 3.6 and later.
1133 line_str = str(line).strip('\n')
1134 if len(line_str) <= line_length and '\n' not in line_str:
1139 split_funcs = [left_hand_split]
1140 elif line.inside_brackets:
1141 split_funcs = [delimiter_split]
1142 if '\n' not in line_str:
1143 # Only attempt RHS if we don't have multiline strings or comments
1145 split_funcs.append(right_hand_split)
1147 split_funcs = [right_hand_split]
1148 for split_func in split_funcs:
1149 # We are accumulating lines in `result` because we might want to abort
1150 # mission and return the original line in the end, or attempt a different
1152 result: List[Line] = []
1154 for l in split_func(line, py36=py36):
1155 if str(l).strip('\n') == line_str:
1156 raise CannotSplit("Split function returned an unchanged result")
1159 split_line(l, line_length=line_length, inner=True, py36=py36)
1161 except CannotSplit as cs:
1172 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1173 """Split line into many lines, starting with the first matching bracket pair.
1175 Note: this usually looks weird, only use this for function definitions.
1176 Prefer RHS otherwise.
1178 head = Line(depth=line.depth)
1179 body = Line(depth=line.depth + 1, inside_brackets=True)
1180 tail = Line(depth=line.depth)
1181 tail_leaves: List[Leaf] = []
1182 body_leaves: List[Leaf] = []
1183 head_leaves: List[Leaf] = []
1184 current_leaves = head_leaves
1185 matching_bracket = None
1186 for leaf in line.leaves:
1188 current_leaves is body_leaves
1189 and leaf.type in CLOSING_BRACKETS
1190 and leaf.opening_bracket is matching_bracket
1192 current_leaves = tail_leaves if body_leaves else head_leaves
1193 current_leaves.append(leaf)
1194 if current_leaves is head_leaves:
1195 if leaf.type in OPENING_BRACKETS:
1196 matching_bracket = leaf
1197 current_leaves = body_leaves
1198 # Since body is a new indent level, remove spurious leading whitespace.
1200 normalize_prefix(body_leaves[0])
1201 # Build the new lines.
1202 for result, leaves in (
1203 (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1206 result.append(leaf, preformatted=True)
1207 comment_after = line.comments.get(id(leaf))
1209 result.append(comment_after, preformatted=True)
1210 split_succeeded_or_raise(head, body, tail)
1211 for result in (head, body, tail):
1216 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1217 """Split line into many lines, starting with the last matching bracket pair."""
1218 head = Line(depth=line.depth)
1219 body = Line(depth=line.depth + 1, inside_brackets=True)
1220 tail = Line(depth=line.depth)
1221 tail_leaves: List[Leaf] = []
1222 body_leaves: List[Leaf] = []
1223 head_leaves: List[Leaf] = []
1224 current_leaves = tail_leaves
1225 opening_bracket = None
1226 for leaf in reversed(line.leaves):
1227 if current_leaves is body_leaves:
1228 if leaf is opening_bracket:
1229 current_leaves = head_leaves if body_leaves else tail_leaves
1230 current_leaves.append(leaf)
1231 if current_leaves is tail_leaves:
1232 if leaf.type in CLOSING_BRACKETS:
1233 opening_bracket = leaf.opening_bracket
1234 current_leaves = body_leaves
1235 tail_leaves.reverse()
1236 body_leaves.reverse()
1237 head_leaves.reverse()
1238 # Since body is a new indent level, remove spurious leading whitespace.
1240 normalize_prefix(body_leaves[0])
1241 # Build the new lines.
1242 for result, leaves in (
1243 (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1246 result.append(leaf, preformatted=True)
1247 comment_after = line.comments.get(id(leaf))
1249 result.append(comment_after, preformatted=True)
1250 split_succeeded_or_raise(head, body, tail)
1251 for result in (head, body, tail):
1256 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1257 tail_len = len(str(tail).strip())
1260 raise CannotSplit("Splitting brackets produced the same line")
1264 f"Splitting brackets on an empty body to save "
1265 f"{tail_len} characters is not worth it"
1269 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1270 """Split according to delimiters of the highest priority.
1272 This kind of split doesn't increase indentation.
1273 If `py36` is True, the split will add trailing commas also in function
1274 signatures that contain * and **.
1277 last_leaf = line.leaves[-1]
1279 raise CannotSplit("Line empty")
1281 delimiters = line.bracket_tracker.delimiters
1283 delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1285 raise CannotSplit("No delimiters found")
1287 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1288 lowest_depth = sys.maxsize
1289 trailing_comma_safe = True
1290 for leaf in line.leaves:
1291 current_line.append(leaf, preformatted=True)
1292 comment_after = line.comments.get(id(leaf))
1294 current_line.append(comment_after, preformatted=True)
1295 lowest_depth = min(lowest_depth, leaf.bracket_depth)
1297 leaf.bracket_depth == lowest_depth
1298 and leaf.type == token.STAR
1299 or leaf.type == token.DOUBLESTAR
1301 trailing_comma_safe = trailing_comma_safe and py36
1302 leaf_priority = delimiters.get(id(leaf))
1303 if leaf_priority == delimiter_priority:
1304 normalize_prefix(current_line.leaves[0])
1307 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1310 delimiter_priority == COMMA_PRIORITY
1311 and current_line.leaves[-1].type != token.COMMA
1312 and trailing_comma_safe
1314 current_line.append(Leaf(token.COMMA, ','))
1315 normalize_prefix(current_line.leaves[0])
1319 def is_import(leaf: Leaf) -> bool:
1320 """Returns True if the given leaf starts an import statement."""
1327 (v == 'import' and p and p.type == syms.import_name)
1328 or (v == 'from' and p and p.type == syms.import_from)
1333 def normalize_prefix(leaf: Leaf) -> None:
1334 """Leave existing extra newlines for imports. Remove everything else."""
1336 spl = leaf.prefix.split('#', 1)
1337 nl_count = spl[0].count('\n')
1339 # Skip one newline since it was for a standalone comment.
1341 leaf.prefix = '\n' * nl_count
def is_python36(node: Node) -> bool:
    """Returns True if the current file is using Python 3.6+ features.

    Currently looking for:
    - f-strings; and
    - trailing commas after * or ** in function signatures.
    """
    for n in node.pre_order():
        if n.type == token.STRING:
            # Inspect the string prefix only; covers f"", F'', rf/fr combos.
            value_head = n.value[:2]  # type: ignore
            if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
                return True

        elif (
            n.type == syms.typedargslist
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # A trailing comma in a signature that also contains * or **
            # is only legal from Python 3.6 onwards.
            for ch in n.children:
                if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
                    return True

    return False
# File suffixes Black considers Python source.
PYTHON_EXTENSIONS = {'.py'}

# Directory names that are never descended into when collecting files:
# VCS metadata, build output, caches, and virtual environments.
BLACKLISTED_DIRECTORIES = {
    'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
}
def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
    """Recursively yield Python files found under `path`.

    Directories named in BLACKLISTED_DIRECTORIES are pruned without being
    entered; only files whose suffix is in PYTHON_EXTENSIONS are yielded.
    """
    for child in path.iterdir():
        if child.is_dir():
            if child.name in BLACKLISTED_DIRECTORIES:
                # Skip the whole subtree, not just this entry.
                continue

            yield from gen_python_files_in_dir(child)

        elif child.suffix in PYTHON_EXTENSIONS:
            yield child
1392 """Provides a reformatting counter."""
1393 change_count: int = 0
1395 failure_count: int = 0
def done(self, src: Path, changed: bool) -> None:
    """Increment the counter for successful reformatting. Write out a message.

    `changed` selects which counter is bumped and which message is shown;
    exactly one of change_count/same_count is incremented per call.
    """
    if changed:
        out(f'reformatted {src}')
        self.change_count += 1
    else:
        out(f'{src} already well formatted, good job.', bold=False)
        self.same_count += 1
def failed(self, src: Path, message: str) -> None:
    """Increment the counter for failed reformatting. Write out a message."""
    # Record the failure first, then report it; the two steps are independent.
    self.failure_count += 1
    err(f'error: cannot format {src}: {message}')
@property
def return_code(self) -> int:
    """Which return code should the app use considering the current state.

    123 on any internal failure, 1 when files were reformatted (useful for
    --check mode), 0 when nothing changed.
    """
    # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
    # 126 we have special returncodes reserved by the shell.
    if self.failure_count:
        return 123

    elif self.change_count:
        return 1

    return 0
def __str__(self) -> str:
    """A color report of the current state.

    Use `click.unstyle` to remove colors.
    """
    report = []
    if self.change_count:
        s = 's' if self.change_count > 1 else ''
        report.append(
            click.style(f'{self.change_count} file{s} reformatted', bold=True)
        )
    if self.same_count:
        s = 's' if self.same_count > 1 else ''
        report.append(f'{self.same_count} file{s} left unchanged')
    if self.failure_count:
        s = 's' if self.failure_count > 1 else ''
        report.append(
            click.style(
                f'{self.failure_count} file{s} failed to reformat', fg='red'
            )
        )
    # Join the non-empty parts into one sentence.
    return ', '.join(report) + '.'
def assert_equivalent(src: str, dst: str) -> None:
    """Raises AssertionError if `src` and `dst` aren't equivalent.

    This is a temporary sanity check until Black becomes stable.

    Both strings are parsed with `ast` and compared by a canonical textual
    dump; any difference means the reformatting changed program semantics.
    """

    import ast
    import traceback

    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
        """Simple visitor generating strings to compare ASTs by content."""
        yield f"{' ' * depth}{node.__class__.__name__}("

        # Sort fields so the dump is independent of _fields ordering.
        for field in sorted(node._fields):
            try:
                value = getattr(node, field)
            except AttributeError:
                continue

            yield f"{' ' * (depth+1)}{field}="

            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        yield from _v(item, depth + 2)

            elif isinstance(value, ast.AST):
                yield from _v(value, depth + 2)

            else:
                yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"

        yield f"{' ' * depth}) # /{node.__class__.__name__}"

    try:
        src_ast = ast.parse(src)
    except Exception as exc:
        # The input itself is broken; not Black's fault.
        raise AssertionError(f"cannot parse source: {exc}") from None

    try:
        dst_ast = ast.parse(dst)
    except Exception as exc:
        log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = '\n'.join(_v(src_ast))
    dst_ast_str = '\n'.join(_v(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
        raise AssertionError(
            f"INTERNAL ERROR: Black produced code that is not equivalent to "
            f"the source. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This diff might be helpful: {log}"
        ) from None
def assert_stable(src: str, dst: str, line_length: int) -> None:
    """Raises AssertionError if `dst` reformats differently the second time.

    This is a temporary sanity check until Black becomes stable.

    Formatting must be idempotent: running the formatter on its own output
    has to be a no-op, otherwise results depend on how many times it ran.
    """
    newdst = format_str(dst, line_length=line_length)
    if dst != newdst:
        log = dump_to_file(
            diff(src, dst, 'source', 'first pass'),
            diff(dst, newdst, 'first pass', 'second pass'),
        )
        raise AssertionError(
            f"INTERNAL ERROR: Black produced different code on the second pass "
            f"of the formatter. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This diff might be helpful: {log}"
        ) from None
def dump_to_file(*output: str) -> str:
    """Dumps `output` to a temporary file. Returns path to the file.

    Each element of `output` is written followed by a newline.  The file is
    created with delete=False so it survives for later inspection.
    """
    import tempfile

    with tempfile.NamedTemporaryFile(
        mode='w', prefix='blk_', suffix='.log', delete=False
    ) as f:
        for lines in output:
            f.write(lines)
            f.write('\n')
    return f.name
def diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Returns a udiff string between strings `a` and `b`.

    `a_name` and `b_name` label the two sides in the diff header.  Lines are
    re-terminated with '\n' because unified_diff expects line sequences.
    """
    import difflib

    a_lines = [line + '\n' for line in a.split('\n')]
    b_lines = [line + '\n' for line in b.split('\n')]
    return ''.join(
        difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
    )
1552 if __name__ == '__main__':