git.madduck.net Git - etc/vim.git/blobdiff - src/black/linegen.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix parser bug where "type" was misinterpreted as a keyword inside a match (#3950)
[etc/vim.git] / src / black / linegen.py
index 0091cbb3bd1d3195d3bc06fed11aeac280b987bf..d12ca39d0372972cef62043b230fa5785d67bd7f 100644 (file)
@@ -1,6 +1,7 @@
 """
 Generating lines of code.
 """
 """
 Generating lines of code.
 """
+
 import sys
 from dataclasses import replace
 from enum import Enum, auto
 import sys
 from dataclasses import replace
 from enum import Enum, auto
@@ -49,6 +50,7 @@ from black.nodes import (
     is_stub_body,
     is_stub_suite,
     is_tuple_containing_walrus,
     is_stub_body,
     is_stub_suite,
     is_tuple_containing_walrus,
+    is_type_ignore_comment_string,
     is_vararg,
     is_walrus_assignment,
     is_yield,
     is_vararg,
     is_walrus_assignment,
     is_yield,
@@ -280,7 +282,9 @@ class LineGenerator(Visitor[Line]):
 
     def visit_suite(self, node: Node) -> Iterator[Line]:
         """Visit a suite."""
 
     def visit_suite(self, node: Node) -> Iterator[Line]:
         """Visit a suite."""
-        if self.mode.is_pyi and is_stub_suite(node):
+        if (
+            self.mode.is_pyi or Preview.dummy_implementations in self.mode
+        ) and is_stub_suite(node):
             yield from self.visit(node.children[2])
         else:
             yield from self.visit_default(node)
             yield from self.visit(node.children[2])
         else:
             yield from self.visit_default(node)
@@ -295,7 +299,9 @@ class LineGenerator(Visitor[Line]):
 
         is_suite_like = node.parent and node.parent.type in STATEMENT
         if is_suite_like:
 
         is_suite_like = node.parent and node.parent.type in STATEMENT
         if is_suite_like:
-            if self.mode.is_pyi and is_stub_body(node):
+            if (
+                self.mode.is_pyi or Preview.dummy_implementations in self.mode
+            ) and is_stub_body(node):
                 yield from self.visit_default(node)
             else:
                 yield from self.line(+1)
                 yield from self.visit_default(node)
             else:
                 yield from self.line(+1)
@@ -304,7 +310,7 @@ class LineGenerator(Visitor[Line]):
 
         else:
             if (
 
         else:
             if (
-                not self.mode.is_pyi
+                not (self.mode.is_pyi or Preview.dummy_implementations in self.mode)
                 or not node.parent
                 or not is_stub_suite(node.parent)
             ):
                 or not node.parent
                 or not is_stub_suite(node.parent)
             ):
@@ -392,6 +398,24 @@ class LineGenerator(Visitor[Line]):
             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
         yield from self.visit_default(node)
 
             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
         yield from self.visit_default(node)
 
+    def visit_tname(self, node: Node) -> Iterator[Line]:
+        """
+        Add potential parentheses around types in function parameter lists to be made
+        into real parentheses in case the type hint is too long to fit on a line
+        Examples:
+        def foo(a: int, b: float = 7): ...
+
+        ->
+
+        def foo(a: (int), b: (float) = 7): ...
+        """
+        if Preview.parenthesize_long_type_hints in self.mode:
+            assert len(node.children) == 3
+            if maybe_make_parens_invisible_in_atom(node.children[2], parent=node):
+                wrap_in_parentheses(node, node.children[2], visible=False)
+
+        yield from self.visit_default(node)
+
     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
         if Preview.hex_codes_in_unicode_sequences in self.mode:
             normalize_unicode_escape_sequences(leaf)
     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
         if Preview.hex_codes_in_unicode_sequences in self.mode:
             normalize_unicode_escape_sequences(leaf)
@@ -493,7 +517,14 @@ class LineGenerator(Visitor[Line]):
         self.visit_except_clause = partial(v, keywords={"except"}, parens={"except"})
         self.visit_with_stmt = partial(v, keywords={"with"}, parens={"with"})
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
         self.visit_except_clause = partial(v, keywords={"except"}, parens={"except"})
         self.visit_with_stmt = partial(v, keywords={"with"}, parens={"with"})
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
-        self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
+
+        # When this is moved out of preview, add ":" directly to ASSIGNMENTS in nodes.py
+        if Preview.parenthesize_long_type_hints in self.mode:
+            assignments = ASSIGNMENTS | {":"}
+        else:
+            assignments = ASSIGNMENTS
+        self.visit_expr_stmt = partial(v, keywords=Ø, parens=assignments)
+
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
@@ -505,6 +536,17 @@ class LineGenerator(Visitor[Line]):
         self.visit_case_block = self.visit_match_case
 
 
         self.visit_case_block = self.visit_match_case
 
 
+def _hugging_power_ops_line_to_string(
+    line: Line,
+    features: Collection[Feature],
+    mode: Mode,
+) -> Optional[str]:
+    try:
+        return line_to_string(next(hug_power_op(line, features, mode)))
+    except CannotTransform:
+        return None
+
+
 def transform_line(
     line: Line, mode: Mode, features: Collection[Feature] = ()
 ) -> Iterator[Line]:
 def transform_line(
     line: Line, mode: Mode, features: Collection[Feature] = ()
 ) -> Iterator[Line]:
@@ -520,6 +562,14 @@ def transform_line(
 
     line_str = line_to_string(line)
 
 
     line_str = line_to_string(line)
 
+    # We need the line string when power operators are hugging to determine if we should
+    # split the line. Default to line_str, if no power operator are present on the line.
+    line_str_hugging_power_ops = (
+        (_hugging_power_ops_line_to_string(line, features, mode) or line_str)
+        if Preview.fix_power_op_line_length in mode
+        else line_str
+    )
+
     ll = mode.line_length
     sn = mode.string_normalization
     string_merge = StringMerger(ll, sn)
     ll = mode.line_length
     sn = mode.string_normalization
     string_merge = StringMerger(ll, sn)
@@ -533,7 +583,7 @@ def transform_line(
         and not line.should_split_rhs
         and not line.magic_trailing_comma
         and (
         and not line.should_split_rhs
         and not line.magic_trailing_comma
         and (
-            is_line_short_enough(line, mode=mode, line_str=line_str)
+            is_line_short_enough(line, mode=mode, line_str=line_str_hugging_power_ops)
             or line.contains_unsplittable_type_ignore()
         )
         and not (line.inside_brackets and line.contains_standalone_comments())
             or line.contains_unsplittable_type_ignore()
         )
         and not (line.inside_brackets and line.contains_standalone_comments())
@@ -543,7 +593,7 @@ def transform_line(
             transformers = [string_merge, string_paren_strip]
         else:
             transformers = []
             transformers = [string_merge, string_paren_strip]
         else:
             transformers = []
-    elif line.is_def:
+    elif line.is_def and not should_split_funcdef_with_rhs(line, mode):
         transformers = [left_hand_split]
     else:
 
         transformers = [left_hand_split]
     else:
 
@@ -622,6 +672,40 @@ def transform_line(
         yield line
 
 
         yield line
 
 
+def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:
+    """If a funcdef has a magic trailing comma in the return type, then we should first
+    split the line with rhs to respect the comma.
+    """
+    if Preview.respect_magic_trailing_comma_in_return_type not in mode:
+        return False
+
+    return_type_leaves: List[Leaf] = []
+    in_return_type = False
+
+    for leaf in line.leaves:
+        if leaf.type == token.COLON:
+            in_return_type = False
+        if in_return_type:
+            return_type_leaves.append(leaf)
+        if leaf.type == token.RARROW:
+            in_return_type = True
+
+    # using `bracket_split_build_line` will mess with whitespace, so we duplicate a
+    # couple lines from it.
+    result = Line(mode=line.mode, depth=line.depth)
+    leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves)
+    for leaf in return_type_leaves:
+        result.append(
+            leaf,
+            preformatted=True,
+            track_bracket=id(leaf) in leaves_to_track,
+        )
+
+    # we could also return true if the line is too long, and the return type is longer
+    # than the param list. Or if `should_split_rhs` returns True.
+    return result.magic_trailing_comma is not None
+
+
 class _BracketSplitComponent(Enum):
     head = auto()
     body = auto()
 class _BracketSplitComponent(Enum):
     head = auto()
     body = auto()
@@ -918,6 +1002,13 @@ def bracket_split_build_line(
                     )
                     if isinstance(node, Node) and isinstance(node.prev_sibling, Leaf)
                 )
                     )
                     if isinstance(node, Node) and isinstance(node.prev_sibling, Leaf)
                 )
+                # Except the false negatives above for PEP 604 unions where we
+                # can't add the comma.
+                and not (
+                    leaves[0].parent
+                    and leaves[0].parent.next_sibling
+                    and leaves[0].parent.next_sibling.type == token.VBAR
+                )
             )
 
             if original.is_import or no_commas:
             )
 
             if original.is_import or no_commas:
@@ -1356,7 +1447,7 @@ def maybe_make_parens_invisible_in_atom(
     Returns whether the node should itself be wrapped in invisible parentheses.
     """
     if (
     Returns whether the node should itself be wrapped in invisible parentheses.
     """
     if (
-        node.type != syms.atom
+        node.type not in (syms.atom, syms.expr)
         or is_empty_tuple(node)
         or is_one_tuple(node)
         or (is_yield(node) and parent.type != syms.expr_stmt)
         or is_empty_tuple(node)
         or is_one_tuple(node)
         or (is_yield(node) and parent.type != syms.expr_stmt)
@@ -1380,6 +1471,7 @@ def maybe_make_parens_invisible_in_atom(
             syms.except_clause,
             syms.funcdef,
             syms.with_stmt,
             syms.except_clause,
             syms.funcdef,
             syms.with_stmt,
+            syms.tname,
             # these ones aren't useful to end users, but they do please fuzzers
             syms.for_stmt,
             syms.del_stmt,
             # these ones aren't useful to end users, but they do please fuzzers
             syms.for_stmt,
             syms.del_stmt,
@@ -1392,8 +1484,13 @@ def maybe_make_parens_invisible_in_atom(
     if is_lpar_token(first) and is_rpar_token(last):
         middle = node.children[1]
         # make parentheses invisible
     if is_lpar_token(first) and is_rpar_token(last):
         middle = node.children[1]
         # make parentheses invisible
-        first.value = ""
-        last.value = ""
+        if (
+            # If the prefix of `middle` includes a type comment with
+            # ignore annotation, then we do not remove the parentheses
+            not is_type_ignore_comment_string(middle.prefix.strip())
+        ):
+            first.value = ""
+            last.value = ""
         maybe_make_parens_invisible_in_atom(
             middle,
             parent=parent,
         maybe_make_parens_invisible_in_atom(
             middle,
             parent=parent,