git.madduck.net Git - etc/vim.git/blobdiff - src/black/linegen.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

chore: Set permissions for GitHub actions (#3043)
[etc/vim.git] / src / black / linegen.py
index 8a28c3901bbeb90dd68806f4e0adaac7241ab861..91fdeef8f2f94a3e1b7b8cb244b19723beac9b85 100644 (file)
@@ -3,7 +3,7 @@ Generating lines of code.
 """
 from functools import partial, wraps
 import sys
 """
 from functools import partial, wraps
 import sys
-from typing import Collection, Iterator, List, Optional, Set, Union
+from typing import Collection, Iterator, List, Optional, Set, Union, cast
 
 from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT
 from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS
 
 from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT
 from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS
@@ -144,6 +144,33 @@ class LineGenerator(Visitor[Line]):
 
             yield from self.visit(child)
 
 
             yield from self.visit(child)
 
+    def visit_funcdef(self, node: Node) -> Iterator[Line]:
+        """Visit function definition."""
+        if Preview.annotation_parens not in self.mode:
+            yield from self.visit_stmt(node, keywords={"def"}, parens=set())
+        else:
+            yield from self.line()
+
+            # Remove redundant brackets around return type annotation.
+            is_return_annotation = False
+            for child in node.children:
+                if child.type == token.RARROW:
+                    is_return_annotation = True
+                elif is_return_annotation:
+                    if child.type == syms.atom and child.children[0].type == token.LPAR:
+                        if maybe_make_parens_invisible_in_atom(
+                            child,
+                            parent=node,
+                            remove_brackets_around_comma=False,
+                        ):
+                            wrap_in_parentheses(node, child, visible=False)
+                    else:
+                        wrap_in_parentheses(node, child, visible=False)
+                    is_return_annotation = False
+
+            for child in node.children:
+                yield from self.visit(child)
+
     def visit_match_case(self, node: Node) -> Iterator[Line]:
         """Visit either a match or case statement."""
         normalize_invisible_parens(node, parens_after=set(), preview=self.mode.preview)
     def visit_match_case(self, node: Node) -> Iterator[Line]:
         """Visit either a match or case statement."""
         normalize_invisible_parens(node, parens_after=set(), preview=self.mode.preview)
@@ -226,6 +253,9 @@ class LineGenerator(Visitor[Line]):
             ):
                 wrap_in_parentheses(node, leaf)
 
             ):
                 wrap_in_parentheses(node, leaf)
 
+        if Preview.remove_redundant_parens in self.mode:
+            remove_await_parens(node)
+
         yield from self.visit_default(node)
 
     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
         yield from self.visit_default(node)
 
     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
@@ -322,10 +352,10 @@ class LineGenerator(Visitor[Line]):
             self.visit_except_clause = partial(
                 v, keywords={"except"}, parens={"except"}
             )
             self.visit_except_clause = partial(
                 v, keywords={"except"}, parens={"except"}
             )
+            self.visit_with_stmt = partial(v, keywords={"with"}, parens={"with"})
         else:
             self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
         else:
             self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
-        self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
-        self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
+            self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
@@ -477,7 +507,10 @@ def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator
             current_leaves is body_leaves
             and leaf.type in CLOSING_BRACKETS
             and leaf.opening_bracket is matching_bracket
             current_leaves is body_leaves
             and leaf.type in CLOSING_BRACKETS
             and leaf.opening_bracket is matching_bracket
+            and isinstance(matching_bracket, Leaf)
         ):
         ):
+            ensure_visible(leaf)
+            ensure_visible(matching_bracket)
             current_leaves = tail_leaves if body_leaves else head_leaves
         current_leaves.append(leaf)
         if current_leaves is head_leaves:
             current_leaves = tail_leaves if body_leaves else head_leaves
         current_leaves.append(leaf)
         if current_leaves is head_leaves:
@@ -845,11 +878,26 @@ def normalize_invisible_parens(
             check_lpar = True
 
         if check_lpar:
             check_lpar = True
 
         if check_lpar:
-            if child.type == syms.atom:
+            if (
+                preview
+                and child.type == syms.atom
+                and node.type == syms.for_stmt
+                and isinstance(child.prev_sibling, Leaf)
+                and child.prev_sibling.type == token.NAME
+                and child.prev_sibling.value == "for"
+            ):
+                if maybe_make_parens_invisible_in_atom(
+                    child,
+                    parent=node,
+                    remove_brackets_around_comma=True,
+                ):
+                    wrap_in_parentheses(node, child, visible=False)
+            elif preview and isinstance(child, Node) and node.type == syms.with_stmt:
+                remove_with_parens(child, node)
+            elif child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(
                     child,
                     parent=node,
                 if maybe_make_parens_invisible_in_atom(
                     child,
                     parent=node,
-                    preview=preview,
                 ):
                     wrap_in_parentheses(node, child, visible=False)
             elif is_one_tuple(child):
                 ):
                     wrap_in_parentheses(node, child, visible=False)
             elif is_one_tuple(child):
@@ -867,42 +915,127 @@ def normalize_invisible_parens(
                     node.insert_child(index, Leaf(token.LPAR, ""))
                     node.append_child(Leaf(token.RPAR, ""))
                 break
                     node.insert_child(index, Leaf(token.LPAR, ""))
                     node.append_child(Leaf(token.RPAR, ""))
                 break
+            elif (
+                index == 1
+                and child.type == token.STAR
+                and node.type == syms.except_clause
+            ):
+                # In except* (PEP 654), the star is actually part of
+                # the keyword. So we need to skip the insertion of
+                # invisible parentheses to work more precisely.
+                continue
 
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
                 wrap_in_parentheses(node, child, visible=False)
 
 
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
                 wrap_in_parentheses(node, child, visible=False)
 
-        check_lpar = isinstance(child, Leaf) and child.value in parens_after
+        comma_check = child.type == token.COMMA if preview else False
+
+        check_lpar = isinstance(child, Leaf) and (
+            child.value in parens_after or comma_check
+        )
+
+
+def remove_await_parens(node: Node) -> None:
+    if node.children[0].type == token.AWAIT and len(node.children) > 1:
+        if (
+            node.children[1].type == syms.atom
+            and node.children[1].children[0].type == token.LPAR
+        ):
+            if maybe_make_parens_invisible_in_atom(
+                node.children[1],
+                parent=node,
+                remove_brackets_around_comma=True,
+            ):
+                wrap_in_parentheses(node, node.children[1], visible=False)
+
+            # Since await is an expression we shouldn't remove
+            # brackets in cases where this would change
+            # the AST due to operator precedence.
+            # Therefore we only aim to remove brackets around
+            # power nodes that aren't also await expressions themselves.
+            # https://peps.python.org/pep-0492/#updated-operator-precedence-table
+            # N.B. We've still removed any redundant nested brackets though :)
+            opening_bracket = cast(Leaf, node.children[1].children[0])
+            closing_bracket = cast(Leaf, node.children[1].children[-1])
+            bracket_contents = cast(Node, node.children[1].children[1])
+            if bracket_contents.type != syms.power:
+                ensure_visible(opening_bracket)
+                ensure_visible(closing_bracket)
+            elif (
+                bracket_contents.type == syms.power
+                and bracket_contents.children[0].type == token.AWAIT
+            ):
+                ensure_visible(opening_bracket)
+                ensure_visible(closing_bracket)
+                # If we are in a nested await then recurse down.
+                remove_await_parens(bracket_contents)
+
+
+def remove_with_parens(node: Node, parent: Node) -> None:
+    """Recursively hide optional parens in `with` statements."""
+    # Removing all unnecessary parentheses in with statements in one pass is a tad
+    # complex as different variations of bracketed statements result in pretty
+    # different parse trees:
+    #
+    # with (open("file")) as f:                       # this is an asexpr_test
+    #     ...
+    #
+    # with (open("file") as f):                       # this is an atom containing an
+    #     ...                                         # asexpr_test
+    #
+    # with (open("file")) as f, (open("file")) as f:  # this is asexpr_test, COMMA,
+    #     ...                                         # asexpr_test
+    #
+    # with (open("file") as f, open("file") as f):    # an atom containing a
+    #     ...                                         # testlist_gexp which then
+    #                                                 # contains multiple asexpr_test(s)
+    if node.type == syms.atom:
+        if maybe_make_parens_invisible_in_atom(
+            node,
+            parent=parent,
+            remove_brackets_around_comma=True,
+        ):
+            wrap_in_parentheses(parent, node, visible=False)
+        if isinstance(node.children[1], Node):
+            remove_with_parens(node.children[1], node)
+    elif node.type == syms.testlist_gexp:
+        for child in node.children:
+            if isinstance(child, Node):
+                remove_with_parens(child, node)
+    elif node.type == syms.asexpr_test and not any(
+        leaf.type == token.COLONEQUAL for leaf in node.leaves()
+    ):
+        if maybe_make_parens_invisible_in_atom(
+            node.children[0],
+            parent=node,
+            remove_brackets_around_comma=True,
+        ):
+            wrap_in_parentheses(node, node.children[0], visible=False)
 
 
 def maybe_make_parens_invisible_in_atom(
     node: LN,
     parent: LN,
 
 
 def maybe_make_parens_invisible_in_atom(
     node: LN,
     parent: LN,
-    preview: bool = False,
+    remove_brackets_around_comma: bool = False,
 ) -> bool:
     """If it's safe, make the parens in the atom `node` invisible, recursively.
     Additionally, remove repeated, adjacent invisible parens from the atom `node`
     as they are redundant.
 
     Returns whether the node should itself be wrapped in invisible parentheses.
 ) -> bool:
     """If it's safe, make the parens in the atom `node` invisible, recursively.
     Additionally, remove repeated, adjacent invisible parens from the atom `node`
     as they are redundant.
 
     Returns whether the node should itself be wrapped in invisible parentheses.
-
     """
     """
-    if (
-        preview
-        and parent.type == syms.for_stmt
-        and isinstance(node.prev_sibling, Leaf)
-        and node.prev_sibling.type == token.NAME
-        and node.prev_sibling.value == "for"
-    ):
-        for_stmt_check = False
-    else:
-        for_stmt_check = True
-
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
         or is_one_tuple(node)
         or (is_yield(node) and parent.type != syms.expr_stmt)
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
         or is_one_tuple(node)
         or (is_yield(node) and parent.type != syms.expr_stmt)
-        or (max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY and for_stmt_check)
+        or (
+            # This condition tries to prevent removing non-optional brackets
+            # around a tuple; however, it can be a bit overzealous, so we provide
+            # an option to skip this check for `for` and `with` statements.
+            not remove_brackets_around_comma
+            and max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
+        )
     ):
         return False
 
     ):
         return False
 
@@ -925,7 +1058,11 @@ def maybe_make_parens_invisible_in_atom(
         # make parentheses invisible
         first.value = ""
         last.value = ""
         # make parentheses invisible
         first.value = ""
         last.value = ""
-        maybe_make_parens_invisible_in_atom(middle, parent=parent, preview=preview)
+        maybe_make_parens_invisible_in_atom(
+            middle,
+            parent=parent,
+            remove_brackets_around_comma=remove_brackets_around_comma,
+        )
 
         if is_atom_with_invisible_parens(middle):
             # Strip the invisible parens from `middle` by replacing
 
         if is_atom_with_invisible_parens(middle):
             # Strip the invisible parens from `middle` by replacing