]> git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Only use trailing commas in function signatures when it's safe
author Łukasz Langa <lukasz@langa.pl>
Fri, 16 Mar 2018 02:25:23 +0000 (19:25 -0700)
committer Lukasz Langa <ambv@fb.com>
Fri, 16 Mar 2018 02:44:09 +0000 (19:44 -0700)
Trailing commas after * or ** in a function signature are only safe for Python 3.6
code.  So now Black checks whether the file was already Python 3.6 to begin
with.  If so, trailing commas are used in such cases.  Otherwise, they're not.

When * and ** don't appear in a function signature, the trailing comma is
always safe.

Fixes #8

README.md
black.py
tests/expression.py
tests/fstring.py [new file with mode: 0644]
tests/function.py
tests/test_black.py

index b3e69851965c409d0ec1c0ab24c712f597bde368..5a6825f3cc99167ef2565b928840309e3b69c13f 100644 (file)
--- a/README.md
+++ b/README.md
@@ -260,6 +260,11 @@ You can still try but prepare to be disappointed.
 
 * added `--check`
 
 
 * added `--check`
 
+* only put trailing commas in function signatures and calls if it's
+  safe to do so. If the file is Python 3.6+ it's always safe, otherwise
+  only safe if there are no `*args` or `**kwargs` used in the signature
+  or call. (#8)
+
 * fixed invalid spacing of dots in relative imports (#6, #13)
 
 * fixed invalid splitting after comma on unpacked variables in for-loops
 * fixed invalid spacing of dots in relative imports (#6, #13)
 
 * fixed invalid splitting after comma on unpacked variables in for-loops
index 774d91dc2a7f64ec2dbc252a55e0fe051d5c429c..0a9d3eae46fa1a548f1eb3c9674fc7be3ba6812b 100644 (file)
--- a/black.py
+++ b/black.py
@@ -7,6 +7,7 @@ import keyword
 import os
 from pathlib import Path
 import tokenize
 import os
 from pathlib import Path
 import tokenize
+import sys
 from typing import (
     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
 )
 from typing import (
     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
 )
@@ -192,6 +193,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
     comments: List[Line] = []
     lines = LineGenerator()
     elt = EmptyLineTracker()
     comments: List[Line] = []
     lines = LineGenerator()
     elt = EmptyLineTracker()
+    py36 = is_python36(src_node)
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -204,7 +206,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
             for comment in comments:
                 dst_contents += str(comment)
             comments = []
             for comment in comments:
                 dst_contents += str(comment)
             comments = []
-            for line in split_line(current_line, line_length=line_length):
+            for line in split_line(current_line, line_length=line_length, py36=py36):
                 dst_contents += str(line)
         else:
             comments.append(current_line)
                 dst_contents += str(line)
         else:
             comments.append(current_line)
@@ -1108,13 +1110,18 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
         yield Leaf(STANDALONE_COMMENT, line)
 
 
         yield Leaf(STANDALONE_COMMENT, line)
 
 
-def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]:
+def split_line(
+    line: Line, line_length: int, inner: bool = False, py36: bool = False
+) -> Iterator[Line]:
     """Splits a `line` into potentially many lines.
 
     They should fit in the allotted `line_length` but might not be able to.
     `inner` signifies that there were a pair of brackets somewhere around the
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
     """Splits a `line` into potentially many lines.
 
     They should fit in the allotted `line_length` but might not be able to.
     `inner` signifies that there were a pair of brackets somewhere around the
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
+
+    If `py36` is True, splitting may generate syntax that is only compatible
+    with Python 3.6 and later.
     """
     line_str = str(line).strip('\n')
     if len(line_str) <= line_length and '\n' not in line_str:
     """
     line_str = str(line).strip('\n')
     if len(line_str) <= line_length and '\n' not in line_str:
@@ -1137,11 +1144,13 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li
         # split altogether.
         result: List[Line] = []
         try:
         # split altogether.
         result: List[Line] = []
         try:
-            for l in split_func(line):
+            for l in split_func(line, py36=py36):
                 if str(l).strip('\n') == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 if str(l).strip('\n') == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
-                result.extend(split_line(l, line_length=line_length, inner=True))
+                result.extend(
+                    split_line(l, line_length=line_length, inner=True, py36=py36)
+                )
         except CannotSplit as cs:
             continue
 
         except CannotSplit as cs:
             continue
 
@@ -1153,7 +1162,7 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li
         yield line
 
 
         yield line
 
 
-def left_hand_split(line: Line) -> Iterator[Line]:
+def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
@@ -1208,7 +1217,7 @@ def left_hand_split(line: Line) -> Iterator[Line]:
             yield result
 
 
             yield result
 
 
-def right_hand_split(line: Line) -> Iterator[Line]:
+def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair."""
     head = Line(depth=line.depth)
     body = Line(depth=line.depth + 1, inside_brackets=True)
     """Split line into many lines, starting with the last matching bracket pair."""
     head = Line(depth=line.depth)
     body = Line(depth=line.depth + 1, inside_brackets=True)
@@ -1259,10 +1268,12 @@ def right_hand_split(line: Line) -> Iterator[Line]:
             yield result
 
 
             yield result
 
 
-def delimiter_split(line: Line) -> Iterator[Line]:
+def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
     """Split according to delimiters of the highest priority.
 
     This kind of split doesn't increase indentation.
     """Split according to delimiters of the highest priority.
 
     This kind of split doesn't increase indentation.
+    If `py36` is True, the split will add trailing commas also in function
+    signatures that contain * and **.
     """
     try:
         last_leaf = line.leaves[-1]
     """
     try:
         last_leaf = line.leaves[-1]
@@ -1276,11 +1287,20 @@ def delimiter_split(line: Line) -> Iterator[Line]:
         raise CannotSplit("No delimiters found")
 
     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
         raise CannotSplit("No delimiters found")
 
     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+    lowest_depth = sys.maxsize
+    trailing_comma_safe = True
     for leaf in line.leaves:
         current_line.append(leaf, preformatted=True)
         comment_after = line.comments.get(id(leaf))
         if comment_after:
             current_line.append(comment_after, preformatted=True)
     for leaf in line.leaves:
         current_line.append(leaf, preformatted=True)
         comment_after = line.comments.get(id(leaf))
         if comment_after:
             current_line.append(comment_after, preformatted=True)
+        lowest_depth = min(lowest_depth, leaf.bracket_depth)
+        # `and` binds tighter than `or`, so the star check must be parenthesized;
+        # otherwise any `**` leaf, at any bracket depth, would mark the trailing
+        # comma as unsafe regardless of the depth comparison.
+        if (
+            leaf.bracket_depth == lowest_depth and  # type: ignore
+            (leaf.type == token.STAR or leaf.type == token.DOUBLESTAR)
+        ):
+            # A star at the outermost depth seen so far: a trailing comma is
+            # only emitted when the file is already known to be Python 3.6+.
+            trailing_comma_safe = trailing_comma_safe and py36
         leaf_priority = delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             normalize_prefix(current_line.leaves[0])
         leaf_priority = delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             normalize_prefix(current_line.leaves[0])
@@ -1290,7 +1310,8 @@ def delimiter_split(line: Line) -> Iterator[Line]:
     if current_line:
         if (
             delimiter_priority == COMMA_PRIORITY and
     if current_line:
         if (
             delimiter_priority == COMMA_PRIORITY and
-            current_line.leaves[-1].type != token.COMMA
+            current_line.leaves[-1].type != token.COMMA and
+            trailing_comma_safe
         ):
             current_line.append(Leaf(token.COMMA, ','))
         normalize_prefix(current_line.leaves[0])
         ):
             current_line.append(Leaf(token.COMMA, ','))
         normalize_prefix(current_line.leaves[0])
@@ -1325,6 +1346,31 @@ def normalize_prefix(leaf: Leaf) -> None:
     leaf.prefix = ''
 
 
     leaf.prefix = ''
 
 
+def is_python36(node: Node) -> bool:
+    """Returns True if the current file is using Python 3.6+ features.
+
+    Currently looking for:
+    - f-strings (any casing and ordering of the prefix, e.g. f'', rF"", Fr''); and
+    - trailing commas after * or ** in function signatures.
+
+    `node` is walked in pre-order; the first match short-circuits the search.
+    """
+    for n in node.pre_order():
+        if n.type == token.STRING:
+            assert isinstance(n, Leaf)
+            # Stripping prefix letters from the left stops at the opening quote,
+            # so the length difference is exactly the length of the prefix.
+            # Comparing against a fixed set of two-character spellings missed
+            # mixed-case raw f-strings such as rF'...' and Fr"...".
+            unprefixed = n.value.lstrip('bBfFrRuU')
+            prefix = n.value[:len(n.value) - len(unprefixed)]
+            if 'f' in prefix.lower():
+                return True
+
+        elif (
+            n.type == syms.typedargslist and
+            n.children and
+            n.children[-1].type == token.COMMA
+        ):
+            # The signature ends with a trailing comma; that is only 3.6-specific
+            # syntax when the signature also contains a bare * or **.
+            for ch in n.children:
+                if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
+                    return True
+
+    return False
+
+
 PYTHON_EXTENSIONS = {'.py'}
 BLACKLISTED_DIRECTORIES = {
     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
 PYTHON_EXTENSIONS = {'.py'}
 BLACKLISTED_DIRECTORIES = {
     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
index 59e4211c8886a46fe4f4e5434ce2a086dc47ba8f..a3c810ee9f92883ff9248551670a80a9d9a1b3bb 100644 (file)
--- a/tests/expression.py
+++ b/tests/expression.py
@@ -71,6 +71,7 @@ call(arg)
 call(kwarg='hey')
 call(arg, kwarg='hey')
 call(arg, another, kwarg='hey', **kwargs)
 call(kwarg='hey')
 call(arg, kwarg='hey')
 call(arg, another, kwarg='hey', **kwargs)
+call(this_is_a_very_long_variable_which_will_force_a_delimiter_split, arg, another, kwarg='hey', **kwargs)  # note: no trailing comma pre-3.6
 lukasz.langa.pl
 call.me(maybe)
 1 .real
 lukasz.langa.pl
 call.me(maybe)
 1 .real
@@ -88,11 +89,6 @@ slice[:-1]
 slice[1:]
 slice[::-1]
 (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None)
 slice[1:]
 slice[::-1]
 (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None)
-f'f-string without formatted values is just a string'
-f'{{NOT a formatted value}}'
-f'some f-string with {a} {few():.2f} {formatted.values!r}'
-f"{f'{nested} inner'} outer"
-f'space between opening braces: { {a for a in (1, 2, 3)}}'
 {'2.7': dead, '3.7': long_live or die_hard}
 {'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'}
 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C]
 {'2.7': dead, '3.7': long_live or die_hard}
 {'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'}
 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C]
@@ -200,6 +196,13 @@ call(arg)
 call(kwarg='hey')
 call(arg, kwarg='hey')
 call(arg, another, kwarg='hey', **kwargs)
 call(kwarg='hey')
 call(arg, kwarg='hey')
 call(arg, another, kwarg='hey', **kwargs)
+call(
+    this_is_a_very_long_variable_which_will_force_a_delimiter_split,
+    arg,
+    another,
+    kwarg='hey',
+    **kwargs
+)  # note: no trailing comma pre-3.6
 lukasz.langa.pl
 call.me(maybe)
 1 .real
 lukasz.langa.pl
 call.me(maybe)
 1 .real
@@ -217,11 +220,6 @@ slice[:-1]
 slice[1:]
 slice[::-1]
 (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None)
 slice[1:]
 slice[::-1]
 (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None)
-f'f-string without formatted values is just a string'
-f'{{NOT a formatted value}}'
-f'some f-string with {a} {few():.2f} {formatted.values!r}'
-f"{f'{nested} inner'} outer"
-f'space between opening braces: { {a for a in (1, 2, 3)}}'
 {'2.7': dead, '3.7': long_live or die_hard}
 {'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'}
 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C]
 {'2.7': dead, '3.7': long_live or die_hard}
 {'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'}
 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C]
diff --git a/tests/fstring.py b/tests/fstring.py
new file mode 100644 (file)
index 0000000..6b821be
--- /dev/null
+++ b/tests/fstring.py
@@ -0,0 +1,5 @@
+f'f-string without formatted values is just a string'
+f'{{NOT a formatted value}}'
+f'some f-string with {a} {few():.2f} {formatted.values!r}'
+f"{f'{nested} inner'} outer"
+f'space between opening braces: { {a for a in (1, 2, 3)}}'
index 858b042a18db32050041b92e4addef6b4f3b40d5..abe2200e3792205088f7e9ab31a5b6f4c410b78c 100644 (file)
--- a/tests/function.py
+++ b/tests/function.py
@@ -6,7 +6,7 @@ from third_party import X, Y, Z
 
 from library import some_connection, \
                     some_decorator
 
 from library import some_connection, \
                     some_decorator
-
+f'trigger 3.6 mode'
 def func_no_args():
   a; b; c
   if True: raise RuntimeError
 def func_no_args():
   a; b; c
   if True: raise RuntimeError
@@ -71,6 +71,8 @@ from third_party import X, Y, Z
 
 from library import some_connection, some_decorator
 
 
 from library import some_connection, some_decorator
 
+f'trigger 3.6 mode'
+
 
 def func_no_args():
     a
 
 def func_no_args():
     a
index 223c907217059b50c3a822c54f2649adf7d85118..1dda5fc963de828384e656d31894cbf379ac01c0 100644 (file)
--- a/tests/test_black.py
+++ b/tests/test_black.py
@@ -108,6 +108,14 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_fstring(self) -> None:
+        # Formats the 'fstring' data file and compares the result with the
+        # expected output stored alongside it. `fs` and `ll` are helpers defined
+        # elsewhere in this test module.
+        source, expected = read_data('fstring')
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        # Sanity checks provided by black itself — presumably that the output is
+        # equivalent to the input and that reformatting it is stable; confirm
+        # against their definitions in black.py.
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, line_length=ll)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_comments(self) -> None:
         source, expected = read_data('comments')
     @patch("black.dump_to_file", dump_to_stderr)
     def test_comments(self) -> None:
         source, expected = read_data('comments')
@@ -215,6 +223,24 @@ class BlackTestCase(unittest.TestCase):
             )
             self.assertEqual(report.return_code, 123)
 
             )
             self.assertEqual(report.return_code, 123)
 
+    def test_is_python36(self):
+        # Keyword-only marker without a trailing comma and no f-string:
+        # nothing 3.6-specific, so detection must be negative.
+        node = black.lib2to3_parse("def f(*, arg): ...\n")
+        self.assertFalse(black.is_python36(node))
+        # Same signature with a trailing comma after a `*`: detected as 3.6.
+        node = black.lib2to3_parse("def f(*, arg,): ...\n")
+        self.assertTrue(black.is_python36(node))
+        # No trailing comma, but the body contains an f-string: detected as 3.6.
+        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
+        self.assertTrue(black.is_python36(node))
+        # The 'function' data file is expected to be detected as 3.6 both before
+        # and after formatting (presumably via the f-string marker added to that
+        # fixture in this commit).
+        source, expected = read_data('function')
+        node = black.lib2to3_parse(source)
+        self.assertTrue(black.is_python36(node))
+        node = black.lib2to3_parse(expected)
+        self.assertTrue(black.is_python36(node))
+        # The 'expression' data file must not be detected as 3.6, neither in its
+        # source form nor in its formatted form.
+        source, expected = read_data('expression')
+        node = black.lib2to3_parse(source)
+        self.assertFalse(black.is_python36(node))
+        node = black.lib2to3_parse(expected)
+        self.assertFalse(black.is_python36(node))
+
 
 if __name__ == '__main__':
     unittest.main()
 
 if __name__ == '__main__':
     unittest.main()