From ad471d327c11152319aa85e6fe81c84bb6fca24e Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Sat, 6 Apr 2024 12:42:20 -0700 Subject: [PATCH] Do not duplicate classmethod when replacing validators --- bump_pydantic/codemods/validator.py | 6 ++++-- tests/unit/test_validator.py | 33 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/bump_pydantic/codemods/validator.py b/bump_pydantic/codemods/validator.py index 2a771b9..969058c 100644 --- a/bump_pydantic/codemods/validator.py +++ b/bump_pydantic/codemods/validator.py @@ -142,8 +142,10 @@ def leave_validator_func(self, original_node: cst.FunctionDef, updated_node: cst self._should_add_comment = False return updated_node - classmethod_decorator = cst.Decorator(decorator=cst.Name("classmethod")) - return updated_node.with_changes(decorators=[*updated_node.decorators, classmethod_decorator]) + if not any(m.matches(d, m.Decorator(decorator=m.Name("classmethod"))) for d in updated_node.decorators): + classmethod_decorator = cst.Decorator(decorator=cst.Name("classmethod")) + updated_node = updated_node.with_changes(decorators=[*updated_node.decorators, classmethod_decorator]) + return updated_node def _decorator_with_leading_comment(self, node: cst.Decorator, comment: str) -> cst.Decorator: return node.with_changes( diff --git a/tests/unit/test_validator.py b/tests/unit/test_validator.py index 2a8e289..12afacb 100644 --- a/tests/unit/test_validator.py +++ b/tests/unit/test_validator.py @@ -265,6 +265,39 @@ def _string_validator(cls, v: t.Any) -> t.Optional[str]: from pydantic import field_validator + class Potato(BaseModel): + name: str + dialect: str + + @field_validator("name", "dialect") + @classmethod + def _string_validator(cls, v: t.Any) -> t.Optional[str]: + if isinstance(v, exp.Expression): + return v.name.lower() + return str(v).lower() if v is not None else None + """ + self.assertCodemod(before, after) + + def test_replace_validator_with_existing_classmethod(self) -> None: + before = """ + from pydantic import validator + + + class Potato(BaseModel): + name: str + dialect: str + + @validator("name", "dialect") + @classmethod + def _string_validator(cls, v: t.Any) -> t.Optional[str]: + if isinstance(v, exp.Expression): + return v.name.lower() + return str(v).lower() if v is not None else None + """ + after = """ + from pydantic import field_validator + + class Potato(BaseModel): name: str dialect: str