Skip to content

Commit

Permalink
✨ Replace validator by field_validator and root_validator by `m…
Browse files Browse the repository at this point in the history
…odel_validator` (#40)
  • Loading branch information
Kludex authored Jun 29, 2023
1 parent 61279ba commit a4a7c7d
Show file tree
Hide file tree
Showing 5 changed files with 664 additions and 1 deletion.
115 changes: 115 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ Bump Pydantic is a tool to help you migrate your code from Pydantic V1 to V2.
- [BP003: Replace `Config` class by `model_config`](#bp003-replace-config-class-by-model_config)
- [BP005: Replace `GenericModel` by `BaseModel`](#bp005-replace-genericmodel-by-basemodel)
- [BP006: Replace `__root__` by `RootModel`](#bp006-replace-__root__-by-rootmodel)
- [BP007: Replace decorators](#bp007-replace-decorators)
- [BP008: Replace `const=True` by `Literal`](#bp008-replace-consttrue-by-literal)
- [BP009: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`](#bp009-replace-pydanticparse_obj_as-by-pydantictypeadapter)
- [License](#license)

---
Expand Down Expand Up @@ -240,6 +243,118 @@ class Users(RootModel[List[User]]):
pass
```

### BP007: Replace decorators

- ✅ Replace `@validator` by `@field_validator`.
- ✅ Replace `@root_validator` by `@model_validator`.

The following code will be transformed:

```py
from pydantic import BaseModel, validator, root_validator


class User(BaseModel):
name: str

@validator('name', pre=True)
def validate_name(cls, v):
return v

@root_validator(pre=True)
def validate_root(cls, values):
return values
```

Into:

```py
from pydantic import BaseModel, field_validator, model_validator


class User(BaseModel):
name: str

@field_validator('name', mode='before')
def validate_name(cls, v):
return v

@model_validator(mode='before')
def validate_root(cls, values):
return values
```

### BP008: Replace `const=True` by `Literal`

- ✅ Replace `field: Enum = Field(Enum.VALUE, const=True)` by `field: Literal[Enum.VALUE] = Enum.VALUE`.

The following code will be transformed:

```py
from enum import Enum

from pydantic import BaseModel, Field


class User(BaseModel):
name: Enum = Field(Enum.VALUE, const=True)
```

Into:

```py
from enum import Enum

from pydantic import BaseModel, Field


class User(BaseModel):
name: Literal[Enum.VALUE] = Enum.VALUE
```

### BP009: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`

- ✅ Replace `pydantic.parse_obj_as(T, obj)` to `pydantic.TypeAdapter(T).validate_python(obj)`.


The following code will be transformed:

```py
from typing import List

from pydantic import BaseModel, parse_obj_as


class User(BaseModel):
name: str


class Users(BaseModel):
users: List[User]


users = parse_obj_as(Users, {'users': [{'name': 'John'}]})
```

Into:

```py
from typing import List

from pydantic import BaseModel, TypeAdapter


class User(BaseModel):
name: str


class Users(BaseModel):
users: List[User]


users = TypeAdapter(Users).validate_python({'users': [{'name': 'John'}]})
```

---

## License
Expand Down
6 changes: 6 additions & 0 deletions bump_pydantic/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from bump_pydantic.codemods.replace_generic_model import ReplaceGenericModelCommand
from bump_pydantic.codemods.replace_imports import ReplaceImportsCodemod
from bump_pydantic.codemods.root_model import RootModelCommand
from bump_pydantic.codemods.validator import ValidatorCodemod


class Rule(str, Enum):
Expand All @@ -25,6 +26,8 @@ class Rule(str, Enum):
"""Replace `GenericModel` with `BaseModel`."""
BP006 = "BP006"
"""Replace `BaseModel.__root__ = T` with `RootModel[T]`."""
BP007 = "BP007"
"""Replace `@validator` with `@field_validator`."""


def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
Expand All @@ -48,6 +51,9 @@ def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]
if Rule.BP006 not in disabled:
codemods.append(RootModelCommand)

if Rule.BP007 not in disabled:
codemods.append(ValidatorCodemod)

# Those codemods need to be the last ones.
codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
return codemods
190 changes: 190 additions & 0 deletions bump_pydantic/codemods/validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from typing import List

import libcst as cst
from libcst import matchers as m
from libcst._nodes.module import Module
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

PREFIX_COMMENT = "# TODO[pydantic]: "
REFACTOR_COMMENT = f"{PREFIX_COMMENT}We couldn't refactor the `{{old_name}}`, please replace it by `{{new_name}}` manually." # noqa: E501
VALIDATOR_COMMENT = REFACTOR_COMMENT.format(old_name="validator", new_name="field_validator")
ROOT_VALIDATOR_COMMENT = REFACTOR_COMMENT.format(old_name="root_validator", new_name="model_validator")
CHECK_LINK_COMMENT = "# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information."

IMPORT_VALIDATOR = m.Module(
body=[
m.ZeroOrMore(),
m.SimpleStatementLine(
body=[
m.ZeroOrMore(),
m.ImportFrom(
module=m.Name("pydantic"),
names=[
m.ZeroOrMore(),
m.ImportAlias(name=m.Name("validator")),
m.ZeroOrMore(),
],
),
m.ZeroOrMore(),
],
),
m.ZeroOrMore(),
]
)
VALIDATOR_DECORATOR = m.Decorator(decorator=m.Call(func=m.Name("validator")))
VALIDATOR_FUNCTION = m.FunctionDef(decorators=[m.ZeroOrMore(), VALIDATOR_DECORATOR, m.ZeroOrMore()])

IMPORT_ROOT_VALIDATOR = m.Module(
body=[
m.ZeroOrMore(),
m.SimpleStatementLine(
body=[
m.ZeroOrMore(),
m.ImportFrom(
module=m.Name("pydantic"),
names=[
m.ZeroOrMore(),
m.ImportAlias(name=m.Name("root_validator")),
m.ZeroOrMore(),
],
),
m.ZeroOrMore(),
],
),
m.ZeroOrMore(),
]
)
ROOT_VALIDATOR_DECORATOR = m.Decorator(decorator=m.Call(func=m.Name("root_validator")))
ROOT_VALIDATOR_FUNCTION = m.FunctionDef(decorators=[m.ZeroOrMore(), ROOT_VALIDATOR_DECORATOR, m.ZeroOrMore()])


class ValidatorCodemod(VisitorBasedCodemodCommand):
def __init__(self, context: CodemodContext) -> None:
super().__init__(context)

self._import_pydantic_validator = self._import_pydantic_root_validator = False
self._already_modified = False
self._should_add_comment = False
self._args: List[cst.Arg] = []

@m.visit(IMPORT_VALIDATOR)
def visit_import_validator(self, node: cst.CSTNode) -> None:
self._import_pydantic_validator = True
self._import_pydantic_root_validator = True

def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
self._import_pydantic_validator = False
self._import_pydantic_root_validator = False
return updated_node

@m.visit(VALIDATOR_DECORATOR | ROOT_VALIDATOR_DECORATOR)
def visit_validator_decorator(self, node: cst.Decorator) -> None:
if m.matches(node.decorator, m.Call()):
for arg in node.decorator.args: # type: ignore[attr-defined]
pre_false = m.Arg(keyword=m.Name("pre"), value=m.Name("False"))
pre_true = m.Arg(keyword=m.Name("pre"), value=m.Name("True"))
if m.matches(arg, m.Arg(keyword=m.Name("allow_reuse")) | pre_false):
continue
if m.matches(arg, pre_true):
self._args.append(arg.with_changes(keyword=cst.Name("mode"), value=cst.SimpleString('"before"')))
elif m.matches(arg.keyword, m.Name(value=m.MatchIfTrue(lambda v: v in ("each_item", "always")))):
self._should_add_comment = True
else:
# The `check_fields` kw-argument and all positional arguments can be just copied.
self._args.append(arg)
else:
"""This only happens for `@validator`, not with `@validator()`. The parenthesis makes it not be a `Call`"""
self._should_add_comment = True

# Removes the trailing comma on the last argument e.g.
# `@validator(allow_reuse=True, )` -> `@validator(allow_reuse=True)`
if self._args:
self._args[-1] = self._args[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)

@m.visit(VALIDATOR_FUNCTION)
def visit_validator_func(self, node: cst.FunctionDef) -> None:
# We are only able to refactor the `@validator` when the function has only `cls` and `v` as arguments.
if len(node.params.params) > 2:
self._should_add_comment = True

@m.leave(ROOT_VALIDATOR_DECORATOR)
def leave_root_validator_func(self, original_node: cst.Decorator, updated_node: cst.Decorator) -> cst.Decorator:
for line in updated_node.leading_lines:
if m.matches(line, m.EmptyLine(comment=m.Comment(value=CHECK_LINK_COMMENT))):
return updated_node

if self._should_add_comment:
return self._decorator_with_leading_comment(updated_node, ROOT_VALIDATOR_COMMENT)

return self._replace_validators(updated_node, "root_validator", "model_validator")

@m.leave(VALIDATOR_DECORATOR)
def leave_validator_decorator(self, original_node: cst.Decorator, updated_node: cst.Decorator) -> cst.Decorator:
for line in updated_node.leading_lines:
if m.matches(line, m.EmptyLine(comment=m.Comment(value=CHECK_LINK_COMMENT))):
return updated_node

if self._should_add_comment:
return self._decorator_with_leading_comment(updated_node, VALIDATOR_COMMENT)

return self._replace_validators(updated_node, "validator", "field_validator")

@m.leave(VALIDATOR_FUNCTION | ROOT_VALIDATOR_FUNCTION)
def leave_validator_func(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
self._args = []
if self._should_add_comment:
self._should_add_comment = False
return updated_node
classmethod_decorator = cst.Decorator(decorator=cst.Name("classmethod"))
return updated_node.with_changes(decorators=[*updated_node.decorators, classmethod_decorator])

def _decorator_with_leading_comment(self, node: cst.Decorator, comment: str) -> cst.Decorator:
return node.with_changes(
leading_lines=[
*node.leading_lines,
cst.EmptyLine(comment=cst.Comment(value=(comment))),
cst.EmptyLine(comment=cst.Comment(value=(CHECK_LINK_COMMENT))),
]
)

def _replace_validators(self, node: cst.Decorator, old_name: str, new_name: str) -> cst.Decorator:
RemoveImportsVisitor.remove_unused_import(self.context, "pydantic", old_name)
AddImportsVisitor.add_needed_import(self.context, "pydantic", new_name)
decorator = node.decorator.with_changes(func=cst.Name(new_name), args=self._args)
return node.with_changes(decorator=decorator)


if __name__ == "__main__":
import textwrap

from rich.console import Console

console = Console()

source = textwrap.dedent(
"""
from pydantic import BaseModel, validator
class Foo(BaseModel):
bar: str
@validator("bar", pre=True, always=True)
def bar_validator(cls, v):
return v
"""
)
console.print(source)
console.print("=" * 80)

mod = cst.parse_module(source)
context = CodemodContext(filename="main.py")
wrapper = cst.MetadataWrapper(mod)
command = ValidatorCodemod(context=context)
# console.print(mod)

mod = wrapper.visit(command)
wrapper = cst.MetadataWrapper(mod)
command = AddImportsVisitor(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
console.print(mod.code)
2 changes: 1 addition & 1 deletion tests/unit/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from bump_pydantic.codemods.field import FieldCodemod


class TestReplaceConfigCommand(CodemodTest):
class TestFieldCommand(CodemodTest):
TRANSFORM = FieldCodemod

maxDiff = None
Expand Down
Loading

0 comments on commit a4a7c7d

Please sign in to comment.