If `write_back` is True, write reformatted code back to stdout.
`line_length` and `fast` options are passed to :func:`format_file_contents`.
"""
+ is_pyi = src.suffix == ".pyi"
with tokenize.open(src) as src_buffer:
src_contents = src_buffer.read()
try:
dst_contents = format_file_contents(
- src_contents, line_length=line_length, fast=fast
+ src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
)
except NothingChanged:
return False
def format_file_contents(
- src_contents: str, line_length: int, fast: bool
+ src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
) -> FileContent:
"""Reformat contents of a file and return new contents.
if src_contents.strip() == "":
raise NothingChanged
- dst_contents = format_str(src_contents, line_length=line_length)
+ dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
if src_contents == dst_contents:
raise NothingChanged
if not fast:
assert_equivalent(src_contents, dst_contents)
- assert_stable(src_contents, dst_contents, line_length=line_length)
+ assert_stable(
+ src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
+ )
return dst_contents
-def format_str(src_contents: str, line_length: int) -> FileContent:
+def format_str(
+ src_contents: str, line_length: int, *, is_pyi: bool = False
+) -> FileContent:
"""Reformat a string and return new contents.
`line_length` determines how many characters per line are allowed.
src_node = lib2to3_parse(src_contents)
dst_contents = ""
future_imports = get_future_imports(src_node)
+ elt = EmptyLineTracker(is_pyi=is_pyi)
py36 = is_python36(src_node)
- lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports)
- elt = EmptyLineTracker()
+ lines = LineGenerator(
+ remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
+ )
empty_line = Line()
after = 0
for current_line in lines.visit(src_node):
and self.leaves[0].value == "class"
)
+ @property
+ def is_trivial_class(self) -> bool:
+     """Is this a class definition line whose last three leaves spell "..."?"""
+     # Three "." leaves in a row are how the "..." body shows up on the line.
+     ellipsis_leaves = [Leaf(token.DOT, ".")] * 3
+     return self.is_class and self.leaves[-3:] == ellipsis_leaves
+
@property
def is_def(self) -> bool:
"""Is this a function definition? (Also returns True for async defs.)"""
the prefix of the first leaf consists of optional newlines. Those newlines
are consumed by `maybe_empty_lines()` and included in the computation.
"""
+ is_pyi: bool = False
previous_line: Optional[Line] = None
previous_after: int = 0
previous_defs: List[int] = Factory(list)
def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
max_allowed = 1
if current_line.depth == 0:
- max_allowed = 2
+ max_allowed = 1 if self.is_pyi else 2
if current_line.leaves:
# Consume the first leaf's extra newlines.
first_leaf = current_line.leaves[0]
depth = current_line.depth
while self.previous_defs and self.previous_defs[-1] >= depth:
self.previous_defs.pop()
- before = 1 if depth else 2
+ if self.is_pyi:
+ before = 0 if depth else 1
+ else:
+ before = 1 if depth else 2
is_decorator = current_line.is_decorator
if is_decorator or current_line.is_def or current_line.is_class:
if not is_decorator:
):
return 0, 0
- newlines = 2
- if current_line.depth:
+ if self.is_pyi:
+ if self.previous_line.depth > current_line.depth:
+ newlines = 1
+ elif current_line.is_class or self.previous_line.is_class:
+ if (
+ current_line.is_trivial_class
+ and self.previous_line.is_trivial_class
+ ):
+ newlines = 0
+ else:
+ newlines = 1
+ else:
+ newlines = 0
+ else:
+ newlines = 2
+ if current_line.depth and newlines:
newlines -= 1
return newlines, 0
Note: destroys the tree it's visiting by mutating prefixes of its leaves
in ways that will no longer stringify to valid Python code on the tree.
"""
+ is_pyi: bool = False
current_line: Line = Factory(Line)
remove_u_prefix: bool = False
yield from self.visit(child)
+ def visit_suite(self, node: Node) -> Iterator[Line]:
+     """Visit a suite; in a stub, a trivial suite is visited via its statement only."""
+     collapsible = self.is_pyi and self.is_trivial_suite(node)
+     if not collapsible:
+         yield from self.visit_default(node)
+         return
+     # children[2] is the single statement that is_trivial_suite() examined.
+     yield from self.visit(node.children[2])
+
+ def is_trivial_suite(self, node: Node) -> bool:
+     """Return True if `node` is NEWLINE, INDENT, a trivial body statement, DEDENT."""
+     parts = node.children
+     if len(parts) != 4:
+         return False
+     newline, indent, stmt, dedent = parts
+     # The statement must be framed by exactly these three leaf tokens.
+     framing = (
+         (newline, token.NEWLINE),
+         (indent, token.INDENT),
+         (dedent, token.DEDENT),
+     )
+     for leaf, expected_type in framing:
+         if not (isinstance(leaf, Leaf) and leaf.type == expected_type):
+             return False
+     return isinstance(stmt, Node) and self.is_trivial_body(stmt)
+
+ def is_trivial_body(self, stmt: Node) -> bool:
+     """Return True if `stmt` is a simple statement whose first child is a bare "..."."""
+     if not isinstance(stmt, Node):
+         return False
+     if stmt.type != syms.simple_stmt or len(stmt.children) != 2:
+         return False
+     expr = stmt.children[0]
+     if expr.type != syms.atom or len(expr.children) != 3:
+         return False
+     # "..." appears as an atom made of three "." leaves.
+     dot = Leaf(token.DOT, ".")
+     return all(part == dot for part in expr.children)
+
def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
"""Visit a statement without nested statements."""
is_suite_like = node.parent and node.parent.type in STATEMENT
if is_suite_like:
- yield from self.line(+1)
- yield from self.visit_default(node)
- yield from self.line(-1)
+ # In a stub, a bare "..." body is visited without the surrounding
+ # self.line(+1) / self.line(-1) calls used for other bodies.
+ if self.is_pyi and self.is_trivial_body(node):
+ yield from self.visit_default(node)
+ else:
+ yield from self.line(+1)
+ yield from self.visit_default(node)
+ yield from self.line(-1)
else:
- yield from self.line()
+ # Skip the self.line() call only when formatting a stub and the parent
+ # is a trivial suite; in every other case it is still made.
+ if (
+ not self.is_pyi
+ or not node.parent
+ or not self.is_trivial_suite(node.parent)
+ ):
+ yield from self.line()
yield from self.visit_default(node)
def visit_async_stmt(self, node: Node) -> Iterator[Line]:
return imports
-PYTHON_EXTENSIONS = {".py"}
+PYTHON_EXTENSIONS = {".py", ".pyi"}
BLACKLISTED_DIRECTORIES = {
"build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
}
) from None
-def assert_stable(src: str, dst: str, line_length: int) -> None:
+def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
- newdst = format_str(dst, line_length=line_length)
+ newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
if dst != newdst:
log = dump_to_file(
diff(src, dst, "source", "first pass"),