src = src.resolve()
if src in cache and cache[src] == get_cache_info(src):
changed = Changed.CACHED
- if (
- changed is not Changed.CACHED
- and format_file_in_place(
- src, line_length=line_length, fast=fast, write_back=write_back
- )
+ if changed is not Changed.CACHED and format_file_in_place(
+ src, line_length=line_length, fast=fast, write_back=write_back
):
changed = Changed.YES
if write_back == WriteBack.YES and changed is not Changed.NO:
manager = Manager()
lock = manager.Lock()
tasks = {
- src: loop.run_in_executor(
+ loop.run_in_executor(
executor, format_file_in_place, src, line_length, fast, write_back, lock
- )
- for src in sources
+ ): src
+ for src in sorted(sources)
}
- _task_values = list(tasks.values())
+ pending: Iterable[asyncio.Task] = tasks.keys()
try:
- loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
- loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
+ loop.add_signal_handler(signal.SIGINT, cancel, pending)
+ loop.add_signal_handler(signal.SIGTERM, cancel, pending)
except NotImplementedError:
# There are no good alternatives for these on Windows
pass
- await asyncio.wait(_task_values)
- for src, task in tasks.items():
- if not task.done():
- report.failed(src, "timed out, cancelling")
- task.cancel()
- cancelled.append(task)
- elif task.cancelled():
- cancelled.append(task)
- elif task.exception():
- report.failed(src, str(task.exception()))
- else:
- formatted.append(src)
- report.done(src, Changed.YES if task.result() else Changed.NO)
-
+ while pending:
+ done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+ for task in done:
+ src = tasks.pop(task)
+ if task.cancelled():
+ cancelled.append(task)
+ elif task.exception():
+ report.failed(src, str(task.exception()))
+ else:
+ formatted.append(src)
+ report.done(src, Changed.YES if task.result() else Changed.NO)
if cancelled:
await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
if write_back == WriteBack.YES and formatted:
If `write_back` is True, write reformatted code back to stdout.
`line_length` and `fast` options are passed to :func:`format_file_contents`.
"""
+ is_pyi = src.suffix == ".pyi"
with tokenize.open(src) as src_buffer:
src_contents = src_buffer.read()
try:
dst_contents = format_file_contents(
- src_contents, line_length=line_length, fast=fast
+ src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
)
except NothingChanged:
return False
def format_file_contents(
- src_contents: str, line_length: int, fast: bool
+ src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
) -> FileContent:
"""Reformat contents a file and return new contents.
if src_contents.strip() == "":
raise NothingChanged
- dst_contents = format_str(src_contents, line_length=line_length)
+ dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
if src_contents == dst_contents:
raise NothingChanged
if not fast:
assert_equivalent(src_contents, dst_contents)
- assert_stable(src_contents, dst_contents, line_length=line_length)
+ assert_stable(
+ src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
+ )
return dst_contents
-def format_str(src_contents: str, line_length: int) -> FileContent:
+def format_str(
+ src_contents: str, line_length: int, *, is_pyi: bool = False
+) -> FileContent:
"""Reformat a string and return new contents.
`line_length` determines how many characters per line are allowed.
src_node = lib2to3_parse(src_contents)
dst_contents = ""
future_imports = get_future_imports(src_node)
+ elt = EmptyLineTracker(is_pyi=is_pyi)
py36 = is_python36(src_node)
- lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports)
- elt = EmptyLineTracker()
+ lines = LineGenerator(
+ remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
+ )
empty_line = Line()
after = 0
for current_line in lines.visit(src_node):
"""
return max(v for k, v in self.delimiters.items() if k not in exclude)
+    def delimiter_count_with_priority(self, priority: int = 0) -> int:
+        """Count the recorded delimiters whose priority equals `priority`.
+
+        A falsy `priority` (the default, 0) means "use the highest priority
+        currently recorded".  Returns 0 when no delimiters are recorded at all.
+        """
+        if self.delimiters:
+            # Only ask for the maximum when the caller did not pick a priority.
+            wanted = priority if priority else self.max_delimiter_priority()
+            return len([p for p in self.delimiters.values() if p == wanted])
+
+        return 0
+
def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
"""In a for loop, or comprehension, the variables are often unpacks.
and self.leaves[0].value == "class"
)
+    @property
+    def is_stub_class(self) -> bool:
+        """Is this line a class definition that ends with a literal "..."?
+
+        Only the last three leaves are inspected; each must be a DOT leaf.
+        """
+        ellipsis_leaves = [Leaf(token.DOT, ".")] * 3
+        return self.is_class and self.leaves[-3:] == ellipsis_leaves
+
@property
def is_def(self) -> bool:
"""Is this a function definition? (Also returns True for async defs.)"""
second_leaf: Optional[Leaf] = self.leaves[1]
except IndexError:
second_leaf = None
- return (
- (first_leaf.type == token.NAME and first_leaf.value == "def")
- or (
- first_leaf.type == token.ASYNC
- and second_leaf is not None
- and second_leaf.type == token.NAME
- and second_leaf.value == "def"
- )
+ return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
+ first_leaf.type == token.ASYNC
+ and second_leaf is not None
+ and second_leaf.type == token.NAME
+ and second_leaf.value == "def"
)
@property
and subscript_start.type == syms.subscriptlist
):
subscript_start = child_towards(subscript_start, leaf)
- return (
- subscript_start is not None
- and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order())
+ return subscript_start is not None and any(
+ n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
)
def __str__(self) -> str:
the prefix of the first leaf consists of optional newlines. Those newlines
are consumed by `maybe_empty_lines()` and included in the computation.
"""
+ is_pyi: bool = False
previous_line: Optional[Line] = None
previous_after: int = 0
previous_defs: List[int] = Factory(list)
def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
max_allowed = 1
if current_line.depth == 0:
- max_allowed = 2
+ max_allowed = 1 if self.is_pyi else 2
if current_line.leaves:
# Consume the first leaf's extra newlines.
first_leaf = current_line.leaves[0]
depth = current_line.depth
while self.previous_defs and self.previous_defs[-1] >= depth:
self.previous_defs.pop()
- before = 1 if depth else 2
+ if self.is_pyi:
+ before = 0 if depth else 1
+ else:
+ before = 1 if depth else 2
is_decorator = current_line.is_decorator
if is_decorator or current_line.is_def or current_line.is_class:
if not is_decorator:
):
return 0, 0
- newlines = 2
- if current_line.depth:
+ if self.is_pyi:
+ if self.previous_line.depth > current_line.depth:
+ newlines = 1
+ elif current_line.is_class or self.previous_line.is_class:
+ if current_line.is_stub_class and self.previous_line.is_stub_class:
+ newlines = 0
+ else:
+ newlines = 1
+ else:
+ newlines = 0
+ else:
+ newlines = 2
+ if current_line.depth and newlines:
newlines -= 1
return newlines, 0
Note: destroys the tree it's visiting by mutating prefixes of its leaves
in ways that will no longer stringify to valid Python code on the tree.
"""
+ is_pyi: bool = False
current_line: Line = Factory(Line)
remove_u_prefix: bool = False
yield from self.visit(child)
+    def visit_suite(self, node: Node) -> Iterator[Line]:
+        """Visit a suite.
+
+        When `is_pyi` is set and `is_stub_suite()` accepts `node`, only the
+        suite's third child (the one between INDENT and DEDENT) is visited.
+        Every other suite goes through `visit_default()`.
+        """
+        if not (self.is_pyi and is_stub_suite(node)):
+            yield from self.visit_default(node)
+            return
+
+        yield from self.visit(node.children[2])
+
def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
"""Visit a statement without nested statements."""
is_suite_like = node.parent and node.parent.type in STATEMENT
if is_suite_like:
- yield from self.line(+1)
- yield from self.visit_default(node)
- yield from self.line(-1)
+ if self.is_pyi and is_stub_body(node):
+ yield from self.visit_default(node)
+ else:
+ yield from self.line(+1)
+ yield from self.visit_default(node)
+ yield from self.line(-1)
else:
- yield from self.line()
+ if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
+ yield from self.line()
yield from self.visit_default(node)
def visit_async_stmt(self, node: Node) -> Iterator[Line]:
return DOUBLESPACE
assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
- if (
- t == token.COLON
- and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
- ):
+ if t == token.COLON and p.type not in {
+ syms.subscript, syms.subscriptlist, syms.sliceop
+ }:
return NO
prev = leaf.prev_sibling
prevp_parent = prevp.parent
assert prevp_parent is not None
- if (
- prevp.type == token.COLON
- and prevp_parent.type in {syms.subscript, syms.sliceop}
- ):
+ if prevp.type == token.COLON and prevp_parent.type in {
+ syms.subscript, syms.sliceop
+ }:
return NO
elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
):
return STRING_PRIORITY
+ if leaf.type != token.NAME:
+ return 0
+
if (
- leaf.type == token.NAME
- and leaf.value == "for"
+ leaf.value == "for"
and leaf.parent
and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
):
return COMPREHENSION_PRIORITY
if (
- leaf.type == token.NAME
- and leaf.value == "if"
+ leaf.value == "if"
and leaf.parent
and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
):
return COMPREHENSION_PRIORITY
+ if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
+ return TERNARY_PRIORITY
+
+ if leaf.value == "is":
+ return COMPARATOR_PRIORITY
+
if (
- leaf.type == token.NAME
- and leaf.value in {"if", "else"}
+ leaf.value == "in"
and leaf.parent
- and leaf.parent.type == syms.test
+ and leaf.parent.type in {syms.comp_op, syms.comparison}
+ and not (
+ previous is not None
+ and previous.type == token.NAME
+ and previous.value == "not"
+ )
):
- return TERNARY_PRIORITY
+ return COMPARATOR_PRIORITY
- if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
+ if (
+ leaf.value == "not"
+ and leaf.parent
+ and leaf.parent.type == syms.comp_op
+ and not (
+ previous is not None
+ and previous.type == token.NAME
+ and previous.value == "is"
+ )
+ ):
+ return COMPARATOR_PRIORITY
+
+ if leaf.value in LOGIC_OPERATORS and leaf.parent:
return LOGIC_PRIORITY
return 0
If the split was by optional parentheses, attempt splitting without them, too.
`omit` is a collection of closing bracket IDs that shouldn't be considered for
this split.
+
+ Note: running this function modifies `bracket_depth` on the leaves of `line`.
"""
head = Line(depth=line.depth)
body = Line(depth=line.depth + 1, inside_brackets=True)
# Since body is a new indent level, remove spurious leading whitespace.
if body_leaves:
normalize_prefix(body_leaves[0], inside_brackets=True)
- elif not head_leaves:
- # No `head` and no `body` means the split failed. `tail` has all content.
+ if not head_leaves:
+ # No `head` means the split failed. Either `tail` has all content or
+ # the matching `opening_bracket` wasn't available on `line` anymore.
raise CannotSplit("No brackets found")
# Build the new lines.
# the closing bracket is an optional paren
and closing_bracket.type == token.RPAR
and not closing_bracket.value
- # there are no delimiters or standalone comments in the body
- and not body.bracket_tracker.delimiters
+ # there are no standalone comments in the body
and not line.contains_standalone_comments(0)
# and it's not an import (optional parens are the only thing we can split
# on in this case; attempting a split without them is a waste of time)
and not line.is_import
):
omit = {id(closing_bracket), *omit}
- try:
- yield from right_hand_split(line, py36=py36, omit=omit)
- return
- except CannotSplit:
- pass
+ delimiter_count = body.bracket_tracker.delimiter_count_with_priority()
+ if (
+ delimiter_count == 0
+ or delimiter_count == 1
+ and (
+ body.leaves[0].type in OPENING_BRACKETS
+ or body.leaves[-1].type in CLOSING_BRACKETS
+ )
+ ):
+ try:
+ yield from right_hand_split(line, py36=py36, omit=omit)
+ return
+ except CannotSplit:
+ pass
ensure_visible(opening_bracket)
ensure_visible(closing_bracket)
except IndexError:
raise CannotSplit("Line empty")
- delimiters = line.bracket_tracker.delimiters
+ bt = line.bracket_tracker
try:
- delimiter_priority = line.bracket_tracker.max_delimiter_priority(
- exclude={id(last_leaf)}
- )
+ delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
except ValueError:
raise CannotSplit("No delimiters found")
yield from append_to_line(comment_after)
lowest_depth = min(lowest_depth, leaf.bracket_depth)
- if (
- leaf.bracket_depth == lowest_depth
- and is_vararg(leaf, within=VARARGS_PARENTS)
+ if leaf.bracket_depth == lowest_depth and is_vararg(
+ leaf, within=VARARGS_PARENTS
):
trailing_comma_safe = trailing_comma_safe and py36
- leaf_priority = delimiters.get(id(leaf))
+ leaf_priority = bt.delimiters.get(id(leaf))
if leaf_priority == delimiter_priority:
yield current_line
return p.type in within
+def is_stub_suite(node: Node) -> bool:
+    """Return True if `node` is a suite whose only statement is a stub body.
+
+    The suite must have exactly four children shaped as
+    NEWLINE, INDENT, <body>, DEDENT, and <body> must satisfy `is_stub_body()`.
+    """
+    children = node.children
+    if len(children) != 4:
+        return False
+
+    newline, indent, body, dedent = children
+    has_suite_shape = (
+        newline.type == token.NEWLINE
+        and indent.type == token.INDENT
+        and dedent.type == token.DEDENT
+    )
+    return has_suite_shape and is_stub_body(body)
+
+
+def is_stub_body(node: LN) -> bool:
+    """Return True if `node` is a simple statement holding just an ellipsis.
+
+    That means: a `simple_stmt` Node with exactly two children, the first of
+    which is an `atom` made of exactly three DOT leaves.
+    """
+    if not isinstance(node, Node):
+        return False
+
+    if node.type != syms.simple_stmt or len(node.children) != 2:
+        return False
+
+    expr = node.children[0]
+    if expr.type != syms.atom or len(expr.children) != 3:
+        return False
+
+    dot = Leaf(token.DOT, ".")
+    return all(part == dot for part in expr.children)
+
+
def max_delimiter_priority_in_atom(node: LN) -> int:
"""Return maximum delimiter priority inside `node`.
return imports
-PYTHON_EXTENSIONS = {".py"}
+PYTHON_EXTENSIONS = {".py", ".pyi"}
BLACKLISTED_DIRECTORIES = {
"build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
}
) from None
-def assert_stable(src: str, dst: str, line_length: int) -> None:
+def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
- newdst = format_str(dst, line_length=line_length)
+ newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
if dst != newdst:
log = dump_to_file(
diff(src, dst, "source", "first pass"),
)
-def cancel(tasks: List[asyncio.Task]) -> None:
+def cancel(tasks: Iterable[asyncio.Task]) -> None:
"""asyncio signal handler that cancels all `tasks` and reports to stderr."""
err("Aborted!")
for task in tasks: