Skip to content

Commit

Permalink
fix(agents-api): Fix tests and fixtures
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 14, 2024
1 parent 82c8af2 commit b1fc2a4
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 24 deletions.
4 changes: 3 additions & 1 deletion agents-api/agents_api/models/entry/create_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def create_entries(
verify_developer_owns_resource_query(
developer_id, "sessions", session_id=session_id
),
mark_session_updated_query(developer_id, session_id) if mark_session_as_updated else "",
mark_session_updated_query(developer_id, session_id)
if mark_session_as_updated
else "",
create_query,
]

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/patch_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def patch_session(
*sessions{{
{rest_fields}, metadata: md, @ "NOW"
}},
updated_at = [floor(now()), true],
updated_at = 'ASSERT',
metadata = concat(md, $metadata),
:put sessions {{
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/update_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def update_session(
*sessions{{
{rest_fields}, @ "NOW"
}},
updated_at = [floor(now()), true]
updated_at = 'ASSERT'
:put sessions {{
{all_fields}, updated_at
Expand Down
45 changes: 38 additions & 7 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,42 @@ def mark_session_updated_query(developer_id: UUID | str, session_id: UUID | str)
to_uuid("{str(session_id)}"),
]]
?[developer_id, session_id, updated_at] :=
?[
developer_id,
session_id,
situation,
summary,
created_at,
metadata,
render_templates,
token_budget,
context_overflow,
updated_at,
] :=
input[developer_id, session_id],
*sessions {{
session_id,
situation,
summary,
created_at,
metadata,
render_templates,
token_budget,
context_overflow,
@ 'NOW'
}},
updated_at = [floor(now()), true]
:update sessions {{
:put sessions {{
developer_id,
session_id,
situation,
summary,
created_at,
metadata,
render_templates,
token_budget,
context_overflow,
updated_at,
}}
"""
Expand Down Expand Up @@ -148,8 +174,7 @@ def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
and then run the query using the client, returning a DataFrame.
"""

if debug:
from pprint import pprint
from pprint import pprint

@wraps(func)
def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame:
Expand All @@ -162,17 +187,23 @@ def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame:
query = "}\n\n{\n".join(queries)
query = f"{{ {query} }}"

debug and print(query)
debug and pprint(
dict(
query=query,
variables=variables,
)
)

# Run the query
from ..clients.cozo import get_cozo_client

client = client or get_cozo_client()
result = client.run(query, variables)
try:
client = client or get_cozo_client()
result = client.run(query, variables)

except Exception as e:
debug and print(repr(getattr(e, "__cause__", None) or e))
raise

# Need to fix the UUIDs in the result
result = result.map(fix_uuid_if_present)
Expand Down
4 changes: 1 addition & 3 deletions agents-api/agents_api/routers/tasks/create_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,4 @@ async def create_task(
data=data,
)

return ResourceCreatedResponse(
id=task.id, created_at=task.created_at, jobs=[]
)
return ResourceCreatedResponse(id=task.id, created_at=task.created_at, jobs=[])
3 changes: 3 additions & 0 deletions agents-api/agents_api/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import timedelta

from temporalio.client import Client
from temporalio.worker import Worker

Expand Down Expand Up @@ -46,6 +48,7 @@ async def create_worker(client: Client | None = None):
# Initialize the worker with the specified task queue, workflows, and activities
worker = Worker(
client,
graceful_shutdown_timeout=timedelta(seconds=30),
task_queue=temporal_task_queue,
workflows=[
SummarizationWorkflow,
Expand Down
16 changes: 7 additions & 9 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,19 @@ def activity_environment():
@fixture(scope="global")
async def workflow_environment():
wf_env = await WorkflowEnvironment.start_local()
return wf_env
yield wf_env
await wf_env.shutdown()


@fixture(scope="global")
async def temporal_client(wf_env=workflow_environment):
return wf_env.client


@fixture(scope="global")
async def temporal_worker(temporal_client=temporal_client):
worker = await create_worker(client=temporal_client)
async def temporal_worker(wf_env=workflow_environment):
worker = await create_worker(client=wf_env.client)

# FIXME: This does not stop the worker properly
c = worker.shutdown()
async with worker as running_worker:
yield running_worker
await running_worker.shutdown()
await c


@fixture(scope="global")
Expand Down
2 changes: 0 additions & 2 deletions agents-api/tests/test_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .fixtures import (
cozo_client,
patch_embed_acompletion,
temporal_client,
temporal_worker,
test_developer_id,
test_doc,
Expand Down Expand Up @@ -46,7 +45,6 @@ async def _(
async def _(
workflow_environment=workflow_environment,
worker=temporal_worker,
client=temporal_client,
):
async with workflow_environment as wf_env:
assert wf_env is not None
Expand Down
38 changes: 38 additions & 0 deletions agents-api/tests/test_entry_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@

# Tests for entry queries

import time

from ward import test

from agents_api.autogen.openapi_model import CreateEntryRequest
from agents_api.models.entry.create_entries import create_entries
from agents_api.models.entry.delete_entries import delete_entries
from agents_api.models.entry.get_history import get_history
from agents_api.models.entry.list_entries import list_entries
from agents_api.models.session.get_session import get_session
from tests.fixtures import cozo_client, test_developer_id, test_session

MODEL = "gpt-4o"
Expand All @@ -35,10 +38,45 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
developer_id=developer_id,
session_id=session.id,
data=[test_entry],
mark_session_as_updated=False,
client=client,
)


@test("model: create entry, update session")
def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
"""
Tests the addition of a new entry to the database.
Verifies that the entry can be successfully added using the create_entries function.
"""

test_entry = CreateEntryRequest(
session_id=session.id,
role="user",
source="internal",
content="test entry content",
)

# FIXME: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep
time.sleep(1)

create_entries(
developer_id=developer_id,
session_id=session.id,
data=[test_entry],
mark_session_as_updated=True,
client=client,
)

updated_session = get_session(
developer_id=developer_id,
session_id=session.id,
client=client,
)

assert updated_session.updated_at > session.updated_at


@test("model: get entries")
def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
"""
Expand Down

0 comments on commit b1fc2a4

Please sign in to comment.