Skip to content

Commit

Permalink
StructuredDataset: add recursive flatten_dict()
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Liu <[email protected]>

wip

Signed-off-by: Austin Liu <[email protected]>

fmt

Signed-off-by: Austin Liu <[email protected]>

fix

Signed-off-by: Austin Liu <[email protected]>

fix

Signed-off-by: Austin Liu <[email protected]>

fix

Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 committed Mar 16, 2024
1 parent 64b8468 commit f5cd70d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
14 changes: 11 additions & 3 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""

import asyncio
import collections
import datetime
import inspect
import warnings
Expand Down Expand Up @@ -71,15 +70,24 @@
UNSET_CARD = "_uc"


def kwtypes(**kwargs) -> OrderedDict[str, Type]:
def kwtypes(*args, **kwargs) -> OrderedDict[str, Type]:
"""
This is a small helper function to convert the keyword arguments to an OrderedDict of types.
.. code-block:: python
kwtypes(a=int, b=str)
"""
d = collections.OrderedDict()
d = OrderedDict()
for arg in args:
# handle positional arguments: dataclass
if hasattr(arg, "__annotations__"):
dm = vars(arg)
d.update(dm["__annotations__"])
# handle positional arguments: dict
elif isinstance(arg, dict):
d.update(arg)
# handle named arguments
for k, v in kwargs.items():
d[k] = v
return d
Expand Down
23 changes: 23 additions & 0 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,13 +820,36 @@ def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType:
return type_models.LiteralType(map_value_type=self._get_dataset_column_literal_type(t.__args__[1]))
raise AssertionError(f"type {t} is currently not supported by StructuredDataset")

def flatten_dict(self, nested_dict):
result = {}

def _flatten(sub_dict, parent_key=""):
for key, value in sub_dict.items():
current_key = f"{parent_key}.{key}" if parent_key else key
if isinstance(value, dict):
return _flatten(value, current_key)
elif hasattr(value, "__dataclass_fields__"):
fields = getattr(value, "__dataclass_fields__")
d = {k: v.type for k, v in fields.items()}
return _flatten(d, current_key)
else:
result[current_key] = value
return result

return _flatten(sub_dict=nested_dict)

def _convert_ordered_dict_of_columns_to_list(
self, column_map: typing.Optional[typing.OrderedDict[str, Type]]
) -> typing.List[StructuredDatasetType.DatasetColumn]:
converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = []
if column_map is None or len(column_map) == 0:
return converted_cols
flat_column_map = {}
for k, v in column_map.items():
d = dict()
d[k] = v
flat_column_map.update(self.flatten_dict(d))
for k, v in flat_column_map.items():
lt = self._get_dataset_column_literal_type(v)
converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt))
return converted_cols
Expand Down

0 comments on commit f5cd70d

Please sign in to comment.