git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

For omitting optional parentheses, ignore delimiters of lower priorities
[etc/vim.git] / black.py
index 81241f6ec0248df6aee8703c0c1dc7d6330e7303..f335ebd828402c6167719bf23d673aa99cb15ddf 100644 (file)
--- a/black.py
+++ b/black.py
@@ -239,11 +239,8 @@ def reformat_one(
                 src = src.resolve()
                 if src in cache and cache[src] == get_cache_info(src):
                     changed = Changed.CACHED
-            if (
-                changed is not Changed.CACHED
-                and format_file_in_place(
-                    src, line_length=line_length, fast=fast, write_back=write_back
-                )
+            if changed is not Changed.CACHED and format_file_in_place(
+                src, line_length=line_length, fast=fast, write_back=write_back
             ):
                 changed = Changed.YES
             if write_back == WriteBack.YES and changed is not Changed.NO:
@@ -285,32 +282,29 @@ async def schedule_formatting(
             manager = Manager()
             lock = manager.Lock()
         tasks = {
-            src: loop.run_in_executor(
+            loop.run_in_executor(
                 executor, format_file_in_place, src, line_length, fast, write_back, lock
-            )
-            for src in sources
+            ): src
+            for src in sorted(sources)
         }
-        _task_values = list(tasks.values())
+        pending: Iterable[asyncio.Task] = tasks.keys()
         try:
-            loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
-            loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
+            loop.add_signal_handler(signal.SIGINT, cancel, pending)
+            loop.add_signal_handler(signal.SIGTERM, cancel, pending)
         except NotImplementedError:
             # There are no good alternatives for these on Windows
             pass
-        await asyncio.wait(_task_values)
-        for src, task in tasks.items():
-            if not task.done():
-                report.failed(src, "timed out, cancelling")
-                task.cancel()
-                cancelled.append(task)
-            elif task.cancelled():
-                cancelled.append(task)
-            elif task.exception():
-                report.failed(src, str(task.exception()))
-            else:
-                formatted.append(src)
-                report.done(src, Changed.YES if task.result() else Changed.NO)
-
+        while pending:
+            done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                src = tasks.pop(task)
+                if task.cancelled():
+                    cancelled.append(task)
+                elif task.exception():
+                    report.failed(src, str(task.exception()))
+                else:
+                    formatted.append(src)
+                    report.done(src, Changed.YES if task.result() else Changed.NO)
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if write_back == WriteBack.YES and formatted:
@@ -712,6 +706,17 @@ class BracketTracker:
         """
         return max(v for k, v in self.delimiters.items() if k not in exclude)
 
+    def delimiter_count_with_priority(self, priority: int = 0) -> int:
+        """Return the number of delimiters with the given `priority`.
+
+        If no `priority` is passed, defaults to max priority on the line.
+        """
+        if not self.delimiters:
+            return 0
+
+        priority = priority or self.max_delimiter_priority()
+        return sum(1 for p in self.delimiters.values() if p == priority)
+
     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
         """In a for loop, or comprehension, the variables are often unpacks.
 
@@ -841,12 +846,11 @@ class Line:
         )
 
     @property
-    def is_trivial_class(self) -> bool:
+    def is_stub_class(self) -> bool:
         """Is this line a class definition with a body consisting only of "..."?"""
-        return (
-            self.is_class
-            and self.leaves[-3:] == [Leaf(token.DOT, ".") for _ in range(3)]
-        )
+        return self.is_class and self.leaves[-3:] == [
+            Leaf(token.DOT, ".") for _ in range(3)
+        ]
 
     @property
     def is_def(self) -> bool:
@@ -860,14 +864,11 @@ class Line:
             second_leaf: Optional[Leaf] = self.leaves[1]
         except IndexError:
             second_leaf = None
-        return (
-            (first_leaf.type == token.NAME and first_leaf.value == "def")
-            or (
-                first_leaf.type == token.ASYNC
-                and second_leaf is not None
-                and second_leaf.type == token.NAME
-                and second_leaf.value == "def"
-            )
+        return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
+            first_leaf.type == token.ASYNC
+            and second_leaf is not None
+            and second_leaf.type == token.NAME
+            and second_leaf.value == "def"
         )
 
     @property
@@ -1032,9 +1033,8 @@ class Line:
             and subscript_start.type == syms.subscriptlist
         ):
             subscript_start = child_towards(subscript_start, leaf)
-        return (
-            subscript_start is not None
-            and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order())
+        return subscript_start is not None and any(
+            n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
         )
 
     def __str__(self) -> str:
@@ -1177,10 +1177,7 @@ class EmptyLineTracker:
                 if self.previous_line.depth > current_line.depth:
                     newlines = 1
                 elif current_line.is_class or self.previous_line.is_class:
-                    if (
-                        current_line.is_trivial_class
-                        and self.previous_line.is_trivial_class
-                    ):
+                    if current_line.is_stub_class and self.previous_line.is_stub_class:
                         newlines = 0
                     else:
                         newlines = 1
@@ -1329,51 +1326,16 @@ class LineGenerator(Visitor[Line]):
 
     def visit_suite(self, node: Node) -> Iterator[Line]:
         """Visit a suite."""
-        if self.is_pyi and self.is_trivial_suite(node):
+        if self.is_pyi and is_stub_suite(node):
             yield from self.visit(node.children[2])
         else:
             yield from self.visit_default(node)
 
-    def is_trivial_suite(self, node: Node) -> bool:
-        if len(node.children) != 4:
-            return False
-        if (
-            not isinstance(node.children[0], Leaf)
-            or node.children[0].type != token.NEWLINE
-        ):
-            return False
-        if (
-            not isinstance(node.children[1], Leaf)
-            or node.children[1].type != token.INDENT
-        ):
-            return False
-        if (
-            not isinstance(node.children[3], Leaf)
-            or node.children[3].type != token.DEDENT
-        ):
-            return False
-        stmt = node.children[2]
-        if not isinstance(stmt, Node):
-            return False
-        return self.is_trivial_body(stmt)
-
-    def is_trivial_body(self, stmt: Node) -> bool:
-        if not isinstance(stmt, Node) or stmt.type != syms.simple_stmt:
-            return False
-        if len(stmt.children) != 2:
-            return False
-        child = stmt.children[0]
-        return (
-            child.type == syms.atom
-            and len(child.children) == 3
-            and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
-        )
-
     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
         """Visit a statement without nested statements."""
         is_suite_like = node.parent and node.parent.type in STATEMENT
         if is_suite_like:
-            if self.is_pyi and self.is_trivial_body(node):
+            if self.is_pyi and is_stub_body(node):
                 yield from self.visit_default(node)
             else:
                 yield from self.line(+1)
@@ -1381,11 +1343,7 @@ class LineGenerator(Visitor[Line]):
                 yield from self.line(-1)
 
         else:
-            if (
-                not self.is_pyi
-                or not node.parent
-                or not self.is_trivial_suite(node.parent)
-            ):
+            if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
                 yield from self.line()
             yield from self.visit_default(node)
 
@@ -1513,10 +1471,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
         return DOUBLESPACE
 
     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
-    if (
-        t == token.COLON
-        and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
-    ):
+    if t == token.COLON and p.type not in {
+        syms.subscript, syms.subscriptlist, syms.sliceop
+    }:
         return NO
 
     prev = leaf.prev_sibling
@@ -1690,10 +1647,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
 
             prevp_parent = prevp.parent
             assert prevp_parent is not None
-            if (
-                prevp.type == token.COLON
-                and prevp_parent.type in {syms.subscript, syms.sliceop}
-            ):
+            if prevp.type == token.COLON and prevp_parent.type in {
+                syms.subscript, syms.sliceop
+            }:
                 return NO
 
             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
@@ -2013,6 +1969,8 @@ def right_hand_split(
     If the split was by optional parentheses, attempt splitting without them, too.
     `omit` is a collection of closing bracket IDs that shouldn't be considered for
     this split.
+
+    Note: running this function modifies `bracket_depth` on the leaves of `line`.
     """
     head = Line(depth=line.depth)
     body = Line(depth=line.depth + 1, inside_brackets=True)
@@ -2039,8 +1997,9 @@ def right_hand_split(
     # Since body is a new indent level, remove spurious leading whitespace.
     if body_leaves:
         normalize_prefix(body_leaves[0], inside_brackets=True)
-    elif not head_leaves:
-        # No `head` and no `body` means the split failed. `tail` has all content.
+    if not head_leaves:
+        # No `head` means the split failed. Either `tail` has all content or
+        # the matching `opening_bracket` wasn't available on `line` anymore.
         raise CannotSplit("No brackets found")
 
     # Build the new lines.
@@ -2058,19 +2017,27 @@ def right_hand_split(
         # the closing bracket is an optional paren
         and closing_bracket.type == token.RPAR
         and not closing_bracket.value
-        # there are no delimiters or standalone comments in the body
-        and not body.bracket_tracker.delimiters
+        # there are no standalone comments in the body
         and not line.contains_standalone_comments(0)
         # and it's not an import (optional parens are the only thing we can split
         # on in this case; attempting a split without them is a waste of time)
         and not line.is_import
     ):
         omit = {id(closing_bracket), *omit}
-        try:
-            yield from right_hand_split(line, py36=py36, omit=omit)
-            return
-        except CannotSplit:
-            pass
+        delimiter_count = body.bracket_tracker.delimiter_count_with_priority()
+        if (
+            delimiter_count == 0
+            or delimiter_count == 1
+            and (
+                body.leaves[0].type in OPENING_BRACKETS
+                or body.leaves[-1].type in CLOSING_BRACKETS
+            )
+        ):
+            try:
+                yield from right_hand_split(line, py36=py36, omit=omit)
+                return
+            except CannotSplit:
+                pass
 
     ensure_visible(opening_bracket)
     ensure_visible(closing_bracket)
@@ -2132,11 +2099,9 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
     except IndexError:
         raise CannotSplit("Line empty")
 
-    delimiters = line.bracket_tracker.delimiters
+    bt = line.bracket_tracker
     try:
-        delimiter_priority = line.bracket_tracker.max_delimiter_priority(
-            exclude={id(last_leaf)}
-        )
+        delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
     except ValueError:
         raise CannotSplit("No delimiters found")
 
@@ -2162,12 +2127,11 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
             yield from append_to_line(comment_after)
 
         lowest_depth = min(lowest_depth, leaf.bracket_depth)
-        if (
-            leaf.bracket_depth == lowest_depth
-            and is_vararg(leaf, within=VARARGS_PARENTS)
+        if leaf.bracket_depth == lowest_depth and is_vararg(
+            leaf, within=VARARGS_PARENTS
         ):
             trailing_comma_safe = trailing_comma_safe and py36
-        leaf_priority = delimiters.get(id(leaf))
+        leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
 
@@ -2472,6 +2436,35 @@ def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
     return p.type in within
 
 
+def is_stub_suite(node: Node) -> bool:
+    """Return True if `node` is a suite with a stub body."""
+    if (
+        len(node.children) != 4
+        or node.children[0].type != token.NEWLINE
+        or node.children[1].type != token.INDENT
+        or node.children[3].type != token.DEDENT
+    ):
+        return False
+
+    return is_stub_body(node.children[2])
+
+
+def is_stub_body(node: LN) -> bool:
+    """Return True if `node` is a simple statement containing an ellipsis."""
+    if not isinstance(node, Node) or node.type != syms.simple_stmt:
+        return False
+
+    if len(node.children) != 2:
+        return False
+
+    child = node.children[0]
+    return (
+        child.type == syms.atom
+        and len(child.children) == 3
+        and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
+    )
+
+
 def max_delimiter_priority_in_atom(node: LN) -> int:
     """Return maximum delimiter priority inside `node`.
 
@@ -2842,7 +2835,7 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     )
 
 
-def cancel(tasks: List[asyncio.Task]) -> None:
+def cancel(tasks: Iterable[asyncio.Task]) -> None:
     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
     err("Aborted!")
     for task in tasks: