transformers = [string_merge, string_paren_strip]
else:
transformers = []
- elif line.is_def:
+ elif line.is_def and not should_split_funcdef_with_rhs(line, mode):
transformers = [left_hand_split]
else:
yield line
+def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:
+    """Return True if a funcdef's return annotation has a magic trailing comma.
+
+    When it does, the line should first be split with rhs (rather than
+    `left_hand_split`) so that the comma inside the return type is respected.
+
+    Only active under the `respect_magic_trailing_comma_in_return_type`
+    preview feature; always returns False otherwise.
+    """
+    if Preview.respect_magic_trailing_comma_in_return_type not in mode:
+        return False
+
+    # Collect the leaves strictly between `->` and the colon that ends the
+    # signature. Only a colon at bracket depth 0 ends the signature: a colon
+    # nested inside the annotation (e.g. a slice in `X[1:2, ...]`) must not cut
+    # the return type short, or a magic trailing comma after it would be missed.
+    # NOTE: relies on `leaf.bracket_depth` having been assigned when the leaves
+    # were appended to a `Line` (bracket tracking).
+    return_type_leaves: List[Leaf] = []
+    in_return_type = False
+
+    for leaf in line.leaves:
+        if leaf.type == token.COLON and leaf.bracket_depth == 0:
+            in_return_type = False
+        if in_return_type:
+            return_type_leaves.append(leaf)
+        if leaf.type == token.RARROW:
+            in_return_type = True
+
+    # using `bracket_split_build_line` will mess with whitespace, so we duplicate a
+    # couple lines from it.
+    result = Line(mode=line.mode, depth=line.depth)
+    leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves)
+    for leaf in return_type_leaves:
+        result.append(
+            leaf,
+            preformatted=True,
+            track_bracket=id(leaf) in leaves_to_track,
+        )
+
+    # we could also return true if the line is too long, and the return type is longer
+    # than the param list. Or if `should_split_rhs` returns True.
+    return result.magic_trailing_comma is not None
+
+
class _BracketSplitComponent(Enum):
head = auto()
body = auto()
--- /dev/null
+# normal, short, function definition
+def foo(a, b) -> tuple[int, float]: ...
+
+
+# normal, short, function definition w/o return type
+def foo(a, b): ...
+
+
+# no splitting
+def foo(a: A, b: B) -> list[p, q]:
+ pass
+
+
+# magic trailing comma in param list
+def foo(a, b,): ...
+
+
+# magic trailing comma in nested params in param list
+def foo(a, b: tuple[int, float,]): ...
+
+
+# magic trailing comma in return type, no params
+def a() -> tuple[
+ a,
+ b,
+]: ...
+
+
+# magic trailing comma in return type, params
+def foo(a: A, b: B) -> list[
+ p,
+ q,
+]:
+ pass
+
+
+# magic trailing comma in param list and in return type
+def foo(
+ a: a,
+ b: b,
+) -> list[
+ a,
+ a,
+]:
+ pass
+
+
+# long function definition, param list is longer
+def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
+ bbbbbbbbbbbbbbbbbb,
+) -> cccccccccccccccccccccccccccccc: ...
+
+
+# long function definition, return type is longer
+# this should maybe split on rhs?
+def aaaaaaaaaaaaaaaaa(bbbbbbbbbbbbbbbbbb) -> list[
+ Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd
+]: ...
+
+
+# long return type, no param list
+def foo() -> list[
+ Loooooooooooooooooooooooooooooooooooong,
+ Loooooooooooooooooooong,
+ Looooooooooooong,
+]: ...
+
+
+# long function name, no param list, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
+ pass
+
+
+# long function name, no param list
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
+ list[int, float]
+): ...
+
+
+# long function name, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
+ a, b
+): ...
+
+
+# unskippable type hint (??)
+def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore
+ pass
+
+
+def foo(a) -> list[
+ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+]: # abpedeifnore
+ pass
+
+def foo(a, b: list[Bad],): ... # type: ignore
+
+# don't lose any comments (no magic)
+def foo( # 1
+ a, # 2
+ b) -> list[ # 3
+ a, # 4
+ b]: # 5
+ ... # 6
+
+
+# don't lose any comments (param list magic)
+def foo( # 1
+ a, # 2
+ b,) -> list[ # 3
+ a, # 4
+ b]: # 5
+ ... # 6
+
+
+# don't lose any comments (return type magic)
+def foo( # 1
+ a, # 2
+ b) -> list[ # 3
+ a, # 4
+ b,]: # 5
+ ... # 6
+
+
+# don't lose any comments (both magic)
+def foo( # 1
+ a, # 2
+ b,) -> list[ # 3
+ a, # 4
+ b,]: # 5
+ ... # 6
+
+# real life example
+def SimplePyFn(
+ context: hl.GeneratorContext,
+ buffer_input: Buffer[UInt8, 2],
+ func_input: Buffer[Int32, 2],
+ float_arg: Scalar[Float32],
+ offset: int = 0,
+) -> tuple[
+ Buffer[UInt8, 2],
+ Buffer[UInt8, 2],
+]: ...
+# output
+# normal, short, function definition
+def foo(a, b) -> tuple[int, float]: ...
+
+
+# normal, short, function definition w/o return type
+def foo(a, b): ...
+
+
+# no splitting
+def foo(a: A, b: B) -> list[p, q]:
+ pass
+
+
+# magic trailing comma in param list
+def foo(
+ a,
+ b,
+): ...
+
+
+# magic trailing comma in nested params in param list
+def foo(
+ a,
+ b: tuple[
+ int,
+ float,
+ ],
+): ...
+
+
+# magic trailing comma in return type, no params
+def a() -> tuple[
+ a,
+ b,
+]: ...
+
+
+# magic trailing comma in return type, params
+def foo(a: A, b: B) -> list[
+ p,
+ q,
+]:
+ pass
+
+
+# magic trailing comma in param list and in return type
+def foo(
+ a: a,
+ b: b,
+) -> list[
+ a,
+ a,
+]:
+ pass
+
+
+# long function definition, param list is longer
+def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
+ bbbbbbbbbbbbbbbbbb,
+) -> cccccccccccccccccccccccccccccc: ...
+
+
+# long function definition, return type is longer
+# this should maybe split on rhs?
+def aaaaaaaaaaaaaaaaa(
+ bbbbbbbbbbbbbbbbbb,
+) -> list[Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd]: ...
+
+
+# long return type, no param list
+def foo() -> list[
+ Loooooooooooooooooooooooooooooooooooong,
+ Loooooooooooooooooooong,
+ Looooooooooooong,
+]: ...
+
+
+# long function name, no param list, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
+ pass
+
+
+# long function name, no param list
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
+ list[int, float]
+): ...
+
+
+# long function name, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
+ a, b
+): ...
+
+
+# unskippable type hint (??)
+def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore
+ pass
+
+
+def foo(
+ a,
+) -> list[
+ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+]: # abpedeifnore
+ pass
+
+
+def foo(
+ a,
+ b: list[Bad],
+): ... # type: ignore
+
+
+# don't lose any comments (no magic)
+def foo(a, b) -> list[a, b]: # 1 # 2 # 3 # 4 # 5
+ ... # 6
+
+
+# don't lose any comments (param list magic)
+def foo( # 1
+ a, # 2
+ b,
+) -> list[a, b]: # 3 # 4 # 5
+ ... # 6
+
+
+# don't lose any comments (return type magic)
+def foo(a, b) -> list[ # 1 # 2 # 3
+ a, # 4
+ b,
+]: # 5
+ ... # 6
+
+
+# don't lose any comments (both magic)
+def foo( # 1
+ a, # 2
+ b,
+) -> list[ # 3
+ a, # 4
+ b,
+]: # 5
+ ... # 6
+
+
+# real life example
+def SimplePyFn(
+ context: hl.GeneratorContext,
+ buffer_input: Buffer[UInt8, 2],
+ func_input: Buffer[Int32, 2],
+ float_arg: Scalar[Float32],
+ offset: int = 0,
+) -> tuple[
+ Buffer[UInt8, 2],
+ Buffer[UInt8, 2],
+]: ...