Skip to content

Commit

Permalink
Handle list subtype that cannot be inferred (flyteorg#1907)
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
Signed-off-by: Rafael Raposo <[email protected]>
  • Loading branch information
honnix authored and RRap0so committed Dec 15, 2023
1 parent 1b4dc0e commit 32042e7
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
4 changes: 2 additions & 2 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,10 +695,10 @@ def binding_data_from_python_std(
)

elif isinstance(t_value, list):
sub_type: type = ListTransformer.get_sub_type(t_value_type)
sub_type: Optional[type] = ListTransformer.get_sub_type_or_none(t_value_type)
collection = _literals_models.BindingDataCollection(
bindings=[
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type, nodes)
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes)
for t in t_value
]
)
Expand Down
12 changes: 11 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,16 @@ def get_sub_type(t: Type[T]) -> Type[T]:
"""
Return the generic Type T of the List
"""
if (sub_type := ListTransformer.get_sub_type_or_none(t)) is not None:
return sub_type

raise ValueError("Only generic univariate typing.List[T] type is supported.")

@staticmethod
def get_sub_type_or_none(t: Type[T]) -> Optional[Type[T]]:
"""
Return the generic Type T of the List, or None if the generic type cannot be inferred
"""
if hasattr(t, "__origin__"):
# Handle annotation on list generic, eg:
# Annotated[typing.List[int], 'foo']
Expand All @@ -1133,7 +1143,7 @@ def get_sub_type(t: Type[T]) -> Type[T]:
if getattr(t, "__origin__") is list and hasattr(t, "__args__"):
return getattr(t, "__args__")[0]

raise ValueError("Only generic univariate typing.List[T] type is supported.")
return None

def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
"""
Expand Down
8 changes: 8 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2334,3 +2334,11 @@ class Datum(DataClassJSONMixin):
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
assert datum_mashumaro.x == pv.x
assert datum_mashumaro.y.value == pv.y


def test_ListTransformer_get_sub_type():
assert ListTransformer.get_sub_type_or_none(typing.List[str]) is str


def test_ListTransformer_get_sub_type_as_none():
assert ListTransformer.get_sub_type_or_none(type([])) is None

0 comments on commit 32042e7

Please sign in to comment.