Skip to content

Commit

Permalink
refactor(agents-api): Minor refactors
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 18, 2024
1 parent 6cd98ae commit 49ddd62
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 35 deletions.
4 changes: 3 additions & 1 deletion agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from beartype import beartype
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
Expand All @@ -8,6 +9,7 @@
from ...env import testing


@beartype
async def evaluate_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for returning immediately, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
Expand All @@ -21,7 +23,7 @@ async def evaluate_step(context: StepContext) -> StepOutcome:
return result

except Exception as e:
logging.error(f"Error in log_step: {e}")
logging.error(f"Error in evaluate_step: {e}")
return StepOutcome(output=None)


Expand Down
39 changes: 26 additions & 13 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,37 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import (
IfElseWorkflowStep,
)
from ...autogen.openapi_model import IfElseWorkflowStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...env import testing


@activity.defn
@beartype
async def if_else_step(context: StepContext[IfElseWorkflowStep]) -> dict:
raise NotImplementedError()
# context_data: dict = context.model_dump()
async def if_else_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for logging, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
try:
assert isinstance(context.current_step, IfElseWorkflowStep)

expr: str = context.current_step.if_
output = simple_eval(expr, names=context.model_dump())

result = StepOutcome(output=output)
return result

except Exception as e:
logging.error(f"Error in if_else_step: {e}")
return StepOutcome(output=None)


# next_workflow = (
# context.current_step.then
# if simple_eval(context.current_step.if_, names=context_data)
# else context.current_step.else_
# )
# Note: This is here just for clarity. We could have just imported if_else_step directly
# They do the same thing, so we dont need to mock the if_else_step function
mock_if_else_step = if_else_step

# return {"goto_workflow": next_workflow}
if_else_step = activity.defn(name="if_else_step")(if_else_step if not testing else mock_if_else_step)
6 changes: 3 additions & 3 deletions agents-api/agents_api/activities/task_steps/log_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

Expand All @@ -11,9 +12,8 @@
from ...env import testing


async def log_step(
context: StepContext,
) -> StepOutcome:
@beartype
async def log_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for logging, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
try:
Expand Down
28 changes: 10 additions & 18 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
with workflow.unsafe.imports_passed_through():
from ..activities.task_steps import (
evaluate_step,
# if_else_step,
if_else_step,
log_step,
# prompt_step,
raise_complete_async,
Expand All @@ -24,7 +24,7 @@
CreateTransitionRequest,
ErrorWorkflowStep,
EvaluateStep,
# IfElseWorkflowStep,
IfElseWorkflowStep,
LogStep,
# PromptStep,
ReturnStep,
Expand Down Expand Up @@ -56,7 +56,7 @@
# NOTE: local activities are directly called in the workflow executor
# They MUST NOT FAIL, otherwise they will crash the workflow
EvaluateStep: evaluate_step,
# IfElseWorkflowStep: if_else_step,
IfElseWorkflowStep: if_else_step,
YieldStep: yield_step,
LogStep: log_step,
ReturnStep: return_step,
Expand Down Expand Up @@ -95,16 +95,17 @@ async def run(
outcome = await execute_activity(
activity,
context,
#
# TODO: This should be a configurable timeout everywhere based on the task
schedule_to_close_timeout=timedelta(seconds=3 if testing else 600),
)

# 2a. Then, based on the outcome and step type, decide what to do next
# 2a. Set globals
# (By default, exit if last otherwise transition 'step' to the next step)
final_output = None
transition_type: TransitionType
next_target: TransitionTarget | None
metadata: dict = {"step_type": step_type.__name__}
metadata: dict = {"__meta__": {"step_type": step_type.__name__}}

if context.is_last_step:
transition_type = "finish"
Expand All @@ -131,19 +132,16 @@ async def transition(**kwargs):
schedule_to_close_timeout=timedelta(seconds=600),
)

# 3. Orchestrate the step
# 3. Then, based on the outcome and step type, decide what to do next
match context.current_step, outcome:
case LogStep(), StepOutcome(output=output):
if output is None:
raise ApplicationError("log step threw an error")
case step, StepOutcome(output=None):
raise ApplicationError(f"{step.__class__.__name__} step threw an error")

case LogStep(), StepOutcome(output=output):
await transition(output=dict(logged=output))
final_output = context.current_input

case ReturnStep(), StepOutcome(output=output):
if output is None:
raise ApplicationError("return step threw an error")

final_output = output
transition_type = "finish"
await transition()
Expand All @@ -166,9 +164,6 @@ async def transition(**kwargs):
await transition()

case EvaluateStep(), StepOutcome(output=output):
if output is None:
raise ApplicationError("evaluate step threw an error")

final_output = output
await transition()

Expand All @@ -182,9 +177,6 @@ async def transition(**kwargs):
case YieldStep(), StepOutcome(
output=output, transition_to=(yield_transition_type, yield_next_target)
):
if output is None:
raise ApplicationError("yield step threw an error")

await transition(
output=output, type=yield_transition_type, next=yield_next_target
)
Expand Down

0 comments on commit 49ddd62

Please sign in to comment.