From 36078bc83f24dcd5f74e021a105429595a3fd63c Mon Sep 17 00:00:00 2001 From: John Litborn <11260241+jakkdl@users.noreply.github.com> Date: Thu, 5 Oct 2023 01:42:35 +0200 Subject: [PATCH] respect magic trailing commas in return types (#3916) --- CHANGES.md | 1 + src/black/linegen.py | 36 ++- src/black/mode.py | 1 + .../return_annotation_brackets_string.py | 11 + .../funcdef_return_type_trailing_comma.py | 300 ++++++++++++++++++ .../return_annotation_brackets.py | 13 + 6 files changed, 361 insertions(+), 1 deletion(-) create mode 100644 tests/data/preview_py_310/funcdef_return_type_trailing_comma.py diff --git a/CHANGES.md b/CHANGES.md index 5e51849..888824e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -18,6 +18,7 @@ - Long type hints are now wrapped in parentheses and properly indented when split across multiple lines (#3899) +- Magic trailing commas are now respected in return types. (#3916) ### Configuration diff --git a/src/black/linegen.py b/src/black/linegen.py index 9ddd461..bdc4ee5 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -573,7 +573,7 @@ def transform_line( transformers = [string_merge, string_paren_strip] else: transformers = [] - elif line.is_def: + elif line.is_def and not should_split_funcdef_with_rhs(line, mode): transformers = [left_hand_split] else: @@ -652,6 +652,40 @@ def transform_line( yield line +def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool: + """If a funcdef has a magic trailing comma in the return type, then we should first + split the line with rhs to respect the comma. + """ + if Preview.respect_magic_trailing_comma_in_return_type not in mode: + return False + + return_type_leaves: List[Leaf] = [] + in_return_type = False + + for leaf in line.leaves: + if leaf.type == token.COLON: + in_return_type = False + if in_return_type: + return_type_leaves.append(leaf) + if leaf.type == token.RARROW: + in_return_type = True + + # using `bracket_split_build_line` will mess with whitespace, so we duplicate a + # couple lines from it. + result = Line(mode=line.mode, depth=line.depth) + leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves) + for leaf in return_type_leaves: + result.append( + leaf, + preformatted=True, + track_bracket=id(leaf) in leaves_to_track, + ) + + # we could also return true if the line is too long, and the return type is longer + # than the param list. Or if `should_split_rhs` returns True. + return result.magic_trailing_comma is not None + + class _BracketSplitComponent(Enum): head = auto() body = auto() diff --git a/src/black/mode.py b/src/black/mode.py index f44a821..30c5d2f 100644 --- a/src/black/mode.py +++ b/src/black/mode.py @@ -181,6 +181,7 @@ class Preview(Enum): string_processing = auto() parenthesize_conditional_expressions = auto() parenthesize_long_type_hints = auto() + respect_magic_trailing_comma_in_return_type = auto() skip_magic_trailing_comma_in_subscript = auto() wrap_long_dict_values_in_parens = auto() wrap_multiple_context_managers_in_parens = auto() diff --git a/tests/data/preview/return_annotation_brackets_string.py b/tests/data/preview/return_annotation_brackets_string.py index 6978829..9148bd0 100644 --- a/tests/data/preview/return_annotation_brackets_string.py +++ b/tests/data/preview/return_annotation_brackets_string.py @@ -2,6 +2,10 @@ def frobnicate() -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]": pass +# splitting the string breaks if there's any parameters +def frobnicate(a) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]": + pass + # output # Long string example @@ -10,3 +14,10 @@ def frobnicate() -> ( " list[ThisIsTrulyUnreasonablyExtremelyLongClassName]" ): pass + + +# splitting the string breaks if there's any parameters +def frobnicate( + a, +) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]": + pass diff --git a/tests/data/preview_py_310/funcdef_return_type_trailing_comma.py b/tests/data/preview_py_310/funcdef_return_type_trailing_comma.py new file mode 100644 index 0000000..15db772 --- /dev/null +++ b/tests/data/preview_py_310/funcdef_return_type_trailing_comma.py @@ -0,0 +1,300 @@ +# normal, short, function definition +def foo(a, b) -> tuple[int, float]: ... + + +# normal, short, function definition w/o return type +def foo(a, b): ... + + +# no splitting +def foo(a: A, b: B) -> list[p, q]: + pass + + +# magic trailing comma in param list +def foo(a, b,): ... + + +# magic trailing comma in nested params in param list +def foo(a, b: tuple[int, float,]): ... + + +# magic trailing comma in return type, no params +def a() -> tuple[ + a, + b, +]: ... + + +# magic trailing comma in return type, params +def foo(a: A, b: B) -> list[ + p, + q, +]: + pass + + +# magic trailing comma in param list and in return type +def foo( + a: a, + b: b, +) -> list[ + a, + a, +]: + pass + + +# long function definition, param list is longer +def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa( + bbbbbbbbbbbbbbbbbb, +) -> cccccccccccccccccccccccccccccc: ... + + +# long function definition, return type is longer +# this should maybe split on rhs? +def aaaaaaaaaaaaaaaaa(bbbbbbbbbbbbbbbbbb) -> list[ + Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd +]: ... + + +# long return type, no param list +def foo() -> list[ + Loooooooooooooooooooooooooooooooooooong, + Loooooooooooooooooooong, + Looooooooooooong, +]: ... + + +# long function name, no param list, no return value +def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(): + pass + + +# long function name, no param list +def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> ( + list[int, float] +): ... + + +# long function name, no return value +def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong( + a, b +): ... + + +# unskippable type hint (??) +def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore + pass + + +def foo(a) -> list[ + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +]: # abpedeifnore + pass + +def foo(a, b: list[Bad],): ... # type: ignore + +# don't lose any comments (no magic) +def foo( # 1 + a, # 2 + b) -> list[ # 3 + a, # 4 + b]: # 5 + ... # 6 + + +# don't lose any comments (param list magic) +def foo( # 1 + a, # 2 + b,) -> list[ # 3 + a, # 4 + b]: # 5 + ... # 6 + + +# don't lose any comments (return type magic) +def foo( # 1 + a, # 2 + b) -> list[ # 3 + a, # 4 + b,]: # 5 + ... # 6 + + +# don't lose any comments (both magic) +def foo( # 1 + a, # 2 + b,) -> list[ # 3 + a, # 4 + b,]: # 5 + ... # 6 + +# real life example +def SimplePyFn( + context: hl.GeneratorContext, + buffer_input: Buffer[UInt8, 2], + func_input: Buffer[Int32, 2], + float_arg: Scalar[Float32], + offset: int = 0, +) -> tuple[ + Buffer[UInt8, 2], + Buffer[UInt8, 2], +]: ... +# output +# normal, short, function definition +def foo(a, b) -> tuple[int, float]: ... + + +# normal, short, function definition w/o return type +def foo(a, b): ... + + +# no splitting +def foo(a: A, b: B) -> list[p, q]: + pass + + +# magic trailing comma in param list +def foo( + a, + b, +): ... + + +# magic trailing comma in nested params in param list +def foo( + a, + b: tuple[ + int, + float, + ], +): ... + + +# magic trailing comma in return type, no params +def a() -> tuple[ + a, + b, +]: ... + + +# magic trailing comma in return type, params +def foo(a: A, b: B) -> list[ + p, + q, +]: + pass + + +# magic trailing comma in param list and in return type +def foo( + a: a, + b: b, +) -> list[ + a, + a, +]: + pass + + +# long function definition, param list is longer +def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa( + bbbbbbbbbbbbbbbbbb, +) -> cccccccccccccccccccccccccccccc: ... + + +# long function definition, return type is longer +# this should maybe split on rhs? +def aaaaaaaaaaaaaaaaa( + bbbbbbbbbbbbbbbbbb, +) -> list[Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd]: ... + + +# long return type, no param list +def foo() -> list[ + Loooooooooooooooooooooooooooooooooooong, + Loooooooooooooooooooong, + Looooooooooooong, +]: ... + + +# long function name, no param list, no return value +def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(): + pass + + +# long function name, no param list +def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> ( + list[int, float] +): ... + + +# long function name, no return value +def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong( + a, b +): ... + + +# unskippable type hint (??) +def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore + pass + + +def foo( + a, +) -> list[ + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +]: # abpedeifnore + pass + + +def foo( + a, + b: list[Bad], +): ... # type: ignore + + +# don't lose any comments (no magic) +def foo(a, b) -> list[a, b]: # 1 # 2 # 3 # 4 # 5 + ... # 6 + + +# don't lose any comments (param list magic) +def foo( # 1 + a, # 2 + b, +) -> list[a, b]: # 3 # 4 # 5 + ... # 6 + + +# don't lose any comments (return type magic) +def foo(a, b) -> list[ # 1 # 2 # 3 + a, # 4 + b, +]: # 5 + ... # 6 + + +# don't lose any comments (both magic) +def foo( # 1 + a, # 2 + b, +) -> list[ # 3 + a, # 4 + b, +]: # 5 + ... # 6 + + +# real life example +def SimplePyFn( + context: hl.GeneratorContext, + buffer_input: Buffer[UInt8, 2], + func_input: Buffer[Int32, 2], + float_arg: Scalar[Float32], + offset: int = 0, +) -> tuple[ + Buffer[UInt8, 2], + Buffer[UInt8, 2], +]: ... diff --git a/tests/data/simple_cases/return_annotation_brackets.py b/tests/data/simple_cases/return_annotation_brackets.py index 265c302..8509ecd 100644 --- a/tests/data/simple_cases/return_annotation_brackets.py +++ b/tests/data/simple_cases/return_annotation_brackets.py @@ -87,6 +87,11 @@ def foo() -> tuple[loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo def foo() -> tuple[int, int, int,]: return 2 +# Magic trailing comma example, with params +# this is broken - the trailing comma is transferred to the param list. Fixed in preview +def foo(a,b) -> tuple[int, int, int,]: + return 2 + # output # Control def double(a: int) -> int: @@ -208,3 +213,11 @@ def foo() -> ( ] ): return 2 + + +# Magic trailing comma example, with params +# this is broken - the trailing comma is transferred to the param list. Fixed in preview +def foo( + a, b +) -> tuple[int, int, int,]: + return 2 -- 2.39.2