Skip to content

Commit

Permalink
[flytekit] Support attribute access on promises (flyteorg#1825)
Browse files Browse the repository at this point in the history
Support attribute access on promises

In the following workflow:

- `basic_workflow` contains trivial examples to access output attributes
- `failed_workflow` contains examples that causes exception (e.g. out of range)
- `advanced_workflow` contains examples with more complex attribute access

```python
from dataclasses import dataclass
from typing import Dict, List, NamedTuple

from dataclasses_json import dataclass_json

from flytekit import WorkflowFailurePolicy, task, workflow


@dataclass_json
@DataClass
class foo:
    a: str


@task
def t1() -> (List[str], Dict[str, str], foo):
    return ["a", "b"], {"a": "b"}, foo(a="b")


@task
def t2(a: str) -> str:
    print("a", a)
    return a


@task
def t3() -> (Dict[str, List[str]], List[Dict[str, str]], Dict[str, foo]):
    return {"a": ["b"]}, [{"a": "b"}], {"a": foo(a="b")}


@task
def t4(a: List[str]):
    print("a", a)


@task
def t5(a: Dict[str, str]):
    print("a", a)


@workflow
def basic_workflow():
    l, d, f = t1()
    t2(a=l[0])
    t2(a=d["a"])
    t2(a=f.a)


@workflow(
    # The workflow doesn't fail when one of the nodes fails but other nodes are still executable
    failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE
)
def failed_workflow():
    # This workflow is supposed to fail due to exceptions
    l, d, f = t1()
    t2(a=l[100])
    t2(a=d["b"])
    t2(a=f.b)


@workflow
def advanced_workflow():
    dl, ld, dd = t3()
    t2(a=dl["a"][0])
    t2(a=ld[0]["a"])
    t2(a=dd["a"].a)

    t4(a=dl["a"])
    t5(a=ld[0])
```

Signed-off-by: byhsu <[email protected]>
  • Loading branch information
ByronHsu authored and ringohoffman committed Nov 24, 2023
1 parent 8b2b975 commit 88fe671
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 18 deletions.
6 changes: 2 additions & 4 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,8 @@ filelock==3.12.4
flask==2.3.3
# via mlflow
flatbuffers==23.5.26
# via
# tensorflow
# tf2onnx
flyteidl==1.5.17
# via tensorflow
flyteidl==1.10.0
# via flytekit
fonttools==4.42.1
# via matplotlib
Expand Down
144 changes: 140 additions & 4 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import collections
import inspect
from copy import deepcopy
from enum import Enum
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast

from google.protobuf import struct_pb2 as _struct
from typing_extensions import Protocol, get_args

from flytekit.core import constants as _common_constants
Expand All @@ -22,6 +24,7 @@
from flytekit.core.node import Node
from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.exceptions import user as _user_exceptions
from flytekit.exceptions.user import FlytePromiseAttributeResolveException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literals_models
Expand Down Expand Up @@ -79,13 +82,72 @@ def my_wf(in1: int, in2: int) -> int:
var = flyte_interface_types[k]
t = native_types[k]
try:
if type(v) is Promise:
v = resolve_attr_path_in_promise(v)
result[k] = TypeEngine.to_literal(ctx, v, t, var.type)
except TypeTransformerFailedError as exc:
raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc

return result


def resolve_attr_path_in_promise(p: Promise) -> Promise:
"""
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value
This is for local execution only. The remote execution will be resolved in flytepropeller.
"""

curr_val = p.val

used = 0

for attr in p.attr_path:
# If current value is Flyte literal collection (list) or map (dictionary), use [] to resolve
if (
type(curr_val.value) is _literals_models.LiteralMap
or type(curr_val.value) is _literals_models.LiteralCollection
):
if type(attr) == str and attr not in curr_val.value.literals:
raise FlytePromiseAttributeResolveException(
f"Failed to resolve attribute path {p.attr_path} in promise {p},"
f" attribute {attr} not found in {curr_val.value.literals.keys()}"
)

if type(attr) == int and attr >= len(curr_val.value.literals):
raise FlytePromiseAttributeResolveException(
f"Failed to resolve attribute path {p.attr_path} in promise {p},"
f" index {attr} out of range {len(curr_val.value.literals)}"
)

curr_val = curr_val.value.literals[attr]
used += 1
# Scalar is always the leaf. There can't be a collection or map in a scalar.
if type(curr_val.value) is _literals_models.Scalar:
break

# If the current value is a dataclass, resolve the dataclass with the remaining path
if type(curr_val.value) is _literals_models.Scalar and type(curr_val.value.value) is _struct.Struct:
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)

p._val = curr_val
return p


def resolve_attr_path_in_pb_struct(st: _struct.Struct, attr_path: List[Union[str, int]]) -> _struct.Struct:
curr_val = st
for attr in attr_path:
if attr not in curr_val:
raise FlytePromiseAttributeResolveException(
f"Failed to resolve attribute path {attr_path} in struct {curr_val}, attribute {attr} not found"
)
curr_val = curr_val[attr]
return curr_val


def get_primitive_val(prim: Primitive) -> Any:
for value in [
prim.integer,
Expand Down Expand Up @@ -303,6 +365,8 @@ def __init__(self, var: str, val: Union[NodeOutput, _literals_models.Literal]):
self._var = var
self._promise_ready = True
self._val = val
self._ref = None
self._attr_path: List[Union[str, int]] = []
if val and isinstance(val, NodeOutput):
self._ref = val
self._promise_ready = False
Expand Down Expand Up @@ -347,7 +411,7 @@ def ref(self) -> NodeOutput:
"""
If the promise is NOT READY / Incomplete, then it maps to the origin node that owns the promise
"""
return self._ref
return self._ref # type: ignore

@property
def var(self) -> str:
Expand All @@ -356,6 +420,14 @@ def var(self) -> str:
"""
return self._var

@property
def attr_path(self) -> List[Union[str, int]]:
"""
The attribute path the promise will be resolved with.
:rtype: List[Union[str, int]]
"""
return self._attr_path

def eval(self) -> Any:
if not self._promise_ready or self._val is None:
raise ValueError("Cannot Eval with incomplete promises")
Expand Down Expand Up @@ -414,11 +486,64 @@ def with_overrides(self, *args, **kwargs):
def __repr__(self):
if self._promise_ready:
return f"Resolved({self._var}={self._val})"
return f"Promise(node:{self.ref.node_id}.{self._var})"
return f"Promise(node:{self.ref.node_id}.{self._var}.{self.attr_path})"

def __str__(self):
return str(self.__repr__())

def deepcopy(self) -> Promise:
new_promise = Promise(var=self.var, val=self.val)
new_promise._promise_ready = self._promise_ready
new_promise._ref = self._ref
new_promise._attr_path = deepcopy(self._attr_path)
return new_promise

def __getitem__(self, key) -> Promise:
"""
When we use [] to access the attribute on the promise, for example
```
@workflow
def wf():
o = t1()
t2(x=o["a"][0])
```
The attribute keys are appended on the promise and a new promise is returned with the updated attribute path.
We don't modify the original promise because it might be used in other places as well.
"""

return self._append_attr(key)

def __getattr__(self, key) -> Promise:
"""
When we use . to access the attribute on the promise, for example
```
@workflow
def wf():
o = t1()
t2(o.a.b)
```
The attribute keys are appended on the promise and a new promise is returned with the updated attribute path.
We don't modify the original promise because it might be used in other places as well.
"""

return self._append_attr(key)

def _append_attr(self, key) -> Promise:
new_promise = self.deepcopy()

# The attr_path on the promise is for local_execute
new_promise._attr_path.append(key)

if new_promise.ref is not None:
# The attr_path on the ref node is for remote execute
new_promise._ref = new_promise.ref.with_attr(key)

return new_promise


def create_native_named_tuple(
ctx: FlyteContext,
Expand Down Expand Up @@ -710,13 +835,16 @@ def __repr__(self):


class NodeOutput(type_models.OutputReference):
def __init__(self, node: Node, var: str):
def __init__(self, node: Node, var: str, attr_path: Optional[List[Union[str, int]]] = None):
"""
:param node:
:param var: The name of the variable this NodeOutput references
"""
if attr_path is None:
attr_path = []

self._node = node
super(NodeOutput, self).__init__(self._node.id, var)
super(NodeOutput, self).__init__(self._node.id, var, attr_path)

@property
def node_id(self):
Expand All @@ -736,6 +864,14 @@ def __repr__(self) -> str:
s = f"Node({self.node if self.node.id is not None else None}:{self.var})"
return s

def deepcopy(self) -> NodeOutput:
return NodeOutput(node=self.node, var=self.var, attr_path=deepcopy(self._attr_path))

def with_attr(self, key) -> NodeOutput:
new_node_output = self.deepcopy()
new_node_output._attr_path.append(key)
return new_node_output


class SupportsNodeCreation(Protocol):
@property
Expand Down
18 changes: 17 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,23 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
f"evaluation doesn't work with json dataclasses"
)

return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema)
# Recursively construct the dataclass_type which contains the literal type of each field
literal_type = {}

# Get the type of each field from dataclass
for field in t.__dataclass_fields__.values(): # type: ignore
try:
literal_type[field.name] = TypeEngine.to_literal_type(field.type)
except Exception as e:
logger.warning(
"Field {} of type {} cannot be converted to a literal type. Error: {}".format(
field.name, field.type, e
)
)

ts = TypeStructure(tag="", dataclass_type=literal_type)

return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts)

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if not dataclasses.is_dataclass(python_val):
Expand Down
4 changes: 4 additions & 0 deletions flytekit/exceptions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,7 @@ class FlyteInvalidInputException(FlyteUserException):
def __init__(self, request: typing.Any):
self.request = request
super().__init__()


class FlytePromiseAttributeResolveException(FlyteAssertion):
_ERROR_CODE = "USER:PromiseAttributeResolveError"
48 changes: 43 additions & 5 deletions flytekit/models/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json as _json
import typing
from typing import Dict

from flyteidl.core import types_pb2 as _types_pb2
from google.protobuf import json_format as _json_format
Expand Down Expand Up @@ -127,21 +128,34 @@ class TypeStructure(_common.FlyteIdlEntity):
Models _types_pb2.TypeStructure
"""

def __init__(self, tag: str):
def __init__(self, tag: str, dataclass_type: Dict[str, "LiteralType"] = None):
self._tag = tag
self._dataclass_type = dataclass_type

@property
def tag(self) -> str:
return self._tag

@property
def dataclass_type(self) -> Dict[str, "LiteralType"]:
return self._dataclass_type

def to_flyte_idl(self) -> _types_pb2.TypeStructure:
return _types_pb2.TypeStructure(
tag=self._tag,
dataclass_type={k: v.to_flyte_idl() for k, v in self._dataclass_type.items()}
if self._dataclass_type is not None
else None,
)

@classmethod
def from_flyte_idl(cls, proto: _types_pb2.TypeStructure):
return cls(tag=proto.tag)
return cls(
tag=proto.tag,
dataclass_type={k: LiteralType.from_flyte_idl(v) for k, v in proto.dataclass_type.items()}
if proto.dataclass_type is not None
else None,
)


class StructuredDatasetType(_common.FlyteIdlEntity):
Expand Down Expand Up @@ -396,16 +410,18 @@ def from_flyte_idl(cls, proto):


class OutputReference(_common.FlyteIdlEntity):
def __init__(self, node_id, var):
def __init__(self, node_id, var, attr_path: typing.List[typing.Union[str, int]] = None):
"""
A reference to an output produced by a node. The type can be retrieved -and validated- from
the underlying interface of the node.
:param Text node_id: Node id must exist at the graph layer.
:param Text var: Variable name must refer to an output variable for the node.
:param List[Union[str, int]] attr_path: The attribute path the promise will be resolved with.
"""
self._node_id = node_id
self._var = var
self._attr_path = attr_path if attr_path is not None else []

@property
def node_id(self):
Expand All @@ -423,6 +439,14 @@ def var(self):
"""
return self._var

@property
def attr_path(self) -> typing.List[typing.Union[str, int]]:
"""
The attribute path the promise will be resolved with.
:rtype: list[union[str, int]]
"""
return self._attr_path

@var.setter
def var(self, var_name):
self._var = var_name
Expand All @@ -431,15 +455,29 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.core.types.OutputReference
"""
return _types_pb2.OutputReference(node_id=self.node_id, var=self.var)
return _types_pb2.OutputReference(
node_id=self.node_id,
var=self.var,
attr_path=[
_types_pb2.PromiseAttribute(
string_value=p if type(p) == str else None,
int_value=p if type(p) == int else None,
)
for p in self._attr_path
],
)

@classmethod
def from_flyte_idl(cls, pb2_object):
"""
:param flyteidl.core.types.OutputReference pb2_object:
:rtype: OutputReference
"""
return cls(node_id=pb2_object.node_id, var=pb2_object.var)
return cls(
node_id=pb2_object.node_id,
var=pb2_object.var,
attr_path=[p.string_value or p.int_value for p in pb2_object.attr_path],
)


class Error(_common.FlyteIdlEntity):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
install_requires=[
"googleapis-common-protos>=1.57",
"flyteidl>=1.5.16",
"flyteidl>=1.10.0",
"wheel>=0.30.0,<1.0.0",
"pandas>=1.0.0,<2.0.0",
"pyarrow>=4.0.0,<11.0.0",
Expand Down
Loading

0 comments on commit 88fe671

Please sign in to comment.