diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 220fc3fb89..ba326ec27e 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -697,6 +697,15 @@ def binding_data_from_python_std( ) elif t_value is not None and expected_literal_type.union_type is not None: + # If the value is not a container type, then we can directly convert it to a scalar in the Union case. + # This pushes the handling of the Union types to the type engine. + if not isinstance(t_value, list) and not isinstance(t_value, dict): + scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar + return _literals_models.BindingData(scalar=scalar) + + # If it is a container type, then we need to iterate over the variants in the Union type, try each one. This is + # akin to what the Type Engine does when it finds a Union type (see the UnionTransformer), but we can't rely on + # that in this case, because of the mix and match of realized values, and Promises. for i in range(len(expected_literal_type.union_type.variants)): try: lt_type = expected_literal_type.union_type.variants[i] diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 74f3db99e3..e022c875e0 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -13,6 +13,7 @@ from flytekit.core.promise import ( Promise, VoidPromise, + binding_data_from_python_std, create_and_link_node, create_and_link_node_from_remote, resolve_attr_path_in_promise, @@ -20,6 +21,7 @@ ) from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteAssertion, FlytePromiseAttributeResolveException +from flytekit.models.types import LiteralType, SimpleType, TypeStructure from flytekit.types.pickle.pickle import BatchSize @@ -234,3 +236,18 @@ class Foo: # exception with pytest.raises(FlytePromiseAttributeResolveException): tgt_promise = resolve_attr_path_in_promise(src_promise["c"]) + + +def test_prom_with_union_literals(): + ctx = FlyteContextManager.current_context() + pt = typing.Union[str, int] + lt = TypeEngine.to_literal_type(pt) + assert lt.union_type.variants == [ + LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")), + LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")), + ] + + bd = binding_data_from_python_std(ctx, lt, 3, pt, []) + assert bd.scalar.union.stored_type.structure.tag == "int" + bd = binding_data_from_python_std(ctx, lt, "hello", pt, []) + assert bd.scalar.union.stored_type.structure.tag == "str" diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index d4db4f34fe..0546b9dc7a 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1759,6 +1759,42 @@ def test_annotated_union_type(): assert v == "hello" +def test_union_type_simple(): + pt = typing.Union[str, int] + lt = TypeEngine.to_literal_type(pt) + assert lt.union_type.variants == [ + LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")), + LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")), + ] + ctx = FlyteContextManager.current_context() + lv = TypeEngine.to_literal(ctx, 3, pt, lt) + assert lv.scalar.union is not None + assert lv.scalar.union.stored_type.structure.tag == "int" + assert lv.scalar.union.stored_type.structure.dataclass_type is None + + +def test_union_containers(): + pt = typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]], int] + lt = TypeEngine.to_literal_type(pt) + + list_of_maps_of_list_ints = [ + {"first_map_a": [42], "first_map_b": [42, 2]}, + { + "second_map_c": [33], + "second_map_d": [9, 99], + }, + ] + map_of_list_ints = { + "ll_1": [1, 23, 3], + "ll_2": [4, 5, 6], + } + ctx = FlyteContextManager.current_context() + lv = TypeEngine.to_literal(ctx, list_of_maps_of_list_ints, pt, lt) + assert lv.scalar.union.stored_type.structure.tag == "Typed List" + lv = TypeEngine.to_literal(ctx, map_of_list_ints, pt, lt) + assert lv.scalar.union.stored_type.structure.tag == "Python Dictionary" + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") def test_optional_type(): pt = typing.Optional[int]