From 47c66def557f23b25ce660a1ebc32d4ba9ccae72 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 10 Jul 2023 11:36:26 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Refactor=20`Annotated[...,=20Field(?= =?UTF-8?q?)]`=20(#66)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bump_pydantic/codemods/field.py | 22 +++++++++++++++++++--- tests/unit/test_field.py | 17 +++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/bump_pydantic/codemods/field.py b/bump_pydantic/codemods/field.py index e6d647f..0a4692a 100644 --- a/bump_pydantic/codemods/field.py +++ b/bump_pydantic/codemods/field.py @@ -43,6 +43,20 @@ ] ) +ANN_ASSIGN_WITH_FIELD = m.AnnAssign( + value=m.Call(func=m.Name("Field")), +) | m.AnnAssign( + annotation=m.Annotation( + annotation=m.Subscript( + slice=[ + m.ZeroOrMore(), + m.SubscriptElement(slice=m.Index(value=m.Call(func=m.Name("Field")))), + m.ZeroOrMore(), + ] + ) + ) +) + class FieldCodemod(VisitorBasedCodemodCommand): def __init__(self, context: CodemodContext) -> None: @@ -60,12 +74,12 @@ def leave_field_import(self, original_node: cst.Module, updated_node: cst.Module self.has_field_import = False return updated_node - @m.visit(m.AnnAssign(value=m.Call(func=m.Name("Field")))) + @m.visit(ANN_ASSIGN_WITH_FIELD) def visit_field_assign(self, node: cst.AnnAssign) -> None: self.inside_field_assign = True self._const: Union[cst.Arg, None] = None - @m.leave(m.AnnAssign(value=m.Call(func=m.Name("Field")))) + @m.leave(ANN_ASSIGN_WITH_FIELD) def leave_field_assign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign: self.inside_field_assign = False @@ -124,10 +138,12 @@ def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> c source = textwrap.dedent( """ + from typing import Annotated + from pydantic import BaseModel, Field class A(BaseModel): - a: List[str] = Field(..., description="My description", min_items=1) + a: Annotated[List[str], Field(..., description="My description", min_items=1)] """ ) console.print(source) diff --git a/tests/unit/test_field.py b/tests/unit/test_field.py index 62e02b5..33a34a1 100644 --- a/tests/unit/test_field.py +++ b/tests/unit/test_field.py @@ -119,3 +119,20 @@ class Settings(BaseSettings): strawberry: int = Field(..., frozen=False) """ self.assertCodemod(before, after) + + def test_annotated_field(self) -> None: + before = """ + from pydantic import BaseModel, Field + from typing import Annotated + + class Potato(BaseModel): + potato: Annotated[List[int], Field(..., min_items=1, max_items=10)] + """ + after = """ + from pydantic import BaseModel, Field + from typing import Annotated + + class Potato(BaseModel): + potato: Annotated[List[int], Field(..., min_length=1, max_length=10)] + """ + self.assertCodemod(before, after)