forked from flyteorg/flytekit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pandera integration using new plugin system (flyteorg#354)
- Loading branch information
1 parent
cb46788
commit b1cd377
Showing
7 changed files
with
198 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .schema import PanderaTransformer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import typing | ||
from typing import Type | ||
|
||
import pandas | ||
import pandera | ||
|
||
from flytekit import FlyteContext | ||
from flytekit.extend import TypeEngine, TypeTransformer | ||
from flytekit.models.literals import Literal, Scalar, Schema | ||
from flytekit.models.types import LiteralType, SchemaType | ||
from flytekit.types.schema import FlyteSchema, PandasSchemaWriter, SchemaFormat, SchemaOpenMode | ||
from flytekit.types.schema.types import FlyteSchemaTransformer | ||
|
||
|
||
class PanderaTransformer(TypeTransformer[pandera.typing.DataFrame]): | ||
_SUPPORTED_TYPES: typing.Dict[ | ||
type, SchemaType.SchemaColumn.SchemaColumnType | ||
] = FlyteSchemaTransformer._SUPPORTED_TYPES | ||
|
||
class EmptySchema(pandera.SchemaModel): | ||
pass | ||
|
||
def __init__(self): | ||
super().__init__("Pandera Transformer", pandera.typing.DataFrame) | ||
|
||
def _pandera_schema(self, t: Type[pandera.typing.DataFrame]): | ||
try: | ||
type_args = typing.get_args(t) | ||
except AttributeError: | ||
# for python < 3.8 | ||
type_args = getattr(t, "__args__", None) | ||
|
||
if type_args: | ||
schema_model, *_ = type_args | ||
else: | ||
schema_model = self.EmptySchema | ||
return schema_model.to_schema() | ||
|
||
def _get_col_dtypes(self, t: Type[pandera.typing.DataFrame]): | ||
return {k: v.pandas_dtype for k, v in self._pandera_schema(t).columns.items()} | ||
|
||
def _get_schema_type(self, t: Type[pandera.typing.DataFrame]) -> SchemaType: | ||
converted_cols: typing.List[SchemaType.SchemaColumn] = [] | ||
for k, col in self._pandera_schema(t).columns.items(): | ||
if col.pandas_dtype not in self._SUPPORTED_TYPES: | ||
raise AssertionError(f"type {v} is currently not supported by the pandera schema") | ||
converted_cols.append(SchemaType.SchemaColumn(name=k, type=self._SUPPORTED_TYPES[col.pandas_dtype])) | ||
return SchemaType(columns=converted_cols) | ||
|
||
def get_literal_type(self, t: Type[pandera.typing.DataFrame]) -> LiteralType: | ||
return LiteralType(schema=self._get_schema_type(t)) | ||
|
||
def to_literal( | ||
self, | ||
ctx: FlyteContext, | ||
python_val: pandas.DataFrame, | ||
python_type: Type[pandera.typing.DataFrame], | ||
expected: LiteralType, | ||
) -> Literal: | ||
if isinstance(python_val, pandas.DataFrame): | ||
local_dir = ctx.file_access.get_random_local_directory() | ||
w = PandasSchemaWriter( | ||
local_dir=local_dir, cols=self._get_col_dtypes(python_type), fmt=SchemaFormat.PARQUET | ||
) | ||
w.write(python_val) | ||
remote_path = ctx.file_access.get_random_remote_directory() | ||
ctx.file_access.put_data(local_dir, remote_path, is_multipart=True) | ||
return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type(python_type)))) | ||
else: | ||
raise AssertionError( | ||
f"Only Pandas Dataframe object can be returned from a task, returned object type {type(python_val)}" | ||
) | ||
|
||
def to_python_value( | ||
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[pandera.typing.DataFrame] | ||
) -> pandera.typing.DataFrame: | ||
if not (lv and lv.scalar and lv.scalar.schema): | ||
raise AssertionError("Can only covert a literal schema to a pandera schema") | ||
|
||
def downloader(x, y): | ||
ctx.file_access.download_directory(x, y) | ||
|
||
df = FlyteSchema( | ||
local_path=ctx.file_access.get_random_local_directory(), | ||
remote_path=lv.scalar.schema.uri, | ||
downloader=downloader, | ||
supported_mode=SchemaOpenMode.READ, | ||
) | ||
return self._pandera_schema(expected_python_type)(df.open().all()) | ||
|
||
|
||
TypeEngine.register(PanderaTransformer()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from setuptools import setup | ||
|
||
PLUGIN_NAME = "pandera" | ||
|
||
microlib_name = f"flytekitplugins-{PLUGIN_NAME}" | ||
|
||
plugin_requires = ["flytekit>=0.16.0b6,<1.0.0", "pandera>=0.6.1"] | ||
|
||
setup( | ||
name=microlib_name, | ||
version="0.1.0", | ||
author="flyteorg", | ||
author_email="[email protected]", | ||
description="Pandera plugin for flytekit", | ||
namespace_packages=["flytekitplugins"], | ||
packages=[f"flytekitplugins.{PLUGIN_NAME}"], | ||
install_requires=plugin_requires, | ||
license="apache2", | ||
python_requires=">=3.7", | ||
classifiers=[ | ||
"Intended Audience :: Science/Research", | ||
"Intended Audience :: Developers", | ||
"License :: OSI Approved :: Apache Software License", | ||
"Programming Language :: Python :: 3.7", | ||
"Programming Language :: Python :: 3.8", | ||
"Topic :: Scientific/Engineering", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
"Topic :: Software Development", | ||
"Topic :: Software Development :: Libraries", | ||
"Topic :: Software Development :: Libraries :: Python Modules", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import pandas | ||
import pandera | ||
import pytest | ||
from flytekitplugins.pandera import schema # noqa: F401 | ||
|
||
from flytekit import task, workflow | ||
|
||
|
||
def test_pandera_dataframe_type_hints(): | ||
class InSchema(pandera.SchemaModel): | ||
col1: pandera.typing.Series[int] | ||
col2: pandera.typing.Series[float] | ||
|
||
class IntermediateSchema(InSchema): | ||
col3: pandera.typing.Series[float] | ||
|
||
@pandera.dataframe_check | ||
@classmethod | ||
def col3_check(cls, df: pandera.typing.DataFrame) -> pandera.typing.Series[bool]: | ||
return df["col3"] == df["col1"] * df["col2"] | ||
|
||
class OutSchema(IntermediateSchema): | ||
col4: pandera.typing.Series[str] | ||
|
||
@task | ||
def transform1(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[IntermediateSchema]: | ||
return df.assign(col3=df["col1"] * df["col2"]) | ||
|
||
@task | ||
def transform2(df: pandera.typing.DataFrame[IntermediateSchema]) -> pandera.typing.DataFrame[OutSchema]: | ||
return df.assign(col4="foo") | ||
|
||
@workflow | ||
def my_wf() -> pandera.typing.DataFrame[OutSchema]: | ||
df = pandas.DataFrame({"col1": [1, 2, 3], "col2": [10.0, 11.0, 12.0]}) | ||
return transform2(df=transform1(df=df)) | ||
|
||
@workflow | ||
def invalid_wf() -> pandera.typing.DataFrame[OutSchema]: | ||
df = pandas.DataFrame({"col1": [1, 2, 3], "col2": list("abc")}) | ||
return transform2(df=transform1(df=df)) | ||
|
||
result = my_wf() | ||
assert isinstance(result, pandas.DataFrame) | ||
|
||
# raise error at runtime on invalid types | ||
with pytest.raises(pandera.errors.SchemaError): | ||
invalid_wf() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"data", | ||
[ | ||
pandas.DataFrame({"col1": [1, 2, 3]}), | ||
pandas.DataFrame({"col1": [1, 2, 3], "col2": list("abc")}), | ||
pandas.DataFrame(), | ||
], | ||
) | ||
def test_pandera_dataframe_no_schema_model(data): | ||
@task | ||
def transform(df: pandera.typing.DataFrame) -> pandera.typing.DataFrame: | ||
return df | ||
|
||
@workflow | ||
def my_wf(df: pandera.typing.DataFrame) -> pandera.typing.DataFrame: | ||
return transform(df=df) | ||
|
||
result = my_wf(df=data) | ||
assert isinstance(result, pandas.DataFrame) |