From: Jelle Zijlstra Date: Tue, 15 May 2018 19:09:35 +0000 (-0400) Subject: Add support for pyi files (#210) X-Git-Url: https://git.madduck.net/etc/vim.git/commitdiff_plain/14ba1bf8b6248e6860ba6a0cb9468c4c1c25a102 Add support for pyi files (#210) Fixes #207 --- diff --git a/black.py b/black.py index 7823ae0..81241f6 100644 --- a/black.py +++ b/black.py @@ -329,12 +329,13 @@ def format_file_in_place( If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` options are passed to :func:`format_file_contents`. """ + is_pyi = src.suffix == ".pyi" with tokenize.open(src) as src_buffer: src_contents = src_buffer.read() try: dst_contents = format_file_contents( - src_contents, line_length=line_length, fast=fast + src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi ) except NothingChanged: return False @@ -383,7 +384,7 @@ def format_stdin_to_stdout( def format_file_contents( - src_contents: str, line_length: int, fast: bool + src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False ) -> FileContent: """Reformat contents a file and return new contents. @@ -394,17 +395,21 @@ def format_file_contents( if src_contents.strip() == "": raise NothingChanged - dst_contents = format_str(src_contents, line_length=line_length) + dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi) if src_contents == dst_contents: raise NothingChanged if not fast: assert_equivalent(src_contents, dst_contents) - assert_stable(src_contents, dst_contents, line_length=line_length) + assert_stable( + src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi + ) return dst_contents -def format_str(src_contents: str, line_length: int) -> FileContent: +def format_str( + src_contents: str, line_length: int, *, is_pyi: bool = False +) -> FileContent: """Reformat a string and return new contents. `line_length` determines how many characters per line are allowed. @@ -412,9 +417,11 @@ def format_str(src_contents: str, line_length: int) -> FileContent: src_node = lib2to3_parse(src_contents) dst_contents = "" future_imports = get_future_imports(src_node) + elt = EmptyLineTracker(is_pyi=is_pyi) py36 = is_python36(src_node) - lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports) - elt = EmptyLineTracker() + lines = LineGenerator( + remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi + ) empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -833,6 +840,14 @@ class Line: and self.leaves[0].value == "class" ) + @property + def is_trivial_class(self) -> bool: + """Is this line a class definition with a body consisting only of "..."?""" + return ( + self.is_class + and self.leaves[-3:] == [Leaf(token.DOT, ".") for _ in range(3)] + ) + @property def is_def(self) -> bool: """Is this a function definition? (Also returns True for async defs.)""" @@ -1100,6 +1115,7 @@ class EmptyLineTracker: the prefix of the first leaf consists of optional newlines. Those newlines are consumed by `maybe_empty_lines()` and included in the computation. """ + is_pyi: bool = False previous_line: Optional[Line] = None previous_after: int = 0 previous_defs: List[int] = Factory(list) @@ -1123,7 +1139,7 @@ class EmptyLineTracker: def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: max_allowed = 1 if current_line.depth == 0: - max_allowed = 2 + max_allowed = 1 if self.is_pyi else 2 if current_line.leaves: # Consume the first leaf's extra newlines. first_leaf = current_line.leaves[0] @@ -1135,7 +1151,10 @@ class EmptyLineTracker: depth = current_line.depth while self.previous_defs and self.previous_defs[-1] >= depth: self.previous_defs.pop() - before = 1 if depth else 2 + if self.is_pyi: + before = 0 if depth else 1 + else: + before = 1 if depth else 2 is_decorator = current_line.is_decorator if is_decorator or current_line.is_def or current_line.is_class: if not is_decorator: @@ -1154,8 +1173,22 @@ class EmptyLineTracker: ): return 0, 0 - newlines = 2 - if current_line.depth: + if self.is_pyi: + if self.previous_line.depth > current_line.depth: + newlines = 1 + elif current_line.is_class or self.previous_line.is_class: + if ( + current_line.is_trivial_class + and self.previous_line.is_trivial_class + ): + newlines = 0 + else: + newlines = 1 + else: + newlines = 0 + else: + newlines = 2 + if current_line.depth and newlines: newlines -= 1 return newlines, 0 @@ -1177,6 +1210,7 @@ class LineGenerator(Visitor[Line]): Note: destroys the tree it's visiting by mutating prefixes of its leaves in ways that will no longer stringify to valid Python code on the tree. """ + is_pyi: bool = False current_line: Line = Factory(Line) remove_u_prefix: bool = False @@ -1293,16 +1327,66 @@ class LineGenerator(Visitor[Line]): yield from self.visit(child) + def visit_suite(self, node: Node) -> Iterator[Line]: + """Visit a suite.""" + if self.is_pyi and self.is_trivial_suite(node): + yield from self.visit(node.children[2]) + else: + yield from self.visit_default(node) + + def is_trivial_suite(self, node: Node) -> bool: + if len(node.children) != 4: + return False + if ( + not isinstance(node.children[0], Leaf) + or node.children[0].type != token.NEWLINE + ): + return False + if ( + not isinstance(node.children[1], Leaf) + or node.children[1].type != token.INDENT + ): + return False + if ( + not isinstance(node.children[3], Leaf) + or node.children[3].type != token.DEDENT + ): + return False + stmt = node.children[2] + if not isinstance(stmt, Node): + return False + return self.is_trivial_body(stmt) + + def is_trivial_body(self, stmt: Node) -> bool: + if not isinstance(stmt, Node) or stmt.type != syms.simple_stmt: + return False + if len(stmt.children) != 2: + return False + child = stmt.children[0] + return ( + child.type == syms.atom + and len(child.children) == 3 + and all(leaf == Leaf(token.DOT, ".") for leaf in child.children) + ) + def visit_simple_stmt(self, node: Node) -> Iterator[Line]: """Visit a statement without nested statements.""" is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: - yield from self.line(+1) - yield from self.visit_default(node) - yield from self.line(-1) + if self.is_pyi and self.is_trivial_body(node): + yield from self.visit_default(node) + else: + yield from self.line(+1) + yield from self.visit_default(node) + yield from self.line(-1) else: - yield from self.line() + if ( + not self.is_pyi + or not node.parent + or not self.is_trivial_suite(node.parent) + ): + yield from self.line() yield from self.visit_default(node) def visit_async_stmt(self, node: Node) -> Iterator[Line]: @@ -2554,7 +2638,7 @@ def get_future_imports(node: Node) -> Set[str]: return imports -PYTHON_EXTENSIONS = {".py"} +PYTHON_EXTENSIONS = {".py", ".pyi"} BLACKLISTED_DIRECTORIES = { "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv" } @@ -2717,9 +2801,9 @@ def assert_equivalent(src: str, dst: str) -> None: ) from None -def assert_stable(src: str, dst: str, line_length: int) -> None: +def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" - newdst = format_str(dst, line_length=line_length) + newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi) if dst != newdst: log = dump_to_file( diff(src, dst, "source", "first pass"), diff --git a/tests/stub.pyi b/tests/stub.pyi new file mode 100644 index 0000000..986cc84 --- /dev/null +++ b/tests/stub.pyi @@ -0,0 +1,27 @@ +class C: + ... + +class B: + ... + +class A: + def f(self) -> int: + ... + + def g(self) -> str: ... + +def g(): + ... + +def h(): ... + +# output +class C: ... +class B: ... + +class A: + def f(self) -> int: ... + def g(self) -> str: ... + +def g(): ... +def h(): ... diff --git a/tests/test_black.py b/tests/test_black.py index bc133ca..82e3f5a 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -31,7 +31,7 @@ def dump_to_stderr(*output: str) -> str: def read_data(name: str) -> Tuple[str, str]: """read_data('test_name') -> 'input', 'output'""" - if not name.endswith((".py", ".out", ".diff")): + if not name.endswith((".py", ".pyi", ".out", ".diff")): name += ".py" _input: List[str] = [] _output: List[str] = [] @@ -340,6 +340,13 @@ class BlackTestCase(unittest.TestCase): self.assertFormatEqual(expected, actual) black.assert_stable(source, actual, line_length=ll) + @patch("black.dump_to_file", dump_to_stderr) + def test_stub(self) -> None: + source, expected = read_data("stub.pyi") + actual = fs(source, is_pyi=True) + self.assertFormatEqual(expected, actual) + black.assert_stable(source, actual, line_length=ll, is_pyi=True) + @patch("black.dump_to_file", dump_to_stderr) def test_fmtonoff(self) -> None: source, expected = read_data("fmtonoff")