Skip to content

Commit

Permalink
Merge pull request #485 from julep-ai/f/transition-stream
Browse files Browse the repository at this point in the history
  • Loading branch information
creatorrr authored Sep 3, 2024
2 parents e73ae61 + eaeafb9 commit 19e02e2
Show file tree
Hide file tree
Showing 31 changed files with 567 additions and 231 deletions.
1 change: 1 addition & 0 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .prompt_step import prompt_step
from .raise_complete_async import raise_complete_async
from .return_step import return_step
from .set_value_step import set_value_step
from .switch_step import switch_step
from .tool_call_step import tool_call_step
from .transition_step import transition_step
Expand Down
30 changes: 0 additions & 30 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,13 @@
from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import (
Content,
ContentModel,
)
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.utils.template import render_template


def _content_to_dict(
content: str | list[str] | list[Content | ContentModel], role: str
) -> str | list[dict]:
if isinstance(content, str):
return content

result = []
for s in content:
if isinstance(s, str):
result.append({"content": {"type": "text", "text": s, "role": role}})
elif isinstance(s, Content):
result.append({"content": {"type": s.type, "text": s.text, "role": role}})
elif isinstance(s, ContentModel):
result.append(
{
"content": {
"type": s.type,
"image_url": {"url": s.image_url.url},
"role": role,
}
}
)

return result


@activity.defn
@beartype
async def prompt_step(context: StepContext) -> StepOutcome:
Expand Down
37 changes: 37 additions & 0 deletions agents-api/agents_api/activities/task_steps/set_value_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any

from beartype import beartype
from temporalio import activity

from ...activities.utils import simple_eval_dict
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


@beartype
async def set_value_step(
context: StepContext,
additional_values: dict[str, Any] = {},
override_expr: dict[str, str] | None = None,
) -> StepOutcome:
try:
expr = override_expr if override_expr is not None else context.current_step.set

values = context.model_dump() | additional_values
output = simple_eval_dict(expr, values)
result = StepOutcome(output=output)

return result

except BaseException as e:
activity.logger.error(f"Error in set_value_step: {e}")
return StepOutcome(error=str(e) or repr(e))


# Note: This is here just for clarity. We could have just imported set_value_step directly
# They do the same thing, so we dont need to mock the set_value_step function
mock_set_value_step = set_value_step

set_value_step = activity.defn(name="set_value_step")(
set_value_step if not testing else mock_set_value_step
)
16 changes: 1 addition & 15 deletions agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,20 +544,6 @@ class SearchStep(BaseModel):
"""


class SetKey(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
key: str
"""
The key to set
"""
value: str
"""
The value to set
"""


class SetStep(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
Expand All @@ -566,7 +552,7 @@ class SetStep(BaseModel):
"""
The kind of step
"""
set: SetKey
set: dict[str, str]
"""
The value to set
"""
Expand Down
15 changes: 15 additions & 0 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,19 @@ async def run_task_execution_workflow(
task_queue=temporal_task_queue,
id=str(job_id),
run_timeout=timedelta(days=31),
# TODO: Should add search_attributes for queryability
)


async def get_workflow_handle(
*,
handle_id: str,
client: Client | None = None,
):
client = client or (await get_client())

handle = client.get_workflow_handle(
handle_id,
)

return handle
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
simple_jinja_regex = re.compile(r"{{|{%.+}}|%}", re.DOTALL)


# FIXME: This does not work for some reason
# TODO: This does not work for some reason
def is_simple_jinja(template_string: str) -> bool:
return simple_jinja_regex.search(template_string) is None

Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/models/execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
from .get_execution_transition import get_execution_transition
from .list_execution_transitions import list_execution_transitions
from .list_executions import list_executions
from .lookup_temporal_data import lookup_temporal_data
from .prepare_execution_input import prepare_execution_input
from .update_execution import update_execution
15 changes: 6 additions & 9 deletions agents-api/agents_api/models/execution/create_temporal_lookup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import TypeVar
from uuid import UUID, uuid4
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
Expand All @@ -21,6 +21,7 @@

@rewrap_exceptions(
{
AssertionError: partialclass(HTTPException, status_code=404),
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
Expand All @@ -31,14 +32,10 @@
def create_temporal_lookup(
*,
developer_id: UUID,
task_id: UUID,
execution_id: UUID | None = None,
execution_id: UUID,
workflow_handle: WorkflowHandle,
) -> tuple[list[str], dict]:
execution_id = execution_id or uuid4()

developer_id = str(developer_id)
task_id = str(task_id)
execution_id = str(execution_id)

temporal_columns, temporal_values = cozo_process_mutate_data(
Expand All @@ -63,9 +60,9 @@ def create_temporal_lookup(
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id,
"tasks",
task_id=task_id,
parents=[("agents", "agent_id")],
"executions",
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
temporal_executions_lookup_query,
]
Expand Down
64 changes: 64 additions & 0 deletions agents-api/agents_api/models/execution/lookup_temporal_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Any, TypeVar
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ..utils import (
cozo_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
verify_developer_owns_resource_query,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(dict, one=True)
@cozo_query
@beartype
def lookup_temporal_data(
*,
developer_id: UUID,
execution_id: UUID,
) -> tuple[list[str], dict]:
developer_id = str(developer_id)
execution_id = str(execution_id)

temporal_query = """
?[id] :=
execution_id = to_uuid($execution_id),
*temporal_executions_lookup {
id, execution_id, run_id, first_execution_run_id, result_run_id
}
"""

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id,
"executions",
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
temporal_query,
]

return (
queries,
{
"execution_id": str(execution_id),
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ async def create_task_execution(
create_temporal_lookup,
#
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution.id,
workflow_handle=handle,
)
Expand Down
Loading

0 comments on commit 19e02e2

Please sign in to comment.