From 88fe6716744f8ceb4ca862e547a85f6c85680219 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Tue, 31 Oct 2023 18:12:24 -0700 Subject: [PATCH] [flytekit] Support attribute access on promises (#1825) Support attribute access on promises In the following workflows: - `basic_workflow` contains trivial examples of accessing output attributes - `failed_workflow` contains examples that cause exceptions (e.g. out of range) - `advanced_workflow` contains examples with more complex attribute access ```python from dataclasses import dataclass from typing import Dict, List, NamedTuple from dataclasses_json import dataclass_json from flytekit import WorkflowFailurePolicy, task, workflow @dataclass_json @dataclass class foo: a: str @task def t1() -> (List[str], Dict[str, str], foo): return ["a", "b"], {"a": "b"}, foo(a="b") @task def t2(a: str) -> str: print("a", a) return a @task def t3() -> (Dict[str, List[str]], List[Dict[str, str]], Dict[str, foo]): return {"a": ["b"]}, [{"a": "b"}], {"a": foo(a="b")} @task def t4(a: List[str]): print("a", a) @task def t5(a: Dict[str, str]): print("a", a) @workflow def basic_workflow(): l, d, f = t1() t2(a=l[0]) t2(a=d["a"]) t2(a=f.a) @workflow( # The workflow doesn't fail when one of the nodes fails but other nodes are still executable failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE ) def failed_workflow(): # This workflow is supposed to fail due to exceptions l, d, f = t1() t2(a=l[100]) t2(a=d["b"]) t2(a=f.b) @workflow def advanced_workflow(): dl, ld, dd = t3() t2(a=dl["a"][0]) t2(a=ld[0]["a"]) t2(a=dd["a"].a) t4(a=dl["a"]) t5(a=ld[0]) ``` Signed-off-by: byhsu --- doc-requirements.txt | 6 +- flytekit/core/promise.py | 144 ++++++++++++++++++++++- flytekit/core/type_engine.py | 18 ++- flytekit/exceptions/user.py | 4 + flytekit/models/types.py | 48 +++++++- setup.py | 2 +- tests/flytekit/unit/core/test_promise.py | 69 ++++++++++- tests/flytekit/unit/models/test_types.py | 11 ++ 8 files changed, 284 insertions(+), 18 
deletions(-) diff --git a/doc-requirements.txt b/doc-requirements.txt index c35ba0c62e6..f6728147918 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -272,10 +272,8 @@ filelock==3.12.4 flask==2.3.3 # via mlflow flatbuffers==23.5.26 - # via - # tensorflow - # tf2onnx -flyteidl==1.5.17 + # via tensorflow +flyteidl==1.10.0 # via flytekit fonttools==4.42.1 # via matplotlib diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 02c724e3a1a..ffb92797666 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -2,9 +2,11 @@ import collections import inspect +from copy import deepcopy from enum import Enum from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast +from google.protobuf import struct_pb2 as _struct from typing_extensions import Protocol, get_args from flytekit.core import constants as _common_constants @@ -22,6 +24,7 @@ from flytekit.core.node import Node from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError from flytekit.exceptions import user as _user_exceptions +from flytekit.exceptions.user import FlytePromiseAttributeResolveException from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literals_models @@ -79,6 +82,8 @@ def my_wf(in1: int, in2: int) -> int: var = flyte_interface_types[k] t = native_types[k] try: + if type(v) is Promise: + v = resolve_attr_path_in_promise(v) result[k] = TypeEngine.to_literal(ctx, v, t, var.type) except TypeTransformerFailedError as exc: raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc @@ -86,6 +91,63 @@ def my_wf(in1: int, in2: int) -> int: return result +def resolve_attr_path_in_promise(p: Promise) -> Promise: + """ + resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value + This is for local execution only. 
The remote execution will be resolved in flytepropeller. + """ + + curr_val = p.val + + used = 0 + + for attr in p.attr_path: + # If current value is Flyte literal collection (list) or map (dictionary), use [] to resolve + if ( + type(curr_val.value) is _literals_models.LiteralMap + or type(curr_val.value) is _literals_models.LiteralCollection + ): + if type(attr) == str and attr not in curr_val.value.literals: + raise FlytePromiseAttributeResolveException( + f"Failed to resolve attribute path {p.attr_path} in promise {p}," + f" attribute {attr} not found in {curr_val.value.literals.keys()}" + ) + + if type(attr) == int and attr >= len(curr_val.value.literals): + raise FlytePromiseAttributeResolveException( + f"Failed to resolve attribute path {p.attr_path} in promise {p}," + f" index {attr} out of range {len(curr_val.value.literals)}" + ) + + curr_val = curr_val.value.literals[attr] + used += 1 + # Scalar is always the leaf. There can't be a collection or map in a scalar. + if type(curr_val.value) is _literals_models.Scalar: + break + + # If the current value is a dataclass, resolve the dataclass with the remaining path + if type(curr_val.value) is _literals_models.Scalar and type(curr_val.value.value) is _struct.Struct: + st = curr_val.value.value + new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:]) + literal_type = TypeEngine.to_literal_type(type(new_st)) + # Reconstruct the resolved result to flyte literal (because the resolved result might not be struct) + curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type) + + p._val = curr_val + return p + + +def resolve_attr_path_in_pb_struct(st: _struct.Struct, attr_path: List[Union[str, int]]) -> _struct.Struct: + curr_val = st + for attr in attr_path: + if attr not in curr_val: + raise FlytePromiseAttributeResolveException( + f"Failed to resolve attribute path {attr_path} in struct {curr_val}, attribute {attr} not found" + ) + curr_val = 
curr_val[attr] + return curr_val + + def get_primitive_val(prim: Primitive) -> Any: for value in [ prim.integer, @@ -303,6 +365,8 @@ def __init__(self, var: str, val: Union[NodeOutput, _literals_models.Literal]): self._var = var self._promise_ready = True self._val = val + self._ref = None + self._attr_path: List[Union[str, int]] = [] if val and isinstance(val, NodeOutput): self._ref = val self._promise_ready = False @@ -347,7 +411,7 @@ def ref(self) -> NodeOutput: """ If the promise is NOT READY / Incomplete, then it maps to the origin node that owns the promise """ - return self._ref + return self._ref # type: ignore @property def var(self) -> str: @@ -356,6 +420,14 @@ def var(self) -> str: """ return self._var + @property + def attr_path(self) -> List[Union[str, int]]: + """ + The attribute path the promise will be resolved with. + :rtype: List[Union[str, int]] + """ + return self._attr_path + def eval(self) -> Any: if not self._promise_ready or self._val is None: raise ValueError("Cannot Eval with incomplete promises") @@ -414,11 +486,64 @@ def with_overrides(self, *args, **kwargs): def __repr__(self): if self._promise_ready: return f"Resolved({self._var}={self._val})" - return f"Promise(node:{self.ref.node_id}.{self._var})" + return f"Promise(node:{self.ref.node_id}.{self._var}.{self.attr_path})" def __str__(self): return str(self.__repr__()) + def deepcopy(self) -> Promise: + new_promise = Promise(var=self.var, val=self.val) + new_promise._promise_ready = self._promise_ready + new_promise._ref = self._ref + new_promise._attr_path = deepcopy(self._attr_path) + return new_promise + + def __getitem__(self, key) -> Promise: + """ + When we use [] to access the attribute on the promise, for example + + ``` + @workflow + def wf(): + o = t1() + t2(x=o["a"][0]) + ``` + + The attribute keys are appended on the promise and a new promise is returned with the updated attribute path. + We don't modify the original promise because it might be used in other places as well. 
+ """ + + return self._append_attr(key) + + def __getattr__(self, key) -> Promise: + """ + When we use . to access the attribute on the promise, for example + + ``` + @workflow + def wf(): + o = t1() + t2(o.a.b) + ``` + + The attribute keys are appended on the promise and a new promise is returned with the updated attribute path. + We don't modify the original promise because it might be used in other places as well. + """ + + return self._append_attr(key) + + def _append_attr(self, key) -> Promise: + new_promise = self.deepcopy() + + # The attr_path on the promise is for local_execute + new_promise._attr_path.append(key) + + if new_promise.ref is not None: + # The attr_path on the ref node is for remote execute + new_promise._ref = new_promise.ref.with_attr(key) + + return new_promise + def create_native_named_tuple( ctx: FlyteContext, @@ -710,13 +835,16 @@ def __repr__(self): class NodeOutput(type_models.OutputReference): - def __init__(self, node: Node, var: str): + def __init__(self, node: Node, var: str, attr_path: Optional[List[Union[str, int]]] = None): """ :param node: :param var: The name of the variable this NodeOutput references """ + if attr_path is None: + attr_path = [] + self._node = node - super(NodeOutput, self).__init__(self._node.id, var) + super(NodeOutput, self).__init__(self._node.id, var, attr_path) @property def node_id(self): @@ -736,6 +864,14 @@ def __repr__(self) -> str: s = f"Node({self.node if self.node.id is not None else None}:{self.var})" return s + def deepcopy(self) -> NodeOutput: + return NodeOutput(node=self.node, var=self.var, attr_path=deepcopy(self._attr_path)) + + def with_attr(self, key) -> NodeOutput: + new_node_output = self.deepcopy() + new_node_output._attr_path.append(key) + return new_node_output + class SupportsNodeCreation(Protocol): @property diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d190c52cdf3..cca5bd50d5a 100644 --- a/flytekit/core/type_engine.py +++ 
b/flytekit/core/type_engine.py @@ -405,7 +405,23 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: f"evaluation doesn't work with json dataclasses" ) - return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema) + # Recursively construct the dataclass_type which contains the literal type of each field + literal_type = {} + + # Get the type of each field from dataclass + for field in t.__dataclass_fields__.values(): # type: ignore + try: + literal_type[field.name] = TypeEngine.to_literal_type(field.type) + except Exception as e: + logger.warning( + "Field {} of type {} cannot be converted to a literal type. Error: {}".format( + field.name, field.type, e + ) + ) + + ts = TypeStructure(tag="", dataclass_type=literal_type) + + return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts) def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if not dataclasses.is_dataclass(python_val): diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py index 0bb5e6c03b0..91eb0842730 100644 --- a/flytekit/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -94,3 +94,7 @@ class FlyteInvalidInputException(FlyteUserException): def __init__(self, request: typing.Any): self.request = request super().__init__() + + +class FlytePromiseAttributeResolveException(FlyteAssertion): + _ERROR_CODE = "USER:PromiseAttributeResolveError" diff --git a/flytekit/models/types.py b/flytekit/models/types.py index 5af60957690..1301d50d0db 100644 --- a/flytekit/models/types.py +++ b/flytekit/models/types.py @@ -1,5 +1,6 @@ import json as _json import typing +from typing import Dict from flyteidl.core import types_pb2 as _types_pb2 from google.protobuf import json_format as _json_format @@ -127,21 +128,34 @@ class TypeStructure(_common.FlyteIdlEntity): Models _types_pb2.TypeStructure """ - def __init__(self, tag: str): + def __init__(self, tag: str, 
dataclass_type: Dict[str, "LiteralType"] = None): self._tag = tag + self._dataclass_type = dataclass_type @property def tag(self) -> str: return self._tag + @property + def dataclass_type(self) -> Dict[str, "LiteralType"]: + return self._dataclass_type + def to_flyte_idl(self) -> _types_pb2.TypeStructure: return _types_pb2.TypeStructure( tag=self._tag, + dataclass_type={k: v.to_flyte_idl() for k, v in self._dataclass_type.items()} + if self._dataclass_type is not None + else None, ) @classmethod def from_flyte_idl(cls, proto: _types_pb2.TypeStructure): - return cls(tag=proto.tag) + return cls( + tag=proto.tag, + dataclass_type={k: LiteralType.from_flyte_idl(v) for k, v in proto.dataclass_type.items()} + if proto.dataclass_type is not None + else None, + ) class StructuredDatasetType(_common.FlyteIdlEntity): @@ -396,16 +410,18 @@ def from_flyte_idl(cls, proto): class OutputReference(_common.FlyteIdlEntity): - def __init__(self, node_id, var): + def __init__(self, node_id, var, attr_path: typing.List[typing.Union[str, int]] = None): """ A reference to an output produced by a node. The type can be retrieved -and validated- from the underlying interface of the node. :param Text node_id: Node id must exist at the graph layer. :param Text var: Variable name must refer to an output variable for the node. + :param List[Union[str, int]] attr_path: The attribute path the promise will be resolved with. """ self._node_id = node_id self._var = var + self._attr_path = attr_path if attr_path is not None else [] @property def node_id(self): @@ -423,6 +439,14 @@ def var(self): """ return self._var + @property + def attr_path(self) -> typing.List[typing.Union[str, int]]: + """ + The attribute path the promise will be resolved with. 
+ :rtype: list[union[str, int]] + """ + return self._attr_path + @var.setter def var(self, var_name): self._var = var_name @@ -431,7 +455,17 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.types.OutputReference """ - return _types_pb2.OutputReference(node_id=self.node_id, var=self.var) + return _types_pb2.OutputReference( + node_id=self.node_id, + var=self.var, + attr_path=[ + _types_pb2.PromiseAttribute( + string_value=p if type(p) == str else None, + int_value=p if type(p) == int else None, + ) + for p in self._attr_path + ], + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -439,7 +473,11 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.types.OutputReference pb2_object: :rtype: OutputReference """ - return cls(node_id=pb2_object.node_id, var=pb2_object.var) + return cls( + node_id=pb2_object.node_id, + var=pb2_object.var, + attr_path=[p.string_value or p.int_value for p in pb2_object.attr_path], + ) class Error(_common.FlyteIdlEntity): diff --git a/setup.py b/setup.py index d9123af1c4a..ccbbb5a24c5 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - "flyteidl>=1.5.16", + "flyteidl>=1.10.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 6a487b464a3..b6decebd0d9 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -1,20 +1,24 @@ import typing from dataclasses import dataclass +from typing import Dict, List import pytest -from dataclasses_json import DataClassJsonMixin +from dataclasses_json import DataClassJsonMixin, dataclass_json from typing_extensions import Annotated from flytekit import LaunchPlan, task, workflow from flytekit.core import context_manager -from flytekit.core.context_manager import CompilationState +from flytekit.core.context_manager import CompilationState, FlyteContextManager from 
flytekit.core.promise import ( + Promise, VoidPromise, create_and_link_node, create_and_link_node_from_remote, + resolve_attr_path_in_promise, translate_inputs_to_literals, ) -from flytekit.exceptions.user import FlyteAssertion +from flytekit.core.type_engine import TypeEngine +from flytekit.exceptions.user import FlyteAssertion, FlytePromiseAttributeResolveException from flytekit.types.pickle.pickle import BatchSize @@ -159,3 +163,62 @@ def func(foo: Optional[int] = None): wf.add_entity(func, foo=None) wf() + + +def test_promise_with_attr_path(): + from dataclasses import dataclass + from typing import Dict, List + + from dataclasses_json import dataclass_json + + @dataclass_json + @dataclass + class Foo: + a: str + + @task + def t1() -> (List[str], Dict[str, str], Foo): + return ["a", "b"], {"a": "b"}, Foo(a="b") + + @task + def t2(a: str) -> str: + return a + + @workflow + def my_workflow() -> (str, str, str): + l, d, f = t1() + o1 = t2(a=l[0]) + o2 = t2(a=d["a"]) + o3 = t2(a=f.a) + return o1, o2, o3 + + # Run a local execution with promises having atrribute path + o1, o2, o3 = my_workflow() + assert o1 == "a" + assert o2 == "b" + assert o3 == "b" + + +def test_resolve_attr_path_in_promise(): + @dataclass_json + @dataclass + class Foo: + b: str + + src = {"a": [Foo(b="foo")]} + + src_lit = TypeEngine.to_literal( + FlyteContextManager.current_context(), + src, + Dict[str, List[Foo]], + TypeEngine.to_literal_type(Dict[str, List[Foo]]), + ) + src_promise = Promise("val1", src_lit) + + # happy path + tgt_promise = resolve_attr_path_in_promise(src_promise["a"][0]["b"]) + assert "foo" == TypeEngine.to_python_value(FlyteContextManager.current_context(), tgt_promise.val, str) + + # exception + with pytest.raises(FlytePromiseAttributeResolveException): + tgt_promise = resolve_attr_path_in_promise(src_promise["c"]) diff --git a/tests/flytekit/unit/models/test_types.py b/tests/flytekit/unit/models/test_types.py index 82f0b9b519d..93242c11d91 100644 --- 
a/tests/flytekit/unit/models/test_types.py +++ b/tests/flytekit/unit/models/test_types.py @@ -108,3 +108,14 @@ def test_output_reference(): assert obj.var == "var1" obj2 = _types.OutputReference.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 + + +def test_output_reference_with_attr_path(): + obj = _types.OutputReference(node_id="node1", var="var1", attr_path=["a", 0]) + assert obj.node_id == "node1" + assert obj.var == "var1" + assert obj.attr_path[0] == "a" + assert obj.attr_path[1] == 0 + + obj2 = _types.OutputReference.from_flyte_idl(obj.to_flyte_idl()) + assert obj == obj2