Skip to content

Commit

Permalink
✨ Refactor Annotated[..., Field()] (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Jul 10, 2023
1 parent 6fc078f commit 47c66de
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
22 changes: 19 additions & 3 deletions bump_pydantic/codemods/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@
]
)

ANN_ASSIGN_WITH_FIELD = m.AnnAssign(
value=m.Call(func=m.Name("Field")),
) | m.AnnAssign(
annotation=m.Annotation(
annotation=m.Subscript(
slice=[
m.ZeroOrMore(),
m.SubscriptElement(slice=m.Index(value=m.Call(func=m.Name("Field")))),
m.ZeroOrMore(),
]
)
)
)


class FieldCodemod(VisitorBasedCodemodCommand):
def __init__(self, context: CodemodContext) -> None:
Expand All @@ -60,12 +74,12 @@ def leave_field_import(self, original_node: cst.Module, updated_node: cst.Module
self.has_field_import = False
return updated_node

@m.visit(m.AnnAssign(value=m.Call(func=m.Name("Field"))))
@m.visit(ANN_ASSIGN_WITH_FIELD)
def visit_field_assign(self, node: cst.AnnAssign) -> None:
self.inside_field_assign = True
self._const: Union[cst.Arg, None] = None

@m.leave(m.AnnAssign(value=m.Call(func=m.Name("Field"))))
@m.leave(ANN_ASSIGN_WITH_FIELD)
def leave_field_assign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
self.inside_field_assign = False

Expand Down Expand Up @@ -124,10 +138,12 @@ def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> c

source = textwrap.dedent(
"""
from typing import Annotated
from pydantic import BaseModel, Field
class A(BaseModel):
a: List[str] = Field(..., description="My description", min_items=1)
a: Annotated[List[str], Field(..., description="My description", min_items=1)]
"""
)
console.print(source)
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,20 @@ class Settings(BaseSettings):
strawberry: int = Field(..., frozen=False)
"""
self.assertCodemod(before, after)

def test_annotated_field(self) -> None:
before = """
from pydantic import BaseModel, Field
from typing import Annotated
class Potato(BaseModel):
potato: Annotated[List[int], Field(..., min_items=1, max_items=10)]
"""
after = """
from pydantic import BaseModel, Field
from typing import Annotated
class Potato(BaseModel):
potato: Annotated[List[int], Field(..., min_length=1, max_length=10)]
"""
self.assertCodemod(before, after)

0 comments on commit 47c66de

Please sign in to comment.