import asyncio
from asyncio.base_events import BaseEventLoop
from concurrent.futures import Executor, ProcessPoolExecutor
from functools import partial
import keyword
import os
from pathlib import Path
import sys
import tokenize
from typing import (
    Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
)

from attr import dataclass, Factory
import click

from blib2to3.pytree import Node, Leaf, type_repr
from blib2to3 import pygram, pytree
from blib2to3.pgen2 import driver, token
from blib2to3.pgen2.parse import ParseError

__version__ = "18.3a2"
DEFAULT_LINE_LENGTH = 88

syms = pygram.python_symbols

LN = Union[Leaf, Node]
out = partial(click.secho, bold=True, err=True)
err = partial(click.secho, fg='red', err=True)


class NothingChanged(UserWarning):
    """Raised by `format_file_contents` when the reformatted code is the same as source."""


class CannotSplit(Exception):
    """A readable split that fits the allotted line length is impossible.

    Raised by `left_hand_split()` and `right_hand_split()`.
    """


    default=DEFAULT_LINE_LENGTH,
    help='How many characters per line to allow.',
        "Don't write back the files, just return the status. Return code 0 "
        "means nothing changed. Return code 1 means some files were "
        "reformatted. Return code 123 means there was an internal error."
    help='If --fast is given, skip temporary sanity checks. [default: --safe]',
@click.version_option(version=__version__)
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
    """The uncompromising code formatter."""
    sources: List[Path] = []
            sources.extend(gen_python_files_in_dir(p))
            # if a file was explicitly given, we don't care about its extension
            sources.append(Path('-'))
            err(f'invalid path: {s}')
    elif len(sources) == 1:
            if not p.is_file() and str(p) == '-':
                changed = format_stdin_to_stdout(line_length=line_length, fast=fast)
                changed = format_file_in_place(
                    p, line_length=line_length, fast=fast, write_back=not check
            report.done(p, changed)
        except Exception as exc:
            report.failed(p, str(exc))
        ctx.exit(report.return_code)
        loop = asyncio.get_event_loop()
        executor = ProcessPoolExecutor(max_workers=os.cpu_count())
            return_code = loop.run_until_complete(
                    sources, line_length, not check, fast, loop, executor
            ctx.exit(return_code)


async def schedule_formatting(
        src: loop.run_in_executor(
            executor, format_file_in_place, src, line_length, fast, write_back
    await asyncio.wait(tasks.values())
    for src, task in tasks.items():
            report.failed(src, 'timed out, cancelling')
            cancelled.append(task)
        elif task.exception():
            report.failed(src, str(task.exception()))
            report.done(src, task.result())
        await asyncio.wait(cancelled, timeout=2)
    out('All done! ✨ 🍰 ✨')
    click.echo(str(report))
    return report.return_code
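
# A minimal, self-contained sketch of the concurrency pattern used above:
# CPU-bound work is pushed onto a ProcessPoolExecutor from the event loop via
# `loop.run_in_executor` and then gathered with `asyncio.wait`. The worker
# `_square` and its inputs are placeholders for illustration, not part of Black.
def _square(n: int) -> int:
    return n * n


async def _demo_schedule(loop: BaseEventLoop, executor: Executor) -> List[int]:
    # One future per input, exactly like the per-file tasks dict above.
    tasks = {n: loop.run_in_executor(executor, _square, n) for n in (1, 2, 3)}
    await asyncio.wait(tasks.values())
    return [task.result() for task in tasks.values()]
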


def format_file_in_place(
    src: Path, line_length: int, fast: bool, write_back: bool = False
    """Format the file and rewrite if changed. Return True if changed."""
    with tokenize.open(src) as src_buffer:
        src_contents = src_buffer.read()
        contents = format_file_contents(
            src_contents, line_length=line_length, fast=fast
    except NothingChanged:
        with open(src, "w", encoding=src_buffer.encoding) as f:


def format_stdin_to_stdout(line_length: int, fast: bool) -> bool:
    """Format file on stdin and pipe output to stdout. Return True if changed."""
    contents = sys.stdin.read()
        contents = format_file_contents(contents, line_length=line_length, fast=fast)
    except NothingChanged:
        sys.stdout.write(contents)


def format_file_contents(
    src_contents: str, line_length: int, fast: bool
    """Reformat the contents of a file and return the new contents."""
    if src_contents.strip() == '':
    dst_contents = format_str(src_contents, line_length=line_length)
    if src_contents == dst_contents:
        assert_equivalent(src_contents, dst_contents)
        assert_stable(src_contents, dst_contents, line_length=line_length)
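
# Illustrative only: how the function above is meant to be called. `NothingChanged`
# signals "already formatted"; with `fast=False` the AST-equivalence and stability
# checks above guard the result. The sample source string is made up.
def _demo_format_contents() -> str:
    src = "x = {'a':1}\n"
    try:
        return format_file_contents(src, line_length=DEFAULT_LINE_LENGTH, fast=False)
    except NothingChanged:
        return src
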


def format_str(src_contents: str, line_length: int) -> FileContent:
    """Reformats a string and returns new contents."""
    src_node = lib2to3_parse(src_contents)
    comments: List[Line] = []
    lines = LineGenerator()
    elt = EmptyLineTracker()
    py36 = is_python36(src_node)
    for current_line in lines.visit(src_node):
        for _ in range(after):
            dst_contents += str(empty_line)
        before, after = elt.maybe_empty_lines(current_line)
        for _ in range(before):
            dst_contents += str(empty_line)
        if not current_line.is_comment:
            for comment in comments:
                dst_contents += str(comment)
            for line in split_line(current_line, line_length=line_length, py36=py36):
                dst_contents += str(line)
            comments.append(current_line)
        if elt.previous_defs:
            # Separate postscriptum comments from the last module-level def.
            dst_contents += str(empty_line)
            dst_contents += str(empty_line)
        for comment in comments:
            dst_contents += str(comment)


def lib2to3_parse(src_txt: str) -> Node:
    """Given a string with source, return the lib2to3 Node."""
    grammar = pygram.python_grammar_no_print_statement
    drv = driver.Driver(grammar, pytree.convert)
    if not src_txt.endswith('\n'):
        nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
        result = drv.parse_string(src_txt, True)
    except ParseError as pe:
        lineno, column = pe.context[1]
        lines = src_txt.splitlines()
            faulty_line = lines[lineno - 1]
            faulty_line = "<line number missing in source>"
        raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None

    if isinstance(result, Leaf):
        result = Node(syms.file_input, [result])
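
# Sketch of the parse/serialize round trip used throughout this module: the
# blib2to3 driver produces a tree whose `str()` is the original source, because
# every leaf carries its whitespace in `prefix`. The sample snippet is arbitrary.
def _demo_roundtrip() -> bool:
    src = "print('hello')\n"
    node = lib2to3_parse(src)
    return str(node) == src
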


def lib2to3_unparse(node: Node) -> str:
    """Given a lib2to3 node, return its string representation."""


T = TypeVar('T')


class Visitor(Generic[T]):
    """Basic lib2to3 visitor that yields things on visiting."""

    def visit(self, node: LN) -> Iterator[T]:
            name = token.tok_name[node.type]
            name = type_repr(node.type)
        yield from getattr(self, f'visit_{name}', self.visit_default)(node)

    def visit_default(self, node: LN) -> Iterator[T]:
        if isinstance(node, Node):
            for child in node.children:
                yield from self.visit(child)
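
# Sketch of how the dynamic dispatch above is used: a subclass only defines
# `visit_<NAME>` methods for the node or token types it cares about; everything
# else falls through to `visit_default`. This collector is illustrative, not
# part of Black.
class _NameCollector(Visitor[str]):
    """Yields the value of every NAME leaf in the tree."""

    def visit_NAME(self, leaf: Leaf) -> Iterator[str]:
        yield leaf.value

# e.g. list(_NameCollector().visit(lib2to3_parse("x = y + 1\n"))) should
# yield ['x', 'y'].
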


class DebugVisitor(Visitor[T]):

    def visit_default(self, node: LN) -> Iterator[T]:
        indent = ' ' * (2 * self.tree_depth)
        if isinstance(node, Node):
            _type = type_repr(node.type)
            out(f'{indent}{_type}', fg='yellow')
            for child in node.children:
                yield from self.visit(child)
            out(f'{indent}/{_type}', fg='yellow', bold=False)
            _type = token.tok_name.get(node.type, str(node.type))
            out(f'{indent}{_type}', fg='blue', nl=False)
                # We don't have to handle prefixes for `Node` objects since
                # that delegates to the first child anyway.
                out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
            out(f' {node.value!r}', fg='blue', bold=False)


KEYWORDS = set(keyword.kwlist)
WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
STANDALONE_COMMENT = 153
LOGIC_OPERATORS = {'and', 'or'}
COMPREHENSION_PRIORITY = 20
COMPARATOR_PRIORITY = 3


class BracketTracker:
    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
    delimiters: Dict[LeafID, Priority] = Factory(dict)
    previous: Optional[Leaf] = None

    def mark(self, leaf: Leaf) -> None:
        if leaf.type == token.COMMENT:
        if leaf.type in CLOSING_BRACKETS:
            opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
            leaf.opening_bracket = opening_bracket
        leaf.bracket_depth = self.depth
            delim = is_delimiter(leaf)
                self.delimiters[id(leaf)] = delim
            elif self.previous is not None:
                if leaf.type == token.STRING and self.previous.type == token.STRING:
                    self.delimiters[id(self.previous)] = STRING_PRIORITY
                    leaf.type == token.NAME
                    and leaf.value == 'for'
                    and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
                    self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
                    leaf.type == token.NAME
                    and leaf.value == 'if'
                    and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
                    self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
                    leaf.type == token.NAME
                    and leaf.value in LOGIC_OPERATORS
                    self.delimiters[id(self.previous)] = LOGIC_PRIORITY
        if leaf.type in OPENING_BRACKETS:
            self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf

    def any_open_brackets(self) -> bool:
        """Returns True if there is an as-yet-unmatched open bracket on the line."""
        return bool(self.bracket_match)

    def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
        """Returns the highest priority of a delimiter found on the line.

        Values are consistent with what `is_delimiter()` returns.
        """
        return max(v for k, v in self.delimiters.items() if k not in exclude)


class Line:
    leaves: List[Leaf] = Factory(list)
    comments: Dict[LeafID, Leaf] = Factory(dict)
    bracket_tracker: BracketTracker = Factory(BracketTracker)
    inside_brackets: bool = False
    has_for: bool = False
    _for_loop_variable: bool = False

    def append(self, leaf: Leaf, preformatted: bool = False) -> None:
        has_value = leaf.value.strip()
        if self.leaves and not preformatted:
            # Note: at this point leaf.prefix should be empty except for
            # imports, for which we only preserve newlines.
            leaf.prefix += whitespace(leaf)
        if self.inside_brackets or not preformatted:
            self.maybe_decrement_after_for_loop_variable(leaf)
            self.bracket_tracker.mark(leaf)
            self.maybe_remove_trailing_comma(leaf)
            self.maybe_increment_for_loop_variable(leaf)
            if self.maybe_adapt_standalone_comment(leaf):
        if not self.append_comment(leaf):
            self.leaves.append(leaf)

    def is_comment(self) -> bool:
        return bool(self) and self.leaves[0].type == STANDALONE_COMMENT

    def is_decorator(self) -> bool:
        return bool(self) and self.leaves[0].type == token.AT

    def is_import(self) -> bool:
        return bool(self) and is_import(self.leaves[0])

    def is_class(self) -> bool:
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value == 'class'

    def is_def(self) -> bool:
        """Also returns True for async defs."""
            first_leaf = self.leaves[0]
            second_leaf: Optional[Leaf] = self.leaves[1]
            (first_leaf.type == token.NAME and first_leaf.value == 'def')
                first_leaf.type == token.ASYNC
                and second_leaf is not None
                and second_leaf.type == token.NAME
                and second_leaf.value == 'def'

    def is_flow_control(self) -> bool:
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value in FLOW_CONTROL

    def is_yield(self) -> bool:
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value == 'yield'

    def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
            and self.leaves[-1].type == token.COMMA
            and closing.type in CLOSING_BRACKETS
        if closing.type == token.RSQB or closing.type == token.RBRACE:
        # For parens let's check if it's safe to remove the comma. If the
        # trailing one is the only one, we might mistakenly change a tuple
        # into a different type by removing the comma.
        depth = closing.bracket_depth + 1
        opening = closing.opening_bracket
        for _opening_index, leaf in enumerate(self.leaves):
        for leaf in self.leaves[_opening_index + 1:]:
            bracket_depth = leaf.bracket_depth
            if bracket_depth == depth and leaf.type == token.COMMA:
                if leaf.parent and leaf.parent.type == syms.arglist:

    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
        """In a for loop or comprehension, the variables are often unpacked tuples.

        To avoid splitting on the comma in this situation, we will increase
        the depth of tokens between `for` and `in`.
        """
        if leaf.type == token.NAME and leaf.value == 'for':
            self.bracket_tracker.depth += 1
            self._for_loop_variable = True

    def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
        # See `maybe_increment_for_loop_variable` above for explanation.
        if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
            self.bracket_tracker.depth -= 1
            self._for_loop_variable = False
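
    # Worked example of the two hooks above (illustrative, not executed):
    #
    #   for item, weight in enumerate(weights): ...
    #
    # Between `for` and `in`, the tracker's depth is bumped by one, so the comma
    # in `item, weight` is treated as if it were nested in brackets and is never
    # chosen as a split point; after `in` the depth is restored and splitting
    # behaves normally again.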

    def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
        """Hack a standalone comment to act as a trailing comment for line splitting.

        If this line has brackets and a standalone `comment`, we need to adapt
        it to be able to still reformat the line.

        This is not perfect: the line to which the standalone comment gets
        appended will appear "too long" when splitting.
        """
            comment.type == STANDALONE_COMMENT
            and self.bracket_tracker.any_open_brackets()
        comment.type = token.COMMENT
        comment.prefix = '\n' + '    ' * (self.depth + 1)
        return self.append_comment(comment)

    def append_comment(self, comment: Leaf) -> bool:
        if comment.type != token.COMMENT:
            after = id(self.last_non_delimiter())
            comment.type = STANDALONE_COMMENT
            if after in self.comments:
                self.comments[after].value += str(comment)
                self.comments[after] = comment

    def last_non_delimiter(self) -> Leaf:
        for i in range(len(self.leaves)):
            last = self.leaves[-i - 1]
            if not is_delimiter(last):
        raise LookupError("No non-delimiters found")

    def __str__(self) -> str:
        indent = '    ' * self.depth
        leaves = iter(self.leaves)
        res = f'{first.prefix}{indent}{first.value}'
        for comment in self.comments.values():

    def __bool__(self) -> bool:
        return bool(self.leaves or self.comments)


class EmptyLineTracker:
    """Provides a stateful method that returns the number of potential extra
    empty lines needed before and after the currently processed line.

    Note: this tracker works on lines that haven't been split yet. It assumes
    the prefix of the first leaf consists of optional newlines. Those newlines
    are consumed by `maybe_empty_lines()` and included in the computation.
    """
    previous_line: Optional[Line] = None
    previous_after: int = 0
    previous_defs: List[int] = Factory(list)

    def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
        """Returns the number of extra empty lines before and after the `current_line`.

        This is for separating `def`, `async def` and `class` with extra empty lines
        (two on module-level), as well as providing an extra empty line after flow
        control keywords to make them more prominent.
        """
        if current_line.is_comment:
            # Don't count standalone comments towards previous empty lines.
        before, after = self._maybe_empty_lines(current_line)
        before -= self.previous_after
        self.previous_after = after
        self.previous_line = current_line

    def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
        if current_line.leaves:
            # Consume the first leaf's extra newlines.
            first_leaf = current_line.leaves[0]
            before = int('\n' in first_leaf.prefix)
            first_leaf.prefix = ''
        depth = current_line.depth
        while self.previous_defs and self.previous_defs[-1] >= depth:
            self.previous_defs.pop()
            before = 1 if depth else 2
        is_decorator = current_line.is_decorator
        if is_decorator or current_line.is_def or current_line.is_class:
                self.previous_defs.append(depth)
            if self.previous_line is None:
                # Don't insert empty lines before the first line in the file.
            if self.previous_line and self.previous_line.is_decorator:
                # Don't insert empty lines between decorators.
            if current_line.depth:
        if current_line.is_flow_control:
            and self.previous_line.is_import
            and not current_line.is_import
            and depth == self.previous_line.depth
            return (before or 1), 0
            and self.previous_line.is_yield
            and (not current_line.is_yield or depth != self.previous_line.depth)
            return (before or 1), 0
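
    # Worked example of the policy above (illustrative): for
    #
    #     import os
    #     def f():
    #         return os.getcwd()
    #     x = f()
    #
    # the tracker requests two empty lines before the module-level `def f():`
    # (only one if the def were nested), and two empty lines again between the
    # end of `f()` and `x = f()`, because leaving the body of a module-level
    # def pops `previous_defs` and asks for two lines.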


class LineGenerator(Visitor[Line]):
    """Generates reformatted Line objects. Empty lines are not emitted.

    Note: destroys the tree it's visiting by mutating the prefixes of its leaves
    in ways that make the tree no longer stringify to valid Python code.
    """
    current_line: Line = Factory(Line)
    standalone_comments: List[Leaf] = Factory(list)

    def line(self, indent: int = 0) -> Iterator[Line]:
        """If the line is empty, only emit if it makes sense.
        If the line is too long, split it first and then generate.

        If any lines were generated, set up a new current_line.
        """
        if not self.current_line:
            self.current_line.depth += indent
            return  # Line is empty, don't emit. Creating a new one unnecessary.

        complete_line = self.current_line
        self.current_line = Line(depth=complete_line.depth + indent)

    def visit_default(self, node: LN) -> Iterator[Line]:
        if isinstance(node, Leaf):
            for comment in generate_comments(node):
                if self.current_line.bracket_tracker.any_open_brackets():
                    # any comment within brackets is subject to splitting
                    self.current_line.append(comment)
                elif comment.type == token.COMMENT:
                    # regular trailing comment
                    self.current_line.append(comment)
                    yield from self.line()
                    # regular standalone comment, to be processed later (see
                    # docstring in `generate_comments()`)
                    self.standalone_comments.append(comment)
            normalize_prefix(node)
            if node.type not in WHITESPACE:
                for comment in self.standalone_comments:
                    yield from self.line()
                    self.current_line.append(comment)
                    yield from self.line()
                self.standalone_comments = []
                self.current_line.append(node)
        yield from super().visit_default(node)

    def visit_suite(self, node: Node) -> Iterator[Line]:
        """Body of a statement after a colon."""
        children = iter(node.children)
        # Process newline before indenting. It might contain an inline
        # comment that should go right after the colon.
        newline = next(children)
        yield from self.visit(newline)
        yield from self.line(+1)

        for child in children:
            yield from self.visit(child)

        yield from self.line(-1)

    def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
        """Visit a statement.

        The relevant Python language keywords for this statement are NAME leaves
        """
        for child in node.children:
            if child.type == token.NAME and child.value in keywords:  # type: ignore
                yield from self.line()

            yield from self.visit(child)

    def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
        """A statement without nested statements."""
        is_suite_like = node.parent and node.parent.type in STATEMENT
            yield from self.line(+1)
            yield from self.visit_default(node)
            yield from self.line(-1)
            yield from self.line()
            yield from self.visit_default(node)

    def visit_async_stmt(self, node: Node) -> Iterator[Line]:
        yield from self.line()

        children = iter(node.children)
        for child in children:
            yield from self.visit(child)

            if child.type == token.ASYNC:

        internal_stmt = next(children)
        for child in internal_stmt.children:
            yield from self.visit(child)

    def visit_decorators(self, node: Node) -> Iterator[Line]:
        for child in node.children:
            yield from self.line()
            yield from self.visit(child)

    def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
        yield from self.line()

    def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
        yield from self.visit_default(leaf)
        yield from self.line()

    def __attrs_post_init__(self) -> None:
        """You are in a twisty little maze of passages."""
        v = self.visit_stmt
        self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
        self.visit_while_stmt = partial(v, keywords={'while', 'else'})
        self.visit_for_stmt = partial(v, keywords={'for', 'else'})
        self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
        self.visit_except_clause = partial(v, keywords={'except'})
        self.visit_funcdef = partial(v, keywords={'def'})
        self.visit_with_stmt = partial(v, keywords={'with'})
        self.visit_classdef = partial(v, keywords={'class'})
        self.visit_async_funcdef = self.visit_async_stmt
        self.visit_decorated = self.visit_decorators
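
# Sketch of how this visitor is driven (it mirrors `format_str` above): parse
# the source, walk the tree, and render each emitted `Line`. Purely
# illustrative; the snippet being formatted is arbitrary.
def _demo_line_generator() -> List[str]:
    node = lib2to3_parse("if True:\n    x = 1\n")
    return [str(line) for line in LineGenerator().visit(node)]
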


BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
OPENING_BRACKETS = set(BRACKET.keys())
CLOSING_BRACKETS = set(BRACKET.values())
BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}


def whitespace(leaf: Leaf) -> str:  # noqa C901
    """Return whitespace prefix if needed for the given `leaf`."""
    NO = ''
    SPACE = ' '
    t = leaf.type
    p = leaf.parent
    v = leaf.value
    if t in ALWAYS_NO_SPACE:
    if t == token.COMMENT:
    assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
    if t == token.COLON and p.type != syms.subscript:
    prev = leaf.prev_sibling
        prevp = preceding_leaf(p)
        if not prevp or prevp.type in OPENING_BRACKETS:
            return SPACE if prevp.type == token.COMMA else NO
        if prevp.type == token.EQUAL:
            if prevp.parent and prevp.parent.type in {
        elif prevp.type == token.DOUBLESTAR:
            if prevp.parent and prevp.parent.type in {
        elif prevp.type == token.COLON:
            if prevp.parent and prevp.parent.type == syms.subscript:
        elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
    elif prev.type in OPENING_BRACKETS:

    if p.type in {syms.parameters, syms.arglist}:
        # untyped function signatures or calls
        if not prev or prev.type != token.COMMA:

    if p.type == syms.varargslist:
        if prev and prev.type != token.COMMA:

    elif p.type == syms.typedargslist:
        # typed function signatures
            if prev.type != syms.tname:
        elif prev.type == token.EQUAL:
            # A bit hacky: if the equal sign has whitespace, it means we
            # previously found it's a typed argument. So, we're using that, too.
        elif prev.type != token.COMMA:

    elif p.type == syms.tname:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type != token.COMMA:

    elif p.type == syms.trailer:
        # attributes and calls
        if t == token.LPAR or t == token.RPAR:
                prevp = preceding_leaf(p)
                if not prevp or prevp.type != token.NUMBER:
            elif t == token.LSQB:
        elif prev.type != token.COMMA:

    elif p.type == syms.argument:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type == token.LPAR:
        elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:

    elif p.type == syms.decorator:

    elif p.type == syms.dotted_name:
        prevp = preceding_leaf(p)
        if not prevp or prevp.type == token.AT or prevp.type == token.DOT:

    elif p.type == syms.classdef:
        if prev and prev.type == token.LPAR:

    elif p.type == syms.subscript:
        assert p.parent is not None, "subscripts are always parented"
        if p.parent.type == syms.subscriptlist:

    elif p.type == syms.atom:
        if prev and t == token.DOT:
            # dots, but not the first one.

        p.type == syms.listmaker
        or p.type == syms.testlist_gexp
        or p.type == syms.subscriptlist
        # list interior, including unpacking

    elif p.type == syms.dictsetmaker:
        # dict and set interior, including unpacking
        if prev.type == token.DOUBLESTAR:

    elif p.type in {syms.factor, syms.star_expr}:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type in OPENING_BRACKETS:
            prevp_parent = prevp.parent
            assert prevp_parent is not None
            if prevp.type == token.COLON and prevp_parent.type in {
                syms.subscript, syms.sliceop
            elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
        elif t == token.NAME or t == token.NUMBER:

    elif p.type == syms.import_from:
            if prev and prev.type == token.DOT:
        elif t == token.NAME:
            if prev and prev.type == token.DOT:

    elif p.type == syms.sliceop:
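
    # A few illustrative outcomes of the rules above (not executed here):
    #
    #   print ( x )        ->  print(x)       # no space inside call parens
    #   spam [1 :2]        ->  spam[1:2]      # no space around ':' in subscripts
    #   def f(a: int = 1)  ->  unchanged      # '=' keeps spaces for typed defaults
    #   f(a = 1)           ->  f(a=1)         # '=' loses spaces for keyword args
    #
    # The prefix returned here is attached to each leaf in `Line.append` above.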


def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
    """Returns the first leaf that precedes `node`, if any."""
        res = node.prev_sibling
            if isinstance(res, Leaf):
                return list(res.leaves())[-1]


def is_delimiter(leaf: Leaf) -> int:
    """Returns the priority of the `leaf` delimiter. Returns 0 if not a delimiter.

    Higher numbers are higher priority.
    """
    if leaf.type == token.COMMA:
        return COMMA_PRIORITY

    if leaf.type in COMPARATORS:
        return COMPARATOR_PRIORITY

        leaf.type in MATH_OPERATORS
        and leaf.parent.type not in {syms.factor, syms.star_expr}
        return MATH_PRIORITY


def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
    """Cleans the prefix of the `leaf` and generates comments from it, if any.

    Comments in lib2to3 are shoved into the whitespace prefix. This happens
    in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation
    move because it does away with modifying the grammar to include all the
    possible places in which comments can be placed.

    The sad consequence for us though is that comments don't "belong" anywhere.
    This is why this function generates simple parentless Leaf objects for
    comments. We simply don't know what the correct parent should be.

    No matter though, we can live without this. We really only need to
    differentiate between inline and standalone comments. The latter don't
    share the line with any code.

    Inline comments are emitted as regular token.COMMENT leaves. Standalone
    comments are emitted with a fake STANDALONE_COMMENT token identifier.
    """
    if '#' not in leaf.prefix:

    before_comment, content = leaf.prefix.split('#', 1)
    content = content.rstrip()
    if content and (content[0] not in {' ', '!', '#'}):
        content = ' ' + content
    is_standalone_comment = (
        '\n' in before_comment or '\n' in content or leaf.type == token.ENDMARKER
    if not is_standalone_comment:
        # simple trailing comment
        yield Leaf(token.COMMENT, value='#' + content)

    for line in ('#' + content).split('\n'):
        line = line.lstrip()
        if not line.startswith('#'):
        yield Leaf(STANDALONE_COMMENT, line)
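
# Sketch of the behaviour described in the docstring above: a leaf whose prefix
# holds comment text produces COMMENT or STANDALONE_COMMENT leaves. The input
# leaf below is constructed by hand purely for illustration.
def _demo_generate_comments() -> List[int]:
    leaf = Leaf(token.NEWLINE, '\n')
    leaf.prefix = '  # trailing comment'
    inline = [c.type for c in generate_comments(leaf)]       # [token.COMMENT]
    leaf.prefix = '\n# standalone comment\n'
    standalone = [c.type for c in generate_comments(leaf)]   # [STANDALONE_COMMENT]
    return inline + standalone
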


def split_line(
    line: Line, line_length: int, inner: bool = False, py36: bool = False
) -> Iterator[Line]:
    """Splits a `line` into potentially many lines.

    They should fit in the allotted `line_length` but might not be able to.
    `inner` signifies that there were a pair of brackets somewhere around the
    current `line`, possibly transitively. This means we can fall back to splitting
    by delimiters if the LHS/RHS don't yield any results.

    If `py36` is True, splitting may generate syntax that is only compatible
    with Python 3.6 and later.
    """
    line_str = str(line).strip('\n')
    if len(line_str) <= line_length and '\n' not in line_str:
        split_funcs = [left_hand_split]
    elif line.inside_brackets:
        split_funcs = [delimiter_split]
        if '\n' not in line_str:
            # Only attempt RHS if we don't have multiline strings or comments
            split_funcs.append(right_hand_split)
        split_funcs = [right_hand_split]
    for split_func in split_funcs:
        # We are accumulating lines in `result` because we might want to abort
        # mission and return the original line in the end, or attempt a different
        result: List[Line] = []
            for l in split_func(line, py36=py36):
                if str(l).strip('\n') == line_str:
                    raise CannotSplit("Split function returned an unchanged result")

                    split_line(l, line_length=line_length, inner=True, py36=py36)
        except CannotSplit as cs:


def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split line into many lines, starting with the first matching bracket pair.

    Note: this usually looks weird; only use it for function definitions.
    Prefer RHS otherwise.
    """
    head = Line(depth=line.depth)
    body = Line(depth=line.depth + 1, inside_brackets=True)
    tail = Line(depth=line.depth)
    tail_leaves: List[Leaf] = []
    body_leaves: List[Leaf] = []
    head_leaves: List[Leaf] = []
    current_leaves = head_leaves
    matching_bracket = None
    for leaf in line.leaves:
            current_leaves is body_leaves
            and leaf.type in CLOSING_BRACKETS
            and leaf.opening_bracket is matching_bracket
            current_leaves = tail_leaves if body_leaves else head_leaves
        current_leaves.append(leaf)
        if current_leaves is head_leaves:
            if leaf.type in OPENING_BRACKETS:
                matching_bracket = leaf
                current_leaves = body_leaves
    # Since body is a new indent level, remove spurious leading whitespace.
        normalize_prefix(body_leaves[0])
    # Build the new lines.
    for result, leaves in (
        (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
            result.append(leaf, preformatted=True)
            comment_after = line.comments.get(id(leaf))
                result.append(comment_after, preformatted=True)
    split_succeeded_or_raise(head, body, tail)
    for result in (head, body, tail):


def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split line into many lines, starting with the last matching bracket pair."""
    head = Line(depth=line.depth)
    body = Line(depth=line.depth + 1, inside_brackets=True)
    tail = Line(depth=line.depth)
    tail_leaves: List[Leaf] = []
    body_leaves: List[Leaf] = []
    head_leaves: List[Leaf] = []
    current_leaves = tail_leaves
    opening_bracket = None
    for leaf in reversed(line.leaves):
        if current_leaves is body_leaves:
            if leaf is opening_bracket:
                current_leaves = head_leaves if body_leaves else tail_leaves
        current_leaves.append(leaf)
        if current_leaves is tail_leaves:
            if leaf.type in CLOSING_BRACKETS:
                opening_bracket = leaf.opening_bracket
                current_leaves = body_leaves
    tail_leaves.reverse()
    body_leaves.reverse()
    head_leaves.reverse()
    # Since body is a new indent level, remove spurious leading whitespace.
        normalize_prefix(body_leaves[0])
    # Build the new lines.
    for result, leaves in (
        (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
            result.append(leaf, preformatted=True)
            comment_after = line.comments.get(id(leaf))
                result.append(comment_after, preformatted=True)
    split_succeeded_or_raise(head, body, tail)
    for result in (head, body, tail):


def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
    tail_len = len(str(tail).strip())
            raise CannotSplit("Splitting brackets produced the same line")
                f"Splitting brackets on an empty body to save "
                f"{tail_len} characters is not worth it"


def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split according to delimiters of the highest priority.

    This kind of split doesn't increase indentation.
    If `py36` is True, the split will also add trailing commas to function
    signatures that contain `*` and `**`.
    """
        last_leaf = line.leaves[-1]
        raise CannotSplit("Line empty")

    delimiters = line.bracket_tracker.delimiters
        delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
        raise CannotSplit("No delimiters found")

    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    lowest_depth = sys.maxsize
    trailing_comma_safe = True
    for leaf in line.leaves:
        current_line.append(leaf, preformatted=True)
        comment_after = line.comments.get(id(leaf))
            current_line.append(comment_after, preformatted=True)
        lowest_depth = min(lowest_depth, leaf.bracket_depth)
            leaf.bracket_depth == lowest_depth
            and leaf.type in {token.STAR, token.DOUBLESTAR}
            trailing_comma_safe = trailing_comma_safe and py36
        leaf_priority = delimiters.get(id(leaf))
        if leaf_priority == delimiter_priority:
            normalize_prefix(current_line.leaves[0])
            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
            delimiter_priority == COMMA_PRIORITY
            and current_line.leaves[-1].type != token.COMMA
            and trailing_comma_safe
            current_line.append(Leaf(token.COMMA, ','))
        normalize_prefix(current_line.leaves[0])
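
# Worked example of the splitting strategy implemented above (illustrative):
#
#   result = some_function(argument_one, argument_two, argument_three)
#
# If this exceeds the line length, `right_hand_split` first breaks at the last
# bracket pair, and the bracketed body is then re-split on its highest-priority
# delimiters (here, commas), yielding:
#
#   result = some_function(
#       argument_one, argument_two, argument_three
#   )
#
# and, if that is still too long, one argument per line with a trailing comma:
#
#   result = some_function(
#       argument_one,
#       argument_two,
#       argument_three,
#   )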


def is_import(leaf: Leaf) -> bool:
    """Returns True if the given leaf starts an import statement."""
            (v == 'import' and p and p.type == syms.import_name)
            or (v == 'from' and p and p.type == syms.import_from)


def normalize_prefix(leaf: Leaf) -> None:
    """Leave existing extra newlines for imports. Remove everything else."""
        spl = leaf.prefix.split('#', 1)
        nl_count = spl[0].count('\n')
        leaf.prefix = '\n' * nl_count


def is_python36(node: Node) -> bool:
    """Returns True if the current file is using Python 3.6+ features.

    Currently looking for:
    - trailing commas after * or ** in function signatures.
    """
    for n in node.pre_order():
        if n.type == token.STRING:
            value_head = n.value[:2]  # type: ignore
            if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
            n.type == syms.typedargslist
            and n.children[-1].type == token.COMMA
            for ch in n.children:
                if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
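
# Examples of what the detector above reacts to (illustrative):
#
#   greeting = f"hello {name}"        # f-string  -> Python 3.6+
#   def f(a, *, b=1, **kwargs,): ...  # trailing comma after * / ** in signature
#
# Either one makes `is_python36` return True, which in turn lets
# `delimiter_split` add trailing commas after `*args` / `**kwargs`.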


PYTHON_EXTENSIONS = {'.py'}
BLACKLISTED_DIRECTORIES = {
    'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
}


def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
    for child in path.iterdir():
            if child.name in BLACKLISTED_DIRECTORIES:
            yield from gen_python_files_in_dir(child)

        elif child.suffix in PYTHON_EXTENSIONS:


class Report:
    """Provides a reformatting counter."""
    change_count: int = 0
    failure_count: int = 0

    def done(self, src: Path, changed: bool) -> None:
        """Increment the counter for successful reformatting. Write out a message."""
            out(f'reformatted {src}')
            self.change_count += 1
            out(f'{src} already well formatted, good job.', bold=False)
            self.same_count += 1

    def failed(self, src: Path, message: str) -> None:
        """Increment the counter for failed reformatting. Write out a message."""
        err(f'error: cannot format {src}: {message}')
        self.failure_count += 1

    def return_code(self) -> int:
        """Which return code should the app use considering the current state."""
        # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
        # 126 we have special return codes reserved by the shell.
        if self.failure_count:
        elif self.change_count:

    def __str__(self) -> str:
        """A color report of the current state.

        Use `click.unstyle` to remove colors.
        """
        if self.change_count:
            s = 's' if self.change_count > 1 else ''
                click.style(f'{self.change_count} file{s} reformatted', bold=True)
            s = 's' if self.same_count > 1 else ''
            report.append(f'{self.same_count} file{s} left unchanged')
        if self.failure_count:
            s = 's' if self.failure_count > 1 else ''
                    f'{self.failure_count} file{s} failed to reformat', fg='red'
        return ', '.join(report) + '.'


def assert_equivalent(src: str, dst: str) -> None:
    """Raises AssertionError if `src` and `dst` aren't equivalent.

    This is a temporary sanity check until Black becomes stable.
    """

    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
        """Simple visitor generating strings to compare ASTs by content."""
        yield f"{' ' * depth}{node.__class__.__name__}("

        for field in sorted(node._fields):
                value = getattr(node, field)
            except AttributeError:

            yield f"{' ' * (depth+1)}{field}="

            if isinstance(value, list):
                    if isinstance(item, ast.AST):
                        yield from _v(item, depth + 2)

            elif isinstance(value, ast.AST):
                yield from _v(value, depth + 2)

                yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"

        yield f"{' ' * depth}) # /{node.__class__.__name__}"

        src_ast = ast.parse(src)
    except Exception as exc:
        raise AssertionError(f"cannot parse source: {exc}") from None

        dst_ast = ast.parse(dst)
    except Exception as exc:
        log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This invalid output might be helpful: {log}"

    src_ast_str = '\n'.join(_v(src_ast))
    dst_ast_str = '\n'.join(_v(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
        raise AssertionError(
            f"INTERNAL ERROR: Black produced code that is not equivalent to "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This diff might be helpful: {log}"


def assert_stable(src: str, dst: str, line_length: int) -> None:
    """Raises AssertionError if `dst` reformats differently the second time.

    This is a temporary sanity check until Black becomes stable.
    """
    newdst = format_str(dst, line_length=line_length)
            diff(src, dst, 'source', 'first pass'),
            diff(dst, newdst, 'first pass', 'second pass'),
        raise AssertionError(
            f"INTERNAL ERROR: Black produced different code on the second pass "
            f"of the formatter. "
            f"Please report a bug on https://github.com/ambv/black/issues. "
            f"This diff might be helpful: {log}"


def dump_to_file(*output: str) -> str:
    """Dumps `output` to a temporary file. Returns path to the file."""
    with tempfile.NamedTemporaryFile(
        mode='w', prefix='blk_', suffix='.log', delete=False
        for lines in output:


def diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Returns a udiff string between strings `a` and `b`."""
    a_lines = [line + '\n' for line in a.split('\n')]
    b_lines = [line + '\n' for line in b.split('\n')]
        difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)


if __name__ == '__main__':