Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agents-api): Minor fix to tests #457

Merged
merged 7 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions agents-api/agents_api/models/entry/create_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ def create_entries(
verify_developer_owns_resource_query(
developer_id, "sessions", session_id=session_id
),
mark_session_as_updated
and mark_session_updated_query(developer_id, session_id),
mark_session_updated_query(developer_id, session_id)
if mark_session_as_updated
else "",
create_query,
]

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/patch_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def patch_session(
*sessions{{
{rest_fields}, metadata: md, @ "NOW"
}},
updated_at = [floor(now()), true],
updated_at = 'ASSERT',
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
metadata = concat(md, $metadata),

:put sessions {{
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/update_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def update_session(
*sessions{{
{rest_fields}, @ "NOW"
}},
updated_at = [floor(now()), true]
updated_at = 'ASSERT'

:put sessions {{
{all_fields}, updated_at
Expand Down
45 changes: 38 additions & 7 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,42 @@ def mark_session_updated_query(developer_id: UUID | str, session_id: UUID | str)
to_uuid("{str(session_id)}"),
]]

?[developer_id, session_id, updated_at] :=
?[
developer_id,
session_id,
situation,
summary,
created_at,
metadata,
render_templates,
token_budget,
context_overflow,
updated_at,
] :=
input[developer_id, session_id],
*sessions {{
session_id,
situation,
summary,
created_at,
metadata,
render_templates,
token_budget,
context_overflow,
@ 'NOW'
}},
updated_at = [floor(now()), true]

:update sessions {{
:put sessions {{
developer_id,
session_id,
situation,
summary,
created_at,
metadata,
render_templates,
token_budget,
context_overflow,
updated_at,
}}
"""
Expand Down Expand Up @@ -148,8 +174,7 @@ def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
and then run the query using the client, returning a DataFrame.
"""

if debug:
from pprint import pprint
from pprint import pprint

@wraps(func)
def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame:
Expand All @@ -162,17 +187,23 @@ def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame:
query = "}\n\n{\n".join(queries)
query = f"{{ {query} }}"

debug and print(query)
debug and pprint(
dict(
query=query,
variables=variables,
)
)

# Run the query
from ..clients.cozo import get_cozo_client

client = client or get_cozo_client()
result = client.run(query, variables)
try:
client = client or get_cozo_client()
result = client.run(query, variables)

except Exception as e:
debug and print(repr(getattr(e, "__cause__", None) or e))
raise

# Need to fix the UUIDs in the result
result = result.map(fix_uuid_if_present)
Expand Down
26 changes: 6 additions & 20 deletions agents-api/agents_api/routers/tasks/create_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import Annotated
from uuid import uuid4

import pandas as pd
from fastapi import Depends
from pydantic import UUID4
from starlette.status import HTTP_201_CREATED

from agents_api.autogen.openapi_model import (
CreateTaskRequest,
ResourceCreatedResponse,
Task,
)
from agents_api.dependencies.developer_id import get_developer_id
from agents_api.models.task.create_task import create_task as create_task_query
Expand All @@ -18,29 +17,16 @@

@router.post("/agents/{agent_id}/tasks", status_code=HTTP_201_CREATED, tags=["tasks"])
async def create_task(
request: CreateTaskRequest,
data: CreateTaskRequest,
agent_id: UUID4,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
task_id = uuid4()

# TODO: Do thorough validation of the task spec

workflows = [
{"name": "main", "steps": [w.model_dump() for w in request.main]},
] + [{"name": name, "steps": steps} for name, steps in request.model_extra.items()]

resp: pd.DataFrame = create_task_query(
agent_id=agent_id,
task_id=task_id,
task: Task = create_task_query(
developer_id=x_developer_id,
name=request.name,
description=request.description,
input_schema=request.input_schema or {},
tools_available=request.tools or [],
workflows=workflows,
agent_id=agent_id,
data=data,
)

return ResourceCreatedResponse(
id=resp["task_id"][0], created_at=resp["created_at"][0]
)
return ResourceCreatedResponse(id=task.id, created_at=task.created_at, jobs=[])
3 changes: 3 additions & 0 deletions agents-api/agents_api/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import timedelta

from temporalio.client import Client
from temporalio.worker import Worker

Expand Down Expand Up @@ -46,6 +48,7 @@ async def create_worker(client: Client | None = None):
# Initialize the worker with the specified task queue, workflows, and activities
worker = Worker(
client,
graceful_shutdown_timeout=timedelta(seconds=30),
task_queue=temporal_task_queue,
workflows=[
SummarizationWorkflow,
Expand Down
16 changes: 7 additions & 9 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,19 @@ def activity_environment():
@fixture(scope="global")
async def workflow_environment():
wf_env = await WorkflowEnvironment.start_local()
return wf_env
yield wf_env
await wf_env.shutdown()


@fixture(scope="global")
async def temporal_client(wf_env=workflow_environment):
return wf_env.client


@fixture(scope="global")
async def temporal_worker(temporal_client=temporal_client):
worker = await create_worker(client=temporal_client)
async def temporal_worker(wf_env=workflow_environment):
worker = await create_worker(client=wf_env.client)

# FIXME: This does not stop the worker properly
c = worker.shutdown()
async with worker as running_worker:
yield running_worker
await running_worker.shutdown()
await c


@fixture(scope="global")
Expand Down
2 changes: 0 additions & 2 deletions agents-api/tests/test_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .fixtures import (
cozo_client,
patch_embed_acompletion,
temporal_client,
temporal_worker,
test_developer_id,
test_doc,
Expand Down Expand Up @@ -46,7 +45,6 @@ async def _(
async def _(
workflow_environment=workflow_environment,
worker=temporal_worker,
client=temporal_client,
):
async with workflow_environment as wf_env:
assert wf_env is not None
Expand Down
38 changes: 38 additions & 0 deletions agents-api/tests/test_entry_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@

# Tests for entry queries

import time

from ward import test

from agents_api.autogen.openapi_model import CreateEntryRequest
from agents_api.models.entry.create_entries import create_entries
from agents_api.models.entry.delete_entries import delete_entries
from agents_api.models.entry.get_history import get_history
from agents_api.models.entry.list_entries import list_entries
from agents_api.models.session.get_session import get_session
from tests.fixtures import cozo_client, test_developer_id, test_session

MODEL = "gpt-4o"
Expand All @@ -35,10 +38,45 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
developer_id=developer_id,
session_id=session.id,
data=[test_entry],
mark_session_as_updated=False,
client=client,
)


@test("model: create entry, update session")
def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
"""
Tests the addition of a new entry to the database.
Verifies that the entry can be successfully added using the create_entries function.
"""

test_entry = CreateEntryRequest(
session_id=session.id,
role="user",
source="internal",
content="test entry content",
)

# FIXME: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep
time.sleep(1)

create_entries(
developer_id=developer_id,
session_id=session.id,
data=[test_entry],
mark_session_as_updated=True,
client=client,
)

updated_session = get_session(
developer_id=developer_id,
session_id=session.id,
client=client,
)

assert updated_session.updated_at > session.updated_at


@test("model: get entries")
def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
"""
Expand Down
28 changes: 16 additions & 12 deletions agents-api/tests/test_task_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
def _(client=client, agent=test_agent):
data = dict(
name="test user",
main={
"kind_": "evaluate",
"evaluate": {
"additionalProp1": "value1",
},
},
main=[
{
"kind_": "evaluate",
"evaluate": {
"additionalProp1": "value1",
},
}
],
)

response = client.request(
Expand All @@ -31,12 +33,14 @@ def _(client=client, agent=test_agent):
def _(make_request=make_request, agent=test_agent):
data = dict(
name="test user",
main={
"kind_": "evaluate",
"evaluate": {
"additionalProp1": "value1",
},
},
main=[
{
"kind_": "evaluate",
"evaluate": {
"additionalProp1": "value1",
},
}
],
)

response = make_request(
Expand Down
Loading