From 9394de150ebf0adc426523f46dc08e8b2b2b0b63 Mon Sep 17 00:00:00 2001 From: dylanjblack <38996120+dylanjblack@users.noreply.github.com> Date: Sat, 15 Jun 2019 14:49:49 +1000 Subject: [PATCH] Fix trailing comma for function with one arg (#880) (#891) Modified maybe_remove_trailing_comma to remove trailing commas for typedarglists (in addition to arglists), and updated line split logic to ensure that all lines in a function definition that contain only one arg have a trailing comma. --- black.py | 15 +++++++++++---- tests/data/function_trailing_comma.py | 14 ++++++++++++++ tests/test_black.py | 8 ++++++++ 3 files changed, 33 insertions(+), 4 deletions(-) create mode 100644 tests/data/function_trailing_comma.py diff --git a/black.py b/black.py index 635eba2..8318674 100644 --- a/black.py +++ b/black.py @@ -1352,7 +1352,10 @@ class Line: bracket_depth = leaf.bracket_depth if bracket_depth == depth and leaf.type == token.COMMA: commas += 1 - if leaf.parent and leaf.parent.type == syms.arglist: + if leaf.parent and leaf.parent.type in { + syms.arglist, + syms.typedargslist, + }: commas += 1 break @@ -2488,9 +2491,13 @@ def bracket_split_build_line( if leaves: # Since body is a new indent level, remove spurious leading whitespace. normalize_prefix(leaves[0], inside_brackets=True) - # Ensure a trailing comma for imports, but be careful not to add one after - # any comments. - if original.is_import: + # Ensure a trailing comma for imports and standalone function arguments, but + # be careful not to add one after any comments. + no_commas = original.is_def and not any( + l.type == token.COMMA for l in leaves + ) + + if original.is_import or no_commas: for i in range(len(leaves) - 1, -1, -1): if leaves[i].type == STANDALONE_COMMENT: continue diff --git a/tests/data/function_trailing_comma.py b/tests/data/function_trailing_comma.py new file mode 100644 index 0000000..29fd99b --- /dev/null +++ b/tests/data/function_trailing_comma.py @@ -0,0 +1,14 @@ +def f(a,): + ... + +def f(a:int=1,): + ... + +# output + +def f(a): + ... + + +def f(a: int = 1): + ... diff --git a/tests/test_black.py b/tests/test_black.py index 88c03d0..828b3e4 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -264,6 +264,14 @@ class BlackTestCase(unittest.TestCase): black.assert_equivalent(source, actual) black.assert_stable(source, actual, black.FileMode()) + @patch("black.dump_to_file", dump_to_stderr) + def test_function_trailing_comma(self) -> None: + source, expected = read_data("function_trailing_comma") + actual = fs(source) + self.assertFormatEqual(expected, actual) + black.assert_equivalent(source, actual) + black.assert_stable(source, actual, black.FileMode()) + @patch("black.dump_to_file", dump_to_stderr) def test_expression(self) -> None: source, expected = read_data("expression") -- 2.39.5