Skip to content

Commit

Permalink
feat: Resume workflow execution
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Aug 5, 2024
1 parent 41e5df8 commit 1dc4ed3
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 49 deletions.
14 changes: 14 additions & 0 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
from temporalio import activity

from ...autogen.openapi_model import (
CreateTransitionRequest,
EvaluateStep,
# ErrorWorkflowStep,
IfElseWorkflowStep,
InputChatMLMessage,
PromptStep,
ToolCallStep,
UpdateExecutionRequest,
YieldStep,
)
from ...clients.worker.types import ChatML
Expand All @@ -23,6 +25,9 @@
from ...models.execution.create_execution_transition import (
create_execution_transition as create_execution_transition_query,
)
from ...models.execution.update_execution import (
update_execution as update_execution_query,
)
from ...routers.sessions.protocol import Settings
from ...routers.sessions.session import llm_generate

Expand Down Expand Up @@ -142,6 +147,15 @@ async def transition_step(
**transition_data,
)

update_execution_query(
developer_id=context.developer_id,
task_id=context.task.id,
execution_id=context.execution.id,
data=UpdateExecutionRequest(
status="awaiting_input",
),
)

# Raise if it's a waiting step
if transition_info.type == "awaiting_input":
activity.raise_complete_async()
4 changes: 2 additions & 2 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from uuid import UUID

from temporalio.client import Client, TLSConfig
from temporalio.client import Client, TLSConfig, WorkflowHandle

from agents_api.env import (
temporal_client_cert,
Expand Down Expand Up @@ -75,7 +75,7 @@ async def run_task_execution_workflow(
):
client = await get_client()

await client.execute_workflow(
return await client.start_workflow(
"TaskExecutionWorkflow",
args=[execution_input, start, previous_inputs],
task_queue="memory-task-queue",
Expand Down
34 changes: 23 additions & 11 deletions agents-api/agents_api/routers/jobs/routers.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,43 @@
from typing import Literal

from fastapi import APIRouter
from pydantic import UUID4
from temporalio.client import WorkflowExecutionStatus

from agents_api.autogen.openapi_model import JobStatus, State
from agents_api.autogen.openapi_model import JobStatus
from agents_api.clients.temporal import get_client

router = APIRouter()


def map_job_status(status: WorkflowExecutionStatus) -> State:
def map_job_status(
status: WorkflowExecutionStatus,
) -> Literal[
"pending",
"in_progress",
"retrying",
"succeeded",
"aborted",
"failed",
"unknown",
]:
match status:
case WorkflowExecutionStatus.RUNNING:
return State.in_progress
return "in_progress"
case WorkflowExecutionStatus.COMPLETED:
return State.succeeded
return "succeeded"
case WorkflowExecutionStatus.FAILED:
return State.failed
return "failed"
case WorkflowExecutionStatus.CANCELED:
return State.aborted
return "aborted"
case WorkflowExecutionStatus.TERMINATED:
return State.aborted
return "aborted"
case WorkflowExecutionStatus.CONTINUED_AS_NEW:
return State.in_progress
return "in_progress"
case WorkflowExecutionStatus.TIMED_OUT:
return State.failed
return "failed"
case _:
return State.unknown
return "unknown"


@router.get("/jobs/{job_id}", tags=["jobs"])
Expand All @@ -39,7 +51,7 @@ async def get_job_status(job_id: UUID4) -> JobStatus:

return JobStatus(
name=handle.id,
reason=f"Execution status: {state.name}",
reason=f"Execution status: {state}",
created_at=job_description.start_time,
updated_at=job_description.execution_time,
id=job_id,
Expand Down
61 changes: 33 additions & 28 deletions agents-api/agents_api/routers/tasks/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@
from pycozo.client import QueryException
from pydantic import UUID4, BaseModel
from starlette.status import HTTP_201_CREATED
from temporalio.client import WorkflowHandle

from agents_api.autogen.openapi_model import (
CreateExecutionRequest,
CreateTaskRequest,
Execution,
ResourceCreatedResponse,
ResumeExecutionRequest,
StopExecutionRequest,
# ResourceUpdatedResponse,
Task,
Transition,
UpdateExecutionRequest,
)
from agents_api.clients.cozo import client as cozo_client
from agents_api.clients.temporal import run_task_execution_workflow
from agents_api.clients.temporal import get_client, run_task_execution_workflow
from agents_api.common.protocol.tasks import ExecutionInput
from agents_api.dependencies.developer_id import get_developer_id
from agents_api.models.execution.create_execution import (
Expand All @@ -30,6 +33,12 @@
from agents_api.models.execution.get_execution import (
get_execution as get_execution_query,
)
from agents_api.models.execution.get_paused_execution_token import (
get_paused_execution_token,
)
from agents_api.models.execution.get_temporal_workflow_data import (
get_temporal_workflow_data,
)

# from agents_api.models.execution.get_execution_transition import (
# get_execution_transition as get_execution_transition_query,
Expand All @@ -43,6 +52,7 @@
from agents_api.models.execution.list_executions import (
list_executions as list_task_executions_query,
)
from agents_api.models.execution.prepare_execution_input import prepare_execution_input
from agents_api.models.execution.update_execution import (
update_execution as update_execution_query,
)
Expand Down Expand Up @@ -207,22 +217,14 @@ async def create_task_execution(
raise

execution_id = uuid4()
execution = create_execution_query(
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
data=data,
)

execution_input = ExecutionInput.fetch(
execution_input = prepare_execution_input(
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
client=cozo_client,
)

try:
await run_task_execution_workflow(
handle = await run_task_execution_workflow(
execution_input=execution_input,
job_id=uuid4(),
)
Expand All @@ -241,6 +243,14 @@ async def create_task_execution(
detail="Task creation failed",
)

execution = create_execution_query(
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
data=data,
workflow_hande=handle,
)

return ResourceCreatedResponse(
id=execution["execution_id"][0], created_at=execution["created_at"][0]
)
Expand Down Expand Up @@ -305,26 +315,21 @@ async def patch_execution(
@router.put("/tasks/{task_id}/executions/{execution_id}", tags=["tasks"])
async def put_execution(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
task_id: UUID4,
execution_id: UUID4,
data: UpdateExecutionRequest,
data: ResumeExecutionRequest | StopExecutionRequest,
) -> Execution:
try:
res = [
row.to_dict()
for _, row in update_execution_query(
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
data=data,
).iterrows()
][0]
return Execution(**res)
except (IndexError, KeyError):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Execution not found",
temporal_client = await get_client()
if isinstance(data, StopExecutionRequest):
handle = temporal_client.get_workflow_handle_for(
*get_temporal_workflow_data(execution_id=execution_id)
)
await handle.cancel()
else:
token_data = get_paused_execution_token(
developer_id=x_developer_id, execution_id=execution_id
)
handle = temporal_client.get_async_activity_handle(token_data["task_token"])
await handle.complete("finished")


@router.get("/tasks/{task_id}/executions", tags=["tasks"])
Expand Down
20 changes: 12 additions & 8 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from datetime import timedelta

from temporalio import workflow
from temporalio import activity, workflow

with workflow.unsafe.imports_passed_through():
from ..activities.task_steps import (
Expand All @@ -13,17 +13,20 @@
tool_call_step,
transition_step,
)
from ..common.protocol.tasks import (
EvaluateStep,
ExecutionInput,
from ..autogen.openapi_model import (
# ErrorWorkflowStep,
EvaluateStep,
IfElseWorkflowStep,
PromptStep,
StepContext,
ToolCallStep,
TransitionInfo,
WaitForInputStep,
YieldStep,
)
from ..common.protocol.tasks import (
ExecutionInput,
StepContext,
TransitionInfo,
)


@workflow.defn
Expand All @@ -36,8 +39,7 @@ async def run(
previous_inputs: list[dict] = [],
) -> None:
wf_name, step_idx = start
spec = execution_input.task.spec
workflow_map = {wf.name: wf.steps for wf in spec.workflows}
workflow_map = {wf.name: wf.steps for wf in execution_input.task.workflows}
current_workflow = workflow_map[wf_name]
previous_inputs = previous_inputs or [execution_input.arguments]
step = current_workflow[step_idx]
Expand Down Expand Up @@ -108,6 +110,8 @@ async def run(
previous_inputs,
],
)
case WaitForInputStep():
should_wait = True

is_last = step_idx + 1 == len(current_workflow)
# Transition type
Expand Down

0 comments on commit 1dc4ed3

Please sign in to comment.