diff --git a/black.py b/black.py index 635eba207cd..7b72fa10cef 100644 --- a/black.py +++ b/black.py @@ -1352,7 +1352,10 @@ def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: bracket_depth = leaf.bracket_depth if bracket_depth == depth and leaf.type == token.COMMA: commas += 1 - if leaf.parent and leaf.parent.type == syms.arglist: + if leaf.parent and leaf.parent.type in { + syms.arglist, + syms.typedargslist, + }: commas += 1 break @@ -2488,9 +2491,12 @@ def bracket_split_build_line( if leaves: # Since body is a new indent level, remove spurious leading whitespace. normalize_prefix(leaves[0], inside_brackets=True) - # Ensure a trailing comma for imports, but be careful not to add one after - # any comments. - if original.is_import: + # Ensure a trailing comma for imports and standalone function arguments, but + # be careful not to add one after any comments. + no_commas = original.is_def and not [ + l for l in leaves if l.type == token.COMMA + ] + if original.is_import or no_commas: for i in range(len(leaves) - 1, -1, -1): if leaves[i].type == STANDALONE_COMMENT: continue diff --git a/tests/data/function3.py b/tests/data/function3.py new file mode 100644 index 00000000000..29fd99b7d91 --- /dev/null +++ b/tests/data/function3.py @@ -0,0 +1,14 @@ +def f(a,): + ... + +def f(a:int=1,): + ... + +# output + +def f(a): + ... + + +def f(a: int = 1): + ... diff --git a/tests/test_black.py b/tests/test_black.py index 88c03d05fb8..4f0fb16b442 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -264,6 +264,14 @@ def test_function2(self) -> None: black.assert_equivalent(source, actual) black.assert_stable(source, actual, black.FileMode()) + @patch("black.dump_to_file", dump_to_stderr) + def test_function3(self) -> None: + source, expected = read_data("function3") + actual = fs(source) + self.assertFormatEqual(expected, actual) + black.assert_equivalent(source, actual) + black.assert_stable(source, actual, black.FileMode()) + @patch("black.dump_to_file", dump_to_stderr) def test_expression(self) -> None: source, expected = read_data("expression")