Skip to content

Commit

Permalink
Fix union transformer (#2024)
Browse files Browse the repository at this point in the history
* Fix union transformer

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* fix test

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Dec 5, 2023
1 parent 5c6802c commit 1e30eb1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 16 deletions.
39 changes: 23 additions & 16 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
python_type = get_underlying_type(python_type)

found_res = False
is_ambiguous = False
res = None
res_type = None
for i in range(len(get_args(python_type))):
Expand All @@ -1454,13 +1455,15 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
res = trans.to_literal(ctx, python_val, t, expected.union_type.variants[i])
res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name)
if found_res:
# Should really never happen, sanity check
raise TypeError("Ambiguous choice of variant for union type")
is_ambiguous = True
found_res = True
except (TypeTransformerFailedError, AttributeError, ValueError, AssertionError) as e:
except Exception as e:
logger.debug(f"Failed to convert from {python_val} to {t}", e)
continue

if is_ambiguous:
raise TypeError("Ambiguous choice of variant for union type")

if found_res:
return Literal(scalar=Scalar(union=Union(value=res, stored_type=res_type)))

Expand All @@ -1477,6 +1480,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
union_tag = union_type.structure.tag

found_res = False
is_ambiguous = False
cur_transformer = ""
res = None
res_tag = None
for v in get_args(expected_python_type):
Expand All @@ -1494,25 +1499,27 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
assert lv.scalar.union is not None # type checker

res = trans.to_python_value(ctx, lv.scalar.union.value, v)
res_tag = trans.name
if found_res:
raise TypeError(
"Ambiguous choice of variant for union type. "
+ f"Both {res_tag} and {trans.name} transformers match"
)
found_res = True
is_ambiguous = True
cur_transformer = trans.name
break
else:
res = trans.to_python_value(ctx, lv, v)
if found_res:
raise TypeError(
"Ambiguous choice of variant for union type. "
+ f"Both {res_tag} and {trans.name} transformers match"
)
res_tag = trans.name
found_res = True
except (TypeTransformerFailedError, AttributeError) as e:
is_ambiguous = True
cur_transformer = trans.name
break
res_tag = trans.name
found_res = True
except Exception as e:
logger.debug(f"Failed to convert from {lv} to {v}", e)

if is_ambiguous:
raise TypeError(
"Ambiguous choice of variant for union type. "
+ f"Both {res_tag} and {cur_transformer} transformers match"
)

if found_res:
return res

Expand Down
14 changes: 14 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2034,6 +2034,20 @@ def test_schema_in_dataclass():
assert o == ot


def test_union_in_dataclass():
schema = TestSchema()
df = pd.DataFrame(data={"some_str": ["a", "b", "c"]})
schema.open().write(df)
o = Result(result=InnerResult(number=1, schema=schema), schema=schema)
ctx = FlyteContext.current_context()
tf = UnionTransformer()
pt = typing.Union[Result, InnerResult]
lt = tf.get_literal_type(pt)
lv = tf.to_literal(ctx, o, pt, lt)
ot = tf.to_python_value(ctx, lv=lv, expected_python_type=pt)
return o == ot


@dataclass
class InnerResult_dataclassjsonmixin(DataClassJSONMixin):
number: int
Expand Down

0 comments on commit 1e30eb1

Please sign in to comment.