-
Notifications
You must be signed in to change notification settings - Fork 301
/
schema.py
101 lines (82 loc) · 4.13 KB
/
schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import typing
from typing import Type
from flytekit import FlyteContext, lazy_module
from flytekit.extend import TypeEngine, TypeTransformer
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType
from flytekit.types.schema import FlyteSchema, SchemaFormat, SchemaOpenMode
from flytekit.types.schema.types import FlyteSchemaTransformer
from flytekit.types.schema.types_pandas import PandasSchemaWriter
# pandas and pandera are bound through flytekit's lazy_module helper instead of plain
# import statements (presumably so the real import is deferred -- confirm against flytekit).
pandas = lazy_module("pandas")
pandera = lazy_module("pandera")
# Type variable shared by the ``t``/``v`` annotations of PanderaTransformer.assert_type.
T = typing.TypeVar("T")
class PanderaTransformer(TypeTransformer[pandera.typing.DataFrame]):
    """Type transformer for ``pandera.typing.DataFrame`` annotations.

    Outgoing pandas dataframes are passed through the pandera schema taken
    from the annotation, written with ``PandasSchemaWriter`` in parquet
    format, uploaded, and wrapped in a schema ``Literal``.  Incoming schema
    literals are opened through ``FlyteSchema`` and passed through the same
    pandera schema before being returned.
    """

    # Lookup from python column type to schema column type, reused as-is
    # from FlyteSchemaTransformer.
    _SUPPORTED_TYPES: typing.Dict[type, SchemaType.SchemaColumn.SchemaColumnType] = (
        FlyteSchemaTransformer._SUPPORTED_TYPES
    )

    def __init__(self):
        """Initialise the base transformer with this plugin's name and handled type."""
        super().__init__("Pandera Transformer", pandera.typing.DataFrame)  # type: ignore

    def _pandera_schema(self, t: Type[pandera.typing.DataFrame]):
        """Return the pandera schema described by the annotation ``t``.

        The first generic argument of ``t`` is converted with ``to_schema()``.
        When ``t`` carries no generic arguments, a ``pandera.DataFrameSchema``
        constructed without arguments is returned instead.
        """
        try:
            generic_args = typing.get_args(t)
        except AttributeError:
            # for python < 3.8
            generic_args = getattr(t, "__args__", None)

        if not generic_args:
            return pandera.DataFrameSchema()  # type: ignore

        # Only the first generic argument is used; any others are ignored.
        model, *_ = generic_args
        return model.to_schema()

    @staticmethod
    def _get_pandas_type(pandera_dtype: pandera.dtypes.DataType):
        """Return ``pandera_dtype.type.type`` for the given pandera data type."""
        wrapped_dtype = pandera_dtype.type
        return wrapped_dtype.type

    def _get_col_dtypes(self, t: Type[pandera.typing.DataFrame]):
        """Map every column name in the schema of ``t`` to its unwrapped type."""
        schema_columns = self._pandera_schema(t).columns
        return {name: self._get_pandas_type(column.dtype) for name, column in schema_columns.items()}

    def _get_schema_type(self, t: Type[pandera.typing.DataFrame]) -> SchemaType:
        """Build the ``SchemaType`` describing the columns of the schema of ``t``.

        Raises:
            AssertionError: if a column's unwrapped type is not a key of
                ``_SUPPORTED_TYPES``.
        """
        schema_columns: typing.List[SchemaType.SchemaColumn] = []
        for name, column in self._pandera_schema(t).columns.items():
            column_type = self._get_pandas_type(column.dtype)
            if column_type not in self._SUPPORTED_TYPES:
                raise AssertionError(f"type {column_type} is currently not supported by the flytekit-pandera plugin")
            flyte_column_type = self._SUPPORTED_TYPES[column_type]
            schema_columns.append(SchemaType.SchemaColumn(name=name, type=flyte_column_type))
        return SchemaType(columns=schema_columns)

    def get_literal_type(self, t: Type[pandera.typing.DataFrame]) -> LiteralType:
        """Return a ``LiteralType`` whose schema is derived from the annotation ``t``."""
        schema_type = self._get_schema_type(t)
        return LiteralType(schema=schema_type)

    def assert_type(self, t: Type[T], v: T):
        """Check that ``v`` is acceptable for the annotation ``t``.

        Annotations that have an ``__origin__`` attribute are accepted without
        an ``isinstance`` check.  Otherwise ``v`` must be an instance of ``t``
        or of ``pandas.DataFrame``.

        Raises:
            TypeError: if the ``isinstance`` check is performed and fails.
        """
        if hasattr(t, "__origin__"):
            return
        if isinstance(v, (t, pandas.DataFrame)):
            return
        raise TypeError(f"Type of Val '{v}' is not an instance of {t}")

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: pandas.DataFrame,
        python_type: Type[pandera.typing.DataFrame],
        expected: LiteralType,
    ) -> Literal:
        """Convert a pandas dataframe into a schema ``Literal``.

        The dataframe is passed through the pandera schema of ``python_type``,
        the result is written as parquet into a fresh local directory, and
        that directory is uploaded with ``ctx.file_access.put_raw_data``.
        ``expected`` is not used by this implementation.

        Raises:
            AssertionError: if ``python_val`` is not a ``pandas.DataFrame``.
        """
        if not isinstance(python_val, pandas.DataFrame):
            raise AssertionError(
                f"Only Pandas Dataframe object can be returned from a task, returned object type {type(python_val)}"
            )

        staging_dir = ctx.file_access.get_random_local_directory()
        writer = PandasSchemaWriter(
            local_dir=staging_dir,
            cols=self._get_col_dtypes(python_type),
            fmt=SchemaFormat.PARQUET,
        )
        # Calling the pandera schema on the dataframe yields the object that is persisted.
        checked_frame = self._pandera_schema(python_type)(python_val)
        writer.write(checked_frame)

        uploaded_uri = ctx.file_access.put_raw_data(staging_dir)
        schema_literal = Schema(uploaded_uri, self._get_schema_type(python_type))
        return Literal(scalar=Scalar(schema=schema_literal))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[pandera.typing.DataFrame]
    ) -> pandera.typing.DataFrame:
        """Convert a schema ``Literal`` back into a dataframe.

        The literal's schema URI is opened in read mode through ``FlyteSchema``
        and everything read from it is passed through the pandera schema of
        ``expected_python_type``.

        Raises:
            AssertionError: if ``lv`` does not carry a scalar schema.
        """
        carries_schema = lv and lv.scalar and lv.scalar.schema
        if not carries_schema:
            raise AssertionError("Can only convert a literal schema to a pandera schema")

        def _download(x, y):
            # Forward both arguments unchanged, always requesting a multipart fetch.
            ctx.file_access.get_data(x, y, is_multipart=True)

        flyte_schema = FlyteSchema(
            local_path=ctx.file_access.get_random_local_directory(),
            remote_path=lv.scalar.schema.uri,
            downloader=_download,
            supported_mode=SchemaOpenMode.READ,
        )
        # Resolve the pandera schema before reading so evaluation order matches the one-liner form.
        pandera_schema = self._pandera_schema(expected_python_type)
        return pandera_schema(flyte_schema.open().all())
# Module-level side effect: importing this module passes a PanderaTransformer instance to TypeEngine.register.
TypeEngine.register(PanderaTransformer())