diff --git a/black.py b/black.py
index 635eba207cd..7b72fa10cef 100644
--- a/black.py
+++ b/black.py
@@ -1352,7 +1352,10 @@ def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
             bracket_depth = leaf.bracket_depth
             if bracket_depth == depth and leaf.type == token.COMMA:
                 commas += 1
-                if leaf.parent and leaf.parent.type == syms.arglist:
+                if leaf.parent and leaf.parent.type in {
+                    syms.arglist,
+                    syms.typedargslist,
+                }:
                     commas += 1
                     break
 
@@ -2488,9 +2491,12 @@ def bracket_split_build_line(
         if leaves:
             # Since body is a new indent level, remove spurious leading whitespace.
             normalize_prefix(leaves[0], inside_brackets=True)
-            # Ensure a trailing comma for imports, but be careful not to add one after
-            # any comments.
-            if original.is_import:
+            # Ensure a trailing comma for imports and standalone function arguments, but
+            # be careful not to add one after any comments.
+            no_commas = original.is_def and not [
+                l for l in leaves if l.type == token.COMMA
+            ]
+            if original.is_import or no_commas:
                 for i in range(len(leaves) - 1, -1, -1):
                     if leaves[i].type == STANDALONE_COMMENT:
                         continue
diff --git a/tests/data/function3.py b/tests/data/function3.py
new file mode 100644
index 00000000000..29fd99b7d91
--- /dev/null
+++ b/tests/data/function3.py
@@ -0,0 +1,14 @@
+def f(a,):
+    ...
+
+def f(a:int=1,):
+    ...
+
+# output
+
+def f(a):
+    ...
+
+
+def f(a: int = 1):
+    ...
diff --git a/tests/test_black.py b/tests/test_black.py
index 88c03d05fb8..4f0fb16b442 100644
--- a/tests/test_black.py
+++ b/tests/test_black.py
@@ -264,6 +264,14 @@ def test_function2(self) -> None:
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, black.FileMode())
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_function3(self) -> None:
+        source, expected = read_data("function3")
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, black.FileMode())
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_expression(self) -> None:
         source, expected = read_data("expression")