Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add pydantic plugin with BaseModel type transformer #1620

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions plugins/flytekit-pydantic/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Flytekit Pydantic Plugin

Pydantic is a data validation and settings management library that uses Python type annotations to enforce type hints at runtime and provide user-friendly errors when data is invalid. Pydantic models are classes that inherit from `pydantic.BaseModel` and are used to define the structure and validation of data using Python type annotations.

The plugin adds type support for pydantic models.

To install the plugin, run the following command:

```bash
pip install flytekitplugins-pydantic
```


## Type Example
```python
from pydantic import BaseModel
import flytekitplugins.pydantic


class Config(BaseModel):
lr: float = 1e-3
batch_size: int = 32


@task
def train(cfg: Config):
...
```
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .schema import BaseModelTransformer
51 changes: 51 additions & 0 deletions plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Type

from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct
from pydantic import BaseModel

from flytekit import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.literals import Literal, Scalar
from flytekit.models.types import LiteralType, SimpleType


class BaseModelTransformer(TypeTransformer[BaseModel]):
_TYPE_INFO = LiteralType(simple=SimpleType.STRUCT)

def __init__(self):
"""Construct BaseModelTransformer."""
super().__init__(name="basemodel-transform", t=BaseModel)

def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
return LiteralType(simple=SimpleType.STRUCT)

def to_literal(
self,
ctx: FlyteContext,
python_val: BaseModel,
python_type: Type[BaseModel],
expected: LiteralType,
) -> Literal:
"""This method is used to convert from given python type object pydantic ``BaseModel`` to the Literal representation."""
s = Struct()

s.update({"schema": python_val.schema(), "data": python_val.dict()})

return Literal(scalar=Scalar(generic=s))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel:
"""Re-hydrate the pydantic BaseModel object from Flyte Literal value."""
base_model = MessageToDict(lv.scalar.generic)
schema = base_model["schema"]
data = base_model["data"]

if (expected_schema := expected_python_type.schema()) != schema:
raise TypeTransformerFailedError(
f"The schema `{expected_schema}` of the expected python type {expected_python_type} is not equal to the received schema `{schema}`."
)

return expected_python_type.parse_obj(data)


TypeEngine.register(BaseModelTransformer())
2 changes: 2 additions & 0 deletions plugins/flytekit-pydantic/requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.
-e file:.#egg=flytekitplugins-pydantic
Loading