From 82c8af251ea6e94c73cc6cf6ff410f10fc2dd6d9 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Wed, 14 Aug 2024 14:02:18 -0400 Subject: [PATCH 1/6] fix(agents-api): Minor fix to tests Signed-off-by: Diwank Tomer --- .../agents_api/models/entry/create_entries.py | 3 +- .../agents_api/routers/tasks/create_task.py | 24 ++++------------ agents-api/tests/test_task_routes.py | 28 +++++++++++-------- 3 files changed, 23 insertions(+), 32 deletions(-) diff --git a/agents-api/agents_api/models/entry/create_entries.py b/agents-api/agents_api/models/entry/create_entries.py index 68d644266..d2aa87f86 100644 --- a/agents-api/agents_api/models/entry/create_entries.py +++ b/agents-api/agents_api/models/entry/create_entries.py @@ -78,8 +78,7 @@ def create_entries( verify_developer_owns_resource_query( developer_id, "sessions", session_id=session_id ), - mark_session_as_updated - and mark_session_updated_query(developer_id, session_id), + mark_session_updated_query(developer_id, session_id) if mark_session_as_updated else "", create_query, ] diff --git a/agents-api/agents_api/routers/tasks/create_task.py b/agents-api/agents_api/routers/tasks/create_task.py index 3c76663e4..5f8c86e1f 100644 --- a/agents-api/agents_api/routers/tasks/create_task.py +++ b/agents-api/agents_api/routers/tasks/create_task.py @@ -1,7 +1,5 @@ from typing import Annotated -from uuid import uuid4 -import pandas as pd from fastapi import Depends from pydantic import UUID4 from starlette.status import HTTP_201_CREATED @@ -9,6 +7,7 @@ from agents_api.autogen.openapi_model import ( CreateTaskRequest, ResourceCreatedResponse, + Task, ) from agents_api.dependencies.developer_id import get_developer_id from agents_api.models.task.create_task import create_task as create_task_query @@ -18,29 +17,18 @@ @router.post("/agents/{agent_id}/tasks", status_code=HTTP_201_CREATED, tags=["tasks"]) async def create_task( - request: CreateTaskRequest, + data: CreateTaskRequest, agent_id: UUID4, x_developer_id: Annotated[UUID4, Depends(get_developer_id)], ) -> ResourceCreatedResponse: - task_id = uuid4() - # TODO: Do thorough validation of the task spec - workflows = [ - {"name": "main", "steps": [w.model_dump() for w in request.main]}, - ] + [{"name": name, "steps": steps} for name, steps in request.model_extra.items()] - - resp: pd.DataFrame = create_task_query( - agent_id=agent_id, - task_id=task_id, + task: Task = create_task_query( developer_id=x_developer_id, - name=request.name, - description=request.description, - input_schema=request.input_schema or {}, - tools_available=request.tools or [], - workflows=workflows, + agent_id=agent_id, + data=data, ) return ResourceCreatedResponse( - id=resp["task_id"][0], created_at=resp["created_at"][0] + id=task.id, created_at=task.created_at, jobs=[] ) diff --git a/agents-api/tests/test_task_routes.py b/agents-api/tests/test_task_routes.py index 8f2eef7c6..b790240c4 100644 --- a/agents-api/tests/test_task_routes.py +++ b/agents-api/tests/test_task_routes.py @@ -10,12 +10,14 @@ def _(client=client, agent=test_agent): data = dict( name="test user", - main={ - "kind_": "evaluate", - "evaluate": { - "additionalProp1": "value1", - }, - }, + main=[ + { + "kind_": "evaluate", + "evaluate": { + "additionalProp1": "value1", + }, + } + ], ) response = client.request( @@ -31,12 +33,14 @@ def _(client=client, agent=test_agent): def _(make_request=make_request, agent=test_agent): data = dict( name="test user", - main={ - "kind_": "evaluate", - "evaluate": { - "additionalProp1": "value1", - }, - }, + main=[ + { + "kind_": "evaluate", + "evaluate": { + "additionalProp1": "value1", + }, + } + ], ) response = make_request( From b1fc2a4c0cc1bd8c3ed7fb5a0a5a7260c37d673c Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Wed, 14 Aug 2024 15:59:33 -0400 Subject: [PATCH 2/6] fix(agents-api): Fix tests and fixtures Signed-off-by: Diwank Tomer --- .../agents_api/models/entry/create_entries.py | 4 +- .../models/session/patch_session.py | 2 +- .../models/session/update_session.py | 2 +- agents-api/agents_api/models/utils.py | 45 ++++++++++++++++--- .../agents_api/routers/tasks/create_task.py | 4 +- agents-api/agents_api/worker/worker.py | 3 ++ agents-api/tests/fixtures.py | 16 +++---- agents-api/tests/test_activities.py | 2 - agents-api/tests/test_entry_queries.py | 38 ++++++++++++++++ 9 files changed, 92 insertions(+), 24 deletions(-) diff --git a/agents-api/agents_api/models/entry/create_entries.py b/agents-api/agents_api/models/entry/create_entries.py index d2aa87f86..01193d395 100644 --- a/agents-api/agents_api/models/entry/create_entries.py +++ b/agents-api/agents_api/models/entry/create_entries.py @@ -78,7 +78,9 @@ def create_entries( verify_developer_owns_resource_query( developer_id, "sessions", session_id=session_id ), - mark_session_updated_query(developer_id, session_id) if mark_session_as_updated else "", + mark_session_updated_query(developer_id, session_id) + if mark_session_as_updated + else "", create_query, ] diff --git a/agents-api/agents_api/models/session/patch_session.py b/agents-api/agents_api/models/session/patch_session.py index 5aae245cc..f16a53f44 100644 --- a/agents-api/agents_api/models/session/patch_session.py +++ b/agents-api/agents_api/models/session/patch_session.py @@ -92,7 +92,7 @@ def patch_session( *sessions{{ {rest_fields}, metadata: md, @ "NOW" }}, - updated_at = [floor(now()), true], + updated_at = 'ASSERT', metadata = concat(md, $metadata), :put sessions {{ diff --git a/agents-api/agents_api/models/session/update_session.py b/agents-api/agents_api/models/session/update_session.py index b498650e8..5f7789f29 100644 --- a/agents-api/agents_api/models/session/update_session.py +++ b/agents-api/agents_api/models/session/update_session.py @@ -81,7 +81,7 @@ def update_session( *sessions{{ {rest_fields}, @ "NOW" }}, - updated_at = [floor(now()), true] + updated_at = 'ASSERT' :put sessions {{ {all_fields}, updated_at diff --git a/agents-api/agents_api/models/utils.py b/agents-api/agents_api/models/utils.py index 2939b2208..411349ad3 100644 --- a/agents-api/agents_api/models/utils.py +++ b/agents-api/agents_api/models/utils.py @@ -71,16 +71,42 @@ def mark_session_updated_query(developer_id: UUID | str, session_id: UUID | str) to_uuid("{str(session_id)}"), ]] - ?[developer_id, session_id, updated_at] := + ?[ + developer_id, + session_id, + situation, + summary, + created_at, + metadata, + render_templates, + token_budget, + context_overflow, + updated_at, + ] := input[developer_id, session_id], *sessions {{ session_id, + situation, + summary, + created_at, + metadata, + render_templates, + token_budget, + context_overflow, + @ 'NOW' }}, updated_at = [floor(now()), true] - :update sessions {{ + :put sessions {{ developer_id, session_id, + situation, + summary, + created_at, + metadata, + render_templates, + token_budget, + context_overflow, updated_at, }} """ @@ -148,8 +174,7 @@ def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]): and then run the query using the client, returning a DataFrame. """ - if debug: - from pprint import pprint + from pprint import pprint @wraps(func) def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame: @@ -162,17 +187,23 @@ def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame: query = "}\n\n{\n".join(queries) query = f"{{ {query} }}" + debug and print(query) debug and pprint( dict( - query=query, variables=variables, ) ) + # Run the query from ..clients.cozo import get_cozo_client - client = client or get_cozo_client() - result = client.run(query, variables) + try: + client = client or get_cozo_client() + result = client.run(query, variables) + + except Exception as e: + debug and print(repr(getattr(e, "__cause__", None) or e)) + raise # Need to fix the UUIDs in the result result = result.map(fix_uuid_if_present) diff --git a/agents-api/agents_api/routers/tasks/create_task.py b/agents-api/agents_api/routers/tasks/create_task.py index 5f8c86e1f..0e63210d2 100644 --- a/agents-api/agents_api/routers/tasks/create_task.py +++ b/agents-api/agents_api/routers/tasks/create_task.py @@ -29,6 +29,4 @@ async def create_task( data=data, ) - return ResourceCreatedResponse( - id=task.id, created_at=task.created_at, jobs=[] - ) + return ResourceCreatedResponse(id=task.id, created_at=task.created_at, jobs=[]) diff --git a/agents-api/agents_api/worker/worker.py b/agents-api/agents_api/worker/worker.py index 3b77d3fa8..ac045c5ab 100644 --- a/agents-api/agents_api/worker/worker.py +++ b/agents-api/agents_api/worker/worker.py @@ -1,3 +1,5 @@ +from datetime import timedelta + from temporalio.client import Client from temporalio.worker import Worker @@ -46,6 +48,7 @@ async def create_worker(client: Client | None = None): # Initialize the worker with the specified task queue, workflows, and activities worker = Worker( client, + graceful_shutdown_timeout=timedelta(seconds=30), task_queue=temporal_task_queue, workflows=[ SummarizationWorkflow, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index d8acec61c..4fb53ffd7 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -59,21 +59,19 @@ def activity_environment(): @fixture(scope="global") async def workflow_environment(): wf_env = await WorkflowEnvironment.start_local() - return wf_env + yield wf_env + await wf_env.shutdown() @fixture(scope="global") -async def temporal_client(wf_env=workflow_environment): - return wf_env.client - - -@fixture(scope="global") -async def temporal_worker(temporal_client=temporal_client): - worker = await create_worker(client=temporal_client) +async def temporal_worker(wf_env=workflow_environment): + worker = await create_worker(client=wf_env.client) + # FIXME: This does not stop the worker properly + c = worker.shutdown() async with worker as running_worker: yield running_worker - await running_worker.shutdown() + await c @fixture(scope="global") diff --git a/agents-api/tests/test_activities.py b/agents-api/tests/test_activities.py index a5e35760b..29adb8110 100644 --- a/agents-api/tests/test_activities.py +++ b/agents-api/tests/test_activities.py @@ -5,7 +5,6 @@ from .fixtures import ( cozo_client, patch_embed_acompletion, - temporal_client, temporal_worker, test_developer_id, test_doc, @@ -46,7 +45,6 @@ async def _( async def _( workflow_environment=workflow_environment, worker=temporal_worker, - client=temporal_client, ): async with workflow_environment as wf_env: assert wf_env is not None diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 6161ad94c..ce8b6a00d 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -5,6 +5,8 @@ # Tests for entry queries +import time + from ward import test from agents_api.autogen.openapi_model import CreateEntryRequest @@ -12,6 +14,7 @@ from agents_api.models.entry.delete_entries import delete_entries from agents_api.models.entry.get_history import get_history from agents_api.models.entry.list_entries import list_entries +from agents_api.models.session.get_session import get_session from tests.fixtures import cozo_client, test_developer_id, test_session MODEL = "gpt-4o" @@ -35,10 +38,45 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): developer_id=developer_id, session_id=session.id, data=[test_entry], + mark_session_as_updated=False, client=client, ) +@test("model: create entry, update session") +def _(client=cozo_client, developer_id=test_developer_id, session=test_session): + """ + Tests the addition of a new entry to the database. + Verifies that the entry can be successfully added using the create_entries function. + """ + + test_entry = CreateEntryRequest( + session_id=session.id, + role="user", + source="internal", + content="test entry content", + ) + + # FIXME: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep + time.sleep(1) + + create_entries( + developer_id=developer_id, + session_id=session.id, + data=[test_entry], + mark_session_as_updated=True, + client=client, + ) + + updated_session = get_session( + developer_id=developer_id, + session_id=session.id, + client=client, + ) + + assert updated_session.updated_at > session.updated_at + + @test("model: get entries") def _(client=cozo_client, developer_id=test_developer_id, session=test_session): """ From 3ab775839ab1f7d678b293a2b5efb16aaeac8607 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Wed, 14 Aug 2024 18:54:42 -0400 Subject: [PATCH 3/6] refactor(agents-api): Move gather_messages to its own model Signed-off-by: Diwank Tomer --- agents-api/agents_api/models/chat/__init__.py | 22 +++++ .../agents_api/models/chat/gather_messages.py | 82 +++++++++++++++++ .../{session => chat}/get_cached_response.py | 0 .../{session => chat}/prepare_chat_context.py | 2 +- .../{session => chat}/set_cached_response.py | 0 .../agents_api/models/entry/create_entries.py | 6 +- .../agents_api/models/entry/get_history.py | 2 + .../agents_api/models/entry/list_entries.py | 2 + .../agents_api/models/session/__init__.py | 3 - agents-api/agents_api/models/utils.py | 4 + .../agents_api/routers/sessions/chat.py | 88 ++++--------------- agents-api/poetry.lock | 12 +-- 12 files changed, 137 insertions(+), 86 deletions(-) create mode 100644 agents-api/agents_api/models/chat/__init__.py create mode 100644 agents-api/agents_api/models/chat/gather_messages.py rename agents-api/agents_api/models/{session => chat}/get_cached_response.py (100%) rename agents-api/agents_api/models/{session => chat}/prepare_chat_context.py (98%) rename agents-api/agents_api/models/{session => chat}/set_cached_response.py (100%) diff --git a/agents-api/agents_api/models/chat/__init__.py b/agents-api/agents_api/models/chat/__init__.py new file mode 100644 index 000000000..428b72572 --- /dev/null +++ b/agents-api/agents_api/models/chat/__init__.py @@ -0,0 +1,22 @@ +""" +Module: agents_api/models/docs + +This module is responsible for managing document-related operations within the application, particularly for agents and possibly other entities. It serves as a core component of the document management system, enabling features such as document creation, listing, deletion, and embedding of snippets for enhanced search and retrieval capabilities. + +Main functionalities include: +- Creating new documents and associating them with agents or users. +- Listing documents based on various criteria, including ownership and metadata filters. +- Deleting documents by their unique identifiers. +- Embedding document snippets for retrieval purposes. + +The module interacts with other parts of the application, such as the agents and users modules, to provide a comprehensive document management system. Its role is crucial in enabling document search, retrieval, and management features within the context of agents and users. + +This documentation aims to provide clear, concise, and sufficient context for new developers or contributors to understand the module's role without needing to dive deep into the code immediately. +""" + +# ruff: noqa: F401, F403, F405 + +from .gather_messages import gather_messages +from .get_cached_response import get_cached_response +from .prepare_chat_context import prepare_chat_context +from .set_cached_response import set_cached_response diff --git a/agents-api/agents_api/models/chat/gather_messages.py b/agents-api/agents_api/models/chat/gather_messages.py new file mode 100644 index 000000000..2a3c0eca1 --- /dev/null +++ b/agents-api/agents_api/models/chat/gather_messages.py @@ -0,0 +1,82 @@ +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException +from pycozo.client import QueryException +from pydantic import ValidationError + +from agents_api.autogen.Chat import ChatInput + +from ...autogen.openapi_model import DocReference, History +from ...clients import embed +from ...common.protocol.developers import Developer +from ...common.protocol.sessions import ChatContext +from ..docs.search_docs_hybrid import search_docs_hybrid +from ..entry.get_history import get_history +from ..utils import ( + partialclass, + rewrap_exceptions, +) + + +@rewrap_exceptions( + { + QueryException: partialclass(HTTPException, status_code=400), + ValidationError: partialclass(HTTPException, status_code=400), + TypeError: partialclass(HTTPException, status_code=400), + } +) +@beartype +async def gather_messages( + *, + developer: Developer, + session_id: UUID, + chat_context: ChatContext, + chat_input: ChatInput, +): + new_raw_messages = [msg.model_dump() for msg in chat_input.messages] + recall = chat_input.recall + + assert len(new_raw_messages) > 0 + + # Get the session history + history: History = get_history( + developer_id=developer.id, + session_id=session_id, + allowed_sources=["api_request", "api_response", "tool_response", "summarizer"], + ) + + # Keep leaf nodes only + relations = history.relations + past_messages = [ + entry.model_dump() + for entry in history.entries + if entry.id not in {r.head for r in relations} + ] + + if not recall: + return past_messages, [] + + # Search matching docs + [query_embedding, *_] = await embed.embed( + inputs=[ + f"{msg.get('name') or msg['role']}: {msg['content']}" + for msg in new_raw_messages + ], + join_inputs=True, + ) + query_text = new_raw_messages[-1]["content"] + + # List all the applicable owners to search docs from + active_agent_id = chat_context.get_active_agent().id + user_ids = [user.id for user in chat_context.users] + owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)] + + doc_references: list[DocReference] = search_docs_hybrid( + developer_id=developer.id, + owners=owners, + query=query_text, + query_embedding=query_embedding, + ) + + return past_messages, doc_references diff --git a/agents-api/agents_api/models/session/get_cached_response.py b/agents-api/agents_api/models/chat/get_cached_response.py similarity index 100% rename from agents-api/agents_api/models/session/get_cached_response.py rename to agents-api/agents_api/models/chat/get_cached_response.py diff --git a/agents-api/agents_api/models/session/prepare_chat_context.py b/agents-api/agents_api/models/chat/prepare_chat_context.py similarity index 98% rename from agents-api/agents_api/models/session/prepare_chat_context.py rename to agents-api/agents_api/models/chat/prepare_chat_context.py index 83e6c6f8b..0e076bc20 100644 --- a/agents-api/agents_api/models/session/prepare_chat_context.py +++ b/agents-api/agents_api/models/chat/prepare_chat_context.py @@ -7,6 +7,7 @@ from ...autogen.openapi_model import make_session from ...common.protocol.sessions import ChatContext +from ..session.prepare_session_data import prepare_session_data from ..utils import ( cozo_query, fix_uuid_if_present, @@ -16,7 +17,6 @@ verify_developer_owns_resource_query, wrap_in_class, ) -from .prepare_session_data import prepare_session_data @rewrap_exceptions( diff --git a/agents-api/agents_api/models/session/set_cached_response.py b/agents-api/agents_api/models/chat/set_cached_response.py similarity index 100% rename from agents-api/agents_api/models/session/set_cached_response.py rename to agents-api/agents_api/models/chat/set_cached_response.py diff --git a/agents-api/agents_api/models/entry/create_entries.py b/agents-api/agents_api/models/entry/create_entries.py index 01193d395..31c8b4d01 100644 --- a/agents-api/agents_api/models/entry/create_entries.py +++ b/agents-api/agents_api/models/entry/create_entries.py @@ -1,4 +1,3 @@ -import json from uuid import UUID, uuid4 from beartype import beartype @@ -34,6 +33,7 @@ "id": UUID(d.pop("entry_id")), **d, }, + _kind="inserted", ) @cozo_query @beartype @@ -55,10 +55,6 @@ def create_entries( item["entry_id"] = item.pop("id", None) or str(uuid4()) item["created_at"] = (item.get("created_at") or utcnow()).timestamp() - if not item.get("token_count"): - item["token_count"] = len(json.dumps(item)) // 3.5 - item["tokenizer"] = "character_count" - cols, rows = cozo_process_mutate_data(data_dicts) # Construct a datalog query to insert the processed entries into the 'cozodb' database. diff --git a/agents-api/agents_api/models/entry/get_history.py b/agents-api/agents_api/models/entry/get_history.py index 49eb7b929..fed658ea5 100644 --- a/agents-api/agents_api/models/entry/get_history.py +++ b/agents-api/agents_api/models/entry/get_history.py @@ -62,6 +62,7 @@ def get_history( content, source, token_count, + tokenizer, created_at, timestamp, }, @@ -75,6 +76,7 @@ def get_history( "content": content, "source": source, "token_count": token_count, + "tokenizer": tokenizer, "created_at": created_at, "timestamp": timestamp } diff --git a/agents-api/agents_api/models/entry/list_entries.py b/agents-api/agents_api/models/entry/list_entries.py index 0c47d9a74..da2341c4c 100644 --- a/agents-api/agents_api/models/entry/list_entries.py +++ b/agents-api/agents_api/models/entry/list_entries.py @@ -65,6 +65,7 @@ def list_entries( content, source, token_count, + tokenizer, created_at, timestamp, ] := *entries {{ @@ -75,6 +76,7 @@ def list_entries( content, source, token_count, + tokenizer, created_at, timestamp, }}, diff --git a/agents-api/agents_api/models/session/__init__.py b/agents-api/agents_api/models/session/__init__.py index b4092611f..bc5f7fbb4 100644 --- a/agents-api/agents_api/models/session/__init__.py +++ b/agents-api/agents_api/models/session/__init__.py @@ -14,11 +14,8 @@ from .create_or_update_session import create_or_update_session from .create_session import create_session from .delete_session import delete_session -from .get_cached_response import get_cached_response from .get_session import get_session from .list_sessions import list_sessions from .patch_session import patch_session -from .prepare_chat_context import prepare_chat_context from .prepare_session_data import prepare_session_data -from .set_cached_response import set_cached_response from .update_session import update_session diff --git a/agents-api/agents_api/models/utils.py b/agents-api/agents_api/models/utils.py index 411349ad3..ee2ba3fdd 100644 --- a/agents-api/agents_api/models/utils.py +++ b/agents-api/agents_api/models/utils.py @@ -232,6 +232,7 @@ def wrap_in_class( cls: Type[ModelT] | Callable[..., ModelT], one: bool = False, transform: Callable[[dict], dict] | None = None, + _kind: str | None = None, ): def decorator(func: Callable[P, pd.DataFrame]): @wraps(func) @@ -239,6 +240,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> ModelT | list[ModelT]: df = func(*args, **kwargs) # Convert df to list of dicts + if _kind: + df = df[df["_kind"] == _kind] + data = df.to_dict(orient="records") nonlocal transform diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index e6103c15e..8d0355de2 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -9,76 +9,20 @@ ChatResponse, ChunkChatResponse, CreateEntryRequest, - DocReference, - History, MessageChatResponse, ) -from ...clients import embed, litellm +from ...clients import litellm from ...common.protocol.developers import Developer from ...common.protocol.sessions import ChatContext from ...common.utils.datetime import utcnow from ...common.utils.template import render_template from ...dependencies.developer_id import get_developer_data -from ...models.docs.search_docs_hybrid import search_docs_hybrid +from ...models.chat.gather_messages import gather_messages +from ...models.chat.prepare_chat_context import prepare_chat_context from ...models.entry.create_entries import create_entries -from ...models.entry.get_history import get_history -from ...models.session.prepare_chat_context import prepare_chat_context from .router import router -async def get_messages( - *, - developer: Developer, - session_id: UUID, - new_raw_messages: list[dict], - chat_context: ChatContext, - recall: bool, -): - assert len(new_raw_messages) > 0 - - # Get the session history - history: History = get_history( - developer_id=developer.id, - session_id=session_id, - allowed_sources=["api_request", "api_response", "tool_response", "summarizer"], - ) - - # Keep leaf nodes only - relations = history.relations - past_messages = [ - entry.model_dump() - for entry in history.entries - if entry.id not in {r.head for r in relations} - ] - - if not recall: - return past_messages, [] - - # Search matching docs - [query_embedding, *_] = await embed.embed( - inputs=[ - f"{msg.get('name') or msg['role']}: {msg['content']}" - for msg in new_raw_messages - ], - join_inputs=True, - ) - query_text = new_raw_messages[-1]["content"] - - # List all the applicable owners to search docs from - active_agent_id = chat_context.get_active_agent().id - user_ids = [user.id for user in chat_context.users] - owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)] - - doc_references: list[DocReference] = search_docs_hybrid( - developer_id=developer.id, - owners=owners, - query=query_text, - query_embedding=query_embedding, - ) - - return past_messages, doc_references - - @router.post( "/sessions/{session_id}/chat", status_code=HTTP_201_CREATED, @@ -87,7 +31,7 @@ async def get_messages( async def chat( developer: Annotated[Developer, Depends(get_developer_data)], session_id: UUID, - input: ChatInput, + chat_input: ChatInput, background_tasks: BackgroundTasks, ) -> ChatResponse: # First get the chat context @@ -97,18 +41,17 @@ async def chat( ) # Merge the settings and prepare environment - chat_context.merge_settings(input) + chat_context.merge_settings(chat_input) settings: dict = chat_context.settings.model_dump() env: dict = chat_context.get_chat_environment() - new_raw_messages = [msg.model_dump() for msg in input.messages] + new_raw_messages = [msg.model_dump() for msg in chat_input.messages] # Render the messages - past_messages, doc_references = await get_messages( + past_messages, doc_references = await gather_messages( developer=developer, session_id=session_id, - new_raw_messages=new_raw_messages, chat_context=chat_context, - recall=input.recall, + chat_input=chat_input, ) env["docs"] = doc_references @@ -118,7 +61,7 @@ async def chat( # Get the tools tools = settings.get("tools") or chat_context.get_active_tools() - # Truncate the messages if necessary + # TODO: Truncate the messages if necessary if chat_context.session.context_overflow == "truncate": # messages = messages[-settings["max_tokens"] :] raise NotImplementedError("Truncation is not yet implemented") @@ -133,11 +76,12 @@ async def chat( ) # Save the input and the response to the session history - if input.save: - # TODO: Count the number of tokens before saving it to the session - + if chat_input.save: new_entries = [ - CreateEntryRequest(**msg, source="api_request") for msg in new_messages + CreateEntryRequest.from_model_input( + model=settings["model"], **msg, source="api_request" + ) + for msg in new_messages ] background_tasks.add_task( @@ -156,7 +100,9 @@ async def chat( raise NotImplementedError("Adaptive context is not yet implemented") # Return the response - chat_response_class = ChunkChatResponse if input.stream else MessageChatResponse + chat_response_class = ( + ChunkChatResponse if chat_input.stream else MessageChatResponse + ) chat_response: ChatResponse = chat_response_class( id=uuid4(), created_at=utcnow(), diff --git a/agents-api/poetry.lock b/agents-api/poetry.lock index b8c0a42b7..6a29c9f4a 100644 --- a/agents-api/poetry.lock +++ b/agents-api/poetry.lock @@ -2149,13 +2149,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" [[package]] name = "langchain-core" -version = "0.2.30" +version = "0.2.31" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_core-0.2.30-py3-none-any.whl", hash = "sha256:ea7eccb9566dd51b2b74bd292c4239d843a77cdba8ffae2b5edf7000d70d4194"}, - {file = "langchain_core-0.2.30.tar.gz", hash = "sha256:552ec586698140062cd299a83bad7e308f925b496d306b62529579c6fb122f7a"}, + {file = "langchain_core-0.2.31-py3-none-any.whl", hash = "sha256:b4daf5ddc23c0c3d8c5fd1a6c118f95fb5d0f96067b43f2c5935e1cd572e4374"}, + {file = "langchain_core-0.2.31.tar.gz", hash = "sha256:afb2089d4c10842d2477dc5cfa9ae9feb415c1421c6ef9aa608fea879ee41769"}, ] [package.dependencies] @@ -2255,13 +2255,13 @@ dev = ["Sphinx (>=5.1.1)", "black (==23.12.1)", "build (>=0.10.0)", "coverage (> [[package]] name = "litellm" -version = "1.43.9" +version = "1.43.12" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "litellm-1.43.9-py3-none-any.whl", hash = "sha256:54253281139e61f130b7e1a613a11f7a5ee896c2ee8536b0ca9a5ffbfce4c5f0"}, - {file = "litellm-1.43.9.tar.gz", hash = "sha256:c397a14c9b851f007f09c99e5a28606f7f122fdb4ae954931220f60e9edc6918"}, + {file = "litellm-1.43.12-py3-none-any.whl", hash = "sha256:f2c5f498a079df6eb8448ac41704367a389ea679a22e195c79b7963ede5cc462"}, + {file = "litellm-1.43.12.tar.gz", hash = "sha256:719eca58904942465dfd827e9d8f317112996ef481db71f9562f5263a553c74a"}, ] [package.dependencies] From 313a3120df3d782724b1b998973b14332c87bba4 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Wed, 14 Aug 2024 18:55:08 -0400 Subject: [PATCH 4/6] feat(agents-api): Add token count support using litellm Signed-off-by: Diwank Tomer --- .../agents_api/activities/summarization.py | 10 ---- .../agents_api/activities/truncation.py | 7 +-- agents-api/agents_api/autogen/Entries.py | 4 +- .../agents_api/autogen/openapi_model.py | 59 ++++++++++++++++++- .../agents_api/common/protocol/entries.py | 56 ------------------ .../agents_api/routers/tasks/__init__.py | 3 +- .../routers/tasks/get_execution_details.py | 2 - agents-api/tests/test_chat_routes.py | 44 ++++++++++---- agents-api/tests/test_entry_queries.py | 32 +++++----- agents-api/tests/test_session_queries.py | 22 ------- agents-api/tests/test_task_routes.py | 1 + .../julep/api/types/entries_base_entry.py | 4 +- sdks/python/poetry.lock | 16 ++--- sdks/ts/src/api/models/Entries_BaseEntry.ts | 4 +- sdks/ts/src/api/schemas/$Entries_BaseEntry.ts | 2 + typespec/entries/models.tsp | 4 +- 16 files changed, 127 insertions(+), 143 deletions(-) delete mode 100644 agents-api/agents_api/common/protocol/entries.py diff --git a/agents-api/agents_api/activities/summarization.py b/agents-api/agents_api/activities/summarization.py index 581dcdb00..8a45927ee 100644 --- a/agents-api/agents_api/activities/summarization.py +++ b/agents-api/agents_api/activities/summarization.py @@ -1,23 +1,13 @@ #!/usr/bin/env python3 -import asyncio -from textwrap import dedent -from typing import Callable -from uuid import UUID import pandas as pd from temporalio import activity -# from agents_api.common.protocol.entries import Entry # from agents_api.models.entry.entries_summarization import ( # entries_summarization_query, # get_toplevel_entries_query, # ) -from agents_api.rec_sum.entities import get_entities -from agents_api.rec_sum.summarize import summarize_messages -from agents_api.rec_sum.trim import trim_messages - -from ..env import summarization_model_name # TODO: remove stubs diff --git a/agents-api/agents_api/activities/truncation.py b/agents-api/agents_api/activities/truncation.py index 353e4b570..7f381ac0f 100644 --- a/agents-api/agents_api/activities/truncation.py +++ b/agents-api/agents_api/activities/truncation.py @@ -2,9 +2,7 @@ from temporalio import activity -# from agents_api.autogen.openapi_model import Role -from agents_api.common.protocol.entries import Entry -from agents_api.models.entry.delete_entries import delete_entries +from agents_api.autogen.openapi_model import Entry # from agents_api.models.entry.entries_summarization import get_toplevel_entries_query @@ -13,8 +11,7 @@ def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list if not len(messages): return messages - result: list[UUID] = [] - token_cnt, offset = 0, 0 + _token_cnt, _offset = 0, 0 # if messages[0].role == Role.system: # token_cnt, offset = messages[0].token_count, 1 diff --git a/agents-api/agents_api/autogen/Entries.py b/agents-api/agents_api/autogen/Entries.py index b9921daa4..00070bd71 100644 --- a/agents-api/agents_api/autogen/Entries.py +++ b/agents-api/agents_api/autogen/Entries.py @@ -45,8 +45,8 @@ class BaseEntry(BaseModel): source: Literal[ "api_request", "api_response", "tool_response", "internal", "summarizer", "meta" ] - tokenizer: str | None = None - token_count: int | None = None + tokenizer: str + token_count: int timestamp: Annotated[float, Field(ge=0.0)] """ This is the time that this event refers to. diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index 16a57ac31..24640c46d 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -1,7 +1,9 @@ # ruff: noqa: F401, F403, F405 -from typing import Annotated, Generic, TypeVar +from typing import Annotated, Generic, Self, Type, TypeVar from uuid import UUID +from litellm.utils import _select_tokenizer as select_tokenizer +from litellm.utils import token_counter from pydantic import AwareDatetime, Field from pydantic_partial import create_partial_model @@ -34,7 +36,34 @@ "metadata", ) -ChatMLRole = BaseEntry.model_fields["role"].annotation +ChatMLRole = Literal[ + "user", + "assistant", + "system", + "function", + "function_response", + "function_call", + "auto", +] + +ChatMLContent = ( + list[ChatMLTextContentPart | ChatMLImageContentPart] + | Tool + | ChosenToolCall + | str + | ToolResponse + | list[ + list[ChatMLTextContentPart | ChatMLImageContentPart] + | Tool + | ChosenToolCall + | str + | ToolResponse + ] +) + +ChatMLSource = Literal[ + "api_request", "api_response", "tool_response", "internal", "summarizer", "meta" +] class CreateEntryRequest(BaseEntry): @@ -42,6 +71,32 @@ class CreateEntryRequest(BaseEntry): float, Field(ge=0.0, default_factory=lambda: utcnow().timestamp()) ] + @classmethod + def from_model_input( + cls: Type[Self], + model: str, + *, + role: ChatMLRole, + content: ChatMLContent, + name: str | None = None, + source: ChatMLSource, + **kwargs: dict, + ) -> Self: + tokenizer: dict = select_tokenizer(model=model) + token_count = token_counter( + model=model, messages=[{"role": role, "content": content, "name": name}] + ) + + return cls( + role=role, + content=content, + name=name, + source=source, + tokenizer=tokenizer["type"], + token_count=token_count, + **kwargs, + ) + def make_session( *, diff --git a/agents-api/agents_api/common/protocol/entries.py b/agents-api/agents_api/common/protocol/entries.py deleted file mode 100644 index 18d63f583..000000000 --- a/agents-api/agents_api/common/protocol/entries.py +++ /dev/null @@ -1,56 +0,0 @@ -import json -from typing import Literal -from uuid import UUID - -from pydantic import Field, computed_field - -from ...autogen.openapi_model import ( - ChatMLImageContentPart, - ChatMLTextContentPart, -) -from ...autogen.openapi_model import ( - Entry as BaseEntry, -) - -EntrySource = Literal["api_request", "api_response", "internal", "summarizer"] -Tokenizer = Literal["character_count"] - - -LOW_IMAGE_TOKEN_COUNT = 85 -HIGH_IMAGE_TOKEN_COUNT = 85 + 4 * 170 - - -class Entry(BaseEntry): - """Represents an entry in the system, encapsulating all necessary details such as ID, session ID, source, role, and content among others.""" - - session_id: UUID - token_count: int - tokenizer: str = Field(default="character_count") - - # TODO: Replace this with a proper implementation. - - # @computed_field - # @property - # def token_count(self) -> int: - # """Calculates the token count based on the content's character count. The tokenizer 'character_count' divides the length of the content by 3.5 to estimate the token count. Raises NotImplementedError for unknown tokenizers.""" - # if self.tokenizer == "character_count": - # content_length = 0 - # if isinstance(self.content, str): - # content_length = len(self.content) - # elif isinstance(self.content, dict): - # content_length = len(json.dumps(self.content)) - # elif isinstance(self.content, list): - # for part in self.content: - # if isinstance(part, ChatMLTextContentPart): - # content_length += len(part.text) - # elif isinstance(part, ChatMLImageContentPart): - # content_length += ( - # LOW_IMAGE_TOKEN_COUNT - # if part.image_url.detail == "low" - # else HIGH_IMAGE_TOKEN_COUNT - # ) - - # # Divide the content length by 3.5 to estimate token count based on character count. - # return int(content_length // 3.5) - - # raise NotImplementedError(f"Unknown tokenizer: {self.tokenizer}") diff --git a/agents-api/agents_api/routers/tasks/__init__.py b/agents-api/agents_api/routers/tasks/__init__.py index 8d67171c0..66621b34c 100644 --- a/agents-api/agents_api/routers/tasks/__init__.py +++ b/agents-api/agents_api/routers/tasks/__init__.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401, F403, F405 from .create_task import create_task from .create_task_execution import create_task_execution from .get_execution_details import get_execution_details @@ -5,5 +6,5 @@ from .list_task_executions import list_task_executions from .list_tasks import list_tasks from .patch_execution import patch_execution -from .router import router # noqa: F401 +from .router import router from .update_execution import update_execution diff --git a/agents-api/agents_api/routers/tasks/get_execution_details.py b/agents-api/agents_api/routers/tasks/get_execution_details.py index 1377d7b22..6a9a01caa 100644 --- a/agents-api/agents_api/routers/tasks/get_execution_details.py +++ b/agents-api/agents_api/routers/tasks/get_execution_details.py @@ -1,5 +1,3 @@ -from uuid import uuid4 - from fastapi import HTTPException, status from pydantic import UUID4 diff --git a/agents-api/tests/test_chat_routes.py b/agents-api/tests/test_chat_routes.py index ccf91c89e..1e8065d9d 100644 --- a/agents-api/tests/test_chat_routes.py +++ b/agents-api/tests/test_chat_routes.py @@ -2,11 +2,12 @@ from ward import test -from agents_api.autogen.Sessions import CreateSessionRequest +from agents_api.autogen.openapi_model import ChatInput, CreateSessionRequest from agents_api.clients import embed, litellm +from agents_api.common.protocol.sessions import ChatContext +from agents_api.models.chat.gather_messages import gather_messages +from agents_api.models.chat.prepare_chat_context import prepare_chat_context from agents_api.models.session.create_session import create_session -from agents_api.models.session.prepare_chat_context import prepare_chat_context -from agents_api.routers.sessions.chat import get_messages from tests.fixtures import ( cozo_client, make_request, @@ -28,7 +29,7 @@ async def _( assert (await embed.embed())[0][0] == 1.0 -@test("chat: check that non-recall get_messages works") +@test("chat: check that non-recall gather_messages works") async def _( developer=test_developer, client=cozo_client, @@ -49,14 +50,13 @@ async def _( session_id = session.id - new_raw_messages = [{"role": "user", "content": "hello"}] + messages = [{"role": "user", "content": "hello"}] - past_messages, doc_references = await get_messages( + past_messages, doc_references = await gather_messages( developer=developer, session_id=session_id, - new_raw_messages=new_raw_messages, chat_context=chat_context, - recall=False, + chat_input=ChatInput(messages=messages, recall=False), ) assert isinstance(past_messages, list) @@ -68,7 +68,7 @@ async def _( embed.assert_not_called() -@test("chat: check that get_messages works") +@test("chat: check that gather_messages works") async def _( developer=test_developer, client=cozo_client, @@ -89,14 +89,13 @@ async def _( session_id = session.id - new_raw_messages = [{"role": "user", "content": "hello"}] + messages = [{"role": "user", "content": "hello"}] - past_messages, doc_references = await get_messages( + past_messages, doc_references = await gather_messages( developer=developer, session_id=session_id, - new_raw_messages=new_raw_messages, chat_context=chat_context, - recall=True, + chat_input=ChatInput(messages=messages, recall=True), ) assert isinstance(past_messages, list) @@ -136,3 +135,22 @@ async def _( # Check that both mocks were called at least once embed.assert_called() acompletion.assert_called() + + +@test("model: prepare chat context") +def _( + client=cozo_client, + developer_id=test_developer_id, + agent=test_agent, + session=test_session, + tool=test_tool, + user=test_user, +): + context = prepare_chat_context( + developer_id=developer_id, + session_id=session.id, + client=client, + ) + + assert isinstance(context, ChatContext) + assert len(context.toolsets) > 0 diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index ce8b6a00d..c6b7150b6 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -27,8 +27,8 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): Verifies that the entry can be successfully added using the create_entries function. """ - test_entry = CreateEntryRequest( - session_id=session.id, + test_entry = CreateEntryRequest.from_model_input( + model=MODEL, role="user", source="internal", content="test entry content", @@ -50,8 +50,8 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): Verifies that the entry can be successfully added using the create_entries function. """ - test_entry = CreateEntryRequest( - session_id=session.id, + test_entry = CreateEntryRequest.from_model_input( + model=MODEL, role="user", source="internal", content="test entry content", @@ -84,15 +84,15 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): Verifies that entries matching specific criteria can be successfully retrieved. """ - test_entry = CreateEntryRequest( - session_id=session.id, + test_entry = CreateEntryRequest.from_model_input( + model=MODEL, role="user", source="api_request", content="test entry content", ) - internal_entry = CreateEntryRequest( - session_id=session.id, + internal_entry = CreateEntryRequest.from_model_input( + model=MODEL, role="user", content="test entry content", source="internal", @@ -122,15 +122,15 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): Verifies that entries matching specific criteria can be successfully retrieved. """ - test_entry = CreateEntryRequest( - session_id=session.id, + test_entry = CreateEntryRequest.from_model_input( + model=MODEL, role="user", source="api_request", content="test entry content", ) - internal_entry = CreateEntryRequest( - session_id=session.id, + internal_entry = CreateEntryRequest.from_model_input( + model=MODEL, role="user", content="test entry content", source="internal", @@ -161,15 +161,15 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): Verifies that entries can be successfully deleted using the delete_entries function. """ - test_entry = CreateEntryRequest( - session_id=session.id, + test_entry = CreateEntryRequest.from_model_input( + model=MODEL, role="user", source="api_request", content="test entry content", ) - internal_entry = CreateEntryRequest( - session_id=session.id, + internal_entry = CreateEntryRequest.from_model_input( + model=MODEL, role="user", content="internal entry content", source="internal", diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 94b5bbbe4..7eae8485f 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -5,19 +5,16 @@ from agents_api.autogen.openapi_model import CreateOrUpdateSessionRequest, Session from agents_api.autogen.Sessions import CreateSessionRequest -from agents_api.common.protocol.sessions import ChatContext from agents_api.models.session.create_or_update_session import create_or_update_session from agents_api.models.session.create_session import create_session from agents_api.models.session.delete_session import delete_session from agents_api.models.session.get_session import get_session from agents_api.models.session.list_sessions import list_sessions -from agents_api.models.session.prepare_chat_context import prepare_chat_context from tests.fixtures import ( cozo_client, test_agent, test_developer_id, test_session, - test_tool, test_user, ) @@ -146,22 +143,3 @@ def _( assert result is not None assert isinstance(result, Session) assert result.id == session_id - - -@test("model: prepare chat context") -def _( - client=cozo_client, - developer_id=test_developer_id, - agent=test_agent, - session=test_session, - tool=test_tool, - user=test_user, -): - context = prepare_chat_context( - developer_id=developer_id, - session_id=session.id, - client=client, - ) - - assert isinstance(context, ChatContext) - assert len(context.toolsets) > 0 diff --git a/agents-api/tests/test_task_routes.py b/agents-api/tests/test_task_routes.py index b790240c4..80aae9e7f 100644 --- a/agents-api/tests/test_task_routes.py +++ b/agents-api/tests/test_task_routes.py @@ -1,4 +1,5 @@ # Tests for task routes + from uuid import uuid4 from ward import test diff --git a/sdks/python/julep/api/types/entries_base_entry.py b/sdks/python/julep/api/types/entries_base_entry.py index cd9a8158f..a12c99c1c 100644 --- a/sdks/python/julep/api/types/entries_base_entry.py +++ b/sdks/python/julep/api/types/entries_base_entry.py @@ -15,8 +15,8 @@ class EntriesBaseEntry(pydantic_v1.BaseModel): name: typing.Optional[str] = None content: EntriesBaseEntryContent source: EntriesBaseEntrySource - tokenizer: typing.Optional[str] = None - token_count: typing.Optional[int] = None + tokenizer: str + token_count: int timestamp: float = pydantic_v1.Field() """ This is the time that this event refers to. diff --git a/sdks/python/poetry.lock b/sdks/python/poetry.lock index 060ffa16f..fe5a1519d 100644 --- a/sdks/python/poetry.lock +++ b/sdks/python/poetry.lock @@ -848,21 +848,21 @@ test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "p [[package]] name = "importlib-resources" -version = "6.4.0" +version = "6.4.1" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"}, - {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"}, + {file = "importlib_resources-6.4.1-py3-none-any.whl", hash = "sha256:8fbee7ba7376ca7c47ce8d31b96b93d8787349845f01ebdbee5ef90409035234"}, + {file = "importlib_resources-6.4.1.tar.gz", hash = "sha256:5ede8acf5d752abda46fb6922a4a6ab782b6d904dfd362bf2d8b857eee1759d9"}, ] [package.dependencies] zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] [[package]] name = "ipykernel" @@ -2309,13 +2309,13 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pyright" -version = "1.1.375" +version = "1.1.376" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.375-py3-none-any.whl", hash = "sha256:4c5e27eddeaee8b41cc3120736a1dda6ae120edf8523bb2446b6073a52f286e3"}, - {file = "pyright-1.1.375.tar.gz", hash = "sha256:7765557b0d6782b2fadabff455da2014476404c9e9214f49977a4e49dec19a0f"}, + {file = "pyright-1.1.376-py3-none-any.whl", hash = "sha256:0f2473b12c15c46b3207f0eec224c3cea2bdc07cd45dd4a037687cbbca0fbeff"}, + {file = "pyright-1.1.376.tar.gz", hash = "sha256:bffd63b197cd0810395bb3245c06b01f95a85ddf6bfa0e5644ed69c841e954dd"}, ] [package.dependencies] diff --git a/sdks/ts/src/api/models/Entries_BaseEntry.ts b/sdks/ts/src/api/models/Entries_BaseEntry.ts index e77df13c3..d397d851e 100644 --- a/sdks/ts/src/api/models/Entries_BaseEntry.ts +++ b/sdks/ts/src/api/models/Entries_BaseEntry.ts @@ -17,8 +17,8 @@ export type Entries_BaseEntry = { | "internal" | "summarizer" | "meta"; - tokenizer?: string; - token_count?: number; + tokenizer: string; + token_count: number; /** * This is the time that this event refers to. */ diff --git a/sdks/ts/src/api/schemas/$Entries_BaseEntry.ts b/sdks/ts/src/api/schemas/$Entries_BaseEntry.ts index 6aad5206a..bcdb7122e 100644 --- a/sdks/ts/src/api/schemas/$Entries_BaseEntry.ts +++ b/sdks/ts/src/api/schemas/$Entries_BaseEntry.ts @@ -37,9 +37,11 @@ export const $Entries_BaseEntry = { }, tokenizer: { type: "string", + isRequired: true, }, token_count: { type: "number", + isRequired: true, format: "uint16", }, timestamp: { diff --git a/typespec/entries/models.tsp b/typespec/entries/models.tsp index f9050b4f4..fba6803df 100644 --- a/typespec/entries/models.tsp +++ b/typespec/entries/models.tsp @@ -91,8 +91,8 @@ model BaseEntry { content: EntryContent | EntryContent[]; source: entrySource; - tokenizer?: string; - token_count?: uint16; + tokenizer: string; + token_count: uint16; /** This is the time that this event refers to. */ @minValue(0) From 152c29a9089120abf8985a8ce22d9745d704c0b6 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Wed, 14 Aug 2024 19:49:49 -0400 Subject: [PATCH 5/6] feat(agents-api): Minor modifier to models Signed-off-by: Diwank Tomer --- agents-api/agents_api/models/agent/create_agent.py | 5 ++++- agents-api/agents_api/models/agent/delete_agent.py | 1 + agents-api/agents_api/models/agent/patch_agent.py | 1 + agents-api/agents_api/models/agent/update_agent.py | 1 + agents-api/agents_api/models/docs/delete_doc.py | 1 + agents-api/agents_api/models/docs/embed_snippets.py | 1 + agents-api/agents_api/models/entry/create_entries.py | 2 +- agents-api/agents_api/models/entry/delete_entries.py | 1 + agents-api/agents_api/models/execution/create_execution.py | 1 + .../models/execution/create_execution_transition.py | 5 ++++- agents-api/agents_api/models/execution/update_execution.py | 1 + agents-api/agents_api/models/session/create_session.py | 1 + agents-api/agents_api/models/session/delete_session.py | 1 + agents-api/agents_api/models/session/patch_session.py | 1 + agents-api/agents_api/models/session/update_session.py | 1 + agents-api/agents_api/models/task/create_task.py | 2 +- agents-api/agents_api/models/task/delete_task.py | 1 + agents-api/agents_api/models/task/patch_task.py | 1 + agents-api/agents_api/models/tools/create_tools.py | 1 + agents-api/agents_api/models/tools/delete_tool.py | 1 + agents-api/agents_api/models/tools/patch_tool.py | 1 + agents-api/agents_api/models/tools/update_tool.py | 1 + agents-api/agents_api/models/user/create_user.py | 7 ++++++- agents-api/agents_api/models/user/delete_user.py | 1 + agents-api/agents_api/models/user/patch_user.py | 1 + agents-api/agents_api/models/user/update_user.py | 1 + 26 files changed, 37 insertions(+), 5 deletions(-) diff --git a/agents-api/agents_api/models/agent/create_agent.py b/agents-api/agents_api/models/agent/create_agent.py index 6b649afbb..a4b408c78 100644 --- a/agents-api/agents_api/models/agent/create_agent.py +++ b/agents-api/agents_api/models/agent/create_agent.py @@ -34,7 +34,10 @@ } ) @wrap_in_class( - Agent, one=True, transform=lambda d: {"id": UUID(d.pop("agent_id")), **d} + Agent, + one=True, + transform=lambda d: {"id": UUID(d.pop("agent_id")), **d}, + _kind="inserted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/agent/delete_agent.py b/agents-api/agents_api/models/agent/delete_agent.py index 27e3ecc1c..e1efd7333 100644 --- a/agents-api/agents_api/models/agent/delete_agent.py +++ b/agents-api/agents_api/models/agent/delete_agent.py @@ -41,6 +41,7 @@ "deleted_at": utcnow(), "jobs": [], }, + _kind="deleted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/agent/patch_agent.py b/agents-api/agents_api/models/agent/patch_agent.py index 3a374f91a..87f5db046 100644 --- a/agents-api/agents_api/models/agent/patch_agent.py +++ b/agents-api/agents_api/models/agent/patch_agent.py @@ -29,6 +29,7 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["agent_id"], "jobs": [], **d}, + _kind="replaced", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/agent/update_agent.py b/agents-api/agents_api/models/agent/update_agent.py index 9cd8f04d9..95116e1a2 100644 --- a/agents-api/agents_api/models/agent/update_agent.py +++ b/agents-api/agents_api/models/agent/update_agent.py @@ -28,6 +28,7 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["agent_id"], "jobs": [], **d}, + _kind="replaced", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/docs/delete_doc.py b/agents-api/agents_api/models/docs/delete_doc.py index 9dcaf0f33..e7a02f3d9 100644 --- a/agents-api/agents_api/models/docs/delete_doc.py +++ b/agents-api/agents_api/models/docs/delete_doc.py @@ -32,6 +32,7 @@ "deleted_at": utcnow(), "jobs": [], }, + _kind="deleted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/docs/embed_snippets.py b/agents-api/agents_api/models/docs/embed_snippets.py index 89750192a..64bf8e130 100644 --- a/agents-api/agents_api/models/docs/embed_snippets.py +++ b/agents-api/agents_api/models/docs/embed_snippets.py @@ -30,6 +30,7 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["doc_id"], "updated_at": utcnow(), "jobs": []}, + _kind="replaced", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/entry/create_entries.py b/agents-api/agents_api/models/entry/create_entries.py index 31c8b4d01..e71a81c7a 100644 --- a/agents-api/agents_api/models/entry/create_entries.py +++ b/agents-api/agents_api/models/entry/create_entries.py @@ -90,7 +90,7 @@ def create_entries( TypeError: partialclass(HTTPException, status_code=400), } ) -@wrap_in_class(Relation) +@wrap_in_class(Relation, _kind="inserted") @cozo_query @beartype def add_entry_relations( diff --git a/agents-api/agents_api/models/entry/delete_entries.py b/agents-api/agents_api/models/entry/delete_entries.py index f64bfbf73..5bf34c721 100644 --- a/agents-api/agents_api/models/entry/delete_entries.py +++ b/agents-api/agents_api/models/entry/delete_entries.py @@ -34,6 +34,7 @@ "deleted_at": utcnow(), "jobs": [], }, + _kind="deleted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/execution/create_execution.py b/agents-api/agents_api/models/execution/create_execution.py index 4e85a1db7..155d18b16 100644 --- a/agents-api/agents_api/models/execution/create_execution.py +++ b/agents-api/agents_api/models/execution/create_execution.py @@ -31,6 +31,7 @@ Execution, one=True, transform=lambda d: {"id": d["execution_id"], **d}, + _kind="inserted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/execution/create_execution_transition.py b/agents-api/agents_api/models/execution/create_execution_transition.py index df538be79..701ffd3cc 100644 --- a/agents-api/agents_api/models/execution/create_execution_transition.py +++ b/agents-api/agents_api/models/execution/create_execution_transition.py @@ -53,7 +53,10 @@ } ) @wrap_in_class( - Transition, transform=lambda d: {"id": d["transition_id"], **d}, one=True + Transition, + transform=lambda d: {"id": d["transition_id"], **d}, + one=True, + _kind="inserted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/execution/update_execution.py b/agents-api/agents_api/models/execution/update_execution.py index 1424c4ef2..4386b9502 100644 --- a/agents-api/agents_api/models/execution/update_execution.py +++ b/agents-api/agents_api/models/execution/update_execution.py @@ -31,6 +31,7 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["execution_id"], "jobs": [], **d}, + _kind="replaced", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/session/create_session.py b/agents-api/agents_api/models/session/create_session.py index aab7fddac..7bb0576c9 100644 --- a/agents-api/agents_api/models/session/create_session.py +++ b/agents-api/agents_api/models/session/create_session.py @@ -36,6 +36,7 @@ "updated_at": (d.pop("updated_at")[0]), **d, }, + _kind="inserted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/session/delete_session.py b/agents-api/agents_api/models/session/delete_session.py index e6b0037ca..71f153fb1 100644 --- a/agents-api/agents_api/models/session/delete_session.py +++ b/agents-api/agents_api/models/session/delete_session.py @@ -34,6 +34,7 @@ "deleted_at": utcnow(), "jobs": [], }, + _kind="deleted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/session/patch_session.py b/agents-api/agents_api/models/session/patch_session.py index f16a53f44..131d82bec 100644 --- a/agents-api/agents_api/models/session/patch_session.py +++ b/agents-api/agents_api/models/session/patch_session.py @@ -46,6 +46,7 @@ "jobs": [], **d, }, + _kind="replaced", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/session/update_session.py b/agents-api/agents_api/models/session/update_session.py index 5f7789f29..c4296634a 100644 --- a/agents-api/agents_api/models/session/update_session.py +++ b/agents-api/agents_api/models/session/update_session.py @@ -44,6 +44,7 @@ "jobs": [], **d, }, + _kind="replaced", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/task/create_task.py b/agents-api/agents_api/models/task/create_task.py index 17f991e7b..61f0bca72 100644 --- a/agents-api/agents_api/models/task/create_task.py +++ b/agents-api/agents_api/models/task/create_task.py @@ -32,7 +32,7 @@ TypeError: partialclass(HTTPException, status_code=400), } ) -@wrap_in_class(spec_to_task, one=True) +@wrap_in_class(spec_to_task, one=True, _kind="inserted") @cozo_query @beartype def create_task( diff --git a/agents-api/agents_api/models/task/delete_task.py b/agents-api/agents_api/models/task/delete_task.py index 28c3defb3..1eb780057 100644 --- a/agents-api/agents_api/models/task/delete_task.py +++ b/agents-api/agents_api/models/task/delete_task.py @@ -33,6 +33,7 @@ "deleted_at": utcnow(), **d, }, + _kind="deleted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/task/patch_task.py b/agents-api/agents_api/models/task/patch_task.py index 93d57f32f..4be73c025 100644 --- a/agents-api/agents_api/models/task/patch_task.py +++ b/agents-api/agents_api/models/task/patch_task.py @@ -39,6 +39,7 @@ "updated_at": d["updated_at_ms"][0] / 1000, **d, }, + _kind="replaced", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/tools/create_tools.py b/agents-api/agents_api/models/tools/create_tools.py index 597268863..0c034d4d2 100644 --- a/agents-api/agents_api/models/tools/create_tools.py +++ b/agents-api/agents_api/models/tools/create_tools.py @@ -32,6 +32,7 @@ d["type"]: d.pop("spec"), **d, }, + _kind="inserted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/tools/delete_tool.py b/agents-api/agents_api/models/tools/delete_tool.py index ad6a9d4f5..e6d00498a 100644 --- a/agents-api/agents_api/models/tools/delete_tool.py +++ b/agents-api/agents_api/models/tools/delete_tool.py @@ -28,6 +28,7 @@ ResourceDeletedResponse, one=True, transform=lambda d: {"id": d["tool_id"], "deleted_at": utcnow(), "jobs": [], **d}, + _kind="deleted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/tools/patch_tool.py b/agents-api/agents_api/models/tools/patch_tool.py index acdbcf0b4..68bea3ee5 100644 --- a/agents-api/agents_api/models/tools/patch_tool.py +++ b/agents-api/agents_api/models/tools/patch_tool.py @@ -28,6 +28,7 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, + _kind="replaced", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/tools/update_tool.py b/agents-api/agents_api/models/tools/update_tool.py index 1376aba37..b8a395a94 100644 --- a/agents-api/agents_api/models/tools/update_tool.py +++ b/agents-api/agents_api/models/tools/update_tool.py @@ -31,6 +31,7 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, + _kind="replaced", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/user/create_user.py b/agents-api/agents_api/models/user/create_user.py index f9c5af19e..4115b3326 100644 --- a/agents-api/agents_api/models/user/create_user.py +++ b/agents-api/agents_api/models/user/create_user.py @@ -32,7 +32,12 @@ TypeError: partialclass(HTTPException, status_code=400), } ) -@wrap_in_class(User, one=True, transform=lambda d: {"id": UUID(d.pop("user_id")), **d}) +@wrap_in_class( + User, + one=True, + transform=lambda d: {"id": UUID(d.pop("user_id")), **d}, + _kind="inserted", +) @cozo_query @beartype def create_user( diff --git a/agents-api/agents_api/models/user/delete_user.py b/agents-api/agents_api/models/user/delete_user.py index 04f96e630..b9ebc2db7 100644 --- a/agents-api/agents_api/models/user/delete_user.py +++ b/agents-api/agents_api/models/user/delete_user.py @@ -36,6 +36,7 @@ "deleted_at": utcnow(), "jobs": [], }, + _kind="deleted", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/user/patch_user.py b/agents-api/agents_api/models/user/patch_user.py index fdad01fe7..89fe4db33 100644 --- a/agents-api/agents_api/models/user/patch_user.py +++ b/agents-api/agents_api/models/user/patch_user.py @@ -31,6 +31,7 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["user_id"], "jobs": [], **d}, + _kind="replaced", ) @cozo_query @beartype diff --git a/agents-api/agents_api/models/user/update_user.py b/agents-api/agents_api/models/user/update_user.py index 7132e13a0..88f24e6e0 100644 --- a/agents-api/agents_api/models/user/update_user.py +++ b/agents-api/agents_api/models/user/update_user.py @@ -28,6 +28,7 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["user_id"], "jobs": [], **d}, + _kind="replaced", ) @cozo_query @beartype From e97b6fb093af41fe2e78dab433594a77220d9eed Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Wed, 14 Aug 2024 20:56:49 -0400 Subject: [PATCH 6/6] fix(agents-api): Minor fixes Signed-off-by: Diwank Tomer --- agents-api/agents_api/activities/__init__.py | 10 +--- agents-api/agents_api/clients/temporal.py | 46 ------------------- .../agents_api/routers/docs/create_doc.py | 46 ++++++++++++++++++- agents-api/poetry.lock | 14 +++--- agents-api/tests/fixtures.py | 12 +++++ agents-api/tests/test_activities.py | 24 +++++----- agents-api/tests/test_docs_routes.py | 8 ++-- sdks/python/poetry.lock | 6 +-- 8 files changed, 84 insertions(+), 82 deletions(-) diff --git a/agents-api/agents_api/activities/__init__.py b/agents-api/agents_api/activities/__init__.py index 49722a7d5..c641ab7b9 100644 --- a/agents-api/agents_api/activities/__init__.py +++ b/agents-api/agents_api/activities/__init__.py @@ -1,13 +1,5 @@ """ -The `activities` module within the agents-api package is designed to facilitate various activities related to agent interactions. This includes handling memory management, generating insights from dialogues, summarizing relationships, and more. Each file within the module offers specific functionality: - -- `co_density.py`: Conducts cognitive density analysis to generate concise, entity-dense summaries. -- `dialog_insights.py`: Extracts insights from dialogues, identifying details that participants might find interesting. -- `mem_mgmt.py`: Manages memory by updating and incorporating new personality information from dialogues. -- `mem_rating.py`: Rates memories based on their poignancy and importance. -- `relationship_summary.py`: Summarizes the relationship between individuals based on provided statements. -- `salient_questions.py`: Identifies salient questions from a set of statements. -- `summarization.py`: Summarizes dialogues and updates memory based on the conversation context. +The `activities` module within the agents-api package is designed to facilitate various activities related to agent interactions. This includes handling memory management, generating insights from dialogues, summarizing relationships, and more. This module plays a crucial role in enhancing the capabilities of agents by providing them with the tools to understand and process information more effectively. """ diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index 72a5056c8..29ceedded 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -34,52 +34,6 @@ async def get_client( ) -async def run_summarization_task( - session_id: UUID, job_id: UUID, client: Client | None = None -): - client = client or (await get_client()) - - await client.execute_workflow( - "SummarizationWorkflow", - args=[str(session_id)], - task_queue="memory-task-queue", - id=str(job_id), - ) - - -async def run_embed_docs_task( - doc_id: UUID, - title: str, - content: list[str], - job_id: UUID, - client: Client | None = None, -): - client = client or (await get_client()) - - await client.execute_workflow( - "EmbedDocsWorkflow", - args=[str(doc_id), title, content], - task_queue="memory-task-queue", - id=str(job_id), - ) - - -async def run_truncation_task( - token_count_threshold: int, - session_id: UUID, - job_id: UUID, - client: Client | None = None, -): - client = client or (await get_client()) - - await client.execute_workflow( - "TruncationWorkflow", - args=[str(session_id), token_count_threshold], - task_queue="memory-task-queue", - id=str(job_id), - ) - - async def run_task_execution_workflow( execution_input: ExecutionInput, job_id: UUID, diff --git a/agents-api/agents_api/routers/docs/create_doc.py b/agents-api/agents_api/routers/docs/create_doc.py index 0b43fe8eb..645f82964 100644 --- a/agents-api/agents_api/routers/docs/create_doc.py +++ b/agents-api/agents_api/routers/docs/create_doc.py @@ -1,15 +1,35 @@ from typing import Annotated +from uuid import UUID, uuid4 from fastapi import Depends from pydantic import UUID4 from starlette.status import HTTP_201_CREATED +from temporalio.client import Client as TemporalClient from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse +from ...clients import temporal from ...dependencies.developer_id import get_developer_id from ...models.docs.create_doc import create_doc as create_doc_query from .router import router +async def run_embed_docs_task( + doc_id: UUID, + title: str, + content: list[str], + job_id: UUID, + client: TemporalClient | None = None, +): + client = client or (await temporal.get_client()) + + await client.execute_workflow( + "EmbedDocsWorkflow", + args=[str(doc_id), title, content], + task_queue="memory-task-queue", + id=str(job_id), + ) + + @router.post("/users/{user_id}/docs", status_code=HTTP_201_CREATED, tags=["docs"]) async def create_user_doc( user_id: UUID4, @@ -23,7 +43,18 @@ async def create_user_doc( data=data, ) - return ResourceCreatedResponse(id=doc.id, created_at=doc.created_at, jobs=[]) + embed_job_id = uuid4() + + await run_embed_docs_task( + doc_id=doc.id, + title=doc.title, + content=doc.content, + job_id=embed_job_id, + ) + + return ResourceCreatedResponse( + id=doc.id, created_at=doc.created_at, jobs=[embed_job_id] + ) @router.post("/agents/{agent_id}/docs", status_code=HTTP_201_CREATED, tags=["docs"]) @@ -39,4 +70,15 @@ async def create_agent_doc( data=data, ) - return ResourceCreatedResponse(id=doc.id, created_at=doc.created_at, jobs=[]) + embed_job_id = uuid4() + + await run_embed_docs_task( + doc_id=doc.id, + title=doc.title, + content=doc.content, + job_id=embed_job_id, + ) + + return ResourceCreatedResponse( + id=doc.id, created_at=doc.created_at, jobs=[embed_job_id] + ) diff --git a/agents-api/poetry.lock b/agents-api/poetry.lock index 6a29c9f4a..e4edb74b7 100644 --- a/agents-api/poetry.lock +++ b/agents-api/poetry.lock @@ -2103,18 +2103,18 @@ files = [ [[package]] name = "langchain" -version = "0.2.13" +version = "0.2.14" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain-0.2.13-py3-none-any.whl", hash = "sha256:80f21e48cdada424dd2af9bbf42234fe095744cf181b31eeb63d1da7479e2783"}, - {file = "langchain-0.2.13.tar.gz", hash = "sha256:947e96ac3153a46aa6a0d8207e5f8b6794084c397f60a01bbf4bba78e6838fee"}, + {file = "langchain-0.2.14-py3-none-any.whl", hash = "sha256:eed76194ee7d9c081037a3df7868d4de90e0410b51fc1ca933a8379e464bf40c"}, + {file = "langchain-0.2.14.tar.gz", hash = "sha256:dc2aa5a58882054fb5d043c39ab8332ebd055f88f17839da68e1c7fd0a4fefe2"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" -langchain-core = ">=0.2.30,<0.3.0" +langchain-core = ">=0.2.32,<0.3.0" langchain-text-splitters = ">=0.2.0,<0.3.0" langsmith = ">=0.1.17,<0.2.0" numpy = {version = ">=1,<2", markers = "python_version < \"3.12\""} @@ -2149,13 +2149,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" [[package]] name = "langchain-core" -version = "0.2.31" +version = "0.2.32" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_core-0.2.31-py3-none-any.whl", hash = "sha256:b4daf5ddc23c0c3d8c5fd1a6c118f95fb5d0f96067b43f2c5935e1cd572e4374"}, - {file = "langchain_core-0.2.31.tar.gz", hash = "sha256:afb2089d4c10842d2477dc5cfa9ae9feb415c1421c6ef9aa608fea879ee41769"}, + {file = "langchain_core-0.2.32-py3-none-any.whl", hash = "sha256:1f5584cf0034909e35ea17010a847d4079417e0ddcb5a9eb3fbb2bd55f3268c0"}, + {file = "langchain_core-0.2.32.tar.gz", hash = "sha256:d82cdc350bbbe74261330d87056b7d9f1fb567828e9e03f708d23a48b941819e"}, ] [package.dependencies] diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 4fb53ffd7..437c95083 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -74,6 +74,18 @@ async def temporal_worker(wf_env=workflow_environment): await c +@fixture(scope="test") +def patch_temporal_get_client( + wf_env=workflow_environment, + temporal_worker=temporal_worker, +): + mock_client = wf_env.client + + with patch("agents_api.clients.temporal.get_client") as get_client: + get_client.return_value = mock_client + yield get_client + + @fixture(scope="global") def test_developer_id(cozo_client=cozo_client): developer_id = uuid4() diff --git a/agents-api/tests/test_activities.py b/agents-api/tests/test_activities.py index 29adb8110..ef9d4151b 100644 --- a/agents-api/tests/test_activities.py +++ b/agents-api/tests/test_activities.py @@ -16,7 +16,18 @@ # from agents_api.common.protocol.entries import Entry -@test("activity: embed_docs") +@test("activity: check that workflow environment and worker are started correctly") +async def _( + workflow_environment=workflow_environment, + worker=temporal_worker, +): + async with workflow_environment as wf_env: + assert wf_env is not None + assert worker is not None + assert worker.is_running + + +@test("activity: call direct embed_docs") async def _( cozo_client=cozo_client, developer_id=test_developer_id, @@ -41,17 +52,6 @@ async def _( embed.assert_called_once() -@test("activity: check that workflow environment and worker are started correctly") -async def _( - workflow_environment=workflow_environment, - worker=temporal_worker, -): - async with workflow_environment as wf_env: - assert wf_env is not None - assert worker is not None - assert worker.is_running - - # @test("get extra entries, do not strip system message") # def _(): # session_ids = [uuid.uuid4()] * 3 diff --git a/agents-api/tests/test_docs_routes.py b/agents-api/tests/test_docs_routes.py index 05f095f49..67e38a50e 100644 --- a/agents-api/tests/test_docs_routes.py +++ b/agents-api/tests/test_docs_routes.py @@ -25,6 +25,9 @@ def _(make_request=make_request, user=test_user): assert response.status_code == 201 + result = response.json() + assert len(result["jobs"]) > 0 + @test("route: create agent doc") def _(make_request=make_request, agent=test_agent): @@ -41,9 +44,8 @@ def _(make_request=make_request, agent=test_agent): assert response.status_code == 201 - # FIXME: Should create a job to process the document - # result = response.json() - # assert len(result["jobs"]) > 0 + result = response.json() + assert len(result["jobs"]) > 0 @test("route: delete doc") diff --git a/sdks/python/poetry.lock b/sdks/python/poetry.lock index fe5a1519d..bfef1a91f 100644 --- a/sdks/python/poetry.lock +++ b/sdks/python/poetry.lock @@ -848,13 +848,13 @@ test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "p [[package]] name = "importlib-resources" -version = "6.4.1" +version = "6.4.2" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.4.1-py3-none-any.whl", hash = "sha256:8fbee7ba7376ca7c47ce8d31b96b93d8787349845f01ebdbee5ef90409035234"}, - {file = "importlib_resources-6.4.1.tar.gz", hash = "sha256:5ede8acf5d752abda46fb6922a4a6ab782b6d904dfd362bf2d8b857eee1759d9"}, + {file = "importlib_resources-6.4.2-py3-none-any.whl", hash = "sha256:8bba8c54a8a3afaa1419910845fa26ebd706dc716dd208d9b158b4b6966f5c5c"}, + {file = "importlib_resources-6.4.2.tar.gz", hash = "sha256:6cbfbefc449cc6e2095dd184691b7a12a04f40bc75dd4c55d31c34f174cdf57a"}, ] [package.dependencies]