Skip to content

Commit

Permalink
[WIP] dataclass typing
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Liu <[email protected]>

[WIP] dataclass typing

Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 committed Mar 18, 2024
1 parent 64b8468 commit c175ae0
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,20 @@ def my_wf(in1: int, in2: int) -> int:
t = native_types[k]
try:
if type(v) is Promise:
v = resolve_attr_path_in_promise(v)
v = resolve_attr_path_in_promise(v, t)
result[k] = TypeEngine.to_literal(ctx, v, t, var.type)
except TypeTransformerFailedError as exc:
raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc

return result


def resolve_attr_path_in_promise(p: Promise) -> Promise:
def resolve_attr_path_in_promise(p: Promise, t: type) -> Promise:
"""
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value
This is for local execution only. The remote execution will be resolved in flytepropeller.
"""

curr_val = p.val
curr_val = p.val #cast(t, p.val)

used = 0

Expand Down Expand Up @@ -130,9 +129,10 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
if type(curr_val.value) is _literals_models.Scalar and type(curr_val.value.value) is _struct.Struct:
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
new_st = int(new_st) if t==int else new_st
literal_type = TypeEngine.to_literal_type(t)
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, t, literal_type)

p._val = curr_val
return p
Expand Down

0 comments on commit c175ae0

Please sign in to comment.