git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add support for pyi files (#210)
author Jelle Zijlstra <jelle.zijlstra@gmail.com>
Tue, 15 May 2018 19:09:35 +0000 (15:09 -0400)
committer Łukasz Langa <lukasz@langa.pl>
Tue, 15 May 2018 19:09:35 +0000 (15:09 -0400)
Fixes #207

black.py
tests/stub.pyi [new file with mode: 0644]
tests/test_black.py

index 7823ae0afe2e809bdb2f5ca208ec9cc0565b4ced..81241f6ec0248df6aee8703c0c1dc7d6330e7303 100644 (file)
--- a/black.py
+++ b/black.py
@@ -329,12 +329,13 @@ def format_file_in_place(
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
+    is_pyi = src.suffix == ".pyi"
 
     with tokenize.open(src) as src_buffer:
         src_contents = src_buffer.read()
     try:
         dst_contents = format_file_contents(
-            src_contents, line_length=line_length, fast=fast
+            src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
         )
     except NothingChanged:
         return False
@@ -383,7 +384,7 @@ def format_stdin_to_stdout(
 
 
 def format_file_contents(
-    src_contents: str, line_length: int, fast: bool
+    src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
@@ -394,17 +395,21 @@ def format_file_contents(
     if src_contents.strip() == "":
         raise NothingChanged
 
-    dst_contents = format_str(src_contents, line_length=line_length)
+    dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(src_contents, dst_contents, line_length=line_length)
+        assert_stable(
+            src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
+        )
     return dst_contents
 
 
-def format_str(src_contents: str, line_length: int) -> FileContent:
+def format_str(
+    src_contents: str, line_length: int, *, is_pyi: bool = False
+) -> FileContent:
     """Reformat a string and return new contents.
 
     `line_length` determines how many characters per line are allowed.
@@ -412,9 +417,11 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
+    elt = EmptyLineTracker(is_pyi=is_pyi)
     py36 = is_python36(src_node)
-    lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports)
-    elt = EmptyLineTracker()
+    lines = LineGenerator(
+        remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
+    )
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -833,6 +840,14 @@ class Line:
             and self.leaves[0].value == "class"
         )
 
+    @property
+    def is_trivial_class(self) -> bool:
+        """Is this line a class definition with a body consisting only of "..."?"""
+        return (
+            self.is_class
+            and self.leaves[-3:] == [Leaf(token.DOT, ".") for _ in range(3)]
+        )
+
     @property
     def is_def(self) -> bool:
         """Is this a function definition? (Also returns True for async defs.)"""
@@ -1100,6 +1115,7 @@ class EmptyLineTracker:
     the prefix of the first leaf consists of optional newlines.  Those newlines
     are consumed by `maybe_empty_lines()` and included in the computation.
     """
+    is_pyi: bool = False
     previous_line: Optional[Line] = None
     previous_after: int = 0
     previous_defs: List[int] = Factory(list)
@@ -1123,7 +1139,7 @@ class EmptyLineTracker:
     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
         max_allowed = 1
         if current_line.depth == 0:
-            max_allowed = 2
+            max_allowed = 1 if self.is_pyi else 2
         if current_line.leaves:
             # Consume the first leaf's extra newlines.
             first_leaf = current_line.leaves[0]
@@ -1135,7 +1151,10 @@ class EmptyLineTracker:
         depth = current_line.depth
         while self.previous_defs and self.previous_defs[-1] >= depth:
             self.previous_defs.pop()
-            before = 1 if depth else 2
+            if self.is_pyi:
+                before = 0 if depth else 1
+            else:
+                before = 1 if depth else 2
         is_decorator = current_line.is_decorator
         if is_decorator or current_line.is_def or current_line.is_class:
             if not is_decorator:
@@ -1154,8 +1173,22 @@ class EmptyLineTracker:
             ):
                 return 0, 0
 
-            newlines = 2
-            if current_line.depth:
+            if self.is_pyi:
+                if self.previous_line.depth > current_line.depth:
+                    newlines = 1
+                elif current_line.is_class or self.previous_line.is_class:
+                    if (
+                        current_line.is_trivial_class
+                        and self.previous_line.is_trivial_class
+                    ):
+                        newlines = 0
+                    else:
+                        newlines = 1
+                else:
+                    newlines = 0
+            else:
+                newlines = 2
+            if current_line.depth and newlines:
                 newlines -= 1
             return newlines, 0
 
@@ -1177,6 +1210,7 @@ class LineGenerator(Visitor[Line]):
     Note: destroys the tree it's visiting by mutating prefixes of its leaves
     in ways that will no longer stringify to valid Python code on the tree.
     """
+    is_pyi: bool = False
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
 
@@ -1293,16 +1327,66 @@ class LineGenerator(Visitor[Line]):
 
             yield from self.visit(child)
 
+    def visit_suite(self, node: Node) -> Iterator[Line]:
+        """Visit a suite."""
+        if self.is_pyi and self.is_trivial_suite(node):
+            yield from self.visit(node.children[2])
+        else:
+            yield from self.visit_default(node)
+
+    def is_trivial_suite(self, node: Node) -> bool:
+        if len(node.children) != 4:
+            return False
+        if (
+            not isinstance(node.children[0], Leaf)
+            or node.children[0].type != token.NEWLINE
+        ):
+            return False
+        if (
+            not isinstance(node.children[1], Leaf)
+            or node.children[1].type != token.INDENT
+        ):
+            return False
+        if (
+            not isinstance(node.children[3], Leaf)
+            or node.children[3].type != token.DEDENT
+        ):
+            return False
+        stmt = node.children[2]
+        if not isinstance(stmt, Node):
+            return False
+        return self.is_trivial_body(stmt)
+
+    def is_trivial_body(self, stmt: Node) -> bool:
+        if not isinstance(stmt, Node) or stmt.type != syms.simple_stmt:
+            return False
+        if len(stmt.children) != 2:
+            return False
+        child = stmt.children[0]
+        return (
+            child.type == syms.atom
+            and len(child.children) == 3
+            and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
+        )
+
     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
         """Visit a statement without nested statements."""
         is_suite_like = node.parent and node.parent.type in STATEMENT
         if is_suite_like:
-            yield from self.line(+1)
-            yield from self.visit_default(node)
-            yield from self.line(-1)
+            if self.is_pyi and self.is_trivial_body(node):
+                yield from self.visit_default(node)
+            else:
+                yield from self.line(+1)
+                yield from self.visit_default(node)
+                yield from self.line(-1)
 
         else:
-            yield from self.line()
+            if (
+                not self.is_pyi
+                or not node.parent
+                or not self.is_trivial_suite(node.parent)
+            ):
+                yield from self.line()
             yield from self.visit_default(node)
 
     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
@@ -2554,7 +2638,7 @@ def get_future_imports(node: Node) -> Set[str]:
     return imports
 
 
-PYTHON_EXTENSIONS = {".py"}
+PYTHON_EXTENSIONS = {".py", ".pyi"}
 BLACKLISTED_DIRECTORIES = {
     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
 }
@@ -2717,9 +2801,9 @@ def assert_equivalent(src: str, dst: str) -> None:
         ) from None
 
 
-def assert_stable(src: str, dst: str, line_length: int) -> None:
+def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(dst, line_length=line_length)
+    newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
diff --git a/tests/stub.pyi b/tests/stub.pyi
new file mode 100644 (file)
index 0000000..986cc84
--- /dev/null
@@ -0,0 +1,27 @@
+class C:
+    ...
+
+class B:
+    ...
+
+class A:
+    def f(self) -> int:
+        ...
+
+    def g(self) -> str: ...
+
+def g():
+    ...
+
+def h(): ...
+
+# output
+class C: ...
+class B: ...
+
+class A:
+    def f(self) -> int: ...
+    def g(self) -> str: ...
+
+def g(): ...
+def h(): ...
index bc133cab4aca84f69de1a6eab504134cd7ed1065..82e3f5a6d0d62a2f420b4843fcf494ea35c8862f 100644 (file)
@@ -31,7 +31,7 @@ def dump_to_stderr(*output: str) -> str:
 
 def read_data(name: str) -> Tuple[str, str]:
     """read_data('test_name') -> 'input', 'output'"""
-    if not name.endswith((".py", ".out", ".diff")):
+    if not name.endswith((".py", ".pyi", ".out", ".diff")):
         name += ".py"
     _input: List[str] = []
     _output: List[str] = []
@@ -340,6 +340,13 @@ class BlackTestCase(unittest.TestCase):
         self.assertFormatEqual(expected, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_stub(self) -> None:
+        source, expected = read_data("stub.pyi")
+        actual = fs(source, is_pyi=True)
+        self.assertFormatEqual(expected, actual)
+        black.assert_stable(source, actual, line_length=ll, is_pyi=True)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_fmtonoff(self) -> None:
         source, expected = read_data("fmtonoff")