diff --git a/bump_pydantic/codemods/validator.py b/bump_pydantic/codemods/validator.py index a2fe7fc..6709ce5 100644 --- a/bump_pydantic/codemods/validator.py +++ b/bump_pydantic/codemods/validator.py @@ -140,6 +140,13 @@ def leave_validator_func(self, original_node: cst.FunctionDef, updated_node: cst self._should_add_comment = False return updated_node + # Check if a classmethod decorator already exists + if any( + m.matches(decorator, m.Decorator(decorator=m.Name("classmethod"))) for decorator in updated_node.decorators + ): + return updated_node # If it already exists, return the node as is + + # If it doesn't exist, add the classmethod decorator classmethod_decorator = cst.Decorator(decorator=cst.Name("classmethod")) return updated_node.with_changes(decorators=[*updated_node.decorators, classmethod_decorator]) diff --git a/tests/integration/cases/replace_validator.py b/tests/integration/cases/replace_validator.py index 82009b5..dc5616a 100644 --- a/tests/integration/cases/replace_validator.py +++ b/tests/integration/cases/replace_validator.py @@ -45,6 +45,51 @@ ], ), ), + Case( + name="Replace validator with existing classmethod decorator", + source=File( + "replace_validator_existing_classmethod.py", + content=[ + "from pydantic import BaseModel, validator, root_validator", + "", + "", + "class A(BaseModel):", + " a: int", + " b: str", + "", + " @validator('a')", + " @classmethod", + " def validate_a(cls, v):", + " return v + 1", + "", + " @root_validator()", + " @classmethod", + " def validate_b(cls, values):", + " return values", + ], + ), + expected=File( + "replace_validator_existing_classmethod.py", + content=[ + "from pydantic import field_validator, model_validator, BaseModel", + "", + "", + "class A(BaseModel):", + " a: int", + " b: str", + "", + " @field_validator('a')", + " @classmethod", + " def validate_a(cls, v):", + " return v + 1", + "", + " @model_validator()", + " @classmethod", + " def validate_b(cls, values):", + " return values", + ], + ), + ), Case( name="Replace validator with pre=True", source=File( diff --git a/tests/unit/test_validator.py b/tests/unit/test_validator.py index 2a8e289..f6d9501 100644 --- a/tests/unit/test_validator.py +++ b/tests/unit/test_validator.py @@ -278,6 +278,51 @@ def _string_validator(cls, v: t.Any) -> t.Optional[str]: """ self.assertCodemod(before, after) + def test_replace_validator_existing_classmethod_decorator(self) -> None: + before = """ + from pydantic import validator + + + class Potato(BaseModel): + name: str + dialect: str + + @root_validator(pre=True, allow_reuse=True) + @classmethod + def _normalize_fields(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + if "gateways" not in values and "gateway" in values: + values["gateways"] = values.pop("gateway") + + @validator("name", "dialect", pre=False) + @classmethod + def _string_validator(cls, v: t.Any) -> t.Optional[str]: + if isinstance(v, exp.Expression): + return v.name.lower() + return str(v).lower() if v is not None else None + """ + after = """ + from pydantic import field_validator, model_validator + + + class Potato(BaseModel): + name: str + dialect: str + + @model_validator(mode="before") + @classmethod + def _normalize_fields(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + if "gateways" not in values and "gateway" in values: + values["gateways"] = values.pop("gateway") + + @field_validator("name", "dialect") + @classmethod + def _string_validator(cls, v: t.Any) -> t.Optional[str]: + if isinstance(v, exp.Expression): + return v.name.lower() + return str(v).lower() if v is not None else None + """ + self.assertCodemod(before, after) + @pytest.mark.xfail(reason="Not implemented yet") def test_import_pydantic(self) -> None: before = """