Skip to content

Commit

Permalink
feat(agents-api): Add more tests for the task execution
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 17, 2024
1 parent ae112cb commit 43c03fe
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def prepare_execution_input(
# TODO: Enable these later
user = null,
session = null,
arguments = {{}},
arguments = execution->"input"
"""

queries = [
Expand Down
47 changes: 2 additions & 45 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,48 +162,5 @@ async def run(

# Otherwise, recurse to the next step
workflow.continue_as_new(
execution_input, next_target, previous_inputs + [final_output]
)

##################

# should_wait, is_error = False, False
# # Run the step
# match step:
# case PromptStep():
# outputs = await workflow.execute_activity(
# prompt_step,
# context,
# schedule_to_close_timeout=timedelta(seconds=600),
# )
#
# # TODO: ChatCompletion does not have tool_calls
# # if outputs.tool_calls is not None:
# # should_wait = True

# case ToolCallStep():
# outputs = await workflow.execute_activity(
# tool_call_step,
# context,
# schedule_to_close_timeout=timedelta(seconds=600),
# )

# case IfElseWorkflowStep():
# outputs = await workflow.execute_activity(
# if_else_step,
# context,
# schedule_to_close_timeout=timedelta(seconds=600),
# )
# workflow_step = YieldStep(**outputs["goto_workflow"])
#
# outputs = await workflow.execute_child_workflow(
# TaskExecutionWorkflow.run,
# args=[
# execution_input,
# (workflow_step.workflow, 0),
# previous_inputs,
# ],
# )

# case WaitForInputStep():
# should_wait = True
args=[execution_input, next_target, previous_inputs + [final_output]]
)
110 changes: 107 additions & 3 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,79 @@
# Tests for task queries

from agents_api.models.task.create_task import create_task
from ward import test

from agents_api.autogen.openapi_model import CreateExecutionRequest
from agents_api.autogen.openapi_model import CreateExecutionRequest, CreateTaskRequest
from agents_api.routers.tasks.create_task_execution import start_execution

from .fixtures import cozo_client, test_developer_id, test_task
from .fixtures import cozo_client, test_agent, test_developer_id, test_task
from .utils import patch_testing_temporal


@test("workflow: create task execution")
async def _(client=cozo_client, developer_id=test_developer_id, task=test_task):
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(
**{
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [{"evaluate": {"hello": '"world"'}}],
}
),
client=client,
)

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
client=client,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input
mock_run_task_execution_workflow.assert_called_once()

result = await handle.result()
assert result["hello"] == "world"


@test("workflow: create task execution")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(
**{
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [
{"evaluate": {"hello": '"nope"'}},
{"evaluate": {"hello": '"world"'}},
],
}
),
client=client,
)

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
Expand All @@ -28,3 +89,46 @@ async def _(client=cozo_client, developer_id=test_developer_id, task=test_task):

result = await handle.result()
assert result["hello"] == "world"


@test("workflow: create task execution")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(
**{
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [
# Testing that we can access the input
{"evaluate": {"hello": '_["test"]'}},
],
}
),
client=client,
)

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
client=client,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input
mock_run_task_execution_workflow.assert_called_once()

result = await handle.result()
assert result["hello"] == data.input["test"]

0 comments on commit 43c03fe

Please sign in to comment.