From 5fb5cc8c2bd5a0bb1359fb69cdb705b55afade52 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=C5=81ukasz=20Langa?= Date: Thu, 15 Mar 2018 19:25:23 -0700 Subject: [PATCH] Only use trailing commas in function signatures when it's safe Trailing commas after * or ** in a function signature are only safe for Python 3.6 code. So now Black checks whether the file was already Python 3.6 to begin with. If so, trailing commas are used in such cases. Otherwise, they're not. When * and ** don't appear in a function signature, the trailing comma is always safe. Fixes #8 --- README.md | 5 ++++ black.py | 62 +++++++++++++++++++++++++++++++++++++++------ tests/expression.py | 18 ++++++------- tests/fstring.py | 5 ++++ tests/function.py | 4 ++- tests/test_black.py | 26 +++++++++++++++++++ 6 files changed, 101 insertions(+), 19 deletions(-) create mode 100644 tests/fstring.py diff --git a/README.md b/README.md index b3e6985..5a6825f 100644 --- a/README.md +++ b/README.md @@ -260,6 +260,11 @@ You can still try but prepare to be disappointed. * added `--check` +* only put trailing commas in function signatures and calls if it's + safe to do so. If the file is Python 3.6+ it's always safe, otherwise + only safe if there are no `*args` or `**kwargs` used in the signature + or call. (#8) + * fixed invalid spacing of dots in relative imports (#6, #13) * fixed invalid splitting after comma on unpacked variables in for-loops diff --git a/black.py b/black.py index 774d91d..0a9d3ea 100644 --- a/black.py +++ b/black.py @@ -7,6 +7,7 @@ import keyword import os from pathlib import Path import tokenize +import sys from typing import ( Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union ) @@ -192,6 +193,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent: comments: List[Line] = [] lines = LineGenerator() elt = EmptyLineTracker() + py36 = is_python36(src_node) empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -204,7 +206,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent: for comment in comments: dst_contents += str(comment) comments = [] - for line in split_line(current_line, line_length=line_length): + for line in split_line(current_line, line_length=line_length, py36=py36): dst_contents += str(line) else: comments.append(current_line) @@ -1108,13 +1110,18 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: yield Leaf(STANDALONE_COMMENT, line) -def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]: +def split_line( + line: Line, line_length: int, inner: bool = False, py36: bool = False +) -> Iterator[Line]: """Splits a `line` into potentially many lines. They should fit in the allotted `line_length` but might not be able to. `inner` signifies that there were a pair of brackets somewhere around the current `line`, possibly transitively. This means we can fallback to splitting by delimiters if the LHS/RHS don't yield any results. + + If `py36` is True, splitting may generate syntax that is only compatible + with Python 3.6 and later. """ line_str = str(line).strip('\n') if len(line_str) <= line_length and '\n' not in line_str: @@ -1137,11 +1144,13 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li # split altogether. result: List[Line] = [] try: - for l in split_func(line): + for l in split_func(line, py36=py36): if str(l).strip('\n') == line_str: raise CannotSplit("Split function returned an unchanged result") - result.extend(split_line(l, line_length=line_length, inner=True)) + result.extend( + split_line(l, line_length=line_length, inner=True, py36=py36) + ) except CannotSplit as cs: continue @@ -1153,7 +1162,7 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li yield line -def left_hand_split(line: Line) -> Iterator[Line]: +def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. @@ -1208,7 +1217,7 @@ def left_hand_split(line: Line) -> Iterator[Line]: yield result -def right_hand_split(line: Line) -> Iterator[Line]: +def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the last matching bracket pair.""" head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -1259,10 +1268,12 @@ def right_hand_split(line: Line) -> Iterator[Line]: yield result -def delimiter_split(line: Line) -> Iterator[Line]: +def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split according to delimiters of the highest priority. This kind of split doesn't increase indentation. + If `py36` is True, the split will add trailing commas also in function + signatures that contain * and **. """ try: last_leaf = line.leaves[-1] @@ -1276,11 +1287,20 @@ def delimiter_split(line: Line) -> Iterator[Line]: raise CannotSplit("No delimiters found") current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + lowest_depth = sys.maxsize + trailing_comma_safe = True for leaf in line.leaves: current_line.append(leaf, preformatted=True) comment_after = line.comments.get(id(leaf)) if comment_after: current_line.append(comment_after, preformatted=True) + lowest_depth = min(lowest_depth, leaf.bracket_depth) + if ( + leaf.bracket_depth == lowest_depth and # type: ignore + leaf.type == token.STAR or + leaf.type == token.DOUBLESTAR + ): + trailing_comma_safe = trailing_comma_safe and py36 leaf_priority = delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: normalize_prefix(current_line.leaves[0]) @@ -1290,7 +1310,8 @@ def delimiter_split(line: Line) -> Iterator[Line]: if current_line: if ( delimiter_priority == COMMA_PRIORITY and - current_line.leaves[-1].type != token.COMMA + current_line.leaves[-1].type != token.COMMA and + trailing_comma_safe ): current_line.append(Leaf(token.COMMA, ',')) normalize_prefix(current_line.leaves[0]) @@ -1325,6 +1346,31 @@ def normalize_prefix(leaf: Leaf) -> None: leaf.prefix = '' +def is_python36(node: Node) -> bool: + """Returns True if the current file is using Python 3.6+ features. + + Currently looking for: + - f-strings; and + - trailing commas after * or ** in function signatures. + """ + for n in node.pre_order(): + if n.type == token.STRING: + assert isinstance(n, Leaf) + if n.value[:2] in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}: + return True + + elif ( + n.type == syms.typedargslist and + n.children and + n.children[-1].type == token.COMMA + ): + for ch in n.children: + if ch.type == token.STAR or ch.type == token.DOUBLESTAR: + return True + + return False + + PYTHON_EXTENSIONS = {'.py'} BLACKLISTED_DIRECTORIES = { 'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv' diff --git a/tests/expression.py b/tests/expression.py index 59e4211..a3c810e 100644 --- a/tests/expression.py +++ b/tests/expression.py @@ -71,6 +71,7 @@ call(arg) call(kwarg='hey') call(arg, kwarg='hey') call(arg, another, kwarg='hey', **kwargs) +call(this_is_a_very_long_variable_which_will_force_a_delimiter_split, arg, another, kwarg='hey', **kwargs) # note: no trailing comma pre-3.6 lukasz.langa.pl call.me(maybe) 1 .real @@ -88,11 +89,6 @@ slice[:-1] slice[1:] slice[::-1] (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None) -f'f-string without formatted values is just a string' -f'{{NOT a formatted value}}' -f'some f-string with {a} {few():.2f} {formatted.values!r}' -f"{f'{nested} inner'} outer" -f'space between opening braces: { {a for a in (1, 2, 3)}}' {'2.7': dead, '3.7': long_live or die_hard} {'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'} [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C] @@ -200,6 +196,13 @@ call(arg) call(kwarg='hey') call(arg, kwarg='hey') call(arg, another, kwarg='hey', **kwargs) +call( + this_is_a_very_long_variable_which_will_force_a_delimiter_split, + arg, + another, + kwarg='hey', + **kwargs +) # note: no trailing comma pre-3.6 lukasz.langa.pl call.me(maybe) 1 .real @@ -217,11 +220,6 @@ slice[:-1] slice[1:] slice[::-1] (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None) -f'f-string without formatted values is just a string' -f'{{NOT a formatted value}}' -f'some f-string with {a} {few():.2f} {formatted.values!r}' -f"{f'{nested} inner'} outer" -f'space between opening braces: { {a for a in (1, 2, 3)}}' {'2.7': dead, '3.7': long_live or die_hard} {'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'} [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C] diff --git a/tests/fstring.py b/tests/fstring.py new file mode 100644 index 0000000..6b821be --- /dev/null +++ b/tests/fstring.py @@ -0,0 +1,5 @@ +f'f-string without formatted values is just a string' +f'{{NOT a formatted value}}' +f'some f-string with {a} {few():.2f} {formatted.values!r}' +f"{f'{nested} inner'} outer" +f'space between opening braces: { {a for a in (1, 2, 3)}}' diff --git a/tests/function.py b/tests/function.py index 858b042..abe2200 100644 --- a/tests/function.py +++ b/tests/function.py @@ -6,7 +6,7 @@ from third_party import X, Y, Z from library import some_connection, \ some_decorator - +f'trigger 3.6 mode' def func_no_args(): a; b; c if True: raise RuntimeError @@ -71,6 +71,8 @@ from third_party import X, Y, Z from library import some_connection, some_decorator +f'trigger 3.6 mode' + def func_no_args(): a diff --git a/tests/test_black.py b/tests/test_black.py index 223c907..1dda5fc 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -108,6 +108,14 @@ class BlackTestCase(unittest.TestCase): black.assert_equivalent(source, actual) black.assert_stable(source, actual, line_length=ll) + @patch("black.dump_to_file", dump_to_stderr) + def test_fstring(self) -> None: + source, expected = read_data('fstring') + actual = fs(source) + self.assertFormatEqual(expected, actual) + black.assert_equivalent(source, actual) + black.assert_stable(source, actual, line_length=ll) + @patch("black.dump_to_file", dump_to_stderr) def test_comments(self) -> None: source, expected = read_data('comments') @@ -215,6 +223,24 @@ class BlackTestCase(unittest.TestCase): ) self.assertEqual(report.return_code, 123) + def test_is_python36(self): + node = black.lib2to3_parse("def f(*, arg): ...\n") + self.assertFalse(black.is_python36(node)) + node = black.lib2to3_parse("def f(*, arg,): ...\n") + self.assertTrue(black.is_python36(node)) + node = black.lib2to3_parse("def f(*, arg): f'string'\n") + self.assertTrue(black.is_python36(node)) + source, expected = read_data('function') + node = black.lib2to3_parse(source) + self.assertTrue(black.is_python36(node)) + node = black.lib2to3_parse(expected) + self.assertTrue(black.is_python36(node)) + source, expected = read_data('expression') + node = black.lib2to3_parse(source) + self.assertFalse(black.is_python36(node)) + node = black.lib2to3_parse(expected) + self.assertFalse(black.is_python36(node)) + if __name__ == '__main__': unittest.main() -- 2.39.5