git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Split the TRAILING_COMMA feature (#763)
author Jelle Zijlstra <jelle.zijlstra@gmail.com>
Mon, 25 Mar 2019 15:22:02 +0000 (08:22 -0700)
committer GitHub <noreply@github.com>
Mon, 25 Mar 2019 15:22:02 +0000 (08:22 -0700)
black.py
tests/test_black.py

index 2dee826d1b1d482e3d3313e4d4878e1e1cce5a5b..0ea9b9ab0c40cac91b7782673f55e978caa21979 100644 (file)
--- a/black.py
+++ b/black.py
@@ -68,7 +68,7 @@ LeafID = int
 Priority = int
 Index = int
 LN = Union[Leaf, Node]
 Priority = int
 Index = int
 LN = Union[Leaf, Node]
-SplitFunc = Callable[["Line", bool], Iterator["Line"]]
+SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
 Timestamp = float
 FileSize = int
 CacheInfo = Tuple[Timestamp, FileSize]
 Timestamp = float
 FileSize = int
 CacheInfo = Tuple[Timestamp, FileSize]
@@ -133,31 +133,35 @@ class Feature(Enum):
     UNICODE_LITERALS = 1
     F_STRINGS = 2
     NUMERIC_UNDERSCORES = 3
     UNICODE_LITERALS = 1
     F_STRINGS = 2
     NUMERIC_UNDERSCORES = 3
-    TRAILING_COMMA = 4
+    TRAILING_COMMA_IN_CALL = 4
+    TRAILING_COMMA_IN_DEF = 5
 
 
 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
     TargetVersion.PY27: set(),
     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
 
 
 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
     TargetVersion.PY27: set(),
     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
-    TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
+    TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
     TargetVersion.PY36: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
     TargetVersion.PY36: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
-        Feature.TRAILING_COMMA,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
     },
     TargetVersion.PY37: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
     },
     TargetVersion.PY37: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
-        Feature.TRAILING_COMMA,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
     },
     TargetVersion.PY38: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
     },
     TargetVersion.PY38: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
-        Feature.TRAILING_COMMA,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
     },
 }
 
     },
 }
 
@@ -683,6 +687,11 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line()
     after = 0
     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line()
     after = 0
+    split_line_features = {
+        feature
+        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
+        if supports_feature(versions, feature)
+    }
     for current_line in lines.visit(src_node):
         for _ in range(after):
             dst_contents += str(empty_line)
     for current_line in lines.visit(src_node):
         for _ in range(after):
             dst_contents += str(empty_line)
@@ -690,9 +699,7 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
         for _ in range(before):
             dst_contents += str(empty_line)
         for line in split_line(
         for _ in range(before):
             dst_contents += str(empty_line)
         for line in split_line(
-            current_line,
-            line_length=mode.line_length,
-            supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
+            current_line, line_length=mode.line_length, features=split_line_features
         ):
             dst_contents += str(line)
     return dst_contents
         ):
             dst_contents += str(line)
     return dst_contents
@@ -2158,7 +2165,7 @@ def split_line(
     line: Line,
     line_length: int,
     inner: bool = False,
     line: Line,
     line_length: int,
     inner: bool = False,
-    supports_trailing_commas: bool = False,
+    features: Collection[Feature] = (),
 ) -> Iterator[Line]:
     """Split a `line` into potentially many lines.
 
 ) -> Iterator[Line]:
     """Split a `line` into potentially many lines.
 
@@ -2167,7 +2174,7 @@ def split_line(
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
 
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
 
-    If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
+    `features` are syntactical features that may be used in the output.
     """
     if line.is_comment:
         yield line
     """
     if line.is_comment:
         yield line
@@ -2188,13 +2195,9 @@ def split_line(
         split_funcs = [left_hand_split]
     else:
 
         split_funcs = [left_hand_split]
     else:
 
-        def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
+        def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
             for omit in generate_trailers_to_omit(line, line_length):
             for omit in generate_trailers_to_omit(line, line_length):
-                lines = list(
-                    right_hand_split(
-                        line, line_length, supports_trailing_commas, omit=omit
-                    )
-                )
+                lines = list(right_hand_split(line, line_length, features, omit=omit))
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
@@ -2202,7 +2205,7 @@ def split_line(
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line.
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line.
-            yield from right_hand_split(line, line_length, supports_trailing_commas)
+            yield from right_hand_split(line, line_length, features=features)
 
         if line.inside_brackets:
             split_funcs = [delimiter_split, standalone_comment_split, rhs]
 
         if line.inside_brackets:
             split_funcs = [delimiter_split, standalone_comment_split, rhs]
@@ -2214,16 +2217,13 @@ def split_line(
         # split altogether.
         result: List[Line] = []
         try:
         # split altogether.
         result: List[Line] = []
         try:
-            for l in split_func(line, supports_trailing_commas):
+            for l in split_func(line, features):
                 if str(l).strip("\n") == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 result.extend(
                     split_line(
                 if str(l).strip("\n") == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 result.extend(
                     split_line(
-                        l,
-                        line_length=line_length,
-                        inner=True,
-                        supports_trailing_commas=supports_trailing_commas,
+                        l, line_length=line_length, inner=True, features=features
                     )
                 )
         except CannotSplit:
                     )
                 )
         except CannotSplit:
@@ -2237,9 +2237,7 @@ def split_line(
         yield line
 
 
         yield line
 
 
-def left_hand_split(
-    line: Line, supports_trailing_commas: bool = False
-) -> Iterator[Line]:
+def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
@@ -2278,7 +2276,7 @@ def left_hand_split(
 def right_hand_split(
     line: Line,
     line_length: int,
 def right_hand_split(
     line: Line,
     line_length: int,
-    supports_trailing_commas: bool = False,
+    features: Collection[Feature] = (),
     omit: Collection[LeafID] = (),
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
     omit: Collection[LeafID] = (),
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
@@ -2337,12 +2335,7 @@ def right_hand_split(
     ):
         omit = {id(closing_bracket), *omit}
         try:
     ):
         omit = {id(closing_bracket), *omit}
         try:
-            yield from right_hand_split(
-                line,
-                line_length,
-                supports_trailing_commas=supports_trailing_commas,
-                omit=omit,
-            )
+            yield from right_hand_split(line, line_length, features=features, omit=omit)
             return
 
         except CannotSplit:
             return
 
         except CannotSplit:
@@ -2431,10 +2424,8 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
     """
 
     @wraps(split_func)
     """
 
     @wraps(split_func)
-    def split_wrapper(
-        line: Line, supports_trailing_commas: bool = False
-    ) -> Iterator[Line]:
-        for l in split_func(line, supports_trailing_commas):
+    def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
+        for l in split_func(line, features):
             normalize_prefix(l.leaves[0], inside_brackets=True)
             yield l
 
             normalize_prefix(l.leaves[0], inside_brackets=True)
             yield l
 
@@ -2442,13 +2433,11 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
 
 
 @dont_increase_indentation
 
 
 @dont_increase_indentation
-def delimiter_split(
-    line: Line, supports_trailing_commas: bool = False
-) -> Iterator[Line]:
+def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
     """Split according to delimiters of the highest priority.
 
     """Split according to delimiters of the highest priority.
 
-    If `supports_trailing_commas` is True, the split will add trailing commas
-    also in function signatures that contain `*` and `**`.
+    If the appropriate Features are given, the split will add trailing commas
+    also in function signatures and calls that contain `*` and `**`.
     """
     try:
         last_leaf = line.leaves[-1]
     """
     try:
         last_leaf = line.leaves[-1]
@@ -2487,10 +2476,16 @@ def delimiter_split(
             yield from append_to_line(comment_after)
 
         lowest_depth = min(lowest_depth, leaf.bracket_depth)
             yield from append_to_line(comment_after)
 
         lowest_depth = min(lowest_depth, leaf.bracket_depth)
-        if leaf.bracket_depth == lowest_depth and is_vararg(
-            leaf, within=VARARGS_PARENTS
-        ):
-            trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
+        if leaf.bracket_depth == lowest_depth:
+            if is_vararg(leaf, within={syms.typedargslist}):
+                trailing_comma_safe = (
+                    trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
+                )
+            elif is_vararg(leaf, within={syms.arglist, syms.argument}):
+                trailing_comma_safe = (
+                    trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
+                )
+
         leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
         leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
@@ -2509,7 +2504,7 @@ def delimiter_split(
 
 @dont_increase_indentation
 def standalone_comment_split(
 
 @dont_increase_indentation
 def standalone_comment_split(
-    line: Line, supports_trailing_commas: bool = False
+    line: Line, features: Collection[Feature] = ()
 ) -> Iterator[Line]:
     """Split standalone comments from the rest of the line."""
     if not line.contains_standalone_comments(0):
 ) -> Iterator[Line]:
     """Split standalone comments from the rest of the line."""
     if not line.contains_standalone_comments(0):
@@ -3059,14 +3054,19 @@ def get_features_used(node: Node) -> Set[Feature]:
             and n.children
             and n.children[-1].type == token.COMMA
         ):
             and n.children
             and n.children[-1].type == token.COMMA
         ):
+            if n.type == syms.typedargslist:
+                feature = Feature.TRAILING_COMMA_IN_DEF
+            else:
+                feature = Feature.TRAILING_COMMA_IN_CALL
+
             for ch in n.children:
                 if ch.type in STARS:
             for ch in n.children:
                 if ch.type in STARS:
-                    features.add(Feature.TRAILING_COMMA)
+                    features.add(feature)
 
                 if ch.type == syms.argument:
                     for argch in ch.children:
                         if argch.type in STARS:
 
                 if ch.type == syms.argument:
                     for argch in ch.children:
                         if argch.type in STARS:
-                            features.add(Feature.TRAILING_COMMA)
+                            features.add(feature)
 
     return features
 
 
     return features
 
index a3e2ff863833c2e2185c3452cdbe57ff29ec4e08..3e0b8a189bfcaa051f4056cc6246a6b56dd73a15 100644 (file)
@@ -857,7 +857,11 @@ class BlackTestCase(unittest.TestCase):
         node = black.lib2to3_parse("def f(*, arg): ...\n")
         self.assertEqual(black.get_features_used(node), set())
         node = black.lib2to3_parse("def f(*, arg,): ...\n")
         node = black.lib2to3_parse("def f(*, arg): ...\n")
         self.assertEqual(black.get_features_used(node), set())
         node = black.lib2to3_parse("def f(*, arg,): ...\n")
-        self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
+        self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
+        node = black.lib2to3_parse("f(*arg,)\n")
+        self.assertEqual(
+            black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
+        )
         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
         node = black.lib2to3_parse("123_456\n")
         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
         node = black.lib2to3_parse("123_456\n")
@@ -866,13 +870,14 @@ class BlackTestCase(unittest.TestCase):
         self.assertEqual(black.get_features_used(node), set())
         source, expected = read_data("function")
         node = black.lib2to3_parse(source)
         self.assertEqual(black.get_features_used(node), set())
         source, expected = read_data("function")
         node = black.lib2to3_parse(source)
-        self.assertEqual(
-            black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
-        )
+        expected_features = {
+            Feature.TRAILING_COMMA_IN_CALL,
+            Feature.TRAILING_COMMA_IN_DEF,
+            Feature.F_STRINGS,
+        }
+        self.assertEqual(black.get_features_used(node), expected_features)
         node = black.lib2to3_parse(expected)
         node = black.lib2to3_parse(expected)
-        self.assertEqual(
-            black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
-        )
+        self.assertEqual(black.get_features_used(node), expected_features)
         source, expected = read_data("expression")
         node = black.lib2to3_parse(source)
         self.assertEqual(black.get_features_used(node), set())
         source, expected = read_data("expression")
         node = black.lib2to3_parse(source)
         self.assertEqual(black.get_features_used(node), set())
@@ -1524,8 +1529,8 @@ class BlackTestCase(unittest.TestCase):
 
             await check("3.6", 200)
             await check("py3.6", 200)
 
             await check("3.6", 200)
             await check("py3.6", 200)
-            await check("3.5,3.7", 200)
-            await check("3.5,py3.7", 200)
+            await check("3.6,3.7", 200)
+            await check("3.6,py3.7", 200)
 
             await check("2", 204)
             await check("2.7", 204)
 
             await check("2", 204)
             await check("2.7", 204)