diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 123cb4a0ef..f79ddc3082 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1445,6 +1445,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp python_type = get_underlying_type(python_type) found_res = False + is_ambiguous = False res = None res_type = None for i in range(len(get_args(python_type))): @@ -1454,13 +1455,15 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp res = trans.to_literal(ctx, python_val, t, expected.union_type.variants[i]) res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name) if found_res: - # Should really never happen, sanity check - raise TypeError("Ambiguous choice of variant for union type") + is_ambiguous = True found_res = True - except (TypeTransformerFailedError, AttributeError, ValueError, AssertionError) as e: + except Exception as e: logger.debug(f"Failed to convert from {python_val} to {t}", e) continue + if is_ambiguous: + raise TypeError("Ambiguous choice of variant for union type") + if found_res: return Literal(scalar=Scalar(union=Union(value=res, stored_type=res_type))) @@ -1477,6 +1480,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: union_tag = union_type.structure.tag found_res = False + is_ambiguous = False + cur_transformer = "" res = None res_tag = None for v in get_args(expected_python_type): @@ -1494,25 +1499,27 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: assert lv.scalar.union is not None # type checker res = trans.to_python_value(ctx, lv.scalar.union.value, v) - res_tag = trans.name if found_res: - raise TypeError( - "Ambiguous choice of variant for union type. " - + f"Both {res_tag} and {trans.name} transformers match" - ) - found_res = True + is_ambiguous = True + cur_transformer = trans.name + break else: res = trans.to_python_value(ctx, lv, v) if found_res: - raise TypeError( - "Ambiguous choice of variant for union type. " - + f"Both {res_tag} and {trans.name} transformers match" - ) - res_tag = trans.name - found_res = True - except (TypeTransformerFailedError, AttributeError) as e: + is_ambiguous = True + cur_transformer = trans.name + break + res_tag = trans.name + found_res = True + except Exception as e: logger.debug(f"Failed to convert from {lv} to {v}", e) + if is_ambiguous: + raise TypeError( + "Ambiguous choice of variant for union type. " + + f"Both {res_tag} and {cur_transformer} transformers match" + ) + if found_res: return res diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 227be2d0ff..3ce3570b49 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -2034,6 +2034,20 @@ def test_schema_in_dataclass(): assert o == ot +def test_union_in_dataclass(): + schema = TestSchema() + df = pd.DataFrame(data={"some_str": ["a", "b", "c"]}) + schema.open().write(df) + o = Result(result=InnerResult(number=1, schema=schema), schema=schema) + ctx = FlyteContext.current_context() + tf = UnionTransformer() + pt = typing.Union[Result, InnerResult] + lt = tf.get_literal_type(pt) + lv = tf.to_literal(ctx, o, pt, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=pt) + return o == ot + + @dataclass class InnerResult_dataclassjsonmixin(DataClassJSONMixin): number: int