Skip to content

Commit

Permalink
Handle list subtype that cannot be inferred
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Oct 28, 2023
1 parent 4d5e1b8 commit 30a25f1
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
4 changes: 2 additions & 2 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,10 +570,10 @@ def binding_data_from_python_std(
)

elif isinstance(t_value, list):
sub_type: type = ListTransformer.get_sub_type(t_value_type)
sub_type: Optional[type] = ListTransformer.get_sub_type_or_none(t_value_type)

Check warning on line 573 in flytekit/core/promise.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/promise.py#L573

Added line #L573 was not covered by tests
collection = _literals_models.BindingDataCollection(
bindings=[
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type, nodes)
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes)
for t in t_value
]
)
Expand Down
12 changes: 11 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,16 @@ def get_sub_type(t: Type[T]) -> Type[T]:
"""
Return the generic Type T of the List
"""
if (sub_type := ListTransformer.get_sub_type_or_none(t)) is not None:
return sub_type

raise ValueError("Only generic univariate typing.List[T] type is supported.")

Check warning on line 1114 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1114

Added line #L1114 was not covered by tests

@staticmethod
def get_sub_type_or_none(t: Type[T]) -> Optional[Type[T]]:
"""
Return the generic Type T of the List, or None if the generic type cannot be inferred
"""
if hasattr(t, "__origin__"):
# Handle annotation on list generic, eg:
# Annotated[typing.List[int], 'foo']
Expand All @@ -1117,7 +1127,7 @@ def get_sub_type(t: Type[T]) -> Type[T]:
if getattr(t, "__origin__") is list and hasattr(t, "__args__"):
return getattr(t, "__args__")[0]

raise ValueError("Only generic univariate typing.List[T] type is supported.")
return None

Check warning on line 1130 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1130

Added line #L1130 was not covered by tests

def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
"""
Expand Down
8 changes: 8 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2334,3 +2334,11 @@ class Datum(DataClassJSONMixin):
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
assert datum_mashumaro.x == pv.x
assert datum_mashumaro.y.value == pv.y


def test_ListTransformer_get_sub_type():
assert ListTransformer.get_sub_type_or_none(typing.List[str]) is str


def test_ListTransformer_get_sub_type_as_none():
assert ListTransformer.get_sub_type_or_none(type([])) is None

0 comments on commit 30a25f1

Please sign in to comment.