Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotated StructuredDataset: support nested_types #2252

Merged
33 changes: 30 additions & 3 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import types
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass, field, is_dataclass
from typing import Dict, Generator, Optional, Type, Union

import _datetime
Expand Down Expand Up @@ -114,6 +114,22 @@ def iter(self) -> Generator[DF, None, None]:
)


# Flatten the (possibly nested) column map recursively.
def flatten_dict(sub_dict: dict, parent_key: str = "") -> typing.Dict:
    """Flatten a nested column specification into dot-separated keys.

    Nested ``dict`` values and dataclass types are expanded recursively, so
    ``{"info": {"contacts": {"tel": str}}}`` becomes ``{"info.contacts.tel": str}``.
    Leaf values (plain types) are kept as-is.
    """
    flat: typing.Dict = {}
    for name, col_type in sub_dict.items():
        full_key = f"{parent_key}.{name}" if parent_key else name
        if isinstance(col_type, dict):
            flat.update(flatten_dict(sub_dict=col_type, parent_key=full_key))
        elif is_dataclass(col_type):
            # Expand a dataclass the same way as a dict of field-name -> field-type.
            nested = {f_name: f_def.type for f_name, f_def in col_type.__dataclass_fields__.items()}
            flat.update(flatten_dict(sub_dict=nested, parent_key=full_key))
        else:
            flat[full_key] = col_type
    return flat


def extract_cols_and_format(
t: typing.Any,
) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional["pa.lib.Schema"]]:
Expand Down Expand Up @@ -142,7 +158,17 @@ def extract_cols_and_format(
if get_origin(t) is Annotated:
base_type, *annotate_args = get_args(t)
for aa in annotate_args:
if isinstance(aa, StructuredDatasetFormat):
if hasattr(aa, "__annotations__"):
# handle dataclass argument
d = collections.OrderedDict()
dm = vars(aa)
d.update(dm["__annotations__"])
ordered_dict_cols = d
elif isinstance(aa, dict):
d = collections.OrderedDict()
d.update(aa)
ordered_dict_cols = d
elif isinstance(aa, StructuredDatasetFormat):
if fmt != "":
raise ValueError(f"A format was already specified {fmt}, cannot use {aa}")
fmt = aa
Expand Down Expand Up @@ -826,7 +852,8 @@ def _convert_ordered_dict_of_columns_to_list(
converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = []
if column_map is None or len(column_map) == 0:
return converted_cols
for k, v in column_map.items():
flat_column_map = flatten_dict(column_map)
for k, v in flat_column_map.items():
lt = self._get_dataset_column_literal_type(v)
converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt))
return converted_cols
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import typing
from dataclasses import dataclass

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -28,7 +29,16 @@
NUMPY_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory()
BQ_PATH = "bq://flyte-dataset:flyte.table"


@dataclass
class MyCols:
    """Dataclass column spec; mirrors ``kwtypes(Name=str, Age=int)`` below."""

    Name: str
    Age: int


# Three equivalent column specifications used to annotate StructuredDataset.
my_cols = kwtypes(Name=str, Age=int)
my_dataclass_cols = MyCols
my_dict_cols = {"Name": str, "Age": int}
# Matching Arrow schema and a sample DataFrame for the tasks below.
fields = [("Name", pa.string()), ("Age", pa.int32())]
arrow_schema = pa.schema(fields)
pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps add more nested dataframes to cover extreme test cases.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just add your example to the unit test? we can add a new file (test_structured_dataset_workflow_with_nested_type.py) to tests/flytekit/unit/types/structured_dataset

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Expand Down Expand Up @@ -157,6 +167,18 @@ def t4(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
return dataset.open(pd.DataFrame).all()


@task
def t4a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame:
    """Open a dataset annotated with a dataclass column spec as a DataFrame."""
    # s3 (parquet) -> pandas -> s3 (parquet)
    return dataset.open(pd.DataFrame).all()


@task
def t4b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame:
    """Open a dataset annotated with a plain-dict column spec as a DataFrame."""
    # s3 (parquet) -> pandas -> s3 (parquet)
    return dataset.open(pd.DataFrame).all()


@task
def t5(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]:
# s3 (parquet) -> pandas -> bq
Expand All @@ -170,6 +192,20 @@ def t6(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
return df


@task
def t6a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame:
    """Read a BQ-backed dataset (dataclass column spec) into pandas."""
    # bq -> pandas -> s3 (parquet)
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def t6b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame:
    """Read a BQ-backed dataset (dict column spec) into pandas."""
    # bq -> pandas -> s3 (parquet)
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def t7(
df1: pd.DataFrame, df2: pd.DataFrame
Expand All @@ -193,6 +229,20 @@ def t8a(dataframe: pa.Table) -> pa.Table:
return dataframe


@task
def t8b(dataframe: pa.Table) -> Annotated[StructuredDataset, my_dataclass_cols]:
    """Wrap an Arrow table as a StructuredDataset (dataclass column spec)."""
    # Arrow table -> s3 (parquet)
    print(dataframe.columns)
    return StructuredDataset(dataframe=dataframe)


@task
def t8c(dataframe: pa.Table) -> Annotated[StructuredDataset, my_dict_cols]:
    """Wrap an Arrow table as a StructuredDataset (dict column spec)."""
    # Arrow table -> s3 (parquet)
    print(dataframe.columns)
    return StructuredDataset(dataframe=dataframe)


@task
def t9(dataframe: np.ndarray) -> Annotated[StructuredDataset, my_cols]:
# numpy -> Arrow table -> s3 (parquet)
Expand All @@ -206,6 +256,20 @@ def t10(dataset: Annotated[StructuredDataset, my_cols]) -> np.ndarray:
return np_array


@task
def t10a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> np.ndarray:
    """Open a dataset (dataclass column spec) as a numpy array."""
    # s3 (parquet) -> Arrow table -> numpy
    np_array = dataset.open(np.ndarray).all()
    return np_array


@task
def t10b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> np.ndarray:
    """Open a dataset (dict column spec) as a numpy array."""
    # s3 (parquet) -> Arrow table -> numpy
    np_array = dataset.open(np.ndarray).all()
    return np_array


StructuredDatasetTransformerEngine.register(PandasToCSVEncodingHandler())
StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler())

Expand All @@ -223,6 +287,20 @@ def t12(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
return df


@task
def t12a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame:
    """Read a CSV-backed dataset (dataclass column spec) into pandas."""
    # csv -> pandas
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def t12b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame:
    """Read a CSV-backed dataset (dict column spec) into pandas."""
    # csv -> pandas
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def generate_pandas() -> pd.DataFrame:
return pd_df
Expand All @@ -249,15 +327,25 @@ def wf():
t3(dataset=StructuredDataset(uri=PANDAS_PATH))
t3a(dataset=StructuredDataset(uri=PANDAS_PATH))
t4(dataset=StructuredDataset(uri=PANDAS_PATH))
t4a(dataset=StructuredDataset(uri=PANDAS_PATH))
t4b(dataset=StructuredDataset(uri=PANDAS_PATH))
t5(dataframe=df)
t6(dataset=StructuredDataset(uri=BQ_PATH))
t6a(dataset=StructuredDataset(uri=BQ_PATH))
t6b(dataset=StructuredDataset(uri=BQ_PATH))
t7(df1=df, df2=df)
t8(dataframe=arrow_df)
t8a(dataframe=arrow_df)
t8b(dataframe=arrow_df)
t8c(dataframe=arrow_df)
t9(dataframe=np_array)
t10(dataset=StructuredDataset(uri=NUMPY_PATH))
t10a(dataset=StructuredDataset(uri=NUMPY_PATH))
t10b(dataset=StructuredDataset(uri=NUMPY_PATH))
t11(dataframe=df)
t12(dataset=StructuredDataset(uri=PANDAS_PATH))
t12a(dataset=StructuredDataset(uri=PANDAS_PATH))
t12b(dataset=StructuredDataset(uri=PANDAS_PATH))


def test_structured_dataset_wf():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from dataclasses import dataclass

import pyarrow as pa
import pytest
from typing_extensions import Annotated

from flytekit import FlyteContextManager, StructuredDataset, kwtypes, task, workflow

pd = pytest.importorskip("pandas")

PANDAS_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory()
NUMPY_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory()
BQ_PATH = "bq://flyte-dataset:flyte.table"


# Sample rows with two levels of nesting under "info" — exercises flat,
# dict-nested, and dataclass-nested column specifications below.
data = [
    {
        "company": "XYZ pvt ltd",
        "location": "London",
        "info": {"president": "Rakesh Kapoor", "contacts": {"email": "[email protected]", "tel": "9876543210"}},
    },
    {
        "company": "ABC pvt ltd",
        "location": "USA",
        "info": {"president": "Kapoor Rakesh", "contacts": {"email": "[email protected]", "tel": "0123456789"}},
    },
]


@dataclass
class ContactsField:
    """Innermost nesting level: mirrors data[i]["info"]["contacts"]."""

    email: str
    tel: str


@dataclass
class InfoField:
    """Middle nesting level: mirrors data[i]["info"]; nests ContactsField."""

    president: str
    contacts: ContactsField


@dataclass
class CompanyField:
    """Top-level record: mirrors one element of ``data``; nests InfoField."""

    location: str
    info: InfoField
    company: str


MyArgDataset = Annotated[StructuredDataset, kwtypes(company=str)]
MyDictDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str}})]
MyDictListDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str, "email": str}})]
MyTopDataClassDataset = Annotated[StructuredDataset, CompanyField]
MyTopDictDataset = Annotated[StructuredDataset, {"company": str, "location": str}]
MySecondDataClassDataset = Annotated[StructuredDataset, kwtypes(info=InfoField)]
MyNestedDataClassDataset = Annotated[StructuredDataset, kwtypes(info=kwtypes(contacts=ContactsField))]


@task()
def create_pd_table() -> StructuredDataset:
    """Normalize ``data`` (top level only) and persist it as a local parquet dataset."""
    df = pd.json_normalize(data, max_level=0)
    print("original dataframe: \n", df)

    return StructuredDataset(dataframe=df, uri=PANDAS_PATH)


@task()
def create_bq_table() -> StructuredDataset:
    """Normalize ``data`` and target a BigQuery uri (requires GCP to actually run)."""
    df = pd.json_normalize(data, max_level=0)
    print("original dataframe: \n", df)

    # Enable one of GCP `uri` below if you want. You can replace `uri` with your own google cloud endpoints.
    return StructuredDataset(dataframe=df, uri=BQ_PATH)


@task()
def create_np_table() -> StructuredDataset:
    """Normalize ``data`` and persist it under the numpy-destined local path."""
    df = pd.json_normalize(data, max_level=0)
    print("original dataframe: \n", df)

    return StructuredDataset(dataframe=df, uri=NUMPY_PATH)


@task()
def create_ar_table() -> StructuredDataset:
    """Normalize ``data`` into an Arrow table and wrap it (no explicit uri)."""
    df = pa.Table.from_pandas(pd.json_normalize(data, max_level=0))
    print("original dataframe: \n", df)

    return StructuredDataset(
        dataframe=df,
    )


@task()
def print_table_by_arg(sd: MyArgDataset) -> pd.DataFrame:
    """Open a dataset typed with a single flat kwtypes column and print it."""
    t = sd.open(pd.DataFrame).all()
    print("MyArgDataset dataframe: \n", t)
    return t


@task()
def print_table_by_dict(sd: MyDictDataset) -> pd.DataFrame:
    """Open a dataset typed with a nested-dict column spec and print it."""
    t = sd.open(pd.DataFrame).all()
    print("MyDictDataset dataframe: \n", t)
    return t


@task()
def print_table_by_list_dict(sd: MyDictListDataset) -> pd.DataFrame:
    """Open a dataset typed with a multi-leaf nested-dict spec and print it."""
    t = sd.open(pd.DataFrame).all()
    print("MyDictListDataset dataframe: \n", t)
    return t


@task()
def print_table_by_top_dataclass(sd: MyTopDataClassDataset) -> pd.DataFrame:
    """Open a dataset typed with a top-level dataclass spec and print it."""
    t = sd.open(pd.DataFrame).all()
    print("MyTopDataClassDataset dataframe: \n", t)
    return t


@task()
def print_table_by_top_dict(sd: MyTopDictDataset) -> pd.DataFrame:
    """Open a dataset typed with a plain top-level dict spec and print it."""
    t = sd.open(pd.DataFrame).all()
    print("MyTopDictDataset dataframe: \n", t)
    return t


@task()
def print_table_by_second_dataclass(sd: MySecondDataClassDataset) -> pd.DataFrame:
    """Open a dataset whose spec nests a dataclass inside kwtypes and print it."""
    t = sd.open(pd.DataFrame).all()
    print("MySecondDataClassDataset dataframe: \n", t)
    return t


@task()
def print_table_by_nested_dataclass(sd: MyNestedDataClassDataset) -> pd.DataFrame:
    """Open a dataset whose spec nests a dataclass two levels deep and print it."""
    t = sd.open(pd.DataFrame).all()
    print("MyNestedDataClassDataset dataframe: \n", t)
    return t


@workflow
def wf():
    """Run every column-spec flavour against each produced dataset."""
    pd_sd = create_pd_table()
    print_table_by_arg(sd=pd_sd)
    print_table_by_dict(sd=pd_sd)
    print_table_by_list_dict(sd=pd_sd)
    print_table_by_top_dataclass(sd=pd_sd)
    print_table_by_top_dict(sd=pd_sd)
    print_table_by_second_dataclass(sd=pd_sd)
    print_table_by_nested_dataclass(sd=pd_sd)
    # NOTE(review): bq_sd/np_sd/ar_sd below all call create_pd_table() rather than
    # create_bq_table()/create_np_table()/create_ar_table(), so those three creator
    # tasks are never exercised. Presumably deliberate to keep this unit test local
    # (the BQ uri needs GCP credentials), but confirm it is not a copy-paste slip.
    bq_sd = create_pd_table()
    print_table_by_arg(sd=bq_sd)
    print_table_by_dict(sd=bq_sd)
    print_table_by_list_dict(sd=bq_sd)
    print_table_by_top_dataclass(sd=bq_sd)
    print_table_by_top_dict(sd=bq_sd)
    print_table_by_second_dataclass(sd=bq_sd)
    print_table_by_nested_dataclass(sd=bq_sd)
    np_sd = create_pd_table()
    print_table_by_arg(sd=np_sd)
    print_table_by_dict(sd=np_sd)
    print_table_by_list_dict(sd=np_sd)
    print_table_by_top_dataclass(sd=np_sd)
    print_table_by_top_dict(sd=np_sd)
    print_table_by_second_dataclass(sd=np_sd)
    print_table_by_nested_dataclass(sd=np_sd)
    ar_sd = create_pd_table()
    print_table_by_arg(sd=ar_sd)
    print_table_by_dict(sd=ar_sd)
    print_table_by_list_dict(sd=ar_sd)
    print_table_by_top_dataclass(sd=ar_sd)
    print_table_by_top_dict(sd=ar_sd)
    print_table_by_second_dataclass(sd=ar_sd)
    print_table_by_nested_dataclass(sd=ar_sd)


def test_structured_dataset_wf():
    """Execute the workflow locally end-to-end as the unit test."""
    wf()
Loading