X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/3eea3aad864b83af3b6c477c32f15eb821fe9341..96e68f034197d6060aad876216a0e40e0eac0987:/black.py?ds=inline diff --git a/black.py b/black.py index 7823ae0..f389176 100644 --- a/black.py +++ b/black.py @@ -239,11 +239,8 @@ def reformat_one( src = src.resolve() if src in cache and cache[src] == get_cache_info(src): changed = Changed.CACHED - if ( - changed is not Changed.CACHED - and format_file_in_place( - src, line_length=line_length, fast=fast, write_back=write_back - ) + if changed is not Changed.CACHED and format_file_in_place( + src, line_length=line_length, fast=fast, write_back=write_back ): changed = Changed.YES if write_back == WriteBack.YES and changed is not Changed.NO: @@ -285,32 +282,29 @@ async def schedule_formatting( manager = Manager() lock = manager.Lock() tasks = { - src: loop.run_in_executor( + loop.run_in_executor( executor, format_file_in_place, src, line_length, fast, write_back, lock - ) - for src in sources + ): src + for src in sorted(sources) } - _task_values = list(tasks.values()) + pending: Iterable[asyncio.Task] = tasks.keys() try: - loop.add_signal_handler(signal.SIGINT, cancel, _task_values) - loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) + loop.add_signal_handler(signal.SIGINT, cancel, pending) + loop.add_signal_handler(signal.SIGTERM, cancel, pending) except NotImplementedError: # There are no good alternatives for these on Windows pass - await asyncio.wait(_task_values) - for src, task in tasks.items(): - if not task.done(): - report.failed(src, "timed out, cancelling") - task.cancel() - cancelled.append(task) - elif task.cancelled(): - cancelled.append(task) - elif task.exception(): - report.failed(src, str(task.exception())) - else: - formatted.append(src) - report.done(src, Changed.YES if task.result() else Changed.NO) - + while pending: + done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + src = tasks.pop(task) + if task.cancelled(): + cancelled.append(task) + elif task.exception(): + report.failed(src, str(task.exception())) + else: + formatted.append(src) + report.done(src, Changed.YES if task.result() else Changed.NO) if cancelled: await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) if write_back == WriteBack.YES and formatted: @@ -329,12 +323,13 @@ def format_file_in_place( If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` options are passed to :func:`format_file_contents`. """ + is_pyi = src.suffix == ".pyi" with tokenize.open(src) as src_buffer: src_contents = src_buffer.read() try: dst_contents = format_file_contents( - src_contents, line_length=line_length, fast=fast + src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi ) except NothingChanged: return False @@ -383,7 +378,7 @@ def format_stdin_to_stdout( def format_file_contents( - src_contents: str, line_length: int, fast: bool + src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False ) -> FileContent: """Reformat contents a file and return new contents. @@ -394,17 +389,21 @@ def format_file_contents( if src_contents.strip() == "": raise NothingChanged - dst_contents = format_str(src_contents, line_length=line_length) + dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi) if src_contents == dst_contents: raise NothingChanged if not fast: assert_equivalent(src_contents, dst_contents) - assert_stable(src_contents, dst_contents, line_length=line_length) + assert_stable( + src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi + ) return dst_contents -def format_str(src_contents: str, line_length: int) -> FileContent: +def format_str( + src_contents: str, line_length: int, *, is_pyi: bool = False +) -> FileContent: """Reformat a string and return new contents. `line_length` determines how many characters per line are allowed. @@ -412,9 +411,11 @@ def format_str(src_contents: str, line_length: int) -> FileContent: src_node = lib2to3_parse(src_contents) dst_contents = "" future_imports = get_future_imports(src_node) + elt = EmptyLineTracker(is_pyi=is_pyi) py36 = is_python36(src_node) - lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports) - elt = EmptyLineTracker() + lines = LineGenerator( + remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi + ) empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -833,6 +834,14 @@ class Line: and self.leaves[0].value == "class" ) + @property + def is_stub_class(self) -> bool: + """Is this line a class definition with a body consisting only of "..."?""" + return ( + self.is_class + and self.leaves[-3:] == [Leaf(token.DOT, ".") for _ in range(3)] + ) + @property def is_def(self) -> bool: """Is this a function definition? (Also returns True for async defs.)""" @@ -845,14 +854,11 @@ class Line: second_leaf: Optional[Leaf] = self.leaves[1] except IndexError: second_leaf = None - return ( - (first_leaf.type == token.NAME and first_leaf.value == "def") - or ( - first_leaf.type == token.ASYNC - and second_leaf is not None - and second_leaf.type == token.NAME - and second_leaf.value == "def" - ) + return (first_leaf.type == token.NAME and first_leaf.value == "def") or ( + first_leaf.type == token.ASYNC + and second_leaf is not None + and second_leaf.type == token.NAME + and second_leaf.value == "def" ) @property @@ -1017,9 +1023,8 @@ class Line: and subscript_start.type == syms.subscriptlist ): subscript_start = child_towards(subscript_start, leaf) - return ( - subscript_start is not None - and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()) + return subscript_start is not None and any( + n.type in TEST_DESCENDANTS for n in subscript_start.pre_order() ) def __str__(self) -> str: @@ -1100,6 +1105,7 @@ class EmptyLineTracker: the prefix of the first leaf consists of optional newlines. Those newlines are consumed by `maybe_empty_lines()` and included in the computation. """ + is_pyi: bool = False previous_line: Optional[Line] = None previous_after: int = 0 previous_defs: List[int] = Factory(list) @@ -1123,7 +1129,7 @@ class EmptyLineTracker: def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: max_allowed = 1 if current_line.depth == 0: - max_allowed = 2 + max_allowed = 1 if self.is_pyi else 2 if current_line.leaves: # Consume the first leaf's extra newlines. first_leaf = current_line.leaves[0] @@ -1135,7 +1141,10 @@ class EmptyLineTracker: depth = current_line.depth while self.previous_defs and self.previous_defs[-1] >= depth: self.previous_defs.pop() - before = 1 if depth else 2 + if self.is_pyi: + before = 0 if depth else 1 + else: + before = 1 if depth else 2 is_decorator = current_line.is_decorator if is_decorator or current_line.is_def or current_line.is_class: if not is_decorator: @@ -1154,8 +1163,19 @@ class EmptyLineTracker: ): return 0, 0 - newlines = 2 - if current_line.depth: + if self.is_pyi: + if self.previous_line.depth > current_line.depth: + newlines = 1 + elif current_line.is_class or self.previous_line.is_class: + if current_line.is_stub_class and self.previous_line.is_stub_class: + newlines = 0 + else: + newlines = 1 + else: + newlines = 0 + else: + newlines = 2 + if current_line.depth and newlines: newlines -= 1 return newlines, 0 @@ -1177,6 +1197,7 @@ class LineGenerator(Visitor[Line]): Note: destroys the tree it's visiting by mutating prefixes of its leaves in ways that will no longer stringify to valid Python code on the tree. """ + is_pyi: bool = False current_line: Line = Factory(Line) remove_u_prefix: bool = False @@ -1293,16 +1314,27 @@ class LineGenerator(Visitor[Line]): yield from self.visit(child) + def visit_suite(self, node: Node) -> Iterator[Line]: + """Visit a suite.""" + if self.is_pyi and is_stub_suite(node): + yield from self.visit(node.children[2]) + else: + yield from self.visit_default(node) + def visit_simple_stmt(self, node: Node) -> Iterator[Line]: """Visit a statement without nested statements.""" is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: - yield from self.line(+1) - yield from self.visit_default(node) - yield from self.line(-1) + if self.is_pyi and is_stub_body(node): + yield from self.visit_default(node) + else: + yield from self.line(+1) + yield from self.visit_default(node) + yield from self.line(-1) else: - yield from self.line() + if not self.is_pyi or not node.parent or not is_stub_suite(node.parent): + yield from self.line() yield from self.visit_default(node) def visit_async_stmt(self, node: Node) -> Iterator[Line]: @@ -1929,6 +1961,8 @@ def right_hand_split( If the split was by optional parentheses, attempt splitting without them, too. `omit` is a collection of closing bracket IDs that shouldn't be considered for this split. + + Note: running this function modifies `bracket_depth` on the leaves of `line`. """ head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -1955,8 +1989,9 @@ def right_hand_split( # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: normalize_prefix(body_leaves[0], inside_brackets=True) - elif not head_leaves: - # No `head` and no `body` means the split failed. `tail` has all content. + if not head_leaves: + # No `head` means the split failed. Either `tail` has all content or + # the matching `opening_bracket` wasn't available on `line` anymore. raise CannotSplit("No brackets found") # Build the new lines. @@ -1974,19 +2009,27 @@ def right_hand_split( # the closing bracket is an optional paren and closing_bracket.type == token.RPAR and not closing_bracket.value - # there are no delimiters or standalone comments in the body - and not body.bracket_tracker.delimiters + # there are no standalone comments in the body and not line.contains_standalone_comments(0) # and it's not an import (optional parens are the only thing we can split # on in this case; attempting a split without them is a waste of time) and not line.is_import ): omit = {id(closing_bracket), *omit} - try: - yield from right_hand_split(line, py36=py36, omit=omit) - return - except CannotSplit: - pass + delimiter_count = len(body.bracket_tracker.delimiters) + if ( + delimiter_count == 0 + or delimiter_count == 1 + and ( + body.leaves[0].type in OPENING_BRACKETS + or body.leaves[-1].type in CLOSING_BRACKETS + ) + ): + try: + yield from right_hand_split(line, py36=py36, omit=omit) + return + except CannotSplit: + pass ensure_visible(opening_bracket) ensure_visible(closing_bracket) @@ -2388,6 +2431,35 @@ def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: return p.type in within +def is_stub_suite(node: Node) -> bool: + """Return True if `node` is a suite with a stub body.""" + if ( + len(node.children) != 4 + or node.children[0].type != token.NEWLINE + or node.children[1].type != token.INDENT + or node.children[3].type != token.DEDENT + ): + return False + + return is_stub_body(node.children[2]) + + +def is_stub_body(node: LN) -> bool: + """Return True if `node` is a simple statement containing an ellipsis.""" + if not isinstance(node, Node) or node.type != syms.simple_stmt: + return False + + if len(node.children) != 2: + return False + + child = node.children[0] + return ( + child.type == syms.atom + and len(child.children) == 3 + and all(leaf == Leaf(token.DOT, ".") for leaf in child.children) + ) + + def max_delimiter_priority_in_atom(node: LN) -> int: """Return maximum delimiter priority inside `node`. @@ -2554,7 +2626,7 @@ def get_future_imports(node: Node) -> Set[str]: return imports -PYTHON_EXTENSIONS = {".py"} +PYTHON_EXTENSIONS = {".py", ".pyi"} BLACKLISTED_DIRECTORIES = { "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv" } @@ -2717,9 +2789,9 @@ def assert_equivalent(src: str, dst: str) -> None: ) from None -def assert_stable(src: str, dst: str, line_length: int) -> None: +def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" - newdst = format_str(dst, line_length=line_length) + newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi) if dst != newdst: log = dump_to_file( diff(src, dst, "source", "first pass"), @@ -2758,7 +2830,7 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: ) -def cancel(tasks: List[asyncio.Task]) -> None: +def cancel(tasks: Iterable[asyncio.Task]) -> None: """asyncio signal handler that cancels all `tasks` and reports to stderr.""" err("Aborted!") for task in tasks: