X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/c160e4b7ce30c661ac4f2dfa5038becf1b8c5c33..bb588073ab286a9f1f8d839ab2cebe13011dd22c:/src/black/linegen.py?ds=sidebyside

diff --git a/src/black/linegen.py b/src/black/linegen.py
index 507e860..d12ca39 100644
--- a/src/black/linegen.py
+++ b/src/black/linegen.py
@@ -1,6 +1,7 @@
 """
 Generating lines of code.
 """
+
 import sys
 from dataclasses import replace
 from enum import Enum, auto
@@ -397,6 +398,24 @@ class LineGenerator(Visitor[Line]):
             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
         yield from self.visit_default(node)
 
+    def visit_tname(self, node: Node) -> Iterator[Line]:
+        """Prepare a parameter's type annotation for optional parenthesization.
+
+        In preview mode (`parenthesize_long_type_hints`) the annotation, i.e. the
+        third child of `node`, is wrapped in invisible parentheses whenever
+        `maybe_make_parens_invisible_in_atom` says it should be, so that they can be
+        made into real parentheses if the type hint is too long to fit on a line.
+        """
+        if Preview.parenthesize_long_type_hints not in self.mode:
+            yield from self.visit_default(node)
+            return
+
+        assert len(node.children) == 3
+        annotation = node.children[2]
+        if maybe_make_parens_invisible_in_atom(annotation, parent=node):
+            wrap_in_parentheses(node, annotation, visible=False)
+        yield from self.visit_default(node)
+
     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
         if Preview.hex_codes_in_unicode_sequences in self.mode:
             normalize_unicode_escape_sequences(leaf)
@@ -498,7 +517,14 @@ class LineGenerator(Visitor[Line]):
         self.visit_except_clause = partial(v, keywords={"except"}, parens={"except"})
         self.visit_with_stmt = partial(v, keywords={"with"}, parens={"with"})
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
-        self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
+
+        # When this is moved out of preview, add ":" directly to ASSIGNMENTS in nodes.py
+        if Preview.parenthesize_long_type_hints in self.mode:
+            assignments = ASSIGNMENTS | {":"}
+        else:
+            assignments = ASSIGNMENTS
+        self.visit_expr_stmt = partial(v, keywords=Ø, parens=assignments)
+
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
@@ -510,6 +536,17 @@ class LineGenerator(Visitor[Line]):
         self.visit_case_block = self.visit_match_case
 
 
+def _hugging_power_ops_line_to_string(
+    line: Line, features: Collection[Feature], mode: Mode
+) -> Optional[str]:
+    """Stringify the first line yielded by `hug_power_op`; None on CannotTransform."""
+    try:
+        hugged = next(hug_power_op(line, features, mode))
+    except CannotTransform:
+        return None
+    return line_to_string(hugged)
+
+
 def transform_line(
     line: Line, mode: Mode, features: Collection[Feature] = ()
 ) -> Iterator[Line]:
@@ -525,6 +562,14 @@ def transform_line(
 
     line_str = line_to_string(line)
 
+    # We need the line string with power operators hugged to determine if we should
+    # split the line. Default to line_str if no power operators are present on the line.
+    line_str_hugging_power_ops = (
+        (_hugging_power_ops_line_to_string(line, features, mode) or line_str)
+        if Preview.fix_power_op_line_length in mode
+        else line_str
+    )
+
     ll = mode.line_length
     sn = mode.string_normalization
     string_merge = StringMerger(ll, sn)
@@ -538,7 +583,7 @@ def transform_line(
         and not line.should_split_rhs
         and not line.magic_trailing_comma
         and (
-            is_line_short_enough(line, mode=mode, line_str=line_str)
+            is_line_short_enough(line, mode=mode, line_str=line_str_hugging_power_ops)
             or line.contains_unsplittable_type_ignore()
         )
         and not (line.inside_brackets and line.contains_standalone_comments())
@@ -548,7 +593,7 @@ def transform_line(
             transformers = [string_merge, string_paren_strip]
         else:
             transformers = []
-    elif line.is_def:
+    elif line.is_def and not should_split_funcdef_with_rhs(line, mode):
         transformers = [left_hand_split]
     else:
 
@@ -627,6 +672,40 @@ def transform_line(
         yield line
 
 
+def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:
+    """Return True if a funcdef has a magic trailing comma in its return type.
+
+    Such a line should first be split with rhs so that the comma is respected. The
+    check is only active under `Preview.respect_magic_trailing_comma_in_return_type`.
+    """
+    if Preview.respect_magic_trailing_comma_in_return_type not in mode:
+        return False
+
+    # Gather the leaves after `->` and before the next `:` (both excluded).
+    annotation: List[Leaf] = []
+    collecting = False
+
+    for leaf in line.leaves:
+        if leaf.type == token.COLON:
+            collecting = False
+        elif collecting:
+            annotation.append(leaf)
+        elif leaf.type == token.RARROW:
+            collecting = True
+
+    # Build the scratch line by hand rather than with `bracket_split_build_line`,
+    # which would alter whitespace; bracket tracking is enabled only for leaves
+    # inside matching brackets.
+    scratch = Line(mode=line.mode, depth=line.depth)
+    tracked = get_leaves_inside_matching_brackets(annotation)
+    for leaf in annotation:
+        scratch.append(leaf, preformatted=True, track_bracket=id(leaf) in tracked)
+
+    # Possible extension: also return True when the line is too long and the return
+    # type is longer than the parameter list, or when `should_split_rhs` is True.
+    return scratch.magic_trailing_comma is not None
+
+
 class _BracketSplitComponent(Enum):
     head = auto()
     body = auto()
@@ -1368,7 +1447,7 @@ def maybe_make_parens_invisible_in_atom(
     Returns whether the node should itself be wrapped in invisible parentheses.
     """
     if (
-        node.type != syms.atom
+        node.type not in (syms.atom, syms.expr)
         or is_empty_tuple(node)
         or is_one_tuple(node)
         or (is_yield(node) and parent.type != syms.expr_stmt)
@@ -1392,6 +1471,7 @@ def maybe_make_parens_invisible_in_atom(
             syms.except_clause,
             syms.funcdef,
             syms.with_stmt,
+            syms.tname,
             # these ones aren't useful to end users, but they do please fuzzers
             syms.for_stmt,
             syms.del_stmt,