Skip to content

Commit

Permalink
feat: Add test for evaluate step
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 17, 2024
1 parent 729a0aa commit ae112cb
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 44 deletions.
13 changes: 6 additions & 7 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from typing import Any

from beartype import beartype
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
Expand All @@ -12,14 +9,16 @@
from ...env import testing


@beartype
async def evaluate_step(
context: StepContext[EvaluateStep],
) -> StepOutcome[dict[str, Any]]:
exprs = context.definition.arguments
) -> StepOutcome:
assert isinstance(context.current_step, EvaluateStep)

exprs = context.current_step.evaluate
output = simple_eval_dict(exprs, values=context.model_dump())

return StepOutcome(output=output)
result = StepOutcome(output=output)
return result


# Note: This is here just for clarity. We could have just imported evaluate_step directly
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ async def if_else_step(context: StepContext[IfElseWorkflowStep]) -> dict:
# context_data: dict = context.model_dump()

# next_workflow = (
# context.definition.then
# if simple_eval(context.definition.if_, names=context_data)
# else context.definition.else_
# context.current_step.then
# if simple_eval(context.current_step.if_, names=context_data)
# else context.current_step.else_
# )

# return {"goto_workflow": next_workflow}
8 changes: 4 additions & 4 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome:

# Render template messages
prompt = (
[InputChatMLMessage(content=context.definition.prompt)]
if isinstance(context.definition.prompt, str)
else context.definition.prompt
[InputChatMLMessage(content=context.current_step.prompt)]
if isinstance(context.current_step.prompt, str)
else context.current_step.prompt
)

template_messages: list[InputChatMLMessage] = prompt
Expand All @@ -47,7 +47,7 @@ async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome:
for m in messages
]

settings: dict = context.definition.settings.model_dump()
settings: dict = context.current_step.settings.model_dump()
# Get settings and run llm
response = await litellm.acompletion(
messages=messages,
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/activities/task_steps/tool_call_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
@beartype
async def tool_call_step(context: StepContext) -> dict:
raise NotImplementedError()
# assert isinstance(context.definition, ToolCallStep)
# assert isinstance(context.current_step, ToolCallStep)

# context.definition.tool_id
# context.definition.arguments
# context.current_step.tool_id
# context.current_step.arguments
# # get tool by id
# # call tool

Expand Down
4 changes: 1 addition & 3 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Any

from beartype import beartype
from temporalio import activity

Expand All @@ -12,7 +10,7 @@


@beartype
async def yield_step(context: StepContext[YieldStep]) -> StepOutcome[dict[str, Any]]:
async def yield_step(context: StepContext[YieldStep]) -> StepOutcome:
all_workflows = context.execution_input.task.workflows
workflow = context.current_step.workflow

Expand Down
23 changes: 10 additions & 13 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# ruff: noqa: F401, F403, F405
from typing import Annotated, Generic, Literal, Self, Type, TypeVar
from typing import Annotated, Any, Generic, Literal, Self, Type, TypeVar
from uuid import UUID

from litellm.utils import _select_tokenizer as select_tokenizer
from litellm.utils import token_counter
from pydantic import AwareDatetime, Field
from pydantic_partial import create_partial_model

from ..common.utils.datetime import utcnow
from .Agents import *
Expand Down Expand Up @@ -97,17 +96,15 @@ class ListResponse(BaseModel, Generic[DataT]):
# Create models
# -------------

CreateTransitionRequest = create_partial_model(
Transition,
#
# The following fields are optional
"id",
"execution_id",
"created_at",
"updated_at",
"metadata",
)
CreateTransitionRequest.model_rebuild()

# Create-request variant of `Transition`: it subclasses `Transition` and
# re-declares the fields below so that callers may omit them.
class CreateTransitionRequest(Transition):
    # The following fields are optional in this request model: each one is
    # re-declared as nullable with a default of None.
    # NOTE(review): presumably these are filled in server-side when the
    # transition is stored — confirm against the code that consumes this model.

    id: UUID | None = None
    execution_id: UUID | None = None
    created_at: AwareDatetime | None = None
    updated_at: AwareDatetime | None = None
    metadata: dict[str, Any] | None = None


class CreateEntryRequest(BaseEntry):
Expand Down
11 changes: 4 additions & 7 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ def current_workflow(self) -> Workflow:
@computed_field
@property
def current_step(self) -> WorkflowStepType:
step = self.current_workflow[self.cursor.step]
step = self.current_workflow.steps[self.cursor.step]
return step

@computed_field
@property
def is_last_step(self) -> bool:
return (self.cursor.step + 1) == len(self.current_workflow)
return (self.cursor.step + 1) == len(self.current_workflow.steps)

def model_dump(self, *args, **kwargs) -> dict[str, Any]:
dump = super().model_dump(*args, **kwargs)
Expand All @@ -108,11 +108,8 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]:
return dump


OutcomeType = TypeVar("OutcomeType", bound=BaseModel)


class StepOutcome(BaseModel, Generic[OutcomeType]):
output: OutcomeType | None
class StepOutcome(BaseModel):
output: Any
transition_to: tuple[TransitionType, TransitionTarget] | None = None


Expand Down
7 changes: 4 additions & 3 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
)
from ..common.protocol.tasks import (
ExecutionInput,
# OutcomeType,
StepContext,
StepOutcome,
# Workflow,
)
from ..env import testing


STEP_TO_ACTIVITY = {
Expand All @@ -59,7 +59,7 @@ async def run(
context = StepContext(
execution_input=execution_input,
inputs=previous_inputs,
current=start,
cursor=start,
)

step_type = type(context.current_step)
Expand All @@ -69,7 +69,8 @@ async def run(
outcome = await workflow.execute_activity(
activity,
context,
schedule_to_close_timeout=timedelta(seconds=600),
# TODO: This should be a configurable timeout everywhere based on the task
schedule_to_close_timeout=timedelta(seconds=3 if testing else 600),
)

# 2. Then, based on the outcome and step type, decide what to do next
Expand Down
2 changes: 1 addition & 1 deletion agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_task(
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [],
"main": [{"evaluate": {"hello": '"world"'}}],
}
),
client=client,
Expand Down
3 changes: 3 additions & 0 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ async def _(client=cozo_client, developer_id=test_developer_id, task=test_task):
assert execution.task_id == task.id
assert execution.input == data.input
mock_run_task_execution_workflow.assert_called_once()

result = await handle.result()
assert result["hello"] == "world"

0 comments on commit ae112cb

Please sign in to comment.