Skip to content

Commit

Permalink
Pandera integration using new plugin system (flyteorg#354)
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicBboy authored Feb 6, 2021
1 parent cb46788 commit b1cd377
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 2 deletions.
1 change: 1 addition & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from flytekit.core.task import reference_task, task
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.loggers import logger
from flytekit.types import schema

__version__ = "develop"

Expand Down
2 changes: 1 addition & 1 deletion plugins/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Flyte plugins are structured as micro-libs and can be authored in an
independent repository. The plugins maintained by the core team are maintained
in this repository

## (Refer to this Blog to understand the idea of microlibs)[https://medium.com/@jherreras/python-microlibs-5be9461ad979]
## [Refer to this Blog to understand the idea of microlibs](https://medium.com/@jherreras/python-microlibs-5be9461ad979)

## Conventions
All plugins should expose a library in the format **flytekitplugins-{}**, where
Expand Down
1 change: 1 addition & 0 deletions plugins/pandera/flytekitplugins/pandera/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .schema import PanderaTransformer
92 changes: 92 additions & 0 deletions plugins/pandera/flytekitplugins/pandera/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import typing
from typing import Type

import pandas
import pandera

from flytekit import FlyteContext
from flytekit.extend import TypeEngine, TypeTransformer
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType
from flytekit.types.schema import FlyteSchema, PandasSchemaWriter, SchemaFormat, SchemaOpenMode
from flytekit.types.schema.types import FlyteSchemaTransformer


class PanderaTransformer(TypeTransformer[pandera.typing.DataFrame]):
    """Flytekit type transformer for ``pandera.typing.DataFrame`` annotations.

    Serializes pandas DataFrames to Flyte's parquet-backed ``Schema`` literal
    and, on deserialization, validates the DataFrame against the pandera
    schema model given as the annotation's type parameter, e.g.
    ``pandera.typing.DataFrame[MySchemaModel]``.
    """

    # Column-dtype mapping shared with the plain FlyteSchema transformer so
    # both transformers accept exactly the same set of column types.
    _SUPPORTED_TYPES: typing.Dict[
        type, SchemaType.SchemaColumn.SchemaColumnType
    ] = FlyteSchemaTransformer._SUPPORTED_TYPES

    class EmptySchema(pandera.SchemaModel):
        # Fallback used when the annotation is an unparameterized
        # ``pandera.typing.DataFrame`` -- it validates nothing.
        pass

    def __init__(self):
        super().__init__("Pandera Transformer", pandera.typing.DataFrame)

    def _pandera_schema(self, t: Type[pandera.typing.DataFrame]):
        """Return the ``DataFrameSchema`` for the annotation's type parameter.

        Falls back to :class:`EmptySchema` when the annotation carries no
        schema model parameter.
        """
        try:
            type_args = typing.get_args(t)
        except AttributeError:
            # typing.get_args only exists on python >= 3.8; older versions
            # expose the parameters via the dunder attribute directly.
            type_args = getattr(t, "__args__", None)

        if type_args:
            schema_model, *_ = type_args
        else:
            schema_model = self.EmptySchema
        return schema_model.to_schema()

    def _get_col_dtypes(self, t: Type[pandera.typing.DataFrame]):
        """Map column name -> pandas dtype for the annotation's schema."""
        return {k: v.pandas_dtype for k, v in self._pandera_schema(t).columns.items()}

    def _get_schema_type(self, t: Type[pandera.typing.DataFrame]) -> SchemaType:
        """Convert the pandera schema's columns into a Flyte ``SchemaType``.

        Raises:
            AssertionError: if a column dtype is not supported by Flyte schemas.
        """
        converted_cols: typing.List[SchemaType.SchemaColumn] = []
        for k, col in self._pandera_schema(t).columns.items():
            if col.pandas_dtype not in self._SUPPORTED_TYPES:
                # bug fix: the original message referenced an undefined name `v`,
                # which itself raised a NameError before the intended error could.
                raise AssertionError(f"type {col.pandas_dtype} is currently not supported by the pandera schema")
            converted_cols.append(SchemaType.SchemaColumn(name=k, type=self._SUPPORTED_TYPES[col.pandas_dtype]))
        return SchemaType(columns=converted_cols)

    def get_literal_type(self, t: Type[pandera.typing.DataFrame]) -> LiteralType:
        """Advertise this python type to Flyte as a schema literal type."""
        return LiteralType(schema=self._get_schema_type(t))

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: pandas.DataFrame,
        python_type: Type[pandera.typing.DataFrame],
        expected: LiteralType,
    ) -> Literal:
        """Write *python_val* to parquet, upload it, and wrap it in a Literal.

        Raises:
            AssertionError: if *python_val* is not a ``pandas.DataFrame``.
        """
        if not isinstance(python_val, pandas.DataFrame):
            raise AssertionError(
                f"Only Pandas Dataframe object can be returned from a task, returned object type {type(python_val)}"
            )
        local_dir = ctx.file_access.get_random_local_directory()
        w = PandasSchemaWriter(
            local_dir=local_dir, cols=self._get_col_dtypes(python_type), fmt=SchemaFormat.PARQUET
        )
        w.write(python_val)
        remote_path = ctx.file_access.get_random_remote_directory()
        ctx.file_access.put_data(local_dir, remote_path, is_multipart=True)
        return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type(python_type))))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[pandera.typing.DataFrame]
    ) -> pandera.typing.DataFrame:
        """Download the schema literal and validate it with the pandera schema.

        Raises:
            AssertionError: if *lv* is not a schema literal.
        """
        if not (lv and lv.scalar and lv.scalar.schema):
            raise AssertionError("Can only convert a literal schema to a pandera schema")

        df = FlyteSchema(
            local_path=ctx.file_access.get_random_local_directory(),
            remote_path=lv.scalar.schema.uri,
            # the lambda-style wrapper in the original added nothing; the
            # bound method has the same (remote, local) signature.
            downloader=ctx.file_access.download_directory,
            supported_mode=SchemaOpenMode.READ,
        )
        # Calling the pandera schema validates the materialized DataFrame.
        return self._pandera_schema(expected_python_type)(df.open().all())


# Make the transformer discoverable by flytekit's type engine at import time.
TypeEngine.register(PanderaTransformer())
32 changes: 32 additions & 0 deletions plugins/pandera/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from setuptools import setup

PLUGIN_NAME = "pandera"

# Plugins follow the flytekitplugins-{name} micro-lib naming convention.
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=0.16.0b6,<1.0.0", "pandera>=0.6.1"]

# Trove classifiers describing the audience, license, and supported runtimes.
_classifiers = [
    "Intended Audience :: Science/Research",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3.7",
    "Programming Language :: Python :: 3.8",
    "Topic :: Scientific/Engineering",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Software Development",
    "Topic :: Software Development :: Libraries",
    "Topic :: Software Development :: Libraries :: Python Modules",
]

setup(
    name=microlib_name,
    version="0.1.0",
    author="flyteorg",
    author_email="[email protected]",
    description="Pandera plugin for flytekit",
    # Installed under the shared flytekitplugins namespace package.
    namespace_packages=["flytekitplugins"],
    packages=[f"flytekitplugins.{PLUGIN_NAME}"],
    install_requires=plugin_requires,
    license="apache2",
    python_requires=">=3.7",
    classifiers=_classifiers,
)
3 changes: 2 additions & 1 deletion plugins/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
"flytekitplugins-spark": "spark",
"flytekitplugins-pod": "pod",
"flytekitplugins-kfpytorch": "kfpytorch",
"flytekitplugins-aws": "aws",
"flytekitplugins-awssagemaker": "awssagemaker",
"flytekitplugins-kftensorflow": "kftensorflow",
"flytekitplugins-pandera": "pandera",
}


Expand Down
69 changes: 69 additions & 0 deletions plugins/tests/pandera/test_wf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pandas
import pandera
import pytest
from flytekitplugins.pandera import schema # noqa: F401

from flytekit import task, workflow


def test_pandera_dataframe_type_hints():
    """Schema-parameterized DataFrame annotations validate task inputs/outputs."""

    # Each schema model below extends its parent with one additional column.
    class InSchema(pandera.SchemaModel):
        col1: pandera.typing.Series[int]
        col2: pandera.typing.Series[float]

    class IntermediateSchema(InSchema):
        col3: pandera.typing.Series[float]

        @pandera.dataframe_check
        @classmethod
        def col3_check(cls, df: pandera.typing.DataFrame) -> pandera.typing.Series[bool]:
            # col3 must equal the elementwise product of col1 and col2
            return df["col3"] == df["col1"] * df["col2"]

    class OutSchema(IntermediateSchema):
        col4: pandera.typing.Series[str]

    @task
    def transform1(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[IntermediateSchema]:
        return df.assign(col3=df["col1"] * df["col2"])

    @task
    def transform2(df: pandera.typing.DataFrame[IntermediateSchema]) -> pandera.typing.DataFrame[OutSchema]:
        return df.assign(col4="foo")

    @workflow
    def my_wf() -> pandera.typing.DataFrame[OutSchema]:
        raw = pandas.DataFrame({"col1": [1, 2, 3], "col2": [10.0, 11.0, 12.0]})
        return transform2(df=transform1(df=raw))

    @workflow
    def invalid_wf() -> pandera.typing.DataFrame[OutSchema]:
        # col2 holds strings, violating InSchema's float column contract
        bad = pandas.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
        return transform2(df=transform1(df=bad))

    validated = my_wf()
    assert isinstance(validated, pandas.DataFrame)

    # the schema violation surfaces as a pandera error at execution time
    with pytest.raises(pandera.errors.SchemaError):
        invalid_wf()


@pytest.mark.parametrize(
    "data",
    [
        pandas.DataFrame({"col1": [1, 2, 3]}),
        pandas.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}),
        pandas.DataFrame(),
    ],
)
def test_pandera_dataframe_no_schema_model(data):
    """An unparameterized pandera DataFrame annotation performs no validation."""

    @task
    def transform(df: pandera.typing.DataFrame) -> pandera.typing.DataFrame:
        return df

    @workflow
    def my_wf(df: pandera.typing.DataFrame) -> pandera.typing.DataFrame:
        return transform(df=df)

    # every frame passes through untouched, regardless of its columns
    out = my_wf(df=data)
    assert isinstance(out, pandas.DataFrame)

0 comments on commit b1cd377

Please sign in to comment.