]> git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

respect magic trailing commas in return types (#3916)
authorJohn Litborn <11260241+jakkdl@users.noreply.github.com>
Wed, 4 Oct 2023 23:42:35 +0000 (01:42 +0200)
committerGitHub <noreply@github.com>
Wed, 4 Oct 2023 23:42:35 +0000 (16:42 -0700)
CHANGES.md
src/black/linegen.py
src/black/mode.py
tests/data/preview/return_annotation_brackets_string.py
tests/data/preview_py_310/funcdef_return_type_trailing_comma.py [new file with mode: 0644]
tests/data/simple_cases/return_annotation_brackets.py

index 5e518497c9220559fc937f42bb534497a805bb37..888824ee055e71dd9f3c7c2d6c15bd48742f652a 100644 (file)
@@ -18,6 +18,7 @@
 
 - Long type hints are now wrapped in parentheses and properly indented when split across
   multiple lines (#3899)
+- Magic trailing commas are now respected in return types. (#3916)
 
 ### Configuration
 
index 9ddd4619f691275a8a1ce9b70d4bce7348596281..bdc4ee54ab285351588659f458f804aa212573c2 100644 (file)
@@ -573,7 +573,7 @@ def transform_line(
             transformers = [string_merge, string_paren_strip]
         else:
             transformers = []
-    elif line.is_def:
+    elif line.is_def and not should_split_funcdef_with_rhs(line, mode):
         transformers = [left_hand_split]
     else:
 
@@ -652,6 +652,40 @@ def transform_line(
         yield line
 
 
+def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:
+    """If a funcdef has a magic trailing comma in the return type, then we should first
+    split the line with rhs to respect the comma.
+    """
+    if Preview.respect_magic_trailing_comma_in_return_type not in mode:
+        return False
+
+    return_type_leaves: List[Leaf] = []
+    in_return_type = False
+
+    for leaf in line.leaves:
+        if leaf.type == token.COLON:
+            in_return_type = False
+        if in_return_type:
+            return_type_leaves.append(leaf)
+        if leaf.type == token.RARROW:
+            in_return_type = True
+
+    # using `bracket_split_build_line` will mess with whitespace, so we duplicate a
+    # couple lines from it.
+    result = Line(mode=line.mode, depth=line.depth)
+    leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves)
+    for leaf in return_type_leaves:
+        result.append(
+            leaf,
+            preformatted=True,
+            track_bracket=id(leaf) in leaves_to_track,
+        )
+
+    # we could also return true if the line is too long, and the return type is longer
+    # than the param list. Or if `should_split_rhs` returns True.
+    return result.magic_trailing_comma is not None
+
+
 class _BracketSplitComponent(Enum):
     head = auto()
     body = auto()
index f44a821bcd0fc961f7fc59c206473aa4966bd616..30c5d2f1b2f1577a95606220973bffb628ccffed 100644 (file)
@@ -181,6 +181,7 @@ class Preview(Enum):
     string_processing = auto()
     parenthesize_conditional_expressions = auto()
     parenthesize_long_type_hints = auto()
+    respect_magic_trailing_comma_in_return_type = auto()
     skip_magic_trailing_comma_in_subscript = auto()
     wrap_long_dict_values_in_parens = auto()
     wrap_multiple_context_managers_in_parens = auto()
index 6978829fd5c8759eac46eb2e3241f5365182fe26..9148bd045bc8ade46c801f4b1316c82d29d9b12b 100644 (file)
@@ -2,6 +2,10 @@
 def frobnicate() -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
     pass
 
+# splitting the string breaks if there's any parameters
+def frobnicate(a) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
+    pass
+
 # output
 
 # Long string example
@@ -10,3 +14,10 @@ def frobnicate() -> (
     " list[ThisIsTrulyUnreasonablyExtremelyLongClassName]"
 ):
     pass
+
+
+# splitting the string breaks if there's any parameters
+def frobnicate(
+    a,
+) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
+    pass
diff --git a/tests/data/preview_py_310/funcdef_return_type_trailing_comma.py b/tests/data/preview_py_310/funcdef_return_type_trailing_comma.py
new file mode 100644 (file)
index 0000000..15db772
--- /dev/null
@@ -0,0 +1,300 @@
+# normal, short, function definition
+def foo(a, b) -> tuple[int, float]: ...
+
+
+# normal, short, function definition w/o return type
+def foo(a, b): ...
+
+
+# no splitting
+def foo(a: A, b: B) -> list[p, q]:
+    pass
+
+
+# magic trailing comma in param list
+def foo(a, b,): ...
+
+
+# magic trailing comma in nested params in param list
+def foo(a, b: tuple[int, float,]): ...
+
+
+# magic trailing comma in return type, no params
+def a() -> tuple[
+    a,
+    b,
+]: ...
+
+
+# magic trailing comma in return type, params
+def foo(a: A, b: B) -> list[
+    p,
+    q,
+]:
+    pass
+
+
+# magic trailing comma in param list and in return type
+def foo(
+    a: a,
+    b: b,
+) -> list[
+    a,
+    a,
+]:
+    pass
+
+
+# long function definition, param list is longer
+def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
+    bbbbbbbbbbbbbbbbbb,
+) -> cccccccccccccccccccccccccccccc: ...
+
+
+# long function definition, return type is longer
+# this should maybe split on rhs?
+def aaaaaaaaaaaaaaaaa(bbbbbbbbbbbbbbbbbb) -> list[
+    Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd
+]: ...
+
+
+# long return type, no param list
+def foo() -> list[
+    Loooooooooooooooooooooooooooooooooooong,
+    Loooooooooooooooooooong,
+    Looooooooooooong,
+]: ...
+
+
+# long function name, no param list, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
+    pass
+
+
+# long function name, no param list
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
+    list[int, float]
+): ...
+
+
+# long function name, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
+    a, b
+): ...
+
+
+# unskippable type hint (??)
+def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]:  # type: ignore
+    pass
+
+
+def foo(a) -> list[
+    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+]:  # abpedeifnore
+    pass
+
+def foo(a, b: list[Bad],): ... # type: ignore
+
+# don't lose any comments (no magic)
+def foo( # 1
+    a, # 2
+    b) -> list[ # 3
+               a, # 4
+               b]: # 5
+        ... # 6
+
+
+# don't lose any comments (param list magic)
+def foo( # 1
+    a, # 2
+    b,) -> list[ # 3
+               a, # 4
+               b]: # 5
+        ... # 6
+
+
+# don't lose any comments (return type magic)
+def foo( # 1
+    a, # 2
+    b) -> list[ # 3
+               a, # 4
+               b,]: # 5
+        ... # 6
+
+
+# don't lose any comments (both magic)
+def foo( # 1
+    a, # 2
+    b,) -> list[ # 3
+               a, # 4
+               b,]: # 5
+        ... # 6
+
+# real life example
+def SimplePyFn(
+    context: hl.GeneratorContext,
+    buffer_input: Buffer[UInt8, 2],
+    func_input: Buffer[Int32, 2],
+    float_arg: Scalar[Float32],
+    offset: int = 0,
+) -> tuple[
+    Buffer[UInt8, 2],
+    Buffer[UInt8, 2],
+]: ...
+# output
+# normal, short, function definition
+def foo(a, b) -> tuple[int, float]: ...
+
+
+# normal, short, function definition w/o return type
+def foo(a, b): ...
+
+
+# no splitting
+def foo(a: A, b: B) -> list[p, q]:
+    pass
+
+
+# magic trailing comma in param list
+def foo(
+    a,
+    b,
+): ...
+
+
+# magic trailing comma in nested params in param list
+def foo(
+    a,
+    b: tuple[
+        int,
+        float,
+    ],
+): ...
+
+
+# magic trailing comma in return type, no params
+def a() -> tuple[
+    a,
+    b,
+]: ...
+
+
+# magic trailing comma in return type, params
+def foo(a: A, b: B) -> list[
+    p,
+    q,
+]:
+    pass
+
+
+# magic trailing comma in param list and in return type
+def foo(
+    a: a,
+    b: b,
+) -> list[
+    a,
+    a,
+]:
+    pass
+
+
+# long function definition, param list is longer
+def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
+    bbbbbbbbbbbbbbbbbb,
+) -> cccccccccccccccccccccccccccccc: ...
+
+
+# long function definition, return type is longer
+# this should maybe split on rhs?
+def aaaaaaaaaaaaaaaaa(
+    bbbbbbbbbbbbbbbbbb,
+) -> list[Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd]: ...
+
+
+# long return type, no param list
+def foo() -> list[
+    Loooooooooooooooooooooooooooooooooooong,
+    Loooooooooooooooooooong,
+    Looooooooooooong,
+]: ...
+
+
+# long function name, no param list, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
+    pass
+
+
+# long function name, no param list
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
+    list[int, float]
+): ...
+
+
+# long function name, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
+    a, b
+): ...
+
+
+# unskippable type hint (??)
+def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]:  # type: ignore
+    pass
+
+
+def foo(
+    a,
+) -> list[
+    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+]:  # abpedeifnore
+    pass
+
+
+def foo(
+    a,
+    b: list[Bad],
+): ...  # type: ignore
+
+
+# don't lose any comments (no magic)
+def foo(a, b) -> list[a, b]:  # 1  # 2  # 3  # 4  # 5
+    ...  # 6
+
+
+# don't lose any comments (param list magic)
+def foo(  # 1
+    a,  # 2
+    b,
+) -> list[a, b]:  # 3  # 4  # 5
+    ...  # 6
+
+
+# don't lose any comments (return type magic)
+def foo(a, b) -> list[  # 1  # 2  # 3
+    a,  # 4
+    b,
+]:  # 5
+    ...  # 6
+
+
+# don't lose any comments (both magic)
+def foo(  # 1
+    a,  # 2
+    b,
+) -> list[  # 3
+    a,  # 4
+    b,
+]:  # 5
+    ...  # 6
+
+
+# real life example
+def SimplePyFn(
+    context: hl.GeneratorContext,
+    buffer_input: Buffer[UInt8, 2],
+    func_input: Buffer[Int32, 2],
+    float_arg: Scalar[Float32],
+    offset: int = 0,
+) -> tuple[
+    Buffer[UInt8, 2],
+    Buffer[UInt8, 2],
+]: ...
index 265c30220d8c8e5c9ed22a0ce47552a1ad9cef49..8509ecdb92c0afd94d92d7d65393ad7332d2e978 100644 (file)
@@ -87,6 +87,11 @@ def foo() -> tuple[loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo
 def foo() -> tuple[int, int, int,]:
     return 2
 
+# Magic trailing comma example, with params
+# this is broken - the trailing comma is transferred to the param list. Fixed in preview
+def foo(a,b) -> tuple[int, int, int,]:
+    return 2
+
 # output
 # Control
 def double(a: int) -> int:
@@ -208,3 +213,11 @@ def foo() -> (
     ]
 ):
     return 2
+
+
+# Magic trailing comma example, with params
+# this is broken - the trailing comma is transferred to the param list. Fixed in preview
+def foo(
+    a, b
+) -> tuple[int, int, int,]:
+    return 2