From 9f3c1dd4bceaf016af02773d2e3f70ff1022145a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 26 Jul 2023 15:54:25 +0100 Subject: [PATCH 1/4] =?UTF-8?q?=F0=9F=90=9B=20Use=20`Annotated`=20even=20w?= =?UTF-8?q?hen=20con*=20functions=20are=20wrapped=20on=20`cst.Subscript`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bump_pydantic/codemods/__init__.py | 2 +- bump_pydantic/codemods/con_func.py | 32 ++++++++------------------ tests/integration/cases/con_func.py | 6 +++-- tests/unit/test_con_func.py | 35 ++++++++++++++++++++++++++--- 4 files changed, 46 insertions(+), 29 deletions(-) diff --git a/bump_pydantic/codemods/__init__.py b/bump_pydantic/codemods/__init__.py index e2e6f74..790a3d8 100644 --- a/bump_pydantic/codemods/__init__.py +++ b/bump_pydantic/codemods/__init__.py @@ -30,7 +30,7 @@ class Rule(str, Enum): BP007 = "BP007" """Replace `@validator` with `@field_validator`.""" BP008 = "BP008" - """Replace `constr()` with `Annotated[str, StringConstraints()`.""" + """Replace `con*` functions by `Annotated` versions.""" def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]: diff --git a/bump_pydantic/codemods/con_func.py b/bump_pydantic/codemods/con_func.py index 5e42033..17f4171 100644 --- a/bump_pydantic/codemods/con_func.py +++ b/bump_pydantic/codemods/con_func.py @@ -6,24 +6,18 @@ from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor CONSTR_CALL = m.Call(func=m.Name("constr") | m.Attribute(value=m.Name("pydantic"), attr=m.Name("constr"))) -ANN_ASSIGN_CONSTR_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CONSTR_CALL)) - - CON_NUMBER_CALL = m.OneOf( *[ m.Call(func=m.Name(name) | m.Attribute(value=m.Name("pydantic"), attr=m.Name(name))) for name in ("conint", "confloat", "condecimal", "conbytes") ] ) -ANN_ASSIGN_CON_NUMBER_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CON_NUMBER_CALL)) - CON_COLLECTION_CALL = m.OneOf( *[ m.Call(func=m.Name(name) | m.Attribute(value=m.Name("pydantic"), attr=m.Name(name))) for name in ("conlist", "conset", "confrozenset") ] ) -ANN_ASSIGN_COLLECTION_CALL = m.AnnAssign(annotation=m.Annotation(annotation=CON_COLLECTION_CALL)) MAP_FUNC_TO_TYPE = { "constr": "str", @@ -48,19 +42,14 @@ class ConFuncCallCommand(VisitorBasedCodemodCommand): def __init__(self, context: CodemodContext) -> None: super().__init__(context) - @m.leave(ANN_ASSIGN_CONSTR_CALL | ANN_ASSIGN_CON_NUMBER_CALL | ANN_ASSIGN_COLLECTION_CALL) - def leave_ann_assign_constr_call(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign: - annotation = cast(cst.Call, original_node.annotation.annotation) - if m.matches(annotation.func, m.Name()): - func_name = cast(str, annotation.func.value) # type: ignore + @m.leave(CON_NUMBER_CALL | CON_COLLECTION_CALL | CONSTR_CALL) + def leave_annotation_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Subscript: + if m.matches(original_node.func, m.Name()): + func_name = cast(str, original_node.func.value) # type: ignore else: - func_name = cast(str, annotation.func.attr.value) # type: ignore + func_name = cast(str, original_node.func.attr.value) # type: ignore type_name = MAP_FUNC_TO_TYPE[func_name] - # TODO: When FastAPI supports Pydantic 2.0.4+, remove the conditional below. - if func_name == "constr": - return updated_node - needed_import = MAP_TYPE_TO_NEEDED_IMPORT.get(type_name) if needed_import is not None: AddImportsVisitor.add_needed_import(context=self.context, **needed_import) # type: ignore[arg-type] @@ -76,23 +65,20 @@ def leave_ann_assign_constr_call(self, original_node: cst.AnnAssign, updated_nod slice_value = cst.Index(value=cst.Name(type_name)) AddImportsVisitor.add_needed_import(context=self.context, module="typing_extensions", obj="Annotated") - annotated = cst.Subscript( + return cst.Subscript( value=cst.Name("Annotated"), slice=[ cst.SubscriptElement(slice=slice_value), - cst.SubscriptElement(slice=cst.Index(value=updated_node.annotation.annotation)), + cst.SubscriptElement(slice=cst.Index(value=updated_node)), ], ) - annotation = cst.Annotation(annotation=annotated) # type: ignore[assignment] - return updated_node.with_changes(annotation=annotation) - # TODO: When FastAPI supports Pydantic 2.0.4+, remove the comments below. @m.leave(CONSTR_CALL) def leave_constr_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: self._remove_import(original_node.func) - # AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="StringConstraints") + AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="StringConstraints") return updated_node.with_changes( - # func=cst.Name("StringConstraints"), + func=cst.Name("StringConstraints"), args=[ arg if arg.keyword and arg.keyword.value != "regex" else arg.with_changes(keyword=cst.Name("pattern")) for arg in updated_node.args diff --git a/tests/integration/cases/con_func.py b/tests/integration/cases/con_func.py index 1bab613..d86ce1c 100644 --- a/tests/integration/cases/con_func.py +++ b/tests/integration/cases/con_func.py @@ -18,25 +18,27 @@ " e: condecimal(gt=0, lt=10)", " f: confloat(gt=0, lt=10)", " g: conset(int, min_items=1, max_items=10)", + " h: Optional[conint(ge=1, le=4294967295)] = None", ], ), expected=File( "con_func.py", content=[ - "from pydantic import Field, BaseModel, constr", + "from pydantic import Field, StringConstraints, BaseModel", "from decimal import Decimal", "from typing import List, Set", "from typing_extensions import Annotated", "", "", "class Potato(BaseModel):", - " a: constr(pattern='[a-z]+')", + " a: Annotated[str, StringConstraints(pattern='[a-z]+')]", " b: Annotated[List[int], Field(min_length=1, max_length=10)]", " c: Annotated[int, Field(gt=0, lt=10)]", " d: Annotated[bytes, Field(min_length=1, max_length=10)]", " e: Annotated[Decimal, Field(gt=0, lt=10)]", " f: Annotated[float, Field(gt=0, lt=10)]", " g: Annotated[Set[int], Field(min_length=1, max_length=10)]", + " h: Optional[Annotated[int, Field(ge=1, le=4294967295)]] = None", ], ), ) diff --git a/tests/unit/test_con_func.py b/tests/unit/test_con_func.py index 831f20b..aed0c6c 100644 --- a/tests/unit/test_con_func.py +++ b/tests/unit/test_con_func.py @@ -1,4 +1,4 @@ -import pytest +import libcst as cst from libcst.codemod import CodemodTest from bump_pydantic.codemods.con_func import ConFuncCallCommand @@ -9,7 +9,6 @@ class TestFieldCommand(CodemodTest): maxDiff = None - @pytest.mark.xfail(reason="Annotated is not supported yet!") def test_constr_to_annotated(self) -> None: before = """ from pydantic import BaseModel, constr @@ -26,7 +25,6 @@ class Potato(BaseModel): """ self.assertCodemod(before, after) - @pytest.mark.xfail(reason="Annotated is not supported yet!") def test_pydantic_constr_to_annotated(self) -> None: before = """ import pydantic @@ -76,3 +74,34 @@ class Potato(BaseModel): potato: Annotated[int, Field(ge=0, le=100)] """ self.assertCodemod(before, after) + + def test_conint_to_optional_annotated(self) -> None: + before = """ + from typing import Optional + from pydantic import BaseModel, conint + + class Potato(BaseModel): + potato: Optional[conint(ge=0, le=100)] + """ + import textwrap + + from rich.pretty import pprint + + pprint(cst.parse_module(textwrap.dedent(before))) + another_before = """ + from typing import Optional + from pydantic import BaseModel, conint + + class Potato(BaseModel): + potato: conint(ge=0, le=100) + """ + pprint(cst.parse_module(textwrap.dedent(another_before))) + after = """ + from typing import Optional + from pydantic import Field, BaseModel + from typing_extensions import Annotated + + class Potato(BaseModel): + potato: Optional[Annotated[int, Field(ge=0, le=100)]] + """ + self.assertCodemod(before, after) From 1a4dd5782a1536992b8a2b9e236c1ff00b4aeb10 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 26 Jul 2023 15:56:21 +0100 Subject: [PATCH 2/4] Add test for decimal --- tests/integration/cases/con_func.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/cases/con_func.py b/tests/integration/cases/con_func.py index d86ce1c..631e7ac 100644 --- a/tests/integration/cases/con_func.py +++ b/tests/integration/cases/con_func.py @@ -19,6 +19,7 @@ " f: confloat(gt=0, lt=10)", " g: conset(int, min_items=1, max_items=10)", " h: Optional[conint(ge=1, le=4294967295)] = None", + " i: dict[str, condecimal(max_digits=10, decimal_places=2)]", ], ), expected=File( @@ -39,6 +40,7 @@ " f: Annotated[float, Field(gt=0, lt=10)]", " g: Annotated[Set[int], Field(min_length=1, max_length=10)]", " h: Optional[Annotated[int, Field(ge=1, le=4294967295)]] = None", + " i: dict[str, Annotated[Decimal, Field(max_digits=10, decimal_places=2)]]", ], ), ) From 691a895e8b57c5f45a9ad490f0909d9e1af9e2e8 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 26 Jul 2023 15:57:17 +0100 Subject: [PATCH 3/4] Add missing imports on README --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bb7b45e..f12c76e 100644 --- a/README.md +++ b/README.md @@ -186,6 +186,7 @@ Into: ```py from typing import Generic, TypeVar +from pydantic import BaseModel T = TypeVar('T') @@ -217,7 +218,7 @@ Into: ```py from typing import List -from pydantic import RootModel +from pydantic import RootModel, BaseModel class User(BaseModel): age: int From d247a49a8a0f0bf1c0414934069444857c3c595d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 26 Jul 2023 15:58:10 +0100 Subject: [PATCH 4/4] Remove print on tests --- tests/unit/test_con_func.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/unit/test_con_func.py b/tests/unit/test_con_func.py index aed0c6c..72d5689 100644 --- a/tests/unit/test_con_func.py +++ b/tests/unit/test_con_func.py @@ -1,4 +1,3 @@ -import libcst as cst from libcst.codemod import CodemodTest from bump_pydantic.codemods.con_func import ConFuncCallCommand @@ -83,19 +82,6 @@ def test_conint_to_optional_annotated(self) -> None: class Potato(BaseModel): potato: Optional[conint(ge=0, le=100)] """ - import textwrap - - from rich.pretty import pprint - - pprint(cst.parse_module(textwrap.dedent(before))) - another_before = """ - from typing import Optional - from pydantic import BaseModel, conint - - class Potato(BaseModel): - potato: conint(ge=0, le=100) - """ - pprint(cst.parse_module(textwrap.dedent(another_before))) after = """ from typing import Optional from pydantic import Field, BaseModel