Skip to content

Commit

Permalink
Add escape for scalars to binding during union handling (#2460)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Jun 5, 2024
1 parent 57ee143 commit 72af7c6
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 0 deletions.
9 changes: 9 additions & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,15 @@ def binding_data_from_python_std(
)

elif t_value is not None and expected_literal_type.union_type is not None:
# If the value is not a container type, then we can directly convert it to a scalar in the Union case.
# This pushes the handling of the Union types to the type engine.
if not isinstance(t_value, list) and not isinstance(t_value, dict):
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar
return _literals_models.BindingData(scalar=scalar)

# If it is a container type, then we need to iterate over the variants in the Union type, try each one. This is
# akin to what the Type Engine does when it finds a Union type (see the UnionTransformer), but we can't rely on
# that in this case, because of the mix and match of realized values, and Promises.
for i in range(len(expected_literal_type.union_type.variants)):
try:
lt_type = expected_literal_type.union_type.variants[i]
Expand Down
17 changes: 17 additions & 0 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
from flytekit.core.promise import (
Promise,
VoidPromise,
binding_data_from_python_std,
create_and_link_node,
create_and_link_node_from_remote,
resolve_attr_path_in_promise,
translate_inputs_to_literals,
)
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.user import FlyteAssertion, FlytePromiseAttributeResolveException
from flytekit.models.types import LiteralType, SimpleType, TypeStructure
from flytekit.types.pickle.pickle import BatchSize


Expand Down Expand Up @@ -234,3 +236,18 @@ class Foo:
# exception
with pytest.raises(FlytePromiseAttributeResolveException):
tgt_promise = resolve_attr_path_in_promise(src_promise["c"])


def test_prom_with_union_literals():
ctx = FlyteContextManager.current_context()
pt = typing.Union[str, int]
lt = TypeEngine.to_literal_type(pt)
assert lt.union_type.variants == [
LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
]

bd = binding_data_from_python_std(ctx, lt, 3, pt, [])
assert bd.scalar.union.stored_type.structure.tag == "int"
bd = binding_data_from_python_std(ctx, lt, "hello", pt, [])
assert bd.scalar.union.stored_type.structure.tag == "str"
36 changes: 36 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,42 @@ def test_annotated_union_type():
assert v == "hello"


def test_union_type_simple():
pt = typing.Union[str, int]
lt = TypeEngine.to_literal_type(pt)
assert lt.union_type.variants == [
LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
]
ctx = FlyteContextManager.current_context()
lv = TypeEngine.to_literal(ctx, 3, pt, lt)
assert lv.scalar.union is not None
assert lv.scalar.union.stored_type.structure.tag == "int"
assert lv.scalar.union.stored_type.structure.dataclass_type is None


def test_union_containers():
pt = typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]], int]
lt = TypeEngine.to_literal_type(pt)

list_of_maps_of_list_ints = [
{"first_map_a": [42], "first_map_b": [42, 2]},
{
"second_map_c": [33],
"second_map_d": [9, 99],
},
]
map_of_list_ints = {
"ll_1": [1, 23, 3],
"ll_2": [4, 5, 6],
}
ctx = FlyteContextManager.current_context()
lv = TypeEngine.to_literal(ctx, list_of_maps_of_list_ints, pt, lt)
assert lv.scalar.union.stored_type.structure.tag == "Typed List"
lv = TypeEngine.to_literal(ctx, map_of_list_ints, pt, lt)
assert lv.scalar.union.stored_type.structure.tag == "Python Dictionary"


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_optional_type():
pt = typing.Optional[int]
Expand Down

0 comments on commit 72af7c6

Please sign in to comment.