Skip to content

Commit

Permalink
🐛 Use Annotated even when con* functions are wrapped on `cst.Subscr…
Browse files Browse the repository at this point in the history
…ipt` (#110)
  • Loading branch information
Kludex authored Jul 26, 2023
1 parent 5248663 commit 7493d23
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 30 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ Into:

```py
from typing import Generic, TypeVar
from pydantic import BaseModel

T = TypeVar('T')

Expand Down Expand Up @@ -217,7 +218,7 @@ Into:
```py
from typing import List

from pydantic import RootModel
from pydantic import RootModel, BaseModel

class User(BaseModel):
age: int
Expand Down
2 changes: 1 addition & 1 deletion bump_pydantic/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Rule(str, Enum):
BP007 = "BP007"
"""Replace `@validator` with `@field_validator`."""
BP008 = "BP008"
"""Replace `constr(<args>)` with `Annotated[str, StringConstraints(<args>)`."""
"""Replace `con*` functions by `Annotated` versions."""


def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
Expand Down
32 changes: 9 additions & 23 deletions bump_pydantic/codemods/con_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,18 @@
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

CONSTR_CALL = m.Call(func=m.Name("constr") | m.Attribute(value=m.Name("pydantic"), attr=m.Name("constr")))
ANN_ASSIGN_CONSTR_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CONSTR_CALL))


CON_NUMBER_CALL = m.OneOf(
*[
m.Call(func=m.Name(name) | m.Attribute(value=m.Name("pydantic"), attr=m.Name(name)))
for name in ("conint", "confloat", "condecimal", "conbytes")
]
)
ANN_ASSIGN_CON_NUMBER_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CON_NUMBER_CALL))

CON_COLLECTION_CALL = m.OneOf(
*[
m.Call(func=m.Name(name) | m.Attribute(value=m.Name("pydantic"), attr=m.Name(name)))
for name in ("conlist", "conset", "confrozenset")
]
)
ANN_ASSIGN_COLLECTION_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CON_COLLECTION_CALL))

MAP_FUNC_TO_TYPE = {
"constr": "str",
Expand All @@ -48,19 +42,14 @@ class ConFuncCallCommand(VisitorBasedCodemodCommand):
def __init__(self, context: CodemodContext) -> None:
super().__init__(context)

@m.leave(ANN_ASSIGN_CONSTR_CALL | ANN_ASSIGN_CON_NUMBER_CALL | ANN_ASSIGN_COLLECTION_CALL)
def leave_ann_assign_constr_call(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
annotation = cast(cst.Call, original_node.annotation.annotation)
if m.matches(annotation.func, m.Name()):
func_name = cast(str, annotation.func.value) # type: ignore
@m.leave(CON_NUMBER_CALL | CON_COLLECTION_CALL | CONSTR_CALL)
def leave_annotation_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Subscript:
if m.matches(original_node.func, m.Name()):
func_name = cast(str, original_node.func.value) # type: ignore
else:
func_name = cast(str, annotation.func.attr.value) # type: ignore
func_name = cast(str, original_node.func.attr.value) # type: ignore
type_name = MAP_FUNC_TO_TYPE[func_name]

# TODO: When FastAPI supports Pydantic 2.0.4+, remove the conditional below.
if func_name == "constr":
return updated_node

needed_import = MAP_TYPE_TO_NEEDED_IMPORT.get(type_name)
if needed_import is not None:
AddImportsVisitor.add_needed_import(context=self.context, **needed_import) # type: ignore[arg-type]
Expand All @@ -76,23 +65,20 @@ def leave_ann_assign_constr_call(self, original_node: cst.AnnAssign, updated_nod
slice_value = cst.Index(value=cst.Name(type_name))

AddImportsVisitor.add_needed_import(context=self.context, module="typing_extensions", obj="Annotated")
annotated = cst.Subscript(
return cst.Subscript(
value=cst.Name("Annotated"),
slice=[
cst.SubscriptElement(slice=slice_value),
cst.SubscriptElement(slice=cst.Index(value=updated_node.annotation.annotation)),
cst.SubscriptElement(slice=cst.Index(value=updated_node)),
],
)
annotation = cst.Annotation(annotation=annotated) # type: ignore[assignment]
return updated_node.with_changes(annotation=annotation)

# TODO: When FastAPI supports Pydantic 2.0.4+, remove the comments below.
@m.leave(CONSTR_CALL)
def leave_constr_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
self._remove_import(original_node.func)
# AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="StringConstraints")
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="StringConstraints")
return updated_node.with_changes(
# func=cst.Name("StringConstraints"),
func=cst.Name("StringConstraints"),
args=[
arg if arg.keyword and arg.keyword.value != "regex" else arg.with_changes(keyword=cst.Name("pattern"))
for arg in updated_node.args
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/cases/con_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,29 @@
" e: condecimal(gt=0, lt=10)",
" f: confloat(gt=0, lt=10)",
" g: conset(int, min_items=1, max_items=10)",
" h: Optional[conint(ge=1, le=4294967295)] = None",
" i: dict[str, condecimal(max_digits=10, decimal_places=2)]",
],
),
expected=File(
"con_func.py",
content=[
"from pydantic import Field, BaseModel, constr",
"from pydantic import Field, StringConstraints, BaseModel",
"from decimal import Decimal",
"from typing import List, Set",
"from typing_extensions import Annotated",
"",
"",
"class Potato(BaseModel):",
" a: constr(pattern='[a-z]+')",
" a: Annotated[str, StringConstraints(pattern='[a-z]+')]",
" b: Annotated[List[int], Field(min_length=1, max_length=10)]",
" c: Annotated[int, Field(gt=0, lt=10)]",
" d: Annotated[bytes, Field(min_length=1, max_length=10)]",
" e: Annotated[Decimal, Field(gt=0, lt=10)]",
" f: Annotated[float, Field(gt=0, lt=10)]",
" g: Annotated[Set[int], Field(min_length=1, max_length=10)]",
" h: Optional[Annotated[int, Field(ge=1, le=4294967295)]] = None",
" i: dict[str, Annotated[Decimal, Field(max_digits=10, decimal_places=2)]]",
],
),
)
Expand Down
21 changes: 18 additions & 3 deletions tests/unit/test_con_func.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from libcst.codemod import CodemodTest

from bump_pydantic.codemods.con_func import ConFuncCallCommand
Expand All @@ -9,7 +8,6 @@ class TestFieldCommand(CodemodTest):

maxDiff = None

@pytest.mark.xfail(reason="Annotated is not supported yet!")
def test_constr_to_annotated(self) -> None:
before = """
from pydantic import BaseModel, constr
Expand All @@ -26,7 +24,6 @@ class Potato(BaseModel):
"""
self.assertCodemod(before, after)

@pytest.mark.xfail(reason="Annotated is not supported yet!")
def test_pydantic_constr_to_annotated(self) -> None:
before = """
import pydantic
Expand Down Expand Up @@ -76,3 +73,21 @@ class Potato(BaseModel):
potato: Annotated[int, Field(ge=0, le=100)]
"""
self.assertCodemod(before, after)

def test_conint_to_optional_annotated(self) -> None:
before = """
from typing import Optional
from pydantic import BaseModel, conint
class Potato(BaseModel):
potato: Optional[conint(ge=0, le=100)]
"""
after = """
from typing import Optional
from pydantic import Field, BaseModel
from typing_extensions import Annotated
class Potato(BaseModel):
potato: Optional[Annotated[int, Field(ge=0, le=100)]]
"""
self.assertCodemod(before, after)

0 comments on commit 7493d23

Please sign in to comment.