Annotated StructuredDataset: support nested_types #2252

Merged
14 changes: 11 additions & 3 deletions flytekit/core/base_task.py
@@ -17,7 +17,6 @@
"""

import asyncio
-import collections
import datetime
import inspect
import warnings
@@ -84,15 +83,24 @@
UNSET_CARD = "_uc"


-def kwtypes(**kwargs) -> OrderedDict[str, Type]:
+def kwtypes(*args, **kwargs) -> OrderedDict[str, Type]:
    """
    This is a small helper function to convert keyword arguments (and, as of this change, a
    positional dataclass or a plain dict) to an OrderedDict of types.

    .. code-block:: python

        kwtypes(a=int, b=str)
        kwtypes(SomeDataclass)
        kwtypes({"a": int, "b": str})
    """
-    d = collections.OrderedDict()
+    d = OrderedDict()
+    for arg in args:
+        # handle positional arguments: dataclass
+        if hasattr(arg, "__annotations__"):
+            d.update(arg.__annotations__)
+        # handle positional arguments: dict
+        elif isinstance(arg, dict):
+            d.update(arg)
+    # handle named arguments
    for k, v in kwargs.items():
        d[k] = v
    return d
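
For reference, a minimal sketch of the three call styles the extended helper accepts; the MyCols dataclass here is illustrative and mirrors the test fixture further down:

from dataclasses import dataclass

from flytekit import kwtypes


@dataclass
class MyCols:
    Name: str
    Age: int


# All three spellings should yield the same OrderedDict of column types.
assert kwtypes(Name=str, Age=int) == kwtypes(MyCols) == kwtypes({"Name": str, "Age": int})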
23 changes: 23 additions & 0 deletions flytekit/types/structured/structured_dataset.py
@@ -820,13 +820,36 @@
            return type_models.LiteralType(map_value_type=self._get_dataset_column_literal_type(t.__args__[1]))
        raise AssertionError(f"type {t} is currently not supported by StructuredDataset")

    def flatten_dict(self, nested_dict):
        result = {}

        def _flatten(sub_dict, parent_key=""):
            for key, value in sub_dict.items():
                current_key = f"{parent_key}.{key}" if parent_key else key
                if isinstance(value, dict):
                    # Recurse without returning, so sibling keys that follow a
                    # nested dict are not skipped.
                    _flatten(value, current_key)
                elif hasattr(value, "__dataclass_fields__"):
                    # Expand dataclass fields into a {name: type} dict and recurse.
                    fields = getattr(value, "__dataclass_fields__")
                    d = {k: v.type for k, v in fields.items()}
                    _flatten(d, current_key)
                else:
                    result[current_key] = value
            return result

        return _flatten(sub_dict=nested_dict)
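
As a sanity check, a small sketch of the intended flattening behavior; the Address dataclass and the direct engine instantiation are illustrative, not part of this diff:

from dataclasses import dataclass

from flytekit.types.structured.structured_dataset import StructuredDatasetTransformerEngine


@dataclass
class Address:
    city: str
    zip_code: int


engine = StructuredDatasetTransformerEngine()
# Nested dataclasses and dicts are expanded into dot-separated column names,
# e.g. {"name": str, "address.city": str, "address.zip_code": int}.
print(engine.flatten_dict({"name": str, "address": Address}))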

    def _convert_ordered_dict_of_columns_to_list(
        self, column_map: typing.Optional[typing.OrderedDict[str, Type]]
    ) -> typing.List[StructuredDatasetType.DatasetColumn]:
        converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = []
        if column_map is None or len(column_map) == 0:
            return converted_cols
        # Flatten any nested dataclass or dict column specs into dot-separated
        # names before converting them to literal column types.
        flat_column_map = {}
        for k, v in column_map.items():
            flat_column_map.update(self.flatten_dict({k: v}))
        for k, v in flat_column_map.items():
            lt = self._get_dataset_column_literal_type(v)
            converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt))
        return converted_cols
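
Combined with the kwtypes change above, a task can now annotate a StructuredDataset with a nested dataclass column spec; a hedged sketch, with Location, Row, and the column names invented for illustration:

from dataclasses import dataclass

import pandas as pd
from typing_extensions import Annotated

from flytekit import kwtypes, task
from flytekit.types.structured import StructuredDataset


@dataclass
class Location:
    lat: float
    lng: float


@dataclass
class Row:
    name: str
    location: Location


# The engine flattens this spec to name, location.lat, location.lng at type-translation time.
row_cols = kwtypes(Row)


@task
def head(ds: Annotated[StructuredDataset, row_cols]) -> pd.DataFrame:
    return ds.open(pd.DataFrame).all()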
2 changes: 1 addition & 1 deletion tests/flytekit/integration/remote/test_remote.py
@@ -74,7 +74,7 @@ def run(file_name, wf_name, *args):


# test that child_workflow.parent_wf asynchronously registers a parent wf1 with a child launch plan from another wf2.
-def test_remote_run():
+def test_remote_run_child_workflow():
    run("child_workflow.py", "parent_wf", "--a", "3")


@@ -1,5 +1,6 @@
import os
import typing
from dataclasses import dataclass

import numpy as np
import pyarrow as pa
@@ -28,7 +29,16 @@
NUMPY_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory()
BQ_PATH = "bq://flyte-dataset:flyte.table"


@dataclass
class MyCols:
    Name: str
    Age: int


my_cols = kwtypes(Name=str, Age=int)
my_dataclass_cols = kwtypes(MyCols)
my_dict_cols = kwtypes({"Name": str, "Age": int})
fields = [("Name", pa.string()), ("Age", pa.int32())]
arrow_schema = pa.schema(fields)
pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
Review thread:

Collaborator (author): Perhaps add more nested dataframes to cover extreme test cases.

Member: Could we just add your example to the unit test? We can add a new file (test_structured_dataset_workflow_with_nested_type.py) to tests/flytekit/unit/types/structured_dataset.

Collaborator (author): Done.
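
A rough sketch of what the suggested file might exercise; every name and value below is hypothetical:

# tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py (hypothetical sketch)
from dataclasses import dataclass

import pandas as pd
from typing_extensions import Annotated

from flytekit import kwtypes, task
from flytekit.types.structured import StructuredDataset


@dataclass
class Levels:
    l1: str
    l2: int


nested_cols = kwtypes(name=str, levels=Levels)


@task
def create_df() -> Annotated[StructuredDataset, nested_cols]:
    # Flattened nested columns appear as dot-separated dataframe columns.
    df = pd.DataFrame({"name": ["a", "b"], "levels.l1": ["x", "y"], "levels.l2": [1, 2]})
    return StructuredDataset(dataframe=df)


def test_nested_columns_smoke():
    sd = create_df()
    assert isinstance(sd, StructuredDataset)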

@@ -157,6 +167,18 @@ def t4(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
    return dataset.open(pd.DataFrame).all()


@task
def t4a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame:
    # s3 (parquet) -> pandas -> s3 (parquet)
    return dataset.open(pd.DataFrame).all()


@task
def t4b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame:
    # s3 (parquet) -> pandas -> s3 (parquet)
    return dataset.open(pd.DataFrame).all()


@task
def t5(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]:
    # s3 (parquet) -> pandas -> bq
@@ -170,6 +192,20 @@ def t6(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
    return df


@task
def t6a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame:
    # bq -> pandas -> s3 (parquet)
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def t6b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame:
    # bq -> pandas -> s3 (parquet)
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def t7(
    df1: pd.DataFrame, df2: pd.DataFrame
@@ -193,6 +229,20 @@ def t8a(dataframe: pa.Table) -> pa.Table:
    return dataframe


@task
def t8b(dataframe: pa.Table) -> Annotated[StructuredDataset, my_dataclass_cols]:
    # Arrow table -> s3 (parquet)
    print(dataframe.columns)
    return StructuredDataset(dataframe=dataframe)


@task
def t8c(dataframe: pa.Table) -> Annotated[StructuredDataset, my_dict_cols]:
    # Arrow table -> s3 (parquet)
    print(dataframe.columns)
    return StructuredDataset(dataframe=dataframe)


@task
def t9(dataframe: np.ndarray) -> Annotated[StructuredDataset, my_cols]:
    # numpy -> Arrow table -> s3 (parquet)
@@ -206,6 +256,20 @@ def t10(dataset: Annotated[StructuredDataset, my_cols]) -> np.ndarray:
    return np_array


@task
def t10a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> np.ndarray:
    # s3 (parquet) -> Arrow table -> numpy
    np_array = dataset.open(np.ndarray).all()
    return np_array


@task
def t10b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> np.ndarray:
    # s3 (parquet) -> Arrow table -> numpy
    np_array = dataset.open(np.ndarray).all()
    return np_array


StructuredDatasetTransformerEngine.register(PandasToCSVEncodingHandler())
StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler())

@@ -223,6 +287,20 @@ def t12(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
    return df


@task
def t12a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame:
    # csv -> pandas
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def t12b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame:
    # csv -> pandas
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def generate_pandas() -> pd.DataFrame:
    return pd_df
@@ -249,15 +327,25 @@ def wf():
    t3(dataset=StructuredDataset(uri=PANDAS_PATH))
    t3a(dataset=StructuredDataset(uri=PANDAS_PATH))
    t4(dataset=StructuredDataset(uri=PANDAS_PATH))
    t4a(dataset=StructuredDataset(uri=PANDAS_PATH))
    t4b(dataset=StructuredDataset(uri=PANDAS_PATH))
    t5(dataframe=df)
    t6(dataset=StructuredDataset(uri=BQ_PATH))
    t6a(dataset=StructuredDataset(uri=BQ_PATH))
    t6b(dataset=StructuredDataset(uri=BQ_PATH))
    t7(df1=df, df2=df)
    t8(dataframe=arrow_df)
    t8a(dataframe=arrow_df)
    t8b(dataframe=arrow_df)
    t8c(dataframe=arrow_df)
    t9(dataframe=np_array)
    t10(dataset=StructuredDataset(uri=NUMPY_PATH))
    t10a(dataset=StructuredDataset(uri=NUMPY_PATH))
    t10b(dataset=StructuredDataset(uri=NUMPY_PATH))
    t11(dataframe=df)
    t12(dataset=StructuredDataset(uri=PANDAS_PATH))
    t12a(dataset=StructuredDataset(uri=PANDAS_PATH))
    t12b(dataset=StructuredDataset(uri=PANDAS_PATH))


def test_structured_dataset_wf():