diff --git a/agents-api/agents_api/models/entry/create_entries.py b/agents-api/agents_api/models/entry/create_entries.py index d2aa87f86..01193d395 100644 --- a/agents-api/agents_api/models/entry/create_entries.py +++ b/agents-api/agents_api/models/entry/create_entries.py @@ -78,7 +78,9 @@ def create_entries( verify_developer_owns_resource_query( developer_id, "sessions", session_id=session_id ), - mark_session_updated_query(developer_id, session_id) if mark_session_as_updated else "", + mark_session_updated_query(developer_id, session_id) + if mark_session_as_updated + else "", create_query, ] diff --git a/agents-api/agents_api/models/session/patch_session.py b/agents-api/agents_api/models/session/patch_session.py index 5aae245cc..f16a53f44 100644 --- a/agents-api/agents_api/models/session/patch_session.py +++ b/agents-api/agents_api/models/session/patch_session.py @@ -92,7 +92,7 @@ def patch_session( *sessions{{ {rest_fields}, metadata: md, @ "NOW" }}, - updated_at = [floor(now()), true], + updated_at = 'ASSERT', metadata = concat(md, $metadata), :put sessions {{ diff --git a/agents-api/agents_api/models/session/update_session.py b/agents-api/agents_api/models/session/update_session.py index b498650e8..5f7789f29 100644 --- a/agents-api/agents_api/models/session/update_session.py +++ b/agents-api/agents_api/models/session/update_session.py @@ -81,7 +81,7 @@ def update_session( *sessions{{ {rest_fields}, @ "NOW" }}, - updated_at = [floor(now()), true] + updated_at = 'ASSERT' :put sessions {{ {all_fields}, updated_at diff --git a/agents-api/agents_api/models/utils.py b/agents-api/agents_api/models/utils.py index 2939b2208..411349ad3 100644 --- a/agents-api/agents_api/models/utils.py +++ b/agents-api/agents_api/models/utils.py @@ -71,16 +71,42 @@ def mark_session_updated_query(developer_id: UUID | str, session_id: UUID | str) to_uuid("{str(session_id)}"), ]] - ?[developer_id, session_id, updated_at] := + ?[ + developer_id, + session_id, + situation, + summary, + created_at, + metadata, + render_templates, + token_budget, + context_overflow, + updated_at, + ] := input[developer_id, session_id], *sessions {{ session_id, + situation, + summary, + created_at, + metadata, + render_templates, + token_budget, + context_overflow, + @ 'NOW' }}, updated_at = [floor(now()), true] - :update sessions {{ + :put sessions {{ developer_id, session_id, + situation, + summary, + created_at, + metadata, + render_templates, + token_budget, + context_overflow, updated_at, }} """ @@ -148,8 +174,7 @@ def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]): and then run the query using the client, returning a DataFrame. """ - if debug: - from pprint import pprint + from pprint import pprint @wraps(func) def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame: @@ -162,17 +187,23 @@ def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame: query = "}\n\n{\n".join(queries) query = f"{{ {query} }}" + debug and print(query) debug and pprint( dict( - query=query, variables=variables, ) ) + # Run the query from ..clients.cozo import get_cozo_client - client = client or get_cozo_client() - result = client.run(query, variables) + try: + client = client or get_cozo_client() + result = client.run(query, variables) + + except Exception as e: + debug and print(repr(getattr(e, "__cause__", None) or e)) + raise # Need to fix the UUIDs in the result result = result.map(fix_uuid_if_present) diff --git a/agents-api/agents_api/routers/tasks/create_task.py b/agents-api/agents_api/routers/tasks/create_task.py index 5f8c86e1f..0e63210d2 100644 --- a/agents-api/agents_api/routers/tasks/create_task.py +++ b/agents-api/agents_api/routers/tasks/create_task.py @@ -29,6 +29,4 @@ async def create_task( data=data, ) - return ResourceCreatedResponse( - id=task.id, created_at=task.created_at, jobs=[] - ) + return ResourceCreatedResponse(id=task.id, created_at=task.created_at, jobs=[]) diff --git a/agents-api/agents_api/worker/worker.py b/agents-api/agents_api/worker/worker.py index 3b77d3fa8..ac045c5ab 100644 --- a/agents-api/agents_api/worker/worker.py +++ b/agents-api/agents_api/worker/worker.py @@ -1,3 +1,5 @@ +from datetime import timedelta + from temporalio.client import Client from temporalio.worker import Worker @@ -46,6 +48,7 @@ async def create_worker(client: Client | None = None): # Initialize the worker with the specified task queue, workflows, and activities worker = Worker( client, + graceful_shutdown_timeout=timedelta(seconds=30), task_queue=temporal_task_queue, workflows=[ SummarizationWorkflow, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index d8acec61c..4fb53ffd7 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -59,21 +59,19 @@ def activity_environment(): @fixture(scope="global") async def workflow_environment(): wf_env = await WorkflowEnvironment.start_local() - return wf_env + yield wf_env + await wf_env.shutdown() @fixture(scope="global") -async def temporal_client(wf_env=workflow_environment): - return wf_env.client - - -@fixture(scope="global") -async def temporal_worker(temporal_client=temporal_client): - worker = await create_worker(client=temporal_client) +async def temporal_worker(wf_env=workflow_environment): + worker = await create_worker(client=wf_env.client) + # FIXME: This does not stop the worker properly + c = worker.shutdown() async with worker as running_worker: yield running_worker - await running_worker.shutdown() + await c @fixture(scope="global") diff --git a/agents-api/tests/test_activities.py b/agents-api/tests/test_activities.py index a5e35760b..29adb8110 100644 --- a/agents-api/tests/test_activities.py +++ b/agents-api/tests/test_activities.py @@ -5,7 +5,6 @@ from .fixtures import ( cozo_client, patch_embed_acompletion, - temporal_client, temporal_worker, test_developer_id, test_doc, @@ -46,7 +45,6 @@ async def _( async def _( workflow_environment=workflow_environment, worker=temporal_worker, - client=temporal_client, ): async with workflow_environment as wf_env: assert wf_env is not None diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 6161ad94c..ce8b6a00d 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -5,6 +5,8 @@ # Tests for entry queries +import time + from ward import test from agents_api.autogen.openapi_model import CreateEntryRequest @@ -12,6 +14,7 @@ from agents_api.models.entry.delete_entries import delete_entries from agents_api.models.entry.get_history import get_history from agents_api.models.entry.list_entries import list_entries +from agents_api.models.session.get_session import get_session from tests.fixtures import cozo_client, test_developer_id, test_session MODEL = "gpt-4o" @@ -35,10 +38,45 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): developer_id=developer_id, session_id=session.id, data=[test_entry], + mark_session_as_updated=False, client=client, ) +@test("model: create entry, update session") +def _(client=cozo_client, developer_id=test_developer_id, session=test_session): + """ + Tests the addition of a new entry to the database. + Verifies that the entry can be successfully added using the create_entries function. + """ + + test_entry = CreateEntryRequest( + session_id=session.id, + role="user", + source="internal", + content="test entry content", + ) + + # FIXME: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep + time.sleep(1) + + create_entries( + developer_id=developer_id, + session_id=session.id, + data=[test_entry], + mark_session_as_updated=True, + client=client, + ) + + updated_session = get_session( + developer_id=developer_id, + session_id=session.id, + client=client, + ) + + assert updated_session.updated_at > session.updated_at + + @test("model: get entries") def _(client=cozo_client, developer_id=test_developer_id, session=test_session): """