Skip to content

Commit

Permalink
BranchNode local execution(#4080)
Browse files Browse the repository at this point in the history
Signed-off-by: Da-Yi Wu <[email protected]>

Resolve dynamic cases by compile one more time

Signed-off-by: Da-Yi Wu <[email protected]>

Modify Gate local execution

Signed-off-by: Da-Yi Wu <[email protected]>
  • Loading branch information
ericwudayi committed Oct 26, 2023
1 parent 1b5692a commit c3579f2
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 10 deletions.
23 changes: 21 additions & 2 deletions flytekit/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,28 @@


class BranchNode(object):
def __init__(self, name: str, ifelse_block: _core_wf.IfElseBlock):
def __init__(self, name: str, ifelse_block: _core_wf.IfElseBlock, cs: ConditionalSection):
self._name = name
self._ifelse_block = ifelse_block
self._cs = cs

@property
def name(self):
return self._name

## may output node or None
def __call__(self, **kwargs):
## conditions section have cases
## condition section run case in sequence
## each item is Case
## Case is either
self._cs.eval_by_kwargs(**kwargs)
for c in self._cs.cases:
if c.expr is None:
return c.output_node
if c.expr.eval():
return c.output_node


class ConditionalSection:
"""
Expand Down Expand Up @@ -161,6 +175,11 @@ def __repr__(self):
def __str__(self):
return self.__repr__()

def eval_by_kwargs(self, **kwargs):
for c in self._cases:
if c.expr is not None:
c.expr.eval_by_kwargs(**kwargs)


class LocalExecutedConditionalSection(ConditionalSection):
def __init__(self, name: str):
Expand Down Expand Up @@ -474,7 +493,7 @@ def to_ifelse_block(node_id: str, cs: ConditionalSection) -> Tuple[_core_wf.IfEl

def to_branch_node(name: str, cs: ConditionalSection) -> Tuple[BranchNode, typing.List[Promise]]:
blocks, promises = to_ifelse_block(name, cs)
return BranchNode(name=name, ifelse_block=blocks), promises
return BranchNode(name=name, ifelse_block=blocks, cs=cs), promises


def conditional(name: str) -> ConditionalSection:
Expand Down
16 changes: 13 additions & 3 deletions flytekit/core/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,16 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
return p

# Assume this is an approval operation since that's the only remaining option.
msg = click.style("[Approval Gate] ", fg="yellow") + click.style(
f"@{self.name} Approve {typing.cast(Promise, self._upstream_item).val.value}?", fg="cyan"
)
# msg = click.style("[Approval Gate] ", fg="yellow") + click.style(
# f"@{self.name} Approve {typing.cast(Promise, self._upstream_item).val.value}?", fg="cyan"
# )

assert len(kwargs) == 1

value = kwargs[list(kwargs.keys())[0]]
if isinstance(value, Promise):
value = value.eval()
msg = click.style("[Approval Gate] ", fg="yellow") + click.style(f"@{self.name} Approve {value}?", fg="cyan")
proceed = click.confirm(msg, default=True)
if proceed:
# We need to return a promise here, and a promise is what should've been passed in by the call in approve()
Expand All @@ -127,6 +134,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
def local_execution_mode(self):
return ExecutionState.Mode.LOCAL_TASK_EXECUTION

def __call__(self, *args: object, **kwargs: object) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]:
return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore


def wait_for_input(name: str, timeout: datetime.timedelta, expected_type: typing.Type):
"""Create a Gate object that waits for user input of the specified type.
Expand Down
37 changes: 37 additions & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,24 @@ def eval(self) -> bool:

return _comparators[self.op](lhs, rhs)

def eval_by_kwargs(self, **kwargs):
if isinstance(self.lhs, Promise):
key = f"{self.lhs.ref.node_id}.{self.lhs._var}"
if key in kwargs:
self._lhs._val = kwargs[key]._val
self._lhs._promise_ready = True
## elif is Conjunction or ComparisonExpression recursive run
elif isinstance(self.lhs, ConjunctionExpression) or isinstance(self.lhs, ComparisonExpression):
self.lhs.eval_by_kwargs(**kwargs)

if isinstance(self.rhs, Promise):
key = f"{self.rhs.ref.node_id}.{self.rhs._var}"
if key in kwargs:
self._rhs._val = kwargs[key]._val
self._rhs._promise_ready = True
elif isinstance(self.rhs, ConjunctionExpression) or isinstance(self.rhs, ComparisonExpression):
self.rhs.eval_by_kwargs(**kwargs)

def __and__(self, other):
return ConjunctionExpression(lhs=self, op=ConjunctionOps.AND, rhs=other)

Expand Down Expand Up @@ -256,6 +274,24 @@ def eval(self) -> bool:

return l_eval or r_eval

def eval_by_kwargs(self, **kwargs):
if isinstance(self.lhs, Promise):
key = f"{self.lhs.ref.node_id}.{self.lhs._var}"
if key in kwargs:
self._lhs._val = kwargs[key]._val
self._lhs._promise_ready = True
## elif is Conjunction or ComparisonExpression recursive run
elif isinstance(self.lhs, ConjunctionExpression) or isinstance(self.lhs, ComparisonExpression):
self.lhs.eval_by_kwargs(**kwargs)

if isinstance(self.rhs, Promise):
key = f"{self.rhs.ref.node_id}.{self.rhs._var}"
if key in kwargs:
self._rhs._val = kwargs[key]._val
self._rhs._promise_ready = True
elif isinstance(self.rhs, ConjunctionExpression) or isinstance(self.rhs, ComparisonExpression):
self.rhs.eval_by_kwargs(**kwargs)

def __and__(self, other: Union[ComparisonExpression, "ConjunctionExpression"]):
return ConjunctionExpression(lhs=self, op=ConjunctionOps.AND, rhs=other)

Expand Down Expand Up @@ -1043,6 +1079,7 @@ def flyte_entity_call_handler(
return None
return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs)
else:
## If LocallyExecutable is workflow, then we would get into here,
mode = cast(LocallyExecutable, entity).local_execution_mode()
with FlyteContextManager.with_context(
ctx.with_execution_state(ctx.new_execution_state().with_params(mode=mode))
Expand Down
122 changes: 119 additions & 3 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import inspect
import typing
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from functools import update_wrapper
Expand All @@ -11,7 +12,7 @@
from flytekit.core import constants as _common_constants
from flytekit.core.base_task import PythonTask
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.condition import ConditionalSection
from flytekit.core.condition import BranchNode, ConditionalSection
from flytekit.core.context_manager import (
CompilationState,
ExecutionState,
Expand Down Expand Up @@ -272,6 +273,8 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis
input_kwargs = self.python_interface.default_inputs_as_kwargs
input_kwargs.update(kwargs)
self.compile()
## Cannot just runinto the flyte_entity_call, we should pass the graph into it

try:
return flyte_entity_call_handler(self, *args, **input_kwargs)
except Exception as exc:
Expand All @@ -294,7 +297,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
native_types=self.python_interface.inputs,
)
kwargs_literals = {k: Promise(var=k, val=v) for k, v in literal_map.items()}
## This is dummy, because it is already compiled
self.compile()
## Execute one more time on workflow local_execute
function_outputs = self.execute(**kwargs_literals)

if inspect.iscoroutine(function_outputs):
Expand Down Expand Up @@ -671,6 +676,10 @@ def compile(self, **kwargs):
# Construct the default input promise bindings, but then override with the provided inputs, if any
input_kwargs = construct_input_promises([k for k in self.interface.inputs.keys()])
input_kwargs.update(kwargs)

# This function would run over all the func in workflow
# This is the first time enter __call__ where compilation state = 0
# All the thing this func would do is create Task node and link them.
workflow_outputs = exception_scopes.user_entry_point(self._workflow_function)(**input_kwargs)
all_nodes.extend(comp_ctx.compilation_state.nodes)

Expand All @@ -683,7 +692,6 @@ def compile(self, **kwargs):
if isinstance(n.flyte_entity, PythonAutoContainerTask) and n.flyte_entity.task_resolver == self:
logger.debug(f"WF {self.name} saving task {n.flyte_entity.name}")
self.add(n.flyte_entity)

# Iterate through the workflow outputs
bindings = []
output_names = list(self.interface.outputs.keys())
Expand Down Expand Up @@ -743,7 +751,115 @@ def execute(self, **kwargs):
call execute from dispatch_execute which is in local_execute, workflows should also call an execute inside
local_execute. This makes mocking cleaner.
"""
return exception_scopes.user_entry_point(self._workflow_function)(**kwargs)
# try:
return self.execute_with_graph(**kwargs)
# except:
# print (self._nodes[0])
#return exception_scopes.user_entry_point(self._workflow_function)(**kwargs)

def execute_with_graph(self, **kwargs):
"""
This function is try to execute the workflows and tasks with pre-defined user sequence. Tasks are executed
according to the sequence of the graph by topology sort.
"""

## Set graph according to node
"""
Check if the graph is valid, and do the topological sort
"""
# 0 = unvisited, 1 = in progress, 2 = visited
visited = defaultdict(int)
sorted_node = []
# Create a map that holds the outputs of each node.
intermediate_node_outputs: Dict[Node, Dict[str, Promise]] = {GLOBAL_START_NODE: {}}

## This line only used for dynamic workflow

self.compile(**kwargs)
# Start things off with the outputs of the global input node, i.e. the inputs to the workflow.
# local_execute should've already ensured that all the values in kwargs are Promise objects
for k, v in kwargs.items():
intermediate_node_outputs[GLOBAL_START_NODE][k] = v

def toplogical_sort_node(node):
if visited[node] == 0:
visited[node] = 1
# print (node, node.upstream_nodes)
for n in node.upstream_nodes:
toplogical_sort_node(n)
visited[node] = 2
sorted_node.append(node)
elif visited[node] == 1:
raise Exception("Cycle detected")

def execute_node(node):
if node not in intermediate_node_outputs.keys():
intermediate_node_outputs[node] = {}

# Retrieve the entity from the node, and call it by looking up the promises the node's bindings require,
# and then fill them in using the node output tracker map we have.
entity = node.flyte_entity
entity_kwargs = get_promise_map(node.bindings, intermediate_node_outputs)

# When entity is conditional
if entity is None:
return

if isinstance(entity, BranchNode):
sub_node = entity(**entity_kwargs)
results = execute_node(sub_node)
intermediate_node_outputs[node].update(intermediate_node_outputs[sub_node])
return
else:
# Handle the calling and outputs of each node's entity
results = entity(**entity_kwargs)
expected_output_names = list(entity.python_interface.outputs.keys())

if isinstance(results, VoidPromise) or results is None:
return # pragma: no cover # Move along, nothing to assign

# Because we should've already returned in the above check, we just raise an Exception here.
if len(entity.python_interface.outputs) == 0:
raise FlyteValueException(results, "Interface output should've been VoidPromise or None.")

# if there's only one output,
if len(expected_output_names) == 1:
if entity.python_interface.output_tuple_name and isinstance(results, tuple):
intermediate_node_outputs[node][expected_output_names[0]] = results[0]
else:
intermediate_node_outputs[node][expected_output_names[0]] = results

else:
if len(results) != len(expected_output_names):
raise FlyteValueException(results, f"Different lengths {results} {expected_output_names}")
for idx, r in enumerate(results):
intermediate_node_outputs[node][expected_output_names[idx]] = r

for node in self._nodes:
if visited[node] == 0:
toplogical_sort_node(node)
# print(sorted_node)
# Next iterate through the nodes in order.
for n in sorted_node:
execute_node(n)

# The values that we return below from the output have to be pulled by fulfilling all of the
# workflow's output bindings.
# The return style here has to match what 1) what the workflow would've returned had it been declared
# functionally, and 2) what a user would return in mock function. That is, if it's a tuple, then it
# should be a tuple here, if it's a one element named tuple, then we do a one-element non-named tuple,
# if it's a single element then we return a single element
if len(self._output_bindings) == 0:
return None

if len(self.output_bindings) == 1:
# Again use presence of output_tuple_name to understand that we're dealing with a one-element
# named tuple
if self.python_interface.output_tuple_name:
return (get_promise(self.output_bindings[0].binding, intermediate_node_outputs),)
# Just a normal single element
return get_promise(self.output_bindings[0].binding, intermediate_node_outputs)
return tuple([get_promise(b.binding, intermediate_node_outputs) for b in self.output_bindings])


@overload
Expand Down
5 changes: 3 additions & 2 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ def dynamic_wf() -> str:
def wf() -> str:
return dynamic_wf()

assert wf() == "hello"

with context_manager.FlyteContextManager.with_context(
context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
) as ctx:
Expand All @@ -200,6 +198,9 @@ def wf() -> str:
assert dynamic_job_spec.nodes[0].task_node.overrides.resources.requests[0].value == "3"
assert dynamic_job_spec.nodes[0].task_node.overrides.resources.requests[1].value == "5Gi"

assert wf() == "hello"



def test_dynamic_return_dict():
@dynamic
Expand Down

0 comments on commit c3579f2

Please sign in to comment.