From f69f7a3550509416921e9ecef4f02de259c92099 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Thu, 14 Mar 2024 23:20:34 +0800 Subject: [PATCH 01/11] StructuredDataset: add recursive `flatten_dict()` Signed-off-by: Austin Liu wip Signed-off-by: Austin Liu fmt Signed-off-by: Austin Liu fix Signed-off-by: Austin Liu fix Signed-off-by: Austin Liu fix Signed-off-by: Austin Liu --- flytekit/core/base_task.py | 14 ++++++++--- .../types/structured/structured_dataset.py | 23 +++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 7411fd635e..62886672ba 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -17,7 +17,6 @@ """ import asyncio -import collections import datetime import inspect import warnings @@ -84,7 +83,7 @@ UNSET_CARD = "_uc" -def kwtypes(**kwargs) -> OrderedDict[str, Type]: +def kwtypes(*args, **kwargs) -> OrderedDict[str, Type]: """ This is a small helper function to convert the keyword arguments to an OrderedDict of types. @@ -92,7 +91,16 @@ def kwtypes(**kwargs) -> OrderedDict[str, Type]: kwtypes(a=int, b=str) """ - d = collections.OrderedDict() + d = OrderedDict() + for arg in args: + # handle positional arguments: dataclass + if hasattr(arg, "__annotations__"): + dm = vars(arg) + d.update(dm["__annotations__"]) + # handle positional arguments: dict + elif isinstance(arg, dict): + d.update(arg) + # handle named arguments for k, v in kwargs.items(): d[k] = v return d diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 8faed9ff45..0cff8fb8c1 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -820,13 +820,36 @@ def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType: return type_models.LiteralType(map_value_type=self._get_dataset_column_literal_type(t.__args__[1])) raise AssertionError(f"type {t} is currently not supported by StructuredDataset") + def flatten_dict(self, nested_dict): + result = {} + + def _flatten(sub_dict, parent_key=""): + for key, value in sub_dict.items(): + current_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, dict): + return _flatten(value, current_key) + elif hasattr(value, "__dataclass_fields__"): + fields = getattr(value, "__dataclass_fields__") + d = {k: v.type for k, v in fields.items()} + return _flatten(d, current_key) + else: + result[current_key] = value + return result + + return _flatten(sub_dict=nested_dict) + def _convert_ordered_dict_of_columns_to_list( self, column_map: typing.Optional[typing.OrderedDict[str, Type]] ) -> typing.List[StructuredDatasetType.DatasetColumn]: converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = [] if column_map is None or len(column_map) == 0: return converted_cols + flat_column_map = {} for k, v in column_map.items(): + d = dict() + d[k] = v + flat_column_map.update(self.flatten_dict(d)) + for k, v in flat_column_map.items(): lt = self._get_dataset_column_literal_type(v) converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt)) return converted_cols From 15600ff6006ae4394ea0f153f4960fe2fe06a197 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Mon, 15 Apr 2024 22:13:54 +0800 Subject: [PATCH 02/11] add `levels_wf` as integration test Signed-off-by: Austin Liu fmt Signed-off-by: Austin Liu --- .../integration/remote/test_remote.py | 17 +- .../workflows/basic/structured_datasets.py | 164 
++++++++++++++++++ 2 files changed, 171 insertions(+), 10 deletions(-) create mode 100644 tests/flytekit/integration/remote/workflows/basic/structured_datasets.py diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index f23fc061d9..ad0358874a 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -55,16 +55,7 @@ def run(file_name, wf_name, *args): [ "pyflyte", "--verbose", - "-c", - CONFIG, "run", - "--remote", - "--image", - IMAGE, - "--project", - PROJECT, - "--domain", - DOMAIN, MODULE_PATH / file_name, wf_name, *args, @@ -74,10 +65,16 @@ def run(file_name, wf_name, *args): # test child_workflow.parent_wf asynchronously register a parent wf1 with child lp from another wf2. -def test_remote_run(): +def test_remote_run_child_workflow(): run("child_workflow.py", "parent_wf", "--a", "3") +# test structured_datasets.levels_wf. +# TODO: assert `execution.outputs == ['age', 'level2']`. +def test_remote_run_structured_datasets(): + run("structured_datasets.py", "levels_wf") + + def test_fetch_execute_launch_plan(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) flyte_launch_plan = remote.fetch_launch_plan(name="basic.hello_world.my_wf", version=VERSION) diff --git a/tests/flytekit/integration/remote/workflows/basic/structured_datasets.py b/tests/flytekit/integration/remote/workflows/basic/structured_datasets.py new file mode 100644 index 0000000000..4e95766392 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/structured_datasets.py @@ -0,0 +1,164 @@ +from dataclasses import dataclass +from typing import Annotated + +import pandas as pd + +from flytekit import ImageSpec, StructuredDataset, kwtypes, task, workflow + +## Add `@task(container_image=image)` if want to test in remote mode. Remove `git_commit_sha` after merged. +## Add `GOOGLE_APPLICATION_CREDENTIALS` if wanna test `google-cloud-bigquery`. +flytekit_dev_version = "https://github.com/austin362667/flytekit.git@f5cd70dd053e6f3d4aaf5b90d9c4b28f32c0980a" +image = ImageSpec( + packages=[ + "pandas", + # "google-cloud-bigquery", + # "google-cloud-bigquery-storage", + # "flytekitplugins-bigquery==1.11.0", + f"git+{flytekit_dev_version}", + ], + apt_packages=["git"], + # source_root="./keys", + # env={"GOOGLE_APPLICATION_CREDENTIALS": "./gcp-service-account.json"}, + platform="linux/arm64", + registry="localhost:30000", +) + + +## Case 1. 
+data = [ + { + "company": "XYZ pvt ltd", + "location": "London", + "info": {"president": "Rakesh Kapoor", "contacts": {"email": "contact@xyz.com", "tel": "9876543210"}}, + }, + { + "company": "ABC pvt ltd", + "location": "USA", + "info": {"president": "Kapoor Rakesh", "contacts": {"email": "contact@abc.com", "tel": "0123456789"}}, + }, +] + + +@dataclass +class ContactsField: + email: str + tel: str + + +@dataclass +class InfoField: + president: str + contacts: ContactsField + + +@dataclass +class CompanyField: + location: str + info: InfoField + company: str + + +MyArgDataset = Annotated[StructuredDataset, kwtypes(company=str)] +MyDictDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str}})] +MyDictListDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str, "email": str}})] +MyTopDataClassDataset = Annotated[StructuredDataset, kwtypes(CompanyField)] +MySecondDataClassDataset = Annotated[StructuredDataset, kwtypes(info=InfoField)] +MyNestedDataClassDataset = Annotated[StructuredDataset, kwtypes(info=kwtypes(contacts=ContactsField))] + + +@task(container_image=image) +def create_bq_table() -> StructuredDataset: + df = pd.json_normalize(data, max_level=0) + print("original dataframe: \n", df) + + # Enable one of GCP `uri` below if you want. You can replace `uri` with your own google cloud endpoints. + return StructuredDataset( + dataframe=df, + # uri= "gs://flyte_austin362667_bucket/nested_types" + # uri= "bq://flyte-austin362667-gcp:dataset.nested_type" + ) + + +@task(container_image=image) +def print_table_by_arg(sd: MyArgDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyArgDataset dataframe: \n", t) + return t + + +@task(container_image=image) +def print_table_by_dict(sd: MyDictDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyDictDataset dataframe: \n", t) + return t + + +@task(container_image=image) +def print_table_by_list_dict(sd: MyDictListDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyDictListDataset dataframe: \n", t) + return t + + +@task(container_image=image) +def print_table_by_top_dataclass(sd: MyTopDataClassDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyTopDataClassDataset dataframe: \n", t) + return t + + +@task(container_image=image) +def print_table_by_second_dataclass(sd: MySecondDataClassDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MySecondDataClassDataset dataframe: \n", t) + return t + + +@task(container_image=image) +def print_table_by_nested_dataclass(sd: MyNestedDataClassDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyNestedDataClassDataset dataframe: \n", t) + return t + + +@workflow +def contacts_wf(): + sd = create_bq_table() + print_table_by_arg(sd=sd) + print_table_by_dict(sd=sd) + print_table_by_list_dict(sd=sd) + print_table_by_top_dataclass(sd=sd) + print_table_by_second_dataclass(sd=sd) + print_table_by_nested_dataclass(sd=sd) + return + + +## Case 2. +@dataclass +class Levels: + # level1: str + level2: str + + +Schema = Annotated[StructuredDataset, kwtypes(age=int, levels=Levels)] + + +@task(container_image=image) +def mytask_w() -> StructuredDataset: + df = pd.DataFrame({"age": [1, 2], "levels": [{"level1": "1", "level2": "2"}, {"level1": "2", "level2": "4"}]}) + return StructuredDataset(dataframe=df, uri=None) + + +# Should only show `level2` string. 
+@task(container_image=image) +def mytask_r(sd: Schema) -> list[str]: + t = sd.open(pd.DataFrame).all() + print("dataframe: \n", t) + return t.columns.tolist() + + +@workflow +def levels_wf(): + sd = mytask_w() + mytask_r(sd=sd) + return From bf892f7e634ace6e853c59166cc336b05f5893cc Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Tue, 16 Apr 2024 00:02:06 +0800 Subject: [PATCH 03/11] add structured_dataset unit tests Signed-off-by: Austin Liu --- .../integration/remote/test_remote.py | 15 +- .../workflows/basic/structured_datasets.py | 164 ------------------ .../test_structured_dataset_workflow.py | 88 ++++++++++ 3 files changed, 97 insertions(+), 170 deletions(-) delete mode 100644 tests/flytekit/integration/remote/workflows/basic/structured_datasets.py diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index ad0358874a..0f14e12422 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -55,7 +55,16 @@ def run(file_name, wf_name, *args): [ "pyflyte", "--verbose", + "-c", + CONFIG, "run", + "--remote", + "--image", + IMAGE, + "--project", + PROJECT, + "--domain", + DOMAIN, MODULE_PATH / file_name, wf_name, *args, @@ -69,12 +78,6 @@ def test_remote_run_child_workflow(): run("child_workflow.py", "parent_wf", "--a", "3") -# test structured_datasets.levels_wf. -# TODO: assert `execution.outputs == ['age', 'level2']`. -def test_remote_run_structured_datasets(): - run("structured_datasets.py", "levels_wf") - - def test_fetch_execute_launch_plan(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) flyte_launch_plan = remote.fetch_launch_plan(name="basic.hello_world.my_wf", version=VERSION) diff --git a/tests/flytekit/integration/remote/workflows/basic/structured_datasets.py b/tests/flytekit/integration/remote/workflows/basic/structured_datasets.py deleted file mode 100644 index 4e95766392..0000000000 --- a/tests/flytekit/integration/remote/workflows/basic/structured_datasets.py +++ /dev/null @@ -1,164 +0,0 @@ -from dataclasses import dataclass -from typing import Annotated - -import pandas as pd - -from flytekit import ImageSpec, StructuredDataset, kwtypes, task, workflow - -## Add `@task(container_image=image)` if want to test in remote mode. Remove `git_commit_sha` after merged. -## Add `GOOGLE_APPLICATION_CREDENTIALS` if wanna test `google-cloud-bigquery`. -flytekit_dev_version = "https://github.com/austin362667/flytekit.git@f5cd70dd053e6f3d4aaf5b90d9c4b28f32c0980a" -image = ImageSpec( - packages=[ - "pandas", - # "google-cloud-bigquery", - # "google-cloud-bigquery-storage", - # "flytekitplugins-bigquery==1.11.0", - f"git+{flytekit_dev_version}", - ], - apt_packages=["git"], - # source_root="./keys", - # env={"GOOGLE_APPLICATION_CREDENTIALS": "./gcp-service-account.json"}, - platform="linux/arm64", - registry="localhost:30000", -) - - -## Case 1. 
-data = [ - { - "company": "XYZ pvt ltd", - "location": "London", - "info": {"president": "Rakesh Kapoor", "contacts": {"email": "contact@xyz.com", "tel": "9876543210"}}, - }, - { - "company": "ABC pvt ltd", - "location": "USA", - "info": {"president": "Kapoor Rakesh", "contacts": {"email": "contact@abc.com", "tel": "0123456789"}}, - }, -] - - -@dataclass -class ContactsField: - email: str - tel: str - - -@dataclass -class InfoField: - president: str - contacts: ContactsField - - -@dataclass -class CompanyField: - location: str - info: InfoField - company: str - - -MyArgDataset = Annotated[StructuredDataset, kwtypes(company=str)] -MyDictDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str}})] -MyDictListDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str, "email": str}})] -MyTopDataClassDataset = Annotated[StructuredDataset, kwtypes(CompanyField)] -MySecondDataClassDataset = Annotated[StructuredDataset, kwtypes(info=InfoField)] -MyNestedDataClassDataset = Annotated[StructuredDataset, kwtypes(info=kwtypes(contacts=ContactsField))] - - -@task(container_image=image) -def create_bq_table() -> StructuredDataset: - df = pd.json_normalize(data, max_level=0) - print("original dataframe: \n", df) - - # Enable one of GCP `uri` below if you want. You can replace `uri` with your own google cloud endpoints. - return StructuredDataset( - dataframe=df, - # uri= "gs://flyte_austin362667_bucket/nested_types" - # uri= "bq://flyte-austin362667-gcp:dataset.nested_type" - ) - - -@task(container_image=image) -def print_table_by_arg(sd: MyArgDataset) -> pd.DataFrame: - t = sd.open(pd.DataFrame).all() - print("MyArgDataset dataframe: \n", t) - return t - - -@task(container_image=image) -def print_table_by_dict(sd: MyDictDataset) -> pd.DataFrame: - t = sd.open(pd.DataFrame).all() - print("MyDictDataset dataframe: \n", t) - return t - - -@task(container_image=image) -def print_table_by_list_dict(sd: MyDictListDataset) -> pd.DataFrame: - t = sd.open(pd.DataFrame).all() - print("MyDictListDataset dataframe: \n", t) - return t - - -@task(container_image=image) -def print_table_by_top_dataclass(sd: MyTopDataClassDataset) -> pd.DataFrame: - t = sd.open(pd.DataFrame).all() - print("MyTopDataClassDataset dataframe: \n", t) - return t - - -@task(container_image=image) -def print_table_by_second_dataclass(sd: MySecondDataClassDataset) -> pd.DataFrame: - t = sd.open(pd.DataFrame).all() - print("MySecondDataClassDataset dataframe: \n", t) - return t - - -@task(container_image=image) -def print_table_by_nested_dataclass(sd: MyNestedDataClassDataset) -> pd.DataFrame: - t = sd.open(pd.DataFrame).all() - print("MyNestedDataClassDataset dataframe: \n", t) - return t - - -@workflow -def contacts_wf(): - sd = create_bq_table() - print_table_by_arg(sd=sd) - print_table_by_dict(sd=sd) - print_table_by_list_dict(sd=sd) - print_table_by_top_dataclass(sd=sd) - print_table_by_second_dataclass(sd=sd) - print_table_by_nested_dataclass(sd=sd) - return - - -## Case 2. -@dataclass -class Levels: - # level1: str - level2: str - - -Schema = Annotated[StructuredDataset, kwtypes(age=int, levels=Levels)] - - -@task(container_image=image) -def mytask_w() -> StructuredDataset: - df = pd.DataFrame({"age": [1, 2], "levels": [{"level1": "1", "level2": "2"}, {"level1": "2", "level2": "4"}]}) - return StructuredDataset(dataframe=df, uri=None) - - -# Should only show `level2` string. 
-@task(container_image=image) -def mytask_r(sd: Schema) -> list[str]: - t = sd.open(pd.DataFrame).all() - print("dataframe: \n", t) - return t.columns.tolist() - - -@workflow -def levels_wf(): - sd = mytask_w() - mytask_r(sd=sd) - return diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py index 3b0bf96e7a..8309119839 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py @@ -1,5 +1,6 @@ import os import typing +from dataclasses import dataclass import numpy as np import pyarrow as pa @@ -28,7 +29,16 @@ NUMPY_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory() BQ_PATH = "bq://flyte-dataset:flyte.table" + +@dataclass +class MyCols: + Name: str + Age: int + + my_cols = kwtypes(Name=str, Age=int) +my_dataclass_cols = kwtypes(MyCols) +my_dict_cols = kwtypes({"Name": str, "Age": int}) fields = [("Name", pa.string()), ("Age", pa.int32())] arrow_schema = pa.schema(fields) pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) @@ -157,6 +167,18 @@ def t4(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame: return dataset.open(pd.DataFrame).all() +@task +def t4a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame: + # s3 (parquet) -> pandas -> s3 (parquet) + return dataset.open(pd.DataFrame).all() + + +@task +def t4b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame: + # s3 (parquet) -> pandas -> s3 (parquet) + return dataset.open(pd.DataFrame).all() + + @task def t5(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]: # s3 (parquet) -> pandas -> bq @@ -170,6 +192,20 @@ def t6(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame: return df +@task +def t6a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame: + # bq -> pandas -> s3 (parquet) + df = dataset.open(pd.DataFrame).all() + return df + + +@task +def t6b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame: + # bq -> pandas -> s3 (parquet) + df = dataset.open(pd.DataFrame).all() + return df + + @task def t7( df1: pd.DataFrame, df2: pd.DataFrame @@ -193,6 +229,20 @@ def t8a(dataframe: pa.Table) -> pa.Table: return dataframe +@task +def t8b(dataframe: pa.Table) -> Annotated[StructuredDataset, my_dataclass_cols]: + # Arrow table -> s3 (parquet) + print(dataframe.columns) + return StructuredDataset(dataframe=dataframe) + + +@task +def t8c(dataframe: pa.Table) -> Annotated[StructuredDataset, my_dict_cols]: + # Arrow table -> s3 (parquet) + print(dataframe.columns) + return StructuredDataset(dataframe=dataframe) + + @task def t9(dataframe: np.ndarray) -> Annotated[StructuredDataset, my_cols]: # numpy -> Arrow table -> s3 (parquet) @@ -206,6 +256,20 @@ def t10(dataset: Annotated[StructuredDataset, my_cols]) -> np.ndarray: return np_array +@task +def t10a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> np.ndarray: + # s3 (parquet) -> Arrow table -> numpy + np_array = dataset.open(np.ndarray).all() + return np_array + + +@task +def t10b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> np.ndarray: + # s3 (parquet) -> Arrow table -> numpy + np_array = dataset.open(np.ndarray).all() + return np_array + + StructuredDatasetTransformerEngine.register(PandasToCSVEncodingHandler()) 
StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler()) @@ -223,6 +287,20 @@ def t12(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame: return df +@task +def t12a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame: + # csv -> pandas + df = dataset.open(pd.DataFrame).all() + return df + + +@task +def t12b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame: + # csv -> pandas + df = dataset.open(pd.DataFrame).all() + return df + + @task def generate_pandas() -> pd.DataFrame: return pd_df @@ -249,15 +327,25 @@ def wf(): t3(dataset=StructuredDataset(uri=PANDAS_PATH)) t3a(dataset=StructuredDataset(uri=PANDAS_PATH)) t4(dataset=StructuredDataset(uri=PANDAS_PATH)) + t4a(dataset=StructuredDataset(uri=PANDAS_PATH)) + t4b(dataset=StructuredDataset(uri=PANDAS_PATH)) t5(dataframe=df) t6(dataset=StructuredDataset(uri=BQ_PATH)) + t6a(dataset=StructuredDataset(uri=BQ_PATH)) + t6b(dataset=StructuredDataset(uri=BQ_PATH)) t7(df1=df, df2=df) t8(dataframe=arrow_df) t8a(dataframe=arrow_df) + t8b(dataframe=arrow_df) + t8c(dataframe=arrow_df) t9(dataframe=np_array) t10(dataset=StructuredDataset(uri=NUMPY_PATH)) + t10a(dataset=StructuredDataset(uri=NUMPY_PATH)) + t10b(dataset=StructuredDataset(uri=NUMPY_PATH)) t11(dataframe=df) t12(dataset=StructuredDataset(uri=PANDAS_PATH)) + t12a(dataset=StructuredDataset(uri=PANDAS_PATH)) + t12b(dataset=StructuredDataset(uri=PANDAS_PATH)) def test_structured_dataset_wf(): From 774c16eb41b5a439b970b7dee230f76a992cca63 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Wed, 17 Apr 2024 16:27:18 +0800 Subject: [PATCH 04/11] make `kwtypes()` only accepts named args Signed-off-by: Austin Liu --- flytekit/core/base_task.py | 12 +- .../types/structured/structured_dataset.py | 76 +++-- .../test_structured_dataset_workflow.py | 4 +- ...tured_dataset_workflow_with_nested_type.py | 303 ++++++++++++++++++ 4 files changed, 351 insertions(+), 44 deletions(-) create mode 100644 tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 62886672ba..db68222ed1 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -83,7 +83,7 @@ UNSET_CARD = "_uc" -def kwtypes(*args, **kwargs) -> OrderedDict[str, Type]: +def kwtypes(**kwargs) -> OrderedDict[str, Type]: """ This is a small helper function to convert the keyword arguments to an OrderedDict of types. 
@@ -92,15 +92,7 @@ def kwtypes(*args, **kwargs) -> OrderedDict[str, Type]: kwtypes(a=int, b=str) """ d = OrderedDict() - for arg in args: - # handle positional arguments: dataclass - if hasattr(arg, "__annotations__"): - dm = vars(arg) - d.update(dm["__annotations__"]) - # handle positional arguments: dict - elif isinstance(arg, dict): - d.update(arg) - # handle named arguments + # only handle named arguments for k, v in kwargs.items(): d[k] = v return d diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 0cff8fb8c1..abde7baf36 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -114,6 +114,29 @@ def iter(self) -> Generator[DF, None, None]: ) +# flat the nested column map recursively +def flatten_dict(nested_dict) -> typing.Dict: + result = {} + + def _flatten(sub_dict, parent_key=""): + for key, value in sub_dict.items(): + current_key = f"{parent_key}.{key}" if parent_key else key + # handle sub `dict` + if isinstance(value, dict): + return _flatten(value, current_key) + # handle sub `dataclass` + elif hasattr(value, "__dataclass_fields__"): + fields = getattr(value, "__dataclass_fields__") + d = {k: v.type for k, v in fields.items()} + return _flatten(d, current_key) + # already flattened + else: + result[current_key] = value + return result + + return _flatten(sub_dict=nested_dict) + + def extract_cols_and_format( t: typing.Any, ) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional["pa.lib.Schema"]]: @@ -142,18 +165,29 @@ def extract_cols_and_format( if get_origin(t) is Annotated: base_type, *annotate_args = get_args(t) for aa in annotate_args: - if isinstance(aa, StructuredDatasetFormat): + d = collections.OrderedDict() + # handle `dataclass` argument: + if hasattr(aa, "__annotations__"): + dm = vars(aa) + d.update(dm["__annotations__"]) + # handle `dict` argument: + elif isinstance(aa, dict): + d.update(aa) + # handle defaults: + else: + d = aa + if isinstance(d, StructuredDatasetFormat): if fmt != "": - raise ValueError(f"A format was already specified {fmt}, cannot use {aa}") - fmt = aa - elif isinstance(aa, collections.OrderedDict): + raise ValueError(f"A format was already specified {fmt}, cannot use {d}") + fmt = d + elif isinstance(d, collections.OrderedDict): if ordered_dict_cols is not None: - raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}") - ordered_dict_cols = aa - elif isinstance(aa, pa.lib.Schema): + raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {d}") + ordered_dict_cols = d + elif isinstance(d, pa.lib.Schema): if pa_schema is not None: - raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}") - pa_schema = aa + raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {d}") + pa_schema = d return base_type, ordered_dict_cols, fmt, pa_schema # We return None as the format instead of parquet or something because the transformer engine may find @@ -820,35 +854,13 @@ def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType: return type_models.LiteralType(map_value_type=self._get_dataset_column_literal_type(t.__args__[1])) raise AssertionError(f"type {t} is currently not supported by StructuredDataset") - def flatten_dict(self, nested_dict): - result = {} - - def _flatten(sub_dict, parent_key=""): - for key, value in sub_dict.items(): - current_key = 
f"{parent_key}.{key}" if parent_key else key - if isinstance(value, dict): - return _flatten(value, current_key) - elif hasattr(value, "__dataclass_fields__"): - fields = getattr(value, "__dataclass_fields__") - d = {k: v.type for k, v in fields.items()} - return _flatten(d, current_key) - else: - result[current_key] = value - return result - - return _flatten(sub_dict=nested_dict) - def _convert_ordered_dict_of_columns_to_list( self, column_map: typing.Optional[typing.OrderedDict[str, Type]] ) -> typing.List[StructuredDatasetType.DatasetColumn]: converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = [] if column_map is None or len(column_map) == 0: return converted_cols - flat_column_map = {} - for k, v in column_map.items(): - d = dict() - d[k] = v - flat_column_map.update(self.flatten_dict(d)) + flat_column_map = flatten_dict(column_map) for k, v in flat_column_map.items(): lt = self._get_dataset_column_literal_type(v) converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt)) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py index 8309119839..91fa72b526 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py @@ -37,8 +37,8 @@ class MyCols: my_cols = kwtypes(Name=str, Age=int) -my_dataclass_cols = kwtypes(MyCols) -my_dict_cols = kwtypes({"Name": str, "Age": int}) +my_dataclass_cols = MyCols +my_dict_cols = {"Name": str, "Age": int} fields = [("Name", pa.string()), ("Age", pa.int32())] arrow_schema = pa.schema(fields) pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py new file mode 100644 index 0000000000..1a9c7e1c74 --- /dev/null +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py @@ -0,0 +1,303 @@ +import os +import typing +from dataclasses import dataclass +from tabulate import tabulate +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from typing_extensions import Annotated + +from flytekit import FlyteContext, FlyteContextManager, kwtypes, task, workflow +from flytekit.models import literals +from flytekit.models.literals import StructuredDatasetMetadata +from flytekit.models.types import StructuredDatasetType +from flytekit.types.structured.basic_dfs import CSVToPandasDecodingHandler, PandasToCSVEncodingHandler +from flytekit.types.structured.structured_dataset import ( + CSV, + DF, + PARQUET, + StructuredDataset, + StructuredDatasetDecoder, + StructuredDatasetEncoder, + StructuredDatasetTransformerEngine, +) +from flytekit import ImageSpec, StructuredDataset, kwtypes, task, workflow + +pd = pytest.importorskip("pandas") + +PANDAS_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory() +NUMPY_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory() +BQ_PATH = "bq://flyte-dataset:flyte.table" + + +@dataclass +class MyCols: + Name: str + Age: int + + +my_cols = kwtypes(Name=str, Age=int) +my_dataclass_cols = MyCols +my_dict_cols = {"Name": str, "Age": int} +fields = [("Name", pa.string()), ("Age", pa.int32())] +arrow_schema = 
pa.schema(fields) +pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + + +class MockBQEncodingHandlers(StructuredDatasetEncoder): + def __init__(self): + super().__init__(pd.DataFrame, "bq", "") + + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: + return literals.StructuredDataset( + uri="bq://bucket/key", metadata=StructuredDatasetMetadata(structured_dataset_type) + ) + + +class MockBQDecodingHandlers(StructuredDatasetDecoder): + def __init__(self): + super().__init__(pd.DataFrame, "bq", "") + + def decode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, + ) -> pd.DataFrame: + return pd_df + + +StructuredDatasetTransformerEngine.register(MockBQEncodingHandlers(), False, True) +StructuredDatasetTransformerEngine.register(MockBQDecodingHandlers(), False, True) + + +class NumpyRenderer: + """ + The Polars DataFrame summary statistics are rendered as an HTML table. + """ + + def to_html(self, array: np.ndarray) -> str: + return pd.DataFrame(array).describe().to_html() + + +@pytest.fixture(autouse=True) +def numpy_type(): + class NumpyEncodingHandlers(StructuredDatasetEncoder): + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: + path = typing.cast(str, structured_dataset.uri) + if not path: + path = ctx.file_access.join( + ctx.file_access.raw_output_prefix, + ctx.file_access.get_random_string(), + ) + df = typing.cast(np.ndarray, structured_dataset.dataframe) + name = ["col" + str(i) for i in range(len(df))] + table = pa.Table.from_arrays(df, name) + local_dir = ctx.file_access.get_random_local_directory() + local_path = os.path.join(local_dir, f"{0:05}") + pq.write_table(table, local_path) + ctx.file_access.upload_directory(local_dir, path) + structured_dataset_type.format = PARQUET + return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) + + class NumpyDecodingHandlers(StructuredDatasetDecoder): + def decode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, + ) -> typing.Union[DF, typing.Generator[DF, None, None]]: + path = flyte_value.uri + local_dir = ctx.file_access.get_random_local_directory() + ctx.file_access.get_data(path, local_dir, is_multipart=True) + table = pq.read_table(local_dir) + return table.to_pandas().to_numpy() + + StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray)) + StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray)) + StructuredDatasetTransformerEngine.register_renderer(np.ndarray, NumpyRenderer()) + + +StructuredDatasetTransformerEngine.register(PandasToCSVEncodingHandler()) +StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler()) + + + +## Case 1. 
+data = [ + { + "company": "XYZ pvt ltd", + "location": "London", + "info": {"president": "Rakesh Kapoor", "contacts": {"email": "contact@xyz.com", "tel": "9876543210"}}, + }, + { + "company": "ABC pvt ltd", + "location": "USA", + "info": {"president": "Kapoor Rakesh", "contacts": {"email": "contact@abc.com", "tel": "0123456789"}}, + }, +] + + +@dataclass +class ContactsField: + email: str + tel: str + + +@dataclass +class InfoField: + president: str + contacts: ContactsField + + +@dataclass +class CompanyField: + location: str + info: InfoField + company: str + + +MyArgDataset = Annotated[StructuredDataset, kwtypes(company=str)] +MyDictDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str}})] +MyDictListDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str, "email": str}})] +MyTopDataClassDataset = Annotated[StructuredDataset, CompanyField] +MyTopDictDataset = Annotated[StructuredDataset, {"company": str, "location": str}] +MySecondDataClassDataset = Annotated[StructuredDataset, kwtypes(info=InfoField)] +MyNestedDataClassDataset = Annotated[StructuredDataset, kwtypes(info=kwtypes(contacts=ContactsField))] + +@task() +def create_pd_table() -> StructuredDataset: + df = pd.json_normalize(data, max_level=0) + print("original dataframe: \n", tabulate(df, headers='keys', tablefmt='psql')) + + return StructuredDataset( + dataframe=df, + uri=PANDAS_PATH + ) + +@task() +def create_bq_table() -> StructuredDataset: + df = pd.json_normalize(data, max_level=0) + print("original dataframe: \n", tabulate(df, headers='keys', tablefmt='psql')) + + # Enable one of GCP `uri` below if you want. You can replace `uri` with your own google cloud endpoints. + return StructuredDataset( + dataframe=df, + uri=BQ_PATH + ) + +@task() +def create_np_table() -> StructuredDataset: + df = pd.json_normalize(data, max_level=0) + print("original dataframe: \n", tabulate(df, headers='keys', tablefmt='psql')) + + return StructuredDataset( + dataframe=df, + uri=NUMPY_PATH + ) + +@task() +def create_ar_table() -> StructuredDataset: + df = pa.Table.from_pandas(pd.json_normalize(data, max_level=0)) + print("original dataframe: \n", tabulate(df, headers='keys', tablefmt='psql')) + + return StructuredDataset( + dataframe=df, + ) + + +@task() +def print_table_by_arg(sd: MyArgDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyArgDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + return t + + +@task() +def print_table_by_dict(sd: MyDictDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyDictDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + return t + + +@task() +def print_table_by_list_dict(sd: MyDictListDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyDictListDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + return t + + +@task() +def print_table_by_top_dataclass(sd: MyTopDataClassDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyTopDataClassDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + return t + +@task() +def print_table_by_top_dict(sd: MyTopDictDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyTopDictDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + return t + +@task() +def print_table_by_second_dataclass(sd: MySecondDataClassDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MySecondDataClassDataset dataframe: \n", tabulate(t, headers='keys', 
tablefmt='psql')) + return t + + +@task() +def print_table_by_nested_dataclass(sd: MyNestedDataClassDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyNestedDataClassDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + return t + +@workflow +def wf(): + pd_sd = create_pd_table() + print_table_by_arg(sd=pd_sd) + print_table_by_dict(sd=pd_sd) + print_table_by_list_dict(sd=pd_sd) + print_table_by_top_dataclass(sd=pd_sd) + print_table_by_top_dict(sd=pd_sd) + print_table_by_second_dataclass(sd=pd_sd) + print_table_by_nested_dataclass(sd=pd_sd) + bq_sd = create_pd_table() + print_table_by_arg(sd=bq_sd) + print_table_by_dict(sd=bq_sd) + print_table_by_list_dict(sd=bq_sd) + print_table_by_top_dataclass(sd=bq_sd) + print_table_by_top_dict(sd=bq_sd) + print_table_by_second_dataclass(sd=bq_sd) + print_table_by_nested_dataclass(sd=bq_sd) + np_sd = create_pd_table() + print_table_by_arg(sd=np_sd) + print_table_by_dict(sd=np_sd) + print_table_by_list_dict(sd=np_sd) + print_table_by_top_dataclass(sd=np_sd) + print_table_by_top_dict(sd=np_sd) + print_table_by_second_dataclass(sd=np_sd) + print_table_by_nested_dataclass(sd=np_sd) + ar_sd = create_pd_table() + print_table_by_arg(sd=ar_sd) + print_table_by_dict(sd=ar_sd) + print_table_by_list_dict(sd=ar_sd) + print_table_by_top_dataclass(sd=ar_sd) + print_table_by_top_dict(sd=ar_sd) + print_table_by_second_dataclass(sd=ar_sd) + print_table_by_nested_dataclass(sd=ar_sd) + return + +def test_structured_dataset_wf(): + wf() From 601939cd43d7844ea52b6e762459061fb2d4807a Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Fri, 19 Apr 2024 12:13:28 +0800 Subject: [PATCH 05/11] lint Signed-off-by: Austin Liu --- ...tured_dataset_workflow_with_nested_type.py | 54 +++++++++---------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py index 1a9c7e1c74..23c63aa53d 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py @@ -1,28 +1,26 @@ import os import typing from dataclasses import dataclass -from tabulate import tabulate + import numpy as np import pyarrow as pa import pyarrow.parquet as pq import pytest +from tabulate import tabulate from typing_extensions import Annotated -from flytekit import FlyteContext, FlyteContextManager, kwtypes, task, workflow +from flytekit import FlyteContext, FlyteContextManager, StructuredDataset, kwtypes, task, workflow from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.basic_dfs import CSVToPandasDecodingHandler, PandasToCSVEncodingHandler from flytekit.types.structured.structured_dataset import ( - CSV, DF, PARQUET, - StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, StructuredDatasetTransformerEngine, ) -from flytekit import ImageSpec, StructuredDataset, kwtypes, task, workflow pd = pytest.importorskip("pandas") @@ -133,7 +131,6 @@ def decode( StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler()) - ## Case 1. 
data = [ { @@ -176,41 +173,36 @@ class CompanyField: MySecondDataClassDataset = Annotated[StructuredDataset, kwtypes(info=InfoField)] MyNestedDataClassDataset = Annotated[StructuredDataset, kwtypes(info=kwtypes(contacts=ContactsField))] + @task() def create_pd_table() -> StructuredDataset: df = pd.json_normalize(data, max_level=0) - print("original dataframe: \n", tabulate(df, headers='keys', tablefmt='psql')) + print("original dataframe: \n", tabulate(df, headers="keys", tablefmt="psql")) + + return StructuredDataset(dataframe=df, uri=PANDAS_PATH) - return StructuredDataset( - dataframe=df, - uri=PANDAS_PATH - ) @task() def create_bq_table() -> StructuredDataset: df = pd.json_normalize(data, max_level=0) - print("original dataframe: \n", tabulate(df, headers='keys', tablefmt='psql')) + print("original dataframe: \n", tabulate(df, headers="keys", tablefmt="psql")) # Enable one of GCP `uri` below if you want. You can replace `uri` with your own google cloud endpoints. - return StructuredDataset( - dataframe=df, - uri=BQ_PATH - ) + return StructuredDataset(dataframe=df, uri=BQ_PATH) + @task() def create_np_table() -> StructuredDataset: df = pd.json_normalize(data, max_level=0) - print("original dataframe: \n", tabulate(df, headers='keys', tablefmt='psql')) + print("original dataframe: \n", tabulate(df, headers="keys", tablefmt="psql")) + + return StructuredDataset(dataframe=df, uri=NUMPY_PATH) - return StructuredDataset( - dataframe=df, - uri=NUMPY_PATH - ) @task() def create_ar_table() -> StructuredDataset: df = pa.Table.from_pandas(pd.json_normalize(data, max_level=0)) - print("original dataframe: \n", tabulate(df, headers='keys', tablefmt='psql')) + print("original dataframe: \n", tabulate(df, headers="keys", tablefmt="psql")) return StructuredDataset( dataframe=df, @@ -220,49 +212,52 @@ def create_ar_table() -> StructuredDataset: @task() def print_table_by_arg(sd: MyArgDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyArgDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + print("MyArgDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) return t @task() def print_table_by_dict(sd: MyDictDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyDictDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + print("MyDictDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) return t @task() def print_table_by_list_dict(sd: MyDictListDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyDictListDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + print("MyDictListDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) return t @task() def print_table_by_top_dataclass(sd: MyTopDataClassDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyTopDataClassDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + print("MyTopDataClassDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) return t + @task() def print_table_by_top_dict(sd: MyTopDictDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyTopDictDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + print("MyTopDictDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) return t + @task() def print_table_by_second_dataclass(sd: MySecondDataClassDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MySecondDataClassDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + 
print("MySecondDataClassDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) return t @task() def print_table_by_nested_dataclass(sd: MyNestedDataClassDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyNestedDataClassDataset dataframe: \n", tabulate(t, headers='keys', tablefmt='psql')) + print("MyNestedDataClassDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) return t + @workflow def wf(): pd_sd = create_pd_table() @@ -299,5 +294,6 @@ def wf(): print_table_by_nested_dataclass(sd=ar_sd) return + def test_structured_dataset_wf(): wf() From facdeb86213475cf347256cb63e6dd7a322f0bdd Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Sun, 21 Apr 2024 15:57:29 +0800 Subject: [PATCH 06/11] remove tabulate Signed-off-by: Austin Liu --- ...tured_dataset_workflow_with_nested_type.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py index 23c63aa53d..709466ca82 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py @@ -6,7 +6,6 @@ import pyarrow as pa import pyarrow.parquet as pq import pytest -from tabulate import tabulate from typing_extensions import Annotated from flytekit import FlyteContext, FlyteContextManager, StructuredDataset, kwtypes, task, workflow @@ -177,7 +176,7 @@ class CompanyField: @task() def create_pd_table() -> StructuredDataset: df = pd.json_normalize(data, max_level=0) - print("original dataframe: \n", tabulate(df, headers="keys", tablefmt="psql")) + print("original dataframe: \n", df) return StructuredDataset(dataframe=df, uri=PANDAS_PATH) @@ -185,7 +184,7 @@ def create_pd_table() -> StructuredDataset: @task() def create_bq_table() -> StructuredDataset: df = pd.json_normalize(data, max_level=0) - print("original dataframe: \n", tabulate(df, headers="keys", tablefmt="psql")) + print("original dataframe: \n", df) # Enable one of GCP `uri` below if you want. You can replace `uri` with your own google cloud endpoints. 
return StructuredDataset(dataframe=df, uri=BQ_PATH) @@ -194,7 +193,7 @@ def create_bq_table() -> StructuredDataset: @task() def create_np_table() -> StructuredDataset: df = pd.json_normalize(data, max_level=0) - print("original dataframe: \n", tabulate(df, headers="keys", tablefmt="psql")) + print("original dataframe: \n", df) return StructuredDataset(dataframe=df, uri=NUMPY_PATH) @@ -202,7 +201,7 @@ def create_np_table() -> StructuredDataset: @task() def create_ar_table() -> StructuredDataset: df = pa.Table.from_pandas(pd.json_normalize(data, max_level=0)) - print("original dataframe: \n", tabulate(df, headers="keys", tablefmt="psql")) + print("original dataframe: \n", df) return StructuredDataset( dataframe=df, @@ -212,49 +211,49 @@ def create_ar_table() -> StructuredDataset: @task() def print_table_by_arg(sd: MyArgDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyArgDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) + print("MyArgDataset dataframe: \n", t) return t @task() def print_table_by_dict(sd: MyDictDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyDictDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) + print("MyDictDataset dataframe: \n", t) return t @task() def print_table_by_list_dict(sd: MyDictListDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyDictListDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) + print("MyDictListDataset dataframe: \n", t) return t @task() def print_table_by_top_dataclass(sd: MyTopDataClassDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyTopDataClassDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) + print("MyTopDataClassDataset dataframe: \n", t) return t @task() def print_table_by_top_dict(sd: MyTopDictDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyTopDictDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) + print("MyTopDictDataset dataframe: \n", t) return t @task() def print_table_by_second_dataclass(sd: MySecondDataClassDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MySecondDataClassDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) + print("MySecondDataClassDataset dataframe: \n", t) return t @task() def print_table_by_nested_dataclass(sd: MyNestedDataClassDataset) -> pd.DataFrame: t = sd.open(pd.DataFrame).all() - print("MyNestedDataClassDataset dataframe: \n", tabulate(t, headers="keys", tablefmt="psql")) + print("MyNestedDataClassDataset dataframe: \n", t) return t From ff6b58ff0a5d63deb2ae52c0b1088af39973fc28 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Fri, 26 Apr 2024 17:14:49 +0800 Subject: [PATCH 07/11] refine `flatten_dict()` recursion Signed-off-by: Austin Liu --- .../types/structured/structured_dataset.py | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index abde7baf36..e0e484c908 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -115,26 +115,22 @@ def iter(self) -> Generator[DF, None, None]: # flat the nested column map recursively -def flatten_dict(nested_dict) -> typing.Dict: +def flatten_dict(sub_dict, parent_key="") -> typing.Dict: result = {} - - def _flatten(sub_dict, parent_key=""): - for key, value in sub_dict.items(): - current_key = f"{parent_key}.{key}" if parent_key else key - # handle 
sub `dict` - if isinstance(value, dict): - return _flatten(value, current_key) - # handle sub `dataclass` - elif hasattr(value, "__dataclass_fields__"): - fields = getattr(value, "__dataclass_fields__") - d = {k: v.type for k, v in fields.items()} - return _flatten(d, current_key) - # already flattened - else: - result[current_key] = value - return result - - return _flatten(sub_dict=nested_dict) + for key, value in sub_dict.items(): + current_key = f"{parent_key}.{key}" if parent_key else key + # handle sub `dict` + if isinstance(value, dict): + result.update(flatten_dict(sub_dict=value, parent_key=current_key)) + # handle sub `dataclass` + elif hasattr(value, "__dataclass_fields__"): + fields = getattr(value, "__dataclass_fields__") + d = {k: v.type for k, v in fields.items()} + result.update(flatten_dict(sub_dict=d, parent_key=current_key)) + # already flattened + else: + result[current_key] = value + return result def extract_cols_and_format( From 153f70926ee79d462f96ca8c829c47883e3b43c9 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Fri, 26 Apr 2024 17:19:42 +0800 Subject: [PATCH 08/11] resolve conflicts Signed-off-by: Austin Liu --- tests/flytekit/integration/remote/test_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 0f14e12422..f23fc061d9 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -74,7 +74,7 @@ def run(file_name, wf_name, *args): # test child_workflow.parent_wf asynchronously register a parent wf1 with child lp from another wf2. -def test_remote_run_child_workflow(): +def test_remote_run(): run("child_workflow.py", "parent_wf", "--a", "3") From eaac66c0342b99c13470ffe19a1e34b1ad0f439f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 30 Apr 2024 10:12:36 +0800 Subject: [PATCH 09/11] nit Signed-off-by: Kevin Su --- flytekit/core/base_task.py | 4 +-- .../types/structured/structured_dataset.py | 35 ++++++++----------- ...tured_dataset_workflow_with_nested_type.py | 1 - 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index db68222ed1..7411fd635e 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -17,6 +17,7 @@ """ import asyncio +import collections import datetime import inspect import warnings @@ -91,8 +92,7 @@ def kwtypes(**kwargs) -> OrderedDict[str, Type]: kwtypes(a=int, b=str) """ - d = OrderedDict() - # only handle named arguments + d = collections.OrderedDict() for k, v in kwargs.items(): d[k] = v return d diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index e0e484c908..99f0d1a8c1 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -4,7 +4,7 @@ import types import typing from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field, is_dataclass from typing import Dict, Generator, Optional, Type, Union import _datetime @@ -119,15 +119,13 @@ def flatten_dict(sub_dict, parent_key="") -> typing.Dict: result = {} for key, value in sub_dict.items(): current_key = f"{parent_key}.{key}" if parent_key else key - # handle sub `dict` if isinstance(value, dict): result.update(flatten_dict(sub_dict=value, parent_key=current_key)) # handle sub `dataclass` - elif hasattr(value, 
"__dataclass_fields__"): + elif is_dataclass(value): fields = getattr(value, "__dataclass_fields__") d = {k: v.type for k, v in fields.items()} result.update(flatten_dict(sub_dict=d, parent_key=current_key)) - # already flattened else: result[current_key] = value return result @@ -161,29 +159,26 @@ def extract_cols_and_format( if get_origin(t) is Annotated: base_type, *annotate_args = get_args(t) for aa in annotate_args: - d = collections.OrderedDict() - # handle `dataclass` argument: - if hasattr(aa, "__annotations__"): - dm = vars(aa) - d.update(dm["__annotations__"]) - # handle `dict` argument: + if is_dataclass(aa): + d = collections.OrderedDict() + d.update(asdict(aa)) + ordered_dict_cols = d elif isinstance(aa, dict): + d = collections.OrderedDict() d.update(aa) - # handle defaults: - else: - d = aa - if isinstance(d, StructuredDatasetFormat): + ordered_dict_cols = d + elif isinstance(aa, StructuredDatasetFormat): if fmt != "": - raise ValueError(f"A format was already specified {fmt}, cannot use {d}") - fmt = d - elif isinstance(d, collections.OrderedDict): + raise ValueError(f"A format was already specified {fmt}, cannot use {aa}") + fmt = aa + elif isinstance(aa, collections.OrderedDict): if ordered_dict_cols is not None: raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {d}") - ordered_dict_cols = d - elif isinstance(d, pa.lib.Schema): + ordered_dict_cols = aa + elif isinstance(aa, pa.lib.Schema): if pa_schema is not None: raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {d}") - pa_schema = d + pa_schema = aa return base_type, ordered_dict_cols, fmt, pa_schema # We return None as the format instead of parquet or something because the transformer engine may find diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py index 709466ca82..49e8ac04ee 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py @@ -291,7 +291,6 @@ def wf(): print_table_by_top_dict(sd=ar_sd) print_table_by_second_dataclass(sd=ar_sd) print_table_by_nested_dataclass(sd=ar_sd) - return def test_structured_dataset_wf(): From 98d7902e2ff2cda90da315829a492dfc730054b5 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 30 Apr 2024 10:23:54 +0800 Subject: [PATCH 10/11] nit Signed-off-by: Kevin Su --- flytekit/types/structured/structured_dataset.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 99f0d1a8c1..d6dc6b49e5 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -4,7 +4,7 @@ import types import typing from abc import ABC, abstractmethod -from dataclasses import asdict, dataclass, field, is_dataclass +from dataclasses import dataclass, field, is_dataclass from typing import Dict, Generator, Optional, Type, Union import _datetime @@ -115,13 +115,12 @@ def iter(self) -> Generator[DF, None, None]: # flat the nested column map recursively -def flatten_dict(sub_dict, parent_key="") -> typing.Dict: +def flatten_dict(sub_dict: dict, parent_key: str = "") -> typing.Dict: result = {} for key, value in sub_dict.items(): current_key = 
f"{parent_key}.{key}" if parent_key else key if isinstance(value, dict): result.update(flatten_dict(sub_dict=value, parent_key=current_key)) - # handle sub `dataclass` elif is_dataclass(value): fields = getattr(value, "__dataclass_fields__") d = {k: v.type for k, v in fields.items()} @@ -159,9 +158,11 @@ def extract_cols_and_format( if get_origin(t) is Annotated: base_type, *annotate_args = get_args(t) for aa in annotate_args: - if is_dataclass(aa): + if hasattr(aa, "__annotations__"): + # handle dataclass argument d = collections.OrderedDict() - d.update(asdict(aa)) + dm = vars(aa) + d.update(dm["__annotations__"]) ordered_dict_cols = d elif isinstance(aa, dict): d = collections.OrderedDict() @@ -173,11 +174,11 @@ def extract_cols_and_format( fmt = aa elif isinstance(aa, collections.OrderedDict): if ordered_dict_cols is not None: - raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {d}") + raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}") ordered_dict_cols = aa elif isinstance(aa, pa.lib.Schema): if pa_schema is not None: - raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {d}") + raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}") pa_schema = aa return base_type, ordered_dict_cols, fmt, pa_schema From a6e469a2f3467ae572a62b89b49ad88696159069 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Tue, 30 Apr 2024 14:04:19 +0800 Subject: [PATCH 11/11] clean up Signed-off-by: Austin Liu --- ...tured_dataset_workflow_with_nested_type.py | 120 +----------------- 1 file changed, 1 insertion(+), 119 deletions(-) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py index 49e8ac04ee..62c0f6d651 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py @@ -1,25 +1,10 @@ -import os -import typing from dataclasses import dataclass -import numpy as np import pyarrow as pa -import pyarrow.parquet as pq import pytest from typing_extensions import Annotated -from flytekit import FlyteContext, FlyteContextManager, StructuredDataset, kwtypes, task, workflow -from flytekit.models import literals -from flytekit.models.literals import StructuredDatasetMetadata -from flytekit.models.types import StructuredDatasetType -from flytekit.types.structured.basic_dfs import CSVToPandasDecodingHandler, PandasToCSVEncodingHandler -from flytekit.types.structured.structured_dataset import ( - DF, - PARQUET, - StructuredDatasetDecoder, - StructuredDatasetEncoder, - StructuredDatasetTransformerEngine, -) +from flytekit import FlyteContextManager, StructuredDataset, kwtypes, task, workflow pd = pytest.importorskip("pandas") @@ -28,109 +13,6 @@ BQ_PATH = "bq://flyte-dataset:flyte.table" -@dataclass -class MyCols: - Name: str - Age: int - - -my_cols = kwtypes(Name=str, Age=int) -my_dataclass_cols = MyCols -my_dict_cols = {"Name": str, "Age": int} -fields = [("Name", pa.string()), ("Age", pa.int32())] -arrow_schema = pa.schema(fields) -pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) - - -class MockBQEncodingHandlers(StructuredDatasetEncoder): - def __init__(self): - super().__init__(pd.DataFrame, "bq", "") - - def encode( - self, - ctx: FlyteContext, - structured_dataset: 
StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - return literals.StructuredDataset( - uri="bq://bucket/key", metadata=StructuredDatasetMetadata(structured_dataset_type) - ) - - -class MockBQDecodingHandlers(StructuredDatasetDecoder): - def __init__(self): - super().__init__(pd.DataFrame, "bq", "") - - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> pd.DataFrame: - return pd_df - - -StructuredDatasetTransformerEngine.register(MockBQEncodingHandlers(), False, True) -StructuredDatasetTransformerEngine.register(MockBQDecodingHandlers(), False, True) - - -class NumpyRenderer: - """ - The Polars DataFrame summary statistics are rendered as an HTML table. - """ - - def to_html(self, array: np.ndarray) -> str: - return pd.DataFrame(array).describe().to_html() - - -@pytest.fixture(autouse=True) -def numpy_type(): - class NumpyEncodingHandlers(StructuredDatasetEncoder): - def encode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - path = typing.cast(str, structured_dataset.uri) - if not path: - path = ctx.file_access.join( - ctx.file_access.raw_output_prefix, - ctx.file_access.get_random_string(), - ) - df = typing.cast(np.ndarray, structured_dataset.dataframe) - name = ["col" + str(i) for i in range(len(df))] - table = pa.Table.from_arrays(df, name) - local_dir = ctx.file_access.get_random_local_directory() - local_path = os.path.join(local_dir, f"{0:05}") - pq.write_table(table, local_path) - ctx.file_access.upload_directory(local_dir, path) - structured_dataset_type.format = PARQUET - return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) - - class NumpyDecodingHandlers(StructuredDatasetDecoder): - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> typing.Union[DF, typing.Generator[DF, None, None]]: - path = flyte_value.uri - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(path, local_dir, is_multipart=True) - table = pq.read_table(local_dir) - return table.to_pandas().to_numpy() - - StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray)) - StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray)) - StructuredDatasetTransformerEngine.register_renderer(np.ndarray, NumpyRenderer()) - - -StructuredDatasetTransformerEngine.register(PandasToCSVEncodingHandler()) -StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler()) - - -## Case 1. data = [ { "company": "XYZ pvt ltd",
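
Taken together, the series boils down to two pieces: `extract_cols_and_format()` now accepts a dataclass or a plain dict inside `Annotated[StructuredDataset, ...]`, and `flatten_dict()` (final form in PATCH 10) turns the nested column map into dotted column names before `_convert_ordered_dict_of_columns_to_list()` builds the literal type. The sketch below is a minimal, self-contained illustration of that flattening, reusing the `ContactsField`/`InfoField` dataclasses from the test fixtures; it approximates the behaviour and is not the flytekit implementation itself.

# Standalone sketch of the flattening added in this series.
# ContactsField/InfoField mirror the dataclasses used in the tests; the
# flatten logic follows the final version of flatten_dict() (PATCH 10).
from dataclasses import dataclass, is_dataclass
from typing import Dict


@dataclass
class ContactsField:
    email: str
    tel: str


@dataclass
class InfoField:
    president: str
    contacts: ContactsField


def flatten_dict(sub_dict: dict, parent_key: str = "") -> Dict:
    result = {}
    for key, value in sub_dict.items():
        current_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            # nested dict: recurse with the dotted prefix
            result.update(flatten_dict(value, current_key))
        elif is_dataclass(value):
            # nested dataclass: treat its fields as a name -> type dict
            d = {k: f.type for k, f in value.__dataclass_fields__.items()}
            result.update(flatten_dict(d, current_key))
        else:
            result[current_key] = value
    return result


# A column map like kwtypes(info=kwtypes(contacts=ContactsField)) flattens
# to dotted column names, which the transformer then converts into
# StructuredDatasetType.DatasetColumn entries:
print(flatten_dict({"info": {"contacts": ContactsField}}))
# {'info.contacts.email': <class 'str'>, 'info.contacts.tel': <class 'str'>}
print(flatten_dict({"info": InfoField}))
# {'info.president': <class 'str'>, 'info.contacts.email': <class 'str'>,
#  'info.contacts.tel': <class 'str'>}

The same flattening applies whether the column annotation was supplied as a `kwtypes()` OrderedDict, a plain dict, or a dataclass passed directly to `Annotated` (as in `MyTopDataClassDataset` and `MyTopDictDataset` in the nested-type test).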