Skip to content

Commit

Permalink
Don't add default None when default is ...
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Sep 18, 2023
1 parent b08f445 commit 5306219
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
20 changes: 12 additions & 8 deletions bump_pydantic/codemods/add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,22 @@ def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAss
if self.inside_base_model and self.should_add_none:
if updated_node.value is None:
updated_node = updated_node.with_changes(value=cst.Name("None"))
# TODO: Should accept `pydantic.Field` as well.
elif m.matches(updated_node.value, m.Call(func=m.Name("Field"))):
assert isinstance(updated_node.value, cst.Call)
if updated_node.value.args:
arg = updated_node.value.args[0]
if (arg.keyword is None or arg.keyword.value == "default") and m.matches(arg.value, m.Ellipsis()):
args = updated_node.value.args
if args:
# NOTE: It has a "default" value as positional argument. Nothing to do.
if args[0].keyword is None:
...
# NOTE: It has a "default" or "default_factory" keyword argument. Nothing to do.
elif any(arg.keyword and arg.keyword.value in ("default", "default_factory") for arg in args):
...
else:
updated_node = updated_node.with_changes(
value=updated_node.value.with_changes(
args=[arg.with_changes(value=cst.Name("None")), *updated_node.value.args[1:]]
)
value=updated_node.value.with_changes(args=[cst.Arg(value=cst.Name("None")), *args])
)
# This is the case where `Field` is called without any arguments e.g. `Field()`.

# NOTE: This is the case where `Field` is called without any arguments e.g. `Field()`.
else:
updated_node = updated_node.with_changes(
value=updated_node.value.with_changes(args=[cst.Arg(value=cst.Name("None"))]) # type: ignore
Expand Down
16 changes: 14 additions & 2 deletions tests/integration/cases/add_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
" g: Optional[int] = Field()",
" h: Optional[int] = Field(...)",
" i: Optional[int] = Field(default_factory=lambda: None)",
" j: Optional[int] = ...",
" k: Optional[int] = None",
" l: Optional[int] = Field(lt=10, default=None)",
" m: Optional[int] = Field(lt=10)",
" n: Optional[int] = Field(default=...)",
" o: Optional[int] = Field(default=None)",
],
),
expected=File(
Expand All @@ -38,10 +44,16 @@
" c: Union[int, None] = None",
" d: Any = None",
" e: Dict[str, str]",
" f: Optional[int] = Field(None, lt=10)",
" f: Optional[int] = Field(..., lt=10)",
" g: Optional[int] = Field(None)",
" h: Optional[int] = Field(None)",
" h: Optional[int] = Field(...)",
" i: Optional[int] = Field(default_factory=lambda: None)",
" j: Optional[int] = ...",
" k: Optional[int] = None",
" l: Optional[int] = Field(lt=10, default=None)",
" m: Optional[int] = Field(None, lt=10)",
" n: Optional[int] = Field(default=...)",
" o: Optional[int] = Field(default=None)",
],
),
)
Expand Down

0 comments on commit 5306219

Please sign in to comment.