diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index bc7b4df865..c549dc21aa 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -18,19 +18,51 @@ ) from flytekit.models.core import condition as _core_cond from flytekit.models.core import workflow as _core_wf +from flytekit.models.core.workflow import IfElseBlock from flytekit.models.literals import Binding, BindingData, Literal, RetryStrategy from flytekit.models.types import Error class BranchNode(object): - def __init__(self, name: str, ifelse_block: _core_wf.IfElseBlock): + def __init__(self, name: str, ifelse_block: _core_wf.IfElseBlock, cs: typing.Optional[ConditionalSection] = None): self._name = name self._ifelse_block = ifelse_block + self._cs = cs @property def name(self): return self._name + # Output Node or None + def __call__(self, **kwargs): + for c in self._cs.cases: + _update_promise(c.expr, kwargs) + if c.expr is None: + if c.output_node is None: + raise ValueError(c.err) + return c.output_node + if c.expr.eval(): + return c.output_node + + +def _update_promise( + operand: Union[Literal, Promise, ConjunctionExpression, ComparisonExpression], promises: typing.Dict[str, Promise] +): + if isinstance(operand, Literal): + return Promise(var="placeholder", val=operand) + elif isinstance(operand, ConjunctionExpression) or isinstance(operand, ComparisonExpression): + lhs = _update_promise(operand.lhs, promises) + rhs = _update_promise(operand.rhs, promises) + if isinstance(operand._lhs, Promise) and lhs is not None: + operand._lhs._val = lhs.val + operand._lhs._promise_ready = True + if isinstance(operand._rhs, Promise) and rhs is not None: + operand._rhs._val = rhs.val + operand._rhs._promise_ready = True + + elif isinstance(operand, Promise): + return promises[create_branch_node_promise_var(operand.ref.node_id, operand.var)] + class ConditionalSection: """ @@ -111,7 +143,7 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP 
return self._compute_outputs(n) return self._condition - def if_(self, expr: Union[ComparisonExpression, ConjunctionExpression]) -> Case: + def if_(self, expr: Union[ComparisonExpression, ConjunctionExpression, bool]) -> Case: return self._condition._if(expr) def compute_output_vars(self) -> typing.Optional[typing.List[str]]: @@ -245,7 +277,7 @@ class Case(object): def __init__( self, cs: ConditionalSection, - expr: Optional[Union[ComparisonExpression, ConjunctionExpression]], + expr: Optional[Union[ComparisonExpression, ConjunctionExpression, bool]], stmt: str = "elif", ): self._cs = cs @@ -329,7 +361,7 @@ class Condition(object): def __init__(self, cs: ConditionalSection): self._cs = cs - def _if(self, expr: Union[ComparisonExpression, ConjunctionExpression]) -> Case: + def _if(self, expr: Union[ComparisonExpression, ConjunctionExpression, bool]) -> Case: if expr is None: raise AssertionError(f"Required an expression received None for condition:{self._cs.name}.if_(...)") return self._cs.start_branch(Case(cs=self._cs, expr=expr, stmt="if_")) @@ -438,13 +470,13 @@ def transform_to_boolexpr( def to_case_block(c: Case) -> Tuple[Union[_core_wf.IfBlock], typing.List[Promise]]: + if c.output_promise is None: + raise AssertionError("Illegal Condition block, with no output promise") expr, promises = transform_to_boolexpr(cast(Union[ComparisonExpression, ConjunctionExpression], c.expr)) - if c.output_promise is not None: - n = c.output_node - return _core_wf.IfBlock(condition=expr, then_node=n), promises + return _core_wf.IfBlock(condition=expr, then_node=c.output_node), promises -def to_ifelse_block(node_id: str, cs: ConditionalSection) -> Tuple[_core_wf.IfElseBlock, typing.List[Binding]]: +def to_ifelse_block(node_id: str, cs: ConditionalSection) -> tuple[IfElseBlock, list[Promise]]: if len(cs.cases) == 0: raise AssertionError("Illegal Condition block, with no if-else cases") if len(cs.cases) < 2: @@ -474,7 +506,7 @@ def to_ifelse_block(node_id: str, cs: 
ConditionalSection) -> Tuple[_core_wf.IfEl def to_branch_node(name: str, cs: ConditionalSection) -> Tuple[BranchNode, typing.List[Promise]]: blocks, promises = to_ifelse_block(name, cs) - return BranchNode(name=name, ifelse_block=blocks), promises + return BranchNode(name=name, ifelse_block=blocks, cs=cs), promises def conditional(name: str) -> ConditionalSection: diff --git a/flytekit/core/gate.py b/flytekit/core/gate.py index 7685165743..df56179cf3 100644 --- a/flytekit/core/gate.py +++ b/flytekit/core/gate.py @@ -113,7 +113,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr return p # Assume this is an approval operation since that's the only remaining option. - v = typing.cast(Promise, self._upstream_item).val.value + upstream_item = kwargs[list(kwargs.keys())[0]] + v = typing.cast(Promise, upstream_item).val.value if isinstance(v, Scalar): v = scalar_to_string(v) msg = click.style("[Approval Gate] ", fg="yellow") + click.style( @@ -132,6 +133,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr def local_execution_mode(self): return ExecutionState.Mode.LOCAL_TASK_EXECUTION + def __call__(self, *args: object, **kwargs: object) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: + return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore + def wait_for_input(name: str, timeout: datetime.timedelta, expected_type: typing.Type): """Create a Gate object that waits for user input of the specified type. diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index ffb9279766..c3d650b88e 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -361,7 +361,7 @@ def t1() -> (int, str): ... # TODO: Currently, NodeOutput we're creating is the slimmer core package Node class, but since only the # id is used, it's okay for now. Let's clean all this up though. 
- def __init__(self, var: str, val: Union[NodeOutput, _literals_models.Literal]): + def __init__(self, var: str, val: Optional[Union[NodeOutput, _literals_models.Literal]]): self._var = var self._promise_ready = True self._val = val @@ -738,6 +738,7 @@ def binding_data_from_python_std( # This is the scalar case - e.g. my_task(in1=5) scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar + return _literals_models.BindingData(scalar=scalar) @@ -1190,6 +1191,7 @@ def flyte_entity_call_handler( return None return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs) else: + # If LocallyExecutable is workflow, then we would get into here, mode = cast(LocallyExecutable, entity).local_execution_mode() with FlyteContextManager.with_context( ctx.with_execution_state(ctx.new_execution_state().with_params(mode=mode)) diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index e1e80a4227..b786267b2a 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -281,8 +281,25 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: # The rest of this function mimics the local_execute of the workflow. We can't use the workflow # local_execute directly though since that converts inputs into Promises. 
logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") - self._create_and_cache_dynamic_workflow() - function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs) + if not ctx.compilation_state: + cs = ctx.new_compilation_state(prefix="d") + else: + cs = ctx.compilation_state.with_params(prefix="d") + + updated_ctx = ctx.with_compilation_state(cs) + if self.execution_mode == self.ExecutionBehavior.DYNAMIC: + es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.DYNAMIC_TASK_EXECUTION) + updated_ctx = updated_ctx.with_execution_state(es) + + with FlyteContextManager.with_context(updated_ctx): + self._create_and_cache_dynamic_workflow() + cast(PythonFunctionWorkflow, self._wf).compile(**kwargs) + + # Defensive check: _create_and_cache_dynamic_workflow() above is expected to have set self._wf. + if self._wf is None: + raise ValueError("Dynamic workflow was not created during compilation") + + function_outputs = self._wf.execute(**kwargs) if isinstance(function_outputs, VoidPromise) or function_outputs is None: return VoidPromise(self.name) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 8019edf869..56e375c4e3 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -3,6 +3,7 @@ import asyncio import inspect import typing +from collections import defaultdict from dataclasses import dataclass from enum import Enum from functools import update_wrapper @@ -11,7 +12,7 @@ from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask from flytekit.core.class_based_resolver import ClassStorageTaskResolver -from flytekit.core.condition import ConditionalSection, conditional +from flytekit.core.condition import ConditionalSection, conditional, BranchNode from flytekit.core.context_manager import ( CompilationState, ExecutionState, @@ -37,6 +38,7 @@ extract_obj_name, flyte_entity_call_handler, translate_inputs_to_literals, + resolve_attr_path_in_promise, ) from flytekit.core.python_auto_container import 
PythonAutoContainerTask from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference @@ -134,7 +136,9 @@ def get_promise(binding_data: _literal_models.BindingData, outputs_cache: Dict[N ) # b.var is the name of the input to the task # binding_data.promise.var is the name of the upstream node's output we want - return outputs_cache[binding_data.promise.node][binding_data.promise.var] + o = outputs_cache[binding_data.promise.node][binding_data.promise.var] + o._attr_path = binding_data.promise.attr_path + return resolve_attr_path_in_promise(o) elif binding_data.scalar is not None: return Promise(var="placeholder", val=_literal_models.Literal(scalar=binding_data.scalar)) elif binding_data.collection is not None: @@ -194,6 +198,9 @@ def __init__( self._nodes: List[Node] = [] self._output_bindings: List[_literal_models.Binding] = [] self._docs = docs + # Create a map that holds the outputs of each node. + self._intermediate_node_outputs: Dict[Node, Dict[str, Promise]] = {GLOBAL_START_NODE: {}} + self._return_nodes: List[Node] = [] if self._python_interface.docstring: if self.docs is None: @@ -250,6 +257,10 @@ def nodes(self) -> List[Node]: self.compile() return self._nodes + @property + def intermediate_node_outputs(self) -> Dict[Node, Dict[str, Promise]]: + return self._intermediate_node_outputs + def __repr__(self): return ( f"WorkflowBase - {self._name} && " @@ -272,6 +283,7 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis input_kwargs = self.python_interface.default_inputs_as_kwargs input_kwargs.update(kwargs) self.compile() + try: return flyte_entity_call_handler(self, *args, **input_kwargs) except Exception as exc: @@ -281,6 +293,72 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis def execute(self, **kwargs): raise Exception("Should not be called") + def execute_node(self, node): + """ + Execute Node and store the output in the intermediate_node_outputs. 
+ This function handles the recursive calls needed for branch nodes. + """ + + if node not in self.intermediate_node_outputs.keys(): + self.intermediate_node_outputs[node] = {} + + # Retrieve the entity from the node, and call it by looking up the promises the node's bindings require, + # and then fill them in using the node output tracker map we have. + entity = node.flyte_entity + if entity is None: + # start node doesn't have an entity + return + entity_kwargs = get_promise_map(node.bindings, self.intermediate_node_outputs) + + if isinstance(entity, BranchNode): + sub_node = entity(**entity_kwargs) + self.execute_node(sub_node) + self.intermediate_node_outputs[node].update(self.intermediate_node_outputs[sub_node]) + return + else: + # Handle the calling and outputs of each node's entity + results = entity(**entity_kwargs) + expected_output_names = list(entity.python_interface.outputs.keys()) + + if isinstance(results, VoidPromise) or results is None: + return # pragma: no cover # Move along, nothing to assign + + # Because we should've already returned in the above check, we just raise an Exception here. 
+ if len(entity.python_interface.outputs) == 0: + raise FlyteValueException(results, "Interface output should've been VoidPromise or None.") + + # if there's only one output, + if len(expected_output_names) == 1: + if entity.python_interface.output_tuple_name and isinstance(results, tuple): + self.intermediate_node_outputs[node][expected_output_names[0]] = results[0] + else: + self.intermediate_node_outputs[node][expected_output_names[0]] = results + + else: + if len(results) != len(expected_output_names): + raise FlyteValueException(results, f"Different lengths {results} {expected_output_names}") + for idx, r in enumerate(results): + self.intermediate_node_outputs[node][expected_output_names[idx]] = r + + def create_promise(self): + if len(self.python_interface.outputs) == 0: + return VoidPromise(self.name) + + # The values that we return below from the output have to be pulled by fulfilling all the + # workflow's output bindings. + # The return style here has to match what 1) what the workflow would've returned had it been declared + # functionally, and 2) what a user would return in mock function. That is, if it's a tuple, then it + # should be a tuple here, if it's a one element named tuple, then we do a one-element non-named tuple, + # if it's a single element then we return a single element + if len(self.output_bindings) == 1: + # Again use presence of output_tuple_name to understand that we're dealing with a one-element + # named tuple + if self.python_interface.output_tuple_name: + return (get_promise(self.output_bindings[0].binding, self.intermediate_node_outputs),) + # Just a normal single element + return get_promise(self.output_bindings[0].binding, self.intermediate_node_outputs) + return tuple([get_promise(b.binding, self.intermediate_node_outputs) for b in self.output_bindings]) + def compile(self, **kwargs): pass @@ -435,7 +513,7 @@ def execute(self, **kwargs): """ Called by local_execute. 
This function is how local execution for imperative workflows runs. Because when an entity is added using the add_entity function, all inputs to that entity should've been already declared, we - can just iterate through the nodes in order and we shouldn't run into any dependency issues. That is, we force + can just iterate through the nodes in order, and we shouldn't run into any dependency issues. That is, we force the user to declare entities already in a topological sort. To keep track of outputs, we create a map to start things off, filled in only with the workflow inputs (if any). As things are run, their outputs are stored in this map. @@ -444,67 +522,20 @@ def execute(self, **kwargs): if not self.ready(): raise FlyteValidationException(f"Workflow not ready, wf is currently {self}") - # Create a map that holds the outputs of each node. - intermediate_node_outputs: Dict[Node, Dict[str, Promise]] = {GLOBAL_START_NODE: {}} - # Start things off with the outputs of the global input node, i.e. the inputs to the workflow. # local_execute should've already ensured that all the values in kwargs are Promise objects for k, v in kwargs.items(): - intermediate_node_outputs[GLOBAL_START_NODE][k] = v + self.intermediate_node_outputs[GLOBAL_START_NODE][k] = v # Next iterate through the nodes in order. for node in self.compilation_state.nodes: - if node not in intermediate_node_outputs.keys(): - intermediate_node_outputs[node] = {} - - # Retrieve the entity from the node, and call it by looking up the promises the node's bindings require, - # and then fill them in using the node output tracker map we have. 
- entity = node.flyte_entity - entity_kwargs = get_promise_map(node.bindings, intermediate_node_outputs) + if node not in self.intermediate_node_outputs.keys(): + self.intermediate_node_outputs[node] = {} + self.execute_node(node) - # Handle the calling and outputs of each node's entity - results = entity(**entity_kwargs) - expected_output_names = list(entity.python_interface.outputs.keys()) - - if isinstance(results, VoidPromise) or results is None: - continue # pragma: no cover # Move along, nothing to assign - - # Because we should've already returned in the above check, we just raise an Exception here. - if len(entity.python_interface.outputs) == 0: - raise FlyteValueException(results, "Interface output should've been VoidPromise or None.") - - # if there's only one output, - if len(expected_output_names) == 1: - if entity.python_interface.output_tuple_name and isinstance(results, tuple): - intermediate_node_outputs[node][expected_output_names[0]] = results[0] - else: - intermediate_node_outputs[node][expected_output_names[0]] = results - - else: - if len(results) != len(expected_output_names): - raise FlyteValueException(results, f"Different lengths {results} {expected_output_names}") - for idx, r in enumerate(results): - intermediate_node_outputs[node][expected_output_names[idx]] = r - - # The rest of this function looks like the above but now we're doing it for the workflow as a whole rather + # The rest of this function looks like the above, but now we're doing it for the workflow as a whole rather # than just one node at a time. - if len(self.python_interface.outputs) == 0: - return VoidPromise(self.name) - - # The values that we return below from the output have to be pulled by fulfilling all of the - # workflow's output bindings. - # The return style here has to match what 1) what the workflow would've returned had it been declared - # functionally, and 2) what a user would return in mock function. 
That is, if it's a tuple, then it - # should be a tuple here, if it's a one element named tuple, then we do a one-element non-named tuple, - # if it's a single element then we return a single element - if len(self.output_bindings) == 1: - # Again use presence of output_tuple_name to understand that we're dealing with a one-element - # named tuple - if self.python_interface.output_tuple_name: - return (get_promise(self.output_bindings[0].binding, intermediate_node_outputs),) - # Just a normal single element - return get_promise(self.output_bindings[0].binding, intermediate_node_outputs) - return tuple([get_promise(b.binding, intermediate_node_outputs) for b in self.output_bindings]) + return self.create_promise() def create_conditional(self, name: str) -> ConditionalSection: ctx = FlyteContext.current_context() @@ -678,7 +709,13 @@ def compile(self, **kwargs): # Construct the default input promise bindings, but then override with the provided inputs, if any input_kwargs = construct_input_promises([k for k in self.interface.inputs.keys()]) input_kwargs.update(kwargs) + + # The call below runs over all the functions invoked in the workflow. + # This is the first time __call__ is entered, where compilation state = 0. + # All this call does is create the task nodes and link them. workflow_outputs = exception_scopes.user_entry_point(self._workflow_function)(**input_kwargs) + if isinstance(workflow_outputs, VoidPromise): + self._return_nodes = [workflow_outputs.ref.node] all_nodes.extend(comp_ctx.compilation_state.nodes) # This little loop was added as part of the task resolver change. The task resolver interface itself is @@ -697,6 +734,10 @@ def compile(self, **kwargs): # The reason the length 1 case is separate is because the one output might be a list. We don't want to # iterate through the list here, instead we should let the binding creation unwrap it and make a binding # collection/map out of it. 
+ + if len(output_names) == 0 and not (isinstance(workflow_outputs, VoidPromise) or workflow_outputs is None): + raise FlyteValidationException("Workflow return value is not None, but no output type is specified.") + if len(output_names) == 1: if isinstance(workflow_outputs, tuple): if len(workflow_outputs) != 1: @@ -709,13 +750,20 @@ def compile(self, **kwargs): ) workflow_outputs = workflow_outputs[0] t = self.python_interface.outputs[output_names[0]] - b, _ = binding_from_python_std( + b, nodes = binding_from_python_std( ctx, output_names[0], self.interface.outputs[output_names[0]].type, workflow_outputs, t, ) + # Handling the scalar case for nodes + if len(nodes) == 0 and b is not None: + self._return_nodes = [ + Node(id="placeholder", metadata=None, bindings=b, upstream_nodes=None, flyte_entity=None) + ] + else: + self._return_nodes = nodes bindings.append(b) elif len(output_names) > 1: if not isinstance(workflow_outputs, tuple): @@ -726,7 +774,7 @@ def compile(self, **kwargs): if isinstance(workflow_outputs[i], ConditionalSection): raise AssertionError("A Conditional block (if-else) should always end with an `else_()` clause") t = self.python_interface.outputs[out] - b, _ = binding_from_python_std( + b, nodes = binding_from_python_std( ctx, out, self.interface.outputs[out].type, @@ -734,6 +782,9 @@ def compile(self, **kwargs): t, ) bindings.append(b) + if len(nodes) == 0 and b is not None: + nodes = [Node(id="placeholder", metadata=None, bindings=b, upstream_nodes=None, flyte_entity=None)] + self._return_nodes.extend(nodes) # Save all the things necessary to create an SdkWorkflow, except for the missing project and domain self._nodes = all_nodes @@ -750,7 +801,53 @@ def execute(self, **kwargs): call execute from dispatch_execute which is in local_execute, workflows should also call an execute inside local_execute. This makes mocking cleaner. 
""" - return exception_scopes.user_entry_point(self._workflow_function)(**kwargs) + # Start things off with the outputs of the global input node, i.e. the inputs to the workflow. + # local_execute should've already ensured that all the values in kwargs are Promise objects + + for k, v in kwargs.items(): + self.intermediate_node_outputs[GLOBAL_START_NODE][k] = v + + sorted_nodes = self.get_sorted_nodes() + + # Next iterate through the nodes in order. + for node in sorted_nodes: + self.execute_node(node) + + if len(self.output_bindings) == 0 and len(self._return_nodes) > 0: + # If size of return_nodes is 1, we check if the node has output + if len(self._return_nodes) == 1 and len(self.intermediate_node_outputs[self._return_nodes[0]]) > 0: + raise FlyteValidationException("Workflow return value is not None, but no output type is specified.") + if len(self._return_nodes) > 1: + raise FlyteValidationException("Workflow return multiple values, but no output types are specified.") + if len(self.output_bindings) > 0 and len(self._return_nodes) == 0: + raise FlyteValidationException("Workflow return value is None, but output type is specified.") + return self.create_promise() + + def get_sorted_nodes(self): + """ + This function is to do topological sort on the graph + """ + # 0 = unvisited, 1 = in progress, 2 = visited + visited: typing.Dict[Node, int] = defaultdict(int) + sorted_node = [] + + def topological_sort_node(node: Node): + if visited[node] == 1: + raise FlyteValidationException( + "Cycle detected or one node is called multiple times in local sequential chaining, please check out your chain dependecy." 
+ ) + if visited[node] == 2: + return + visited[node] = 1 + for n in node.upstream_nodes: + topological_sort_node(n) + visited[node] = 2 + sorted_node.append(node) + + for node in self._nodes: + if visited[node] == 0: + topological_sort_node(node) + return sorted_node @overload diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index b3bf0c5eab..66c10ca5fa 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -9,6 +9,7 @@ from flytekit import task, workflow from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.condition import conditional +from flytekit.exceptions.user import FlyteValidationException from flytekit.models.core.workflow import Node from flytekit.tools.translator import get_serializable @@ -312,20 +313,25 @@ def branching(x: int) -> int: assert branching(x=3) == 5 -def test_no_output_condition(): +def test_output_condition(): @task def t(): ... 
@workflow - def wf1(): + def wf1() -> int: t() + return 3 @workflow def branching(x: int): return conditional("test").if_(x == 2).then(t()).else_().then(wf1()) assert branching(x=2) is None + with pytest.raises( + FlyteValidationException, match="Workflow return value is not None, but no output type is specified" + ): + assert branching(x=3) is None def test_subworkflow_condition_named_tuple(): diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index b6decebd0d..6adb7ed9c1 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -192,7 +192,7 @@ def my_workflow() -> (str, str, str): o3 = t2(a=f.a) return o1, o2, o3 - # Run a local execution with promises having atrribute path + # Run a local execution with promises having attribute path o1, o2, o3 = my_workflow() assert o1 == "a" assert o2 == "b" diff --git a/tests/flytekit/unit/core/test_type_conversion_errors.py b/tests/flytekit/unit/core/test_type_conversion_errors.py index fbdd2c8640..ced6168664 100644 --- a/tests/flytekit/unit/core/test_type_conversion_errors.py +++ b/tests/flytekit/unit/core/test_type_conversion_errors.py @@ -83,7 +83,6 @@ def test_workflow_with_task_error(correct_input): TypeError, match=( r"Encountered error while executing workflow '{}':\n" - r" Error encountered while executing 'wf_with_task_error':\n" r" Failed to convert outputs of task '.+' at position 0:\n" r" Expected value of type \ but got .+ of type .+" ).format(wf_with_task_error.name), diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 77051813d1..08e324d6fb 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -1277,10 +1277,17 @@ def test_wf_explicitly_returning_empty_task(): def t1(): ... 
+ @task + def t2() -> int: + return 3 + @workflow def my_subwf(): - return t1() # This forces the wf local_execute to handle VoidPromises + a = t1() + t2() + return a # This forces the wf local_execute to handle VoidPromises + my_subwf() assert my_subwf() is None @@ -1622,7 +1629,7 @@ def wf2(a: typing.Union[int, str]) -> typing.Union[int, str]: with pytest.raises( TypeError, match=re.escape( - "Error encountered while executing 'wf2':\n" + f"Encountered error while executing workflow '{prefix}tests.flytekit.unit.core.test_type_hints.wf2':\n" f" Failed to convert inputs of task '{prefix}tests.flytekit.unit.core.test_type_hints.t2':\n" ' Cannot convert from to typing.Union[float, dict] (using tag str)' diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 974fad08ad..cc79fe6f4a 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -14,7 +14,7 @@ from flytekit.core.condition import conditional from flytekit.core.task import task from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow -from flytekit.exceptions.user import FlyteValidationException, FlyteValueException +from flytekit.exceptions.user import FlyteValidationException from flytekit.tools.translator import get_serializable from flytekit.types.schema import FlyteSchema @@ -233,7 +233,7 @@ def no_outputs_wf(): return t1(a=3) # Should raise an exception because the workflow returns something when it shouldn't - with pytest.raises(FlyteValueException): + with pytest.raises(FlyteValidationException): no_outputs_wf() # Should raise an exception because it doesn't return something when it should diff --git a/tests/flytekit/unit/core/test_workflows_local_chain.py b/tests/flytekit/unit/core/test_workflows_local_chain.py new file mode 100644 index 0000000000..14f84d8b87 --- /dev/null +++ b/tests/flytekit/unit/core/test_workflows_local_chain.py @@ -0,0 +1,412 
import contextlib
import io
import typing
from collections import OrderedDict
from datetime import timedelta
from io import StringIO

import pytest
from mock import patch

import flytekit
from flytekit import dynamic, map_task
from flytekit.configuration import Image, ImageConfig
from flytekit.core.condition import conditional
from flytekit.core.gate import wait_for_input
from flytekit.core.node_creation import create_node
from flytekit.core.task import TaskMetadata, task
from flytekit.core.workflow import workflow
from flytekit.tools.translator import get_serializable

# Settings used when a workflow is serialized with ``get_serializable``.
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)

# NOTE: every test below captures printed output with
# ``contextlib.redirect_stdout`` rather than assigning ``sys.stdout`` by hand.
# The context manager puts the previous stream back even when the workflow or
# an assertion raises, and it restores whatever stream was installed before
# (e.g. the test runner's own capture) instead of forcing ``sys.__stdout__``.


# ---------------------------------------------------------------------------
# Test for simple tasks/subworkflow chaining
# ---------------------------------------------------------------------------


def test_task_local_chain():
    """Tasks chained as ``c >> b >> a`` print in the order c, b, a."""

    @task
    def task_a():
        print("a")

    @task
    def task_b():
        print("b")

    @task
    def task_c():
        print("c")

    @workflow()
    def my_wf():
        a = task_a()
        b = task_b()
        c = task_c()

        c >> b >> a

    with contextlib.redirect_stdout(io.StringIO()) as captured:
        my_wf()
    assert captured.getvalue() == "c\nb\na\n"


def test_subworkflow_local_chain():
    """A subworkflow chained last runs after the tasks, keeping its own inner order."""

    @task
    def task_a():
        print("a")

    @task
    def task_b():
        print("b")

    @task
    def task_c():
        print("c")

    @workflow
    def sub_wf():
        t2 = task_b()
        t3 = task_c()

        t3 >> t2

    @workflow()
    def my_wf():
        sf = sub_wf()
        t3 = task_a()
        t2 = task_b()
        t1 = task_c()

        t1 >> t2 >> t3 >> sf

    with contextlib.redirect_stdout(io.StringIO()) as captured:
        my_wf()
    # Outer chain prints c, b, a; the subworkflow then prints c, b.
    assert captured.getvalue() == "c\nb\na\nc\nb\n"


def test_task_pass_int_local_chain():
    """Chaining producers ``six >> five`` ahead of their consumer still yields the sum."""

    @task
    def task_five() -> int:
        print("5")
        return 5

    @task
    def task_six() -> int:
        print("6")
        return 6

    @task
    def task_add(a: int, b: int) -> int:
        print(a + b)
        return a + b

    @workflow()
    def my_wf() -> int:
        five = task_five()
        six = task_six()
        add = task_add(a=five, b=six)
        six >> five >> add
        return add

    with contextlib.redirect_stdout(io.StringIO()) as captured:
        result = my_wf()
    assert captured.getvalue() == "6\n5\n11\n"
    assert result == 11


def test_wf_nested_comp():
    """A nested workflow chained before a task runs first; serialization is also checked."""

    @task
    def t1(a: int) -> int:
        print("t1", a)
        a = a + 5
        return a

    @workflow()
    def outer() -> typing.Tuple[int, int]:
        # You should not do this. This is just here for testing.
        @workflow
        def wf2() -> int:
            print("wf2")
            return t1(a=5)

        ft = t1(a=3)
        fwf = wf2()
        fwf >> ft
        return ft, fwf

    with contextlib.redirect_stdout(io.StringIO()) as captured:
        _ = outer()
    assert captured.getvalue() == "wf2\nt1 5\nt1 3\n"

    assert (8, 10) == outer()
    entity_mapping = OrderedDict()

    model_wf = get_serializable(entity_mapping, serialization_settings, outer)

    assert len(model_wf.template.interface.outputs) == 2
    assert len(model_wf.template.nodes) == 2
    assert model_wf.template.nodes[1].workflow_node is not None

    sub_wf = model_wf.sub_workflows[0]
    assert len(sub_wf.nodes) == 1
    assert sub_wf.nodes[0].id == "n0"
    assert sub_wf.nodes[0].task_node.reference_id.name == "tests.flytekit.unit.core.test_workflows_local_chain.t1"


# ---------------------------------------------------------------------------
# Test for failing chaining
# ---------------------------------------------------------------------------


def test_cycle_fail():
    """Chains that contradict the data flow, or repeat a node, must raise."""

    @task
    def five() -> int:
        print("five")
        return 5

    @task
    def t1(a: int) -> int:
        print("t1", a)
        a = a + 5
        return a

    @task
    def t2(a: int) -> int:
        print("t2", a)
        a = a + 5
        return a

    @workflow
    def wf():
        a = five()
        b = t1(a=a)
        c = t2(a=b)
        c >> b >> a  # wrong sequence will cause cycle

    @workflow
    def wf_multiple_call():
        a = five()
        b = t1(a=a)
        c = t2(a=b)

        a >> a >> b >> c

    # c >> b >> a is an invalid execute sequence
    with pytest.raises(Exception):
        wf()
    # a is called multiple times.
    with pytest.raises(Exception):
        wf_multiple_call()


# ---------------------------------------------------------------------------
# Test for conditional chaining
# ---------------------------------------------------------------------------


def test_condition_local_chain():
    """Two conditionals chained as ``b >> a`` evaluate b's branch before a's."""

    @task
    def square(n: float) -> float:
        print("square")
        return n * n

    @task
    def double(n: float) -> float:
        print("double")
        return 2 * n

    @workflow()
    def multiplier_3(my_input: float):
        a = (
            conditional("fractions")
            .if_((my_input >= 0) & (my_input < 1.0))
            .then(double(n=my_input))
            .else_()
            .then(square(n=my_input))
        )
        b = (
            conditional("fractions2")
            .if_((my_input >= 0) & (my_input < 1.0))
            .then(square(n=my_input))
            .else_()
            .then(double(n=my_input))
        )

        b >> a

    with contextlib.redirect_stdout(io.StringIO()) as captured:
        multiplier_3(my_input=0.5)  # call square first, then call double when input = 0.5
    assert captured.getvalue() == "square\ndouble\n"


# ---------------------------------------------------------------------------
# Test for dynamic chaining
# ---------------------------------------------------------------------------


def test_wf_dynamic_local_chain():
    """Dynamic tasks and plain tasks can be mixed in one chain."""

    @task
    def t1(a: int) -> int:
        print("t1")
        a = a + 2
        return a

    @dynamic
    def use_result(a: int) -> int:
        print("call use_result")
        if a > 6:
            return 5
        else:
            return 0

    @dynamic
    def use_result2(a: int) -> int:
        print("call use_result2")
        if a > 6:
            return 0
        else:
            return 5

    @task
    def t2(a: int) -> int:
        print("t2")
        return a + 3

    @workflow
    def wf():
        a1 = t1(a=7)
        a2 = t2(a=9)
        b1 = use_result(a=a1)
        b2 = use_result2(a=a1)
        a1 >> b2 >> b1 >> a2

    with contextlib.redirect_stdout(io.StringIO()) as captured:
        wf()
    assert captured.getvalue() == "t1\ncall use_result2\ncall use_result\nt2\n"


def test_create_node_dynamic_local_local_chain():
    """Nodes built with ``create_node`` inside a dynamic workflow honour ``>>``."""

    @task
    def task1(s: str) -> str:
        print("task1")
        return s

    @task
    def task2(s: str) -> str:
        print("task2")
        return s

    @dynamic
    def dynamic_wf() -> str:
        node_1 = create_node(task1, s="hello")
        node_2 = create_node(task2, s="world")
        node_2 >> node_1

        return node_1.o0

    @workflow
    def wf() -> str:
        return dynamic_wf()

    # Run inside the capture, assert outside it, so a failing assertion can
    # never leave stdout redirected.
    with contextlib.redirect_stdout(io.StringIO()) as captured:
        result = wf()
    assert result == "hello"
    assert captured.getvalue() == "task2\ntask1\n"


# ---------------------------------------------------------------------------
# Test for gate chaining.
# ---------------------------------------------------------------------------


def test_dyn_signal_no_approve_local_chain():
    """Input gates chained as ``s2 >> s1`` prompt for my-signal-name-2 first."""

    @task
    def t1(a: int) -> int:
        return a + 5

    @task
    def t2(a: int) -> int:
        return a + 6

    @dynamic
    def dyn(a: int) -> typing.Tuple[int, int]:
        x = t1(a=a)
        s1 = wait_for_input("my-signal-name", timeout=timedelta(hours=1), expected_type=bool)
        s2 = wait_for_input("my-signal-name-2", timeout=timedelta(hours=2), expected_type=int)
        z = t1(a=3)
        y = t2(a=s2)
        z >> x >> s2 >> s1

        return y, z

    @workflow
    def wf_dyn(a: int) -> typing.Tuple[int, int]:
        y, z = dyn(a=a)
        return y, z

    # Only stdin needs mocking; stdout is captured by redirect_stdout, which
    # makes an extra mock of sys.stdout unnecessary.
    with patch("sys.stdin", StringIO("3\ny\n")) as stdin:
        with contextlib.redirect_stdout(io.StringIO()) as captured:
            wf_dyn(a=5)
        # NOTE(review): the text between "of type" and ":" looks like it may be
        # missing the type repr -- confirm against the actual gate prompt.
        gate_message = "[Input Gate] Waiting for input @my-signal-name-2 of type : "
        gate_message += "[Input Gate] Waiting for input @my-signal-name of type : "
        assert captured.getvalue() == gate_message
        assert stdin.read() == ""  # all input consumed


# ---------------------------------------------------------------------------
# Test map task chaining
# ---------------------------------------------------------------------------


def test_map_task_chaining():
    """With ``t2 >> t1``, all items of the second map task run before the first."""

    @task
    def complex_task(a: int) -> str:
        print("t1")
        b = a + 2
        return str(b)

    @task
    def complex_task2(a: int) -> str:
        print("t2")
        b = a + 5
        return str(b)

    maptask = map_task(complex_task, metadata=TaskMetadata(retries=1))
    maptask2 = map_task(complex_task2, metadata=TaskMetadata(retries=1))

    @workflow
    def w1(a: typing.List[int]) -> typing.List[str]:
        t1 = maptask(a=a)
        t2 = maptask2(a=a)
        t2 >> t1
        return t1

    with contextlib.redirect_stdout(io.StringIO()) as captured:
        res = w1(a=[1, 2, 3])
    assert captured.getvalue() == "t2\nt2\nt2\nt1\nt1\nt1\n"
    assert res == ["3", "4", "5"]