Priority = int
 Index = int
 LN = Union[Leaf, Node]
-SplitFunc = Callable[["Line", bool], Iterator["Line"]]
+SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
 Timestamp = float
 FileSize = int
 CacheInfo = Tuple[Timestamp, FileSize]
     UNICODE_LITERALS = 1
     F_STRINGS = 2
     NUMERIC_UNDERSCORES = 3
-    TRAILING_COMMA = 4
+    TRAILING_COMMA_IN_CALL = 4
+    TRAILING_COMMA_IN_DEF = 5
 
 
 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
     TargetVersion.PY27: set(),
     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
-    TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
+    TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
     TargetVersion.PY36: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
-        Feature.TRAILING_COMMA,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
     },
     TargetVersion.PY37: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
-        Feature.TRAILING_COMMA,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
     },
     TargetVersion.PY38: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
-        Feature.TRAILING_COMMA,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
     },
 }
 
     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line()
     after = 0
+    split_line_features = {
+        feature
+        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
+        if supports_feature(versions, feature)
+    }
     for current_line in lines.visit(src_node):
         for _ in range(after):
             dst_contents += str(empty_line)
         for _ in range(before):
             dst_contents += str(empty_line)
         for line in split_line(
-            current_line,
-            line_length=mode.line_length,
-            supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
+            current_line, line_length=mode.line_length, features=split_line_features
         ):
             dst_contents += str(line)
     return dst_contents
     line: Line,
     line_length: int,
     inner: bool = False,
-    supports_trailing_commas: bool = False,
+    features: Collection[Feature] = (),
 ) -> Iterator[Line]:
     """Split a `line` into potentially many lines.
 
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
 
-    If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
+    `features` is the collection of syntactic `Feature`s the output is allowed to use.
     """
     if line.is_comment:
         yield line
         split_funcs = [left_hand_split]
     else:
 
-        def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
+        def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
             for omit in generate_trailers_to_omit(line, line_length):
-                lines = list(
-                    right_hand_split(
-                        line, line_length, supports_trailing_commas, omit=omit
-                    )
-                )
+                lines = list(right_hand_split(line, line_length, features, omit=omit))
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line.
-            yield from right_hand_split(line, line_length, supports_trailing_commas)
+            yield from right_hand_split(line, line_length, features=features)
 
         if line.inside_brackets:
             split_funcs = [delimiter_split, standalone_comment_split, rhs]
         # split altogether.
         result: List[Line] = []
         try:
-            for l in split_func(line, supports_trailing_commas):
+            for l in split_func(line, features):
                 if str(l).strip("\n") == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 result.extend(
                     split_line(
-                        l,
-                        line_length=line_length,
-                        inner=True,
-                        supports_trailing_commas=supports_trailing_commas,
+                        l, line_length=line_length, inner=True, features=features
                     )
                 )
         except CannotSplit:
         yield line
 
 
-def left_hand_split(
-    line: Line, supports_trailing_commas: bool = False
-) -> Iterator[Line]:
+def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
 def right_hand_split(
     line: Line,
     line_length: int,
-    supports_trailing_commas: bool = False,
+    features: Collection[Feature] = (),
     omit: Collection[LeafID] = (),
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
     ):
         omit = {id(closing_bracket), *omit}
         try:
-            yield from right_hand_split(
-                line,
-                line_length,
-                supports_trailing_commas=supports_trailing_commas,
-                omit=omit,
-            )
+            yield from right_hand_split(line, line_length, features=features, omit=omit)
             return
 
         except CannotSplit:
     """
 
     @wraps(split_func)
-    def split_wrapper(
-        line: Line, supports_trailing_commas: bool = False
-    ) -> Iterator[Line]:
-        for l in split_func(line, supports_trailing_commas):
+    def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
+        for l in split_func(line, features):
             normalize_prefix(l.leaves[0], inside_brackets=True)
             yield l
 
 
 
 @dont_increase_indentation
-def delimiter_split(
-    line: Line, supports_trailing_commas: bool = False
-) -> Iterator[Line]:
+def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
     """Split according to delimiters of the highest priority.
 
-    If `supports_trailing_commas` is True, the split will add trailing commas
-    also in function signatures that contain `*` and `**`.
+    Trailing commas are also added to function signatures and calls that contain
+    `*` or `**`, but only when `features` includes `Feature.TRAILING_COMMA_IN_DEF`
+    or `Feature.TRAILING_COMMA_IN_CALL`, respectively.
     """
     try:
         last_leaf = line.leaves[-1]
             yield from append_to_line(comment_after)
 
         lowest_depth = min(lowest_depth, leaf.bracket_depth)
-        if leaf.bracket_depth == lowest_depth and is_vararg(
-            leaf, within=VARARGS_PARENTS
-        ):
-            trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
+        if leaf.bracket_depth == lowest_depth:
+            if is_vararg(leaf, within={syms.typedargslist}):
+                trailing_comma_safe = (
+                    trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
+                )
+            elif is_vararg(leaf, within={syms.arglist, syms.argument}):
+                trailing_comma_safe = (
+                    trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
+                )
+
         leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
 
 @dont_increase_indentation
 def standalone_comment_split(
-    line: Line, supports_trailing_commas: bool = False
+    line: Line, features: Collection[Feature] = ()
 ) -> Iterator[Line]:
     """Split standalone comments from the rest of the line."""
     if not line.contains_standalone_comments(0):
             and n.children
             and n.children[-1].type == token.COMMA
         ):
+            if n.type == syms.typedargslist:
+                feature = Feature.TRAILING_COMMA_IN_DEF
+            else:
+                feature = Feature.TRAILING_COMMA_IN_CALL
+
             for ch in n.children:
                 if ch.type in STARS:
-                    features.add(Feature.TRAILING_COMMA)
+                    features.add(feature)
 
                 if ch.type == syms.argument:
                     for argch in ch.children:
                         if argch.type in STARS:
-                            features.add(Feature.TRAILING_COMMA)
+                            features.add(feature)
 
     return features
 
 
         node = black.lib2to3_parse("def f(*, arg): ...\n")
         self.assertEqual(black.get_features_used(node), set())
         node = black.lib2to3_parse("def f(*, arg,): ...\n")
-        self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
+        self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
+        node = black.lib2to3_parse("f(*arg,)\n")
+        self.assertEqual(
+            black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
+        )
         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
         node = black.lib2to3_parse("123_456\n")
         self.assertEqual(black.get_features_used(node), set())
         source, expected = read_data("function")
         node = black.lib2to3_parse(source)
-        self.assertEqual(
-            black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
-        )
+        expected_features = {
+            Feature.TRAILING_COMMA_IN_CALL,
+            Feature.TRAILING_COMMA_IN_DEF,
+            Feature.F_STRINGS,
+        }
+        self.assertEqual(black.get_features_used(node), expected_features)
         node = black.lib2to3_parse(expected)
-        self.assertEqual(
-            black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
-        )
+        self.assertEqual(black.get_features_used(node), expected_features)
         source, expected = read_data("expression")
         node = black.lib2to3_parse(source)
         self.assertEqual(black.get_features_used(node), set())
 
             await check("3.6", 200)
             await check("py3.6", 200)
-            await check("3.5,3.7", 200)
-            await check("3.5,py3.7", 200)
+            await check("3.6,3.7", 200)
+            await check("3.6,py3.7", 200)
 
             await check("2", 204)
             await check("2.7", 204)