From f5cd70dd053e6f3d4aaf5b90d9c4b28f32c0980a Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Thu, 14 Mar 2024 23:20:34 +0800 Subject: [PATCH] StructuredDataset: add recursive `flatten_dict()` Signed-off-by: Austin Liu wip Signed-off-by: Austin Liu fmt Signed-off-by: Austin Liu fix Signed-off-by: Austin Liu fix Signed-off-by: Austin Liu fix Signed-off-by: Austin Liu --- flytekit/core/base_task.py | 14 ++++++++--- .../types/structured/structured_dataset.py | 23 +++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 1842a9957f..f57719c6f9 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -17,7 +17,6 @@ """ import asyncio -import collections import datetime import inspect import warnings @@ -71,7 +70,7 @@ UNSET_CARD = "_uc" -def kwtypes(**kwargs) -> OrderedDict[str, Type]: +def kwtypes(*args, **kwargs) -> OrderedDict[str, Type]: """ This is a small helper function to convert the keyword arguments to an OrderedDict of types. @@ -79,7 +78,16 @@ def kwtypes(**kwargs) -> OrderedDict[str, Type]: kwtypes(a=int, b=str) """ - d = collections.OrderedDict() + d = OrderedDict() + for arg in args: + # handle positional arguments: dataclass + if hasattr(arg, "__annotations__"): + dm = vars(arg) + d.update(dm["__annotations__"]) + # handle positional arguments: dict + elif isinstance(arg, dict): + d.update(arg) + # handle named arguments for k, v in kwargs.items(): d[k] = v return d diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 1d7af31404..9e5e56734d 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -820,13 +820,36 @@ def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType: return type_models.LiteralType(map_value_type=self._get_dataset_column_literal_type(t.__args__[1])) raise AssertionError(f"type {t} is currently not supported by StructuredDataset") + def flatten_dict(self, nested_dict): + result = {} + + def _flatten(sub_dict, parent_key=""): + for key, value in sub_dict.items(): + current_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, dict): + return _flatten(value, current_key) + elif hasattr(value, "__dataclass_fields__"): + fields = getattr(value, "__dataclass_fields__") + d = {k: v.type for k, v in fields.items()} + return _flatten(d, current_key) + else: + result[current_key] = value + return result + + return _flatten(sub_dict=nested_dict) + def _convert_ordered_dict_of_columns_to_list( self, column_map: typing.Optional[typing.OrderedDict[str, Type]] ) -> typing.List[StructuredDatasetType.DatasetColumn]: converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = [] if column_map is None or len(column_map) == 0: return converted_cols + flat_column_map = {} for k, v in column_map.items(): + d = dict() + d[k] = v + flat_column_map.update(self.flatten_dict(d)) + for k, v in flat_column_map.items(): lt = self._get_dataset_column_literal_type(v) converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt)) return converted_cols