diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 336ffbdad6..2277b18027 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -111,7 +111,6 @@ class PickleParamType(click.ParamType): def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: - uri = FlyteContextManager.current_context().file_access.get_random_local_path() with open(uri, "w+b") as outfile: cloudpickle.dump(value, outfile) @@ -119,7 +118,6 @@ def convert( class DateTimeType(click.DateTime): - _NOW_FMT = "now" _ADDITONAL_FORMATS = [_NOW_FMT] @@ -276,7 +274,6 @@ def get_uri_for_dir( def convert_to_structured_dataset( self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: Directory ) -> Literal: - uri = self.get_uri_for_dir(ctx, value, "00000.parquet") lit = Literal( @@ -338,7 +335,7 @@ def convert_to_union( python_val = converter._click_type.convert(value, param, ctx) literal = converter.convert_to_literal(ctx, param, python_val) return Literal(scalar=Scalar(union=Union(literal, variant))) - except (Exception or AttributeError) as e: + except (Exception, AttributeError) as e: logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e) raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}") @@ -399,7 +396,10 @@ def convert_to_struct( Convert the loaded json object to a Flyte Literal struct type. 
""" if type(value) != self._python_type: - o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value)) + if is_pydantic_basemodel(self._python_type): + o = self._python_type.parse_raw(json.dumps(value)) # type: ignore + else: + o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value)) else: o = value return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) @@ -446,6 +446,15 @@ def convert(self, ctx, param, value) -> typing.Union[Literal, typing.Any]: raise click.BadParameter(f"Failed to convert param {param}, {value} to {self._python_type}") from e +def is_pydantic_basemodel(python_type: typing.Type) -> bool: + try: + import pydantic + except ImportError: + return False + else: + return issubclass(python_type, pydantic.BaseModel) + + def to_click_option( ctx: click.Context, flyte_ctx: FlyteContext, diff --git a/plugins/flytekit-pydantic/README.md b/plugins/flytekit-pydantic/README.md new file mode 100644 index 0000000000..8eb7267100 --- /dev/null +++ b/plugins/flytekit-pydantic/README.md @@ -0,0 +1,28 @@ +# Flytekit Pydantic Plugin + +Pydantic is a data validation and settings management library that uses Python type annotations to enforce type hints at runtime and provide user-friendly errors when data is invalid. Pydantic models are classes that inherit from `pydantic.BaseModel` and are used to define the structure and validation of data using Python type annotations. + +The plugin adds type support for pydantic models. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-pydantic +``` + + +## Type Example +```python +from pydantic import BaseModel + + +class TrainConfig(BaseModel): + lr: float = 1e-3 + batch_size: int = 32 + files: List[FlyteFile] + directories: List[FlyteDirectory] + +@task +def train(cfg: TrainConfig): + ... 
+``` diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py new file mode 100644 index 0000000000..23e7e341bd --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py @@ -0,0 +1,4 @@ +from .basemodel_transformer import BaseModelTransformer +from .deserialization import set_validators_on_supported_flyte_types as _set_validators_on_supported_flyte_types + +_set_validators_on_supported_flyte_types() # enables you to use flytekit.types in pydantic model diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py new file mode 100644 index 0000000000..325da8e500 --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -0,0 +1,67 @@ +"""Serializes & deserializes the pydantic basemodels """ + +from typing import Dict, Type + +import pydantic +from google.protobuf import json_format +from typing_extensions import Annotated + +from flytekit import FlyteContext +from flytekit.core import type_engine +from flytekit.models import literals, types + +from . 
import deserialization, serialization + +BaseModelLiterals = Annotated[ + Dict[str, literals.Literal], + """ + BaseModel serialized to a LiteralMap consisting of: + 1) the basemodel json with placeholders for flyte types + 2) mapping from placeholders to serialized flyte type values in the object store + """, +] + + +class BaseModelTransformer(type_engine.TypeTransformer[pydantic.BaseModel]): + _TYPE_INFO = types.LiteralType(simple=types.SimpleType.STRUCT) + + def __init__(self): + """Construct pydantic.BaseModelTransformer.""" + super().__init__(name="basemodel-transform", t=pydantic.BaseModel) + + def get_literal_type(self, t: Type[pydantic.BaseModel]) -> types.LiteralType: + return types.LiteralType(simple=types.SimpleType.STRUCT) + + def to_literal( + self, + ctx: FlyteContext, + python_val: pydantic.BaseModel, + python_type: Type[pydantic.BaseModel], + expected: types.LiteralType, + ) -> literals.Literal: + """Convert a given ``pydantic.BaseModel`` to the Literal representation.""" + return serialization.serialize_basemodel(python_val) + + def to_python_value( + self, + ctx: FlyteContext, + lv: literals.Literal, + expected_python_type: Type[pydantic.BaseModel], + ) -> pydantic.BaseModel: + """Re-hydrate the pydantic BaseModel object from Flyte Literal value.""" + basemodel_literals: BaseModelLiterals = lv.map.literals + basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(basemodel_literals) + with deserialization.PydanticDeserializationLiteralStore.attach( + basemodel_literals[serialization.OBJECTS_KEY].map + ): + return expected_python_type.parse_raw(basemodel_json_w_placeholders) + + +def read_basemodel_json_from_literalmap(lv: BaseModelLiterals) -> serialization.SerializedBaseModel: + basemodel_literal: literals.Literal = lv[serialization.BASEMODEL_JSON_KEY] + basemodel_json_w_placeholders = json_format.MessageToJson(basemodel_literal.scalar.generic) + assert isinstance(basemodel_json_w_placeholders, str) + return 
basemodel_json_w_placeholders + + +type_engine.TypeEngine.register(BaseModelTransformer()) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py new file mode 100644 index 0000000000..238e78c84d --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py @@ -0,0 +1,31 @@ +import builtins +import datetime +import typing +from typing import Set + +import numpy +import pyarrow +from typing_extensions import Annotated + +from flytekit.core import type_engine + +MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES: Set[str] = {m.__name__ for m in [builtins, typing, datetime, pyarrow, numpy]} + + +def include_in_flyte_types(t: type) -> bool: + if t is None: + return False + object_module = t.__module__ + if any(object_module.startswith(module) for module in MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES): + return False + return True + + +type_engine.TypeEngine.lazy_import_transformers() # loads all transformers +PYDANTIC_SUPPORTED_FLYTE_TYPES = tuple( + filter(include_in_flyte_types, type_engine.TypeEngine.get_available_transformers()) +) + +# this is the UUID placeholder that is set in the serialized basemodel JSON, connecting that field to +# the literal map that holds the actual object that needs to be deserialized (w/ protobuf) +LiteralObjID = Annotated[str, "Key for unique object in literal map."] diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py new file mode 100644 index 0000000000..24fe5afc1e --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -0,0 +1,145 @@ +import contextlib +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Type, TypeVar, Union, cast + +import pydantic +from flytekitplugins.pydantic import commons, serialization + +from flytekit.core import context_manager, type_engine +from flytekit.models 
import literals +from flytekit.types import directory, file + +# this field is used by pydantic to get the validator method +PYDANTIC_VALIDATOR_METHOD_NAME = pydantic.BaseModel.__get_validators__.__name__ +PythonType = TypeVar("PythonType") # target type of the deserialization + + +class PydanticDeserializationLiteralStore: + """ + The purpose of this class is to provide a context manager that can be used to deserialize a basemodel from a + literal map. + + Because pydantic validators are fixed when subclassing a BaseModel, this object is a singleton that + serves as a namespace that can be set with the ``attach`` context manager for the time that + a basemodel is being deserialized. The validators then access this namespace for the flyteobj + placeholders that it is trying to deserialize. + """ + + literal_store: Optional[serialization.LiteralStore] = None # attachment point for the literal map + + def __init__(self) -> None: + raise Exception("This class should not be instantiated") + + def __init_subclass__(cls) -> None: + raise Exception("This class should not be subclassed") + + @classmethod + @contextlib.contextmanager + def attach(cls, literal_map: literals.LiteralMap) -> Generator[None, None, None]: + """ + Read a literal map and populate the object store from it. + + This can be used as a context manager to attach to a literal map for the duration of a deserialization. + Note that this is not threadsafe, and designed to manage a single deserialization at a time. + """ + assert not cls.is_attached(), "can only be attached to one literal map at a time." 
+ try: + cls.literal_store = literal_map.literals + yield + finally: + cls.literal_store = None + + @classmethod + def contains(cls, item: commons.LiteralObjID) -> bool: + assert cls.is_attached(), "can only check for existence of a literal when attached to a literal map" + assert cls.literal_store is not None + return item in cls.literal_store + + @classmethod + def is_attached(cls) -> bool: + return cls.literal_store is not None + + @classmethod + def get_python_object( + cls, identifier: commons.LiteralObjID, expected_type: Type[PythonType] + ) -> Optional[PythonType]: + """Deserialize a flyte literal and return the python object.""" + if not cls.is_attached(): + raise Exception("Must attach to a literal map before deserializing") + literal = cls.literal_store[identifier] # type: ignore + python_object = deserialize_flyte_literal(literal, expected_type) + return python_object + + +def set_validators_on_supported_flyte_types() -> None: + """ + Set pydantic validator for the flyte types supported by this plugin. + """ + for flyte_type in commons.PYDANTIC_SUPPORTED_FLYTE_TYPES: + setattr(flyte_type, PYDANTIC_VALIDATOR_METHOD_NAME, add_flyte_validators_for_type(flyte_type)) + + +def add_flyte_validators_for_type( + flyte_obj_type: Type[type_engine.T], +) -> Callable[[Any], Iterator[Callable[[Any], type_engine.T]]]: + """ + Add flyte deserialisation validators to a type. 
+ """ + + previous_validators = cast( + Iterator[Callable[[Any], type_engine.T]], + getattr(flyte_obj_type, PYDANTIC_VALIDATOR_METHOD_NAME, lambda *_: [])(), + ) + + def validator(object_uid_maybe: Union[commons.LiteralObjID, Any]) -> Union[type_engine.T, Any]: + """Partial of deserialize_flyte_literal with the object_type fixed""" + if not PydanticDeserializationLiteralStore.is_attached(): + return object_uid_maybe # this validator should only trigger when we are deserializing + if not isinstance(object_uid_maybe, str): + return object_uid_maybe # object uids are strings and we don't want to trigger on other types + if not PydanticDeserializationLiteralStore.contains(object_uid_maybe): + return object_uid_maybe # final safety check to make sure that the object uid is in the literal map + return PydanticDeserializationLiteralStore.get_python_object(object_uid_maybe, flyte_obj_type) + + def validator_generator(*args, **kwargs) -> Iterator[Callable[[Any], type_engine.T]]: + """Generator that returns validators.""" + yield validator + yield from previous_validators + yield from ADDITIONAL_FLYTETYPE_VALIDATORS.get(flyte_obj_type, []) + + return validator_generator + + +def validate_flytefile(flytefile: Union[str, file.FlyteFile]) -> file.FlyteFile: + """Validate a flytefile (i.e. deserialize).""" + if isinstance(flytefile, file.FlyteFile): + return flytefile + if isinstance(flytefile, str): # when e.g. initializing from config + return file.FlyteFile(flytefile) + else: + raise ValueError(f"Invalid type for flytefile: {type(flytefile)}") + + +def validate_flytedir(flytedir: Union[str, directory.FlyteDirectory]) -> directory.FlyteDirectory: + """Validate a flytedir (i.e. deserialize).""" + if isinstance(flytedir, directory.FlyteDirectory): + return flytedir + if isinstance(flytedir, str): # when e.g. 
initializing from config + return directory.FlyteDirectory(flytedir) + else: + raise ValueError(f"Invalid type for flytedir: {type(flytedir)}") + + +ADDITIONAL_FLYTETYPE_VALIDATORS: Dict[Type, List[Callable[[Any], Any]]] = { + file.FlyteFile: [validate_flytefile], + directory.FlyteDirectory: [validate_flytedir], +} + + +def deserialize_flyte_literal( + flyteobj_literal: literals.Literal, python_type: Type[PythonType] +) -> Optional[PythonType]: + """Deserialize a Flyte Literal into the python object instance.""" + ctx = context_manager.FlyteContext.current_context() + transformer = type_engine.TypeEngine.get_transformer(python_type) + python_obj = transformer.to_python_value(ctx, flyteobj_literal, python_type) + return python_obj diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py new file mode 100644 index 0000000000..cd5b149fd9 --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -0,0 +1,115 @@ +""" +Logic for serializing a basemodel to a literalmap that can be passed between flyte tasks. + +The serialization process is as follows: + +1. Serialize the basemodel to json, replacing all flyte types with unique placeholder strings +2. Serialize the flyte types to separate literals and store them in the flyte object store (a singleton object) +3. Return a literal map with the json and the flyte object store represented as a literalmap {placeholder: flyte type} + +""" +import uuid +from typing import Any, Dict, Union, cast + +import pydantic +from google.protobuf import json_format, struct_pb2 +from typing_extensions import Annotated + +from flytekit.core import context_manager, type_engine +from flytekit.models import literals + +from . 
import commons + +BASEMODEL_JSON_KEY = "BaseModel JSON" +OBJECTS_KEY = "Serialized Flyte Objects" + +SerializedBaseModel = Annotated[str, "A pydantic BaseModel that has been serialized with placeholders for Flyte types."] + +ObjectStoreID = Annotated[str, "Key for unique literalmap of a serialized basemodel."] +LiteralObjID = Annotated[str, "Key for unique object in literal map."] +LiteralStore = Annotated[Dict[LiteralObjID, literals.Literal], "uid to literals for a serialized BaseModel"] + + +class BaseModelFlyteObjectStore: + """ + This class is an intermediate store for python objects that are being serialized/deserialized. + + On serialization of a basemodel, flyte objects are serialized and stored in this object store. + """ + + def __init__(self) -> None: + self.literal_store: LiteralStore = dict() + + def register_python_object(self, python_object: object) -> LiteralObjID: + """Serialize to literal and return a unique identifier.""" + serialized_item = serialize_to_flyte_literal(python_object) + identifier = make_identifier_for_serializeable(python_object) + assert identifier not in self.literal_store + self.literal_store[identifier] = serialized_item + return identifier + + def to_literal(self) -> literals.Literal: + """Convert the object store to a literal map.""" + return literals.Literal(map=literals.LiteralMap(literals=self.literal_store)) + + +def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.Literal: + """ + Serializes a given pydantic BaseModel instance into a LiteralMap. + The BaseModel is first serialized into a JSON format, where all Flyte types are replaced with unique placeholder strings. 
+ The Flyte Types are serialized into separate Flyte literals + """ + store = BaseModelFlyteObjectStore() + basemodel_literal = serialize_basemodel_to_literal(basemodel, store) + basemodel_literalmap = literals.LiteralMap( + { + BASEMODEL_JSON_KEY: basemodel_literal, # json with flyte types replaced with placeholders + OBJECTS_KEY: store.to_literal(), # flyte type-engine serialized types + } + ) + literal = literals.Literal(map=basemodel_literalmap) # type: ignore + return literal + + +def serialize_basemodel_to_literal( + basemodel: pydantic.BaseModel, + flyteobject_store: BaseModelFlyteObjectStore, +) -> literals.Literal: + """ + Serialize a pydantic BaseModel to json and protobuf, separating out the Flyte types into a separate store. + On deserialization, the store is used to reconstruct the Flyte types. + """ + + def encoder(obj: Any) -> Union[str, commons.LiteralObjID]: + if isinstance(obj, commons.PYDANTIC_SUPPORTED_FLYTE_TYPES): + return flyteobject_store.register_python_object(obj) + return basemodel.__json_encoder__(obj) + + basemodel_json = basemodel.json(encoder=encoder) + return make_literal_from_json(basemodel_json) + + +def serialize_to_flyte_literal(python_obj: object) -> literals.Literal: + """ + Use the Flyte TypeEngine to serialize a python object to a Flyte Literal. + """ + python_type = type(python_obj) + ctx = context_manager.FlyteContextManager().current_context() + literal_type = type_engine.TypeEngine.to_literal_type(python_type) + literal_obj = type_engine.TypeEngine.to_literal(ctx, python_obj, python_type, literal_type) + return literal_obj + + +def make_literal_from_json(json: str) -> literals.Literal: + """ + Converts the json representation of a pydantic BaseModel to a Flyte Literal. 
+ """ + return literals.Literal(scalar=literals.Scalar(generic=json_format.Parse(json, struct_pb2.Struct()))) # type: ignore + + +def make_identifier_for_serializeable(python_type: object) -> LiteralObjID: + """ + Create a unique identifier for a python object. + """ + unique_id = f"{type(python_type).__name__}_{uuid.uuid4().hex}" + return cast(LiteralObjID, unique_id) diff --git a/plugins/flytekit-pydantic/requirements.in b/plugins/flytekit-pydantic/requirements.in new file mode 100644 index 0000000000..44f25884d7 --- /dev/null +++ b/plugins/flytekit-pydantic/requirements.in @@ -0,0 +1,2 @@ +. +-e file:.#egg=flytekitplugins-pydantic diff --git a/plugins/flytekit-pydantic/requirements.txt b/plugins/flytekit-pydantic/requirements.txt new file mode 100644 index 0000000000..68acf7008a --- /dev/null +++ b/plugins/flytekit-pydantic/requirements.txt @@ -0,0 +1,347 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-pydantic + # via -r requirements.in +adal==1.2.7 + # via azure-datalake-store +adlfs==2023.4.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp +arrow==1.2.3 + # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.52 + # via adlfs +azure-identity==1.12.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs +binaryornot==0.4.4 + # via cookiecutter +botocore==1.29.76 + # via aiobotocore +cachetools==5.3.0 + # via google-auth +certifi==2022.12.7 + # via + # kubernetes + # requests +cffi==1.15.1 + # via + # azure-datalake-store + # cryptography +chardet==5.1.0 + # via binaryornot +charset-normalizer==3.1.0 + # via + # aiohttp + # requests +click==8.1.3 + # via + # 
cookiecutter + # flytekit +cloudpickle==2.2.1 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.14 + # via flytekit +cryptography==40.0.2 + # via + # adal + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via gcsfs +deprecated==1.2.13 + # via flytekit +diskcache==5.6.1 + # via flytekit +docker==6.0.1 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +flyteidl==1.3.20 + # via flytekit +flytekit==1.5.0 + # via flytekitplugins-pydantic +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.4.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.4.0 + # via flytekit +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via flytekit +google-api-core==2.11.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.17.3 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.9.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage +googleapis-common-protos==1.59.0 + # via + # flyteidl + # flytekit + # google-api-core + # grpcio-status +grpcio==1.54.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.54.0 + # via flytekit +idna==3.4 + # via + # requests + # yarl +importlib-metadata==6.6.0 + # via + # flytekit + # keyring +isodate==0.6.1 + # via azure-storage-blob +jaraco-classes==3.2.3 + # via keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +jmespath==1.0.1 + # via botocore +joblib==1.2.0 + # via flytekit +keyring==23.13.1 + # via flytekit +kubernetes==26.1.0 + # via flytekit +markupsafe==2.1.2 + # via jinja2 +marshmallow==3.19.0 + # via + # 
dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +more-itertools==9.1.0 + # via jaraco-classes +msal==1.22.0 + # via + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl +mypy-extensions==1.0.0 + # via typing-inspect +natsort==8.3.1 + # via flytekit +numpy==1.24.3 + # via + # flytekit + # pandas + # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.1 + # via + # docker + # marshmallow +pandas==1.5.3 + # via flytekit +portalocker==2.7.0 + # via msal-extensions +protobuf==4.22.3 + # via + # flyteidl + # google-api-core + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +pyarrow==10.0.1 + # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth +pycparser==2.21 + # via cffi +pydantic==1.10.7 + # via flytekitplugins-pydantic +pyjwt[crypto]==2.6.0 + # via + # adal + # msal +pyopenssl==23.1.1 + # via flytekit +python-dateutil==2.8.2 + # via + # adal + # arrow + # botocore + # croniter + # flytekit + # kubernetes + # pandas +python-json-logger==2.0.7 + # via flytekit +python-slugify==8.0.1 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2023.3 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit + # kubernetes + # responses +regex==2023.5.5 + # via docker-image-py +requests==2.30.0 + # via + # adal + # azure-core + # azure-datalake-store + # cookiecutter + # docker + # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib + # responses +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes +responses==0.23.1 + # via flytekit +rsa==4.9 + # via google-auth +s3fs==2023.4.0 + # via flytekit +six==1.16.0 + # via + # azure-core + # 
azure-identity + # google-auth + # isodate + # kubernetes + # python-dateutil +smmap==5.0.0 + # via gitdb +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +types-pyyaml==6.0.12.9 + # via responses +typing-extensions==4.5.0 + # via + # azure-core + # azure-storage-blob + # flytekit + # pydantic + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.15 + # via + # botocore + # docker + # flytekit + # kubernetes + # requests + # responses +websocket-client==1.5.1 + # via + # docker + # kubernetes +wheel==0.40.0 + # via flytekit +wrapt==1.15.0 + # via + # aiobotocore + # deprecated + # flytekit +yarl==1.9.2 + # via aiohttp +zipp==3.15.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-pydantic/setup.py b/plugins/flytekit-pydantic/setup.py new file mode 100644 index 0000000000..313c574dd1 --- /dev/null +++ b/plugins/flytekit-pydantic/setup.py @@ -0,0 +1,40 @@ +from setuptools import setup + +PLUGIN_NAME = "pydantic" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.7.0b0,<2.0.0", "pydantic"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="Plugin adding type support for Pydantic models", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-pydantic", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + 
"Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-pydantic/tests/folder/test_file1.txt b/plugins/flytekit-pydantic/tests/folder/test_file1.txt new file mode 100644 index 0000000000..257cc5642c --- /dev/null +++ b/plugins/flytekit-pydantic/tests/folder/test_file1.txt @@ -0,0 +1 @@ +foo diff --git a/plugins/flytekit-pydantic/tests/folder/test_file2.txt b/plugins/flytekit-pydantic/tests/folder/test_file2.txt new file mode 100644 index 0000000000..5716ca5987 --- /dev/null +++ b/plugins/flytekit-pydantic/tests/folder/test_file2.txt @@ -0,0 +1 @@ +bar diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py new file mode 100644 index 0000000000..3c02dcb3f1 --- /dev/null +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -0,0 +1,296 @@ +import datetime as dt +import os +import pathlib +from typing import Any, Dict, List, Optional, Type, Union + +import pandas as pd +import pytest +from flyteidl.core.types_pb2 import SimpleType +from flytekitplugins.pydantic import BaseModelTransformer +from flytekitplugins.pydantic.commons import PYDANTIC_SUPPORTED_FLYTE_TYPES +from pydantic import BaseModel, Extra + +import flytekit +from flytekit.core import context_manager +from flytekit.core.type_engine import TypeEngine +from flytekit.types import directory +from flytekit.types.file import file + + +class TrainConfig(BaseModel): + """Config BaseModel for testing purposes.""" + + batch_size: int = 32 + lr: float = 1e-3 + loss: str = "cross_entropy" + + 
class Config: + extra = Extra.forbid + + +class Config(BaseModel): + """Config BaseModel for testing purposes with an optional type hint.""" + + model_config: Optional[Union[Dict[str, TrainConfig], TrainConfig]] = TrainConfig() + + +class ConfigWithDatetime(BaseModel): + """Config BaseModel for testing purposes with datetime type hint.""" + + datetime: dt.datetime = dt.datetime.now() + + +class NestedConfig(BaseModel): + """Nested config BaseModel for testing purposes.""" + + files: "ConfigWithFlyteFiles" + dirs: "ConfigWithFlyteDirs" + df: "ConfigWithPandasDataFrame" + datetime: "ConfigWithDatetime" = ConfigWithDatetime() + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, NestedConfig) and all( + getattr(self, attr) == getattr(__value, attr) for attr in ["files", "dirs", "df", "datetime"] + ) + + +class ConfigRequired(BaseModel): + """Config BaseModel for testing purposes with required attribute.""" + + model_config: Union[Dict[str, TrainConfig], TrainConfig] + + +class ConfigWithFlyteFiles(BaseModel): + """Config BaseModel for testing purposes with flytekit.files.FlyteFile type hint.""" + + flytefiles: List[file.FlyteFile] + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, ConfigWithFlyteFiles) and all( + pathlib.Path(self_file).read_text() == pathlib.Path(other_file).read_text() + for self_file, other_file in zip(self.flytefiles, __value.flytefiles) + ) + + +class ConfigWithFlyteDirs(BaseModel): + """Config BaseModel for testing purposes with flytekit.directory.FlyteDirectory type hint.""" + + flytedirs: List[directory.FlyteDirectory] + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, ConfigWithFlyteDirs) and all( + os.listdir(self_dir) == os.listdir(other_dir) + for self_dir, other_dir in zip(self.flytedirs, __value.flytedirs) + ) + + +class ConfigWithPandasDataFrame(BaseModel): + """Config BaseModel for testing purposes with pandas.DataFrame type hint.""" + + df: pd.DataFrame + + def 
__eq__(self, __value: object) -> bool: + return isinstance(__value, ConfigWithPandasDataFrame) and self.df.equals(__value.df) + + +class ChildConfig(Config): + """Child class config BaseModel for testing purposes.""" + + d: List[int] = [1, 2, 3] + + +NestedConfig.update_forward_refs() + + +@pytest.mark.parametrize( + "python_type,kwargs", + [ + (Config, {}), + (ConfigRequired, {"model_config": TrainConfig()}), + (TrainConfig, {}), + (ConfigWithFlyteFiles, {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}), + (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), + (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}), + ( + NestedConfig, + { + "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + "dirs": {"flytedirs": ["tests/folder/"]}, + "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, + }, + ), + ], +) +def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): + """Test that a (de-)serialization roundtrip results in the identical BaseModel.""" + + ctx = context_manager.FlyteContextManager().current_context() + + type_transformer = BaseModelTransformer() + + python_value = python_type(**kwargs) + + literal_value = type_transformer.to_literal( + ctx, + python_value, + python_type, + type_transformer.get_literal_type(python_value), + ) + + reconstructed_value = type_transformer.to_python_value(ctx, literal_value, type(python_value)) + + assert reconstructed_value == python_value + + +@pytest.mark.parametrize( + "config_type,kwargs", + [ + (Config, {"model_config": {"foo": TrainConfig(loss="mse")}}), + (ConfigRequired, {"model_config": {"foo": TrainConfig(loss="mse")}}), + (ConfigWithFlyteFiles, {"flytefiles": ["tests/folder/test_file1.txt"]}), + (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), + (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}), + ( + NestedConfig, + { + "files": {"flytefiles": 
["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + "dirs": {"flytedirs": ["tests/folder/"]}, + "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, + }, + ), + ], +) +def test_pass_to_workflow(config_type: Type, kwargs: Dict[str, Any]): + """Test passing a BaseModel instance to a workflow works.""" + cfg = config_type(**kwargs) + + @flytekit.task + def train(cfg: config_type) -> config_type: + return cfg + + @flytekit.workflow + def wf(cfg: config_type) -> config_type: + return train(cfg=cfg) + + returned_cfg = wf(cfg=cfg) # type: ignore + + assert returned_cfg == cfg + # TODO these assertions are not valid for all types + + +@pytest.mark.parametrize( + "kwargs", + [ + {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + ], +) +def test_flytefiles_in_wf(kwargs: Dict[str, Any]): + """Test passing a BaseModel instance to a workflow works.""" + cfg = ConfigWithFlyteFiles(**kwargs) + + @flytekit.task + def read(cfg: ConfigWithFlyteFiles) -> str: + with open(cfg.flytefiles[0], "r") as f: + return f.read() + + @flytekit.workflow + def wf(cfg: ConfigWithFlyteFiles) -> str: + return read(cfg=cfg) # type: ignore + + string = wf(cfg=cfg) + assert string in {"foo\n", "bar\n"} # type: ignore + + +@pytest.mark.parametrize( + "kwargs", + [ + {"flytedirs": ["tests/folder/"]}, + ], +) +def test_flytedirs_in_wf(kwargs: Dict[str, Any]): + """Test passing a BaseModel instance to a workflow works.""" + cfg = ConfigWithFlyteDirs(**kwargs) + + @flytekit.task + def listdir(cfg: ConfigWithFlyteDirs) -> List[str]: + return os.listdir(cfg.flytedirs[0]) + + @flytekit.workflow + def wf(cfg: ConfigWithFlyteDirs) -> List[str]: + return listdir(cfg=cfg) # type: ignore + + dirs = wf(cfg=cfg) + assert len(dirs) == 2 # type: ignore + + +def test_double_config_in_wf(): + """Test passing a BaseModel instance to a workflow works.""" + cfg1 = TrainConfig(batch_size=13) + cfg2 = TrainConfig(batch_size=31) + + @flytekit.task + def are_different(cfg1: 
TrainConfig, cfg2: TrainConfig) -> bool: + return cfg1 != cfg2 + + @flytekit.workflow + def wf(cfg1: TrainConfig, cfg2: TrainConfig) -> bool: + return are_different(cfg1=cfg1, cfg2=cfg2) # type: ignore + + assert wf(cfg1=cfg1, cfg2=cfg2), wf(cfg1=cfg1, cfg2=cfg2) # type: ignore + + +@pytest.mark.parametrize( + "python_type,config_kwargs", + [ + (Config, {}), + (ConfigRequired, {"model_config": TrainConfig()}), + (TrainConfig, {}), + (ConfigWithFlyteFiles, {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}), + (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), + (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}), + ( + NestedConfig, + { + "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + "dirs": {"flytedirs": ["tests/folder/"]}, + "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, + }, + ), + ], +) +def test_dynamic(python_type: Type[BaseModel], config_kwargs: Dict[str, Any]): + config_instance = python_type(**config_kwargs) + + @flytekit.task + def train(cfg: BaseModel): + print(cfg) + + @flytekit.dynamic(cache=True, cache_version="0.3") + def sub_wf(cfg: BaseModel): + train(cfg=cfg) + + @flytekit.workflow + def wf(): + sub_wf(cfg=config_instance) + + wf() + + +def test_supported(): + assert len(PYDANTIC_SUPPORTED_FLYTE_TYPES) == 9 + + +def test_single_df(): + ctx = context_manager.FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(ConfigWithPandasDataFrame) + assert lt.simple == SimpleType.STRUCT + + pyd = ConfigWithPandasDataFrame(df=pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) + lit = TypeEngine.to_literal(ctx, pyd, ConfigWithPandasDataFrame, lt) + assert lit.map is not None + offloaded_keys = list(lit.map.literals["Serialized Flyte Objects"].map.literals.keys()) + assert len(offloaded_keys) == 1 + assert ( + lit.map.literals["Serialized Flyte Objects"].map.literals[offloaded_keys[0]].scalar.structured_dataset + is not 
None + )