Skip to content

Commit

Permalink
🐛 Do not duplicate comments on validator replacement (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Jun 30, 2023
1 parent 1571fd3 commit 3508bfb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
16 changes: 10 additions & 6 deletions bump_pydantic/codemods/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self, context: CodemodContext) -> None:
self._import_pydantic_validator = self._import_pydantic_root_validator = False
self._already_modified = False
self._should_add_comment = False
self._has_comment = False
self._args: List[cst.Arg] = []

@m.visit(IMPORT_VALIDATOR)
Expand Down Expand Up @@ -104,15 +105,17 @@ def visit_validator_decorator(self, node: cst.Decorator) -> None:

@m.visit(VALIDATOR_FUNCTION)
def visit_validator_func(self, node: cst.FunctionDef) -> None:
for line in node.leading_lines:
if m.matches(line, m.EmptyLine(comment=m.Comment(value=CHECK_LINK_COMMENT))):
self._has_comment = True
# We are only able to refactor the `@validator` when the function has only `cls` and `v` as arguments.
if len(node.params.params) > 2:
self._should_add_comment = True

@m.leave(ROOT_VALIDATOR_DECORATOR)
def leave_root_validator_func(self, original_node: cst.Decorator, updated_node: cst.Decorator) -> cst.Decorator:
for line in updated_node.leading_lines:
if m.matches(line, m.EmptyLine(comment=m.Comment(value=CHECK_LINK_COMMENT))):
return updated_node
if self._has_comment:
return updated_node

if self._should_add_comment:
return self._decorator_with_leading_comment(updated_node, ROOT_VALIDATOR_COMMENT)
Expand All @@ -121,9 +124,8 @@ def leave_root_validator_func(self, original_node: cst.Decorator, updated_node:

@m.leave(VALIDATOR_DECORATOR)
def leave_validator_decorator(self, original_node: cst.Decorator, updated_node: cst.Decorator) -> cst.Decorator:
for line in updated_node.leading_lines:
if m.matches(line, m.EmptyLine(comment=m.Comment(value=CHECK_LINK_COMMENT))):
return updated_node
if self._has_comment:
return updated_node

if self._should_add_comment:
return self._decorator_with_leading_comment(updated_node, VALIDATOR_COMMENT)
Expand All @@ -133,9 +135,11 @@ def leave_validator_decorator(self, original_node: cst.Decorator, updated_node:
@m.leave(VALIDATOR_FUNCTION | ROOT_VALIDATOR_FUNCTION)
def leave_validator_func(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
self._args = []
self._has_comment = False
if self._should_add_comment:
self._should_add_comment = False
return updated_node

classmethod_decorator = cst.Decorator(decorator=cst.Name("classmethod"))
return updated_node.with_changes(decorators=[*updated_node.decorators, classmethod_decorator])

Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,24 @@ def _normalize_fields(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
return values
"""
self.assertCodemod(before, after)

def test_noop_comment(self) -> None:
code = """
import typing as t
from pydantic import BaseModel, validator
class Potato(BaseModel):
name: str
dialect: str
# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("name", "dialect")
def _string_validator(cls, v: t.Any, values: t.Dict[str, t.Any], **kwargs) -> t.Optional[str]:
if isinstance(v, exp.Expression):
return v.name.lower()
return str(v).lower() if v is not None else None
"""
self.assertCodemod(code, code)

0 comments on commit 3508bfb

Please sign in to comment.