Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agents-api): Minor fix to tests #457

Merged
merged 7 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
#!/usr/bin/env python3

import asyncio
from textwrap import dedent
from typing import Callable
from uuid import UUID

import pandas as pd
from temporalio import activity

# from agents_api.common.protocol.entries import Entry
# from agents_api.models.entry.entries_summarization import (
# entries_summarization_query,
# get_toplevel_entries_query,
# )
from agents_api.rec_sum.entities import get_entities
from agents_api.rec_sum.summarize import summarize_messages
from agents_api.rec_sum.trim import trim_messages

from ..env import summarization_model_name


# TODO: remove stubs
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
7 changes: 2 additions & 5 deletions agents-api/agents_api/activities/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from temporalio import activity

# from agents_api.autogen.openapi_model import Role
from agents_api.common.protocol.entries import Entry
from agents_api.models.entry.delete_entries import delete_entries
from agents_api.autogen.openapi_model import Entry

# from agents_api.models.entry.entries_summarization import get_toplevel_entries_query

Expand All @@ -13,8 +11,7 @@ def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list
if not len(messages):
return messages

result: list[UUID] = []
token_cnt, offset = 0, 0
_token_cnt, _offset = 0, 0
# if messages[0].role == Role.system:
# token_cnt, offset = messages[0].token_count, 1

Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/autogen/Entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class BaseEntry(BaseModel):
source: Literal[
"api_request", "api_response", "tool_response", "internal", "summarizer", "meta"
]
tokenizer: str | None = None
token_count: int | None = None
tokenizer: str
token_count: int
timestamp: Annotated[float, Field(ge=0.0)]
"""
This is the time that this event refers to.
Expand Down
59 changes: 57 additions & 2 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# ruff: noqa: F401, F403, F405
from typing import Annotated, Generic, TypeVar
from typing import Annotated, Any, Generic, Self, Type, TypeVar
from uuid import UUID

from litellm.utils import _select_tokenizer as select_tokenizer
from litellm.utils import token_counter
from pydantic import AwareDatetime, Field
from pydantic_partial import create_partial_model

Expand Down Expand Up @@ -34,14 +36,67 @@
"metadata",
)

ChatMLRole = BaseEntry.model_fields["role"].annotation
# Closed set of role names accepted for a ChatML message.
ChatMLRole = Literal[
    "user", "assistant", "system", "function", "function_response", "function_call", "auto"
]

# A single piece of message content: a list of text/image parts, a tool
# definition, a chosen tool call, a plain string, or a tool response.
_ChatMLContentItem = (
    list[ChatMLTextContentPart | ChatMLImageContentPart]
    | Tool
    | ChosenToolCall
    | str
    | ToolResponse
)

# Message content is either one content item or a list of content items.
ChatMLContent = _ChatMLContentItem | list[_ChatMLContentItem]

# Closed set of values accepted for an entry's `source` field.
ChatMLSource = Literal[
    "api_request",
    "api_response",
    "tool_response",
    "internal",
    "summarizer",
    "meta",
]


class CreateEntryRequest(BaseEntry):
    """Entry creation payload.

    Extends `BaseEntry` by giving `timestamp` a default: when it is not
    supplied, it is filled with `utcnow().timestamp()` at construction time.
    """

    timestamp: Annotated[
        float, Field(ge=0.0, default_factory=lambda: utcnow().timestamp())
    ]

    @classmethod
    def from_model_input(
        cls: Type[Self],
        model: str,
        *,
        role: ChatMLRole,
        content: ChatMLContent,
        name: str | None = None,
        source: ChatMLSource,
        **kwargs: Any,
    ) -> Self:
        """Build an entry whose token metadata is computed for `model`.

        Args:
            model: Model name handed to litellm's tokenizer selection and
                token counter.
            role: Role of the message; stored on the entry and included in
                the message that is token-counted.
            content: Message content; stored on the entry and included in
                the message that is token-counted.
            name: Optional participant name; stored on the entry and included
                in the message that is token-counted.
            source: Origin of the entry; stored on the entry.
            **kwargs: Any further entry fields, forwarded unchanged to the
                class constructor. Values may be of any type, not only dicts.

        Returns:
            A new instance of `cls` with `tokenizer` and `token_count` set.
        """
        tokenizer: dict = select_tokenizer(model=model)

        # Count tokens for exactly one message built from this entry's fields.
        token_count = token_counter(
            model=model, messages=[{"role": role, "content": content, "name": name}]
        )

        return cls(
            role=role,
            content=content,
            name=name,
            source=source,
            # Only the tokenizer's "type" entry is recorded on the entry.
            tokenizer=tokenizer["type"],
            token_count=token_count,
            **kwargs,
        )


def make_session(
*,
Expand Down
56 changes: 0 additions & 56 deletions agents-api/agents_api/common/protocol/entries.py

This file was deleted.

22 changes: 22 additions & 0 deletions agents-api/agents_api/models/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Module: agents_api/models/docs

This module is responsible for managing document-related operations within the application, particularly for agents and possibly other entities. It serves as a core component of the document management system, enabling features such as document creation, listing, deletion, and embedding of snippets for enhanced search and retrieval capabilities.

Main functionalities include:
- Creating new documents and associating them with agents or users.
- Listing documents based on various criteria, including ownership and metadata filters.
- Deleting documents by their unique identifiers.
- Embedding document snippets for retrieval purposes.

The module interacts with other parts of the application, such as the agents and users modules, to provide a comprehensive document management system. Its role is crucial in enabling document search, retrieval, and management features within the context of agents and users.

This documentation aims to provide clear, concise, and sufficient context for new developers or contributors to understand the module's role without needing to dive deep into the code immediately.
"""

# ruff: noqa: F401, F403, F405

from .gather_messages import gather_messages
from .get_cached_response import get_cached_response
from .prepare_chat_context import prepare_chat_context
from .set_cached_response import set_cached_response
82 changes: 82 additions & 0 deletions agents-api/agents_api/models/chat/gather_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from agents_api.autogen.Chat import ChatInput

from ...autogen.openapi_model import DocReference, History
from ...clients import embed
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ..docs.search_docs_hybrid import search_docs_hybrid
from ..entry.get_history import get_history
from ..utils import (
partialclass,
rewrap_exceptions,
)


@rewrap_exceptions(
    {
        QueryException: partialclass(HTTPException, status_code=400),
        ValidationError: partialclass(HTTPException, status_code=400),
        TypeError: partialclass(HTTPException, status_code=400),
    }
)
@beartype
async def gather_messages(
    *,
    developer: Developer,
    session_id: UUID,
    chat_context: ChatContext,
    chat_input: ChatInput,
):
    """Collect the past messages of a session and, optionally, matching docs.

    Args:
        developer: Developer whose `id` scopes the history and doc lookups.
        session_id: Session whose history is fetched.
        chat_context: Supplies the active agent and the users whose docs are
            searched.
        chat_input: Supplies the new messages and the `recall` flag.

    Returns:
        A `(past_messages, doc_references)` pair. `past_messages` is a list of
        dumped history entries that are not the head of any relation.
        `doc_references` is an empty list when `chat_input.recall` is falsy,
        otherwise the result of the hybrid doc search.

    Raises:
        AssertionError: If `chat_input.messages` is empty.
        HTTPException: Per the `rewrap_exceptions` mapping above, for
            `QueryException`, `ValidationError` and `TypeError` (status 400).
    """
    new_raw_messages = [msg.model_dump() for msg in chat_input.messages]
    recall = chat_input.recall

    # NOTE(review): `assert` is stripped when Python runs with -O; consider an
    # explicit error if empty input must always be rejected.
    assert len(new_raw_messages) > 0

    # Get the session history
    history: History = get_history(
        developer_id=developer.id,
        session_id=session_id,
        allowed_sources=["api_request", "api_response", "tool_response", "summarizer"],
    )

    # Keep leaf nodes only. The set of relation heads is built once up front
    # instead of being rebuilt for every entry in the filter below.
    relation_heads = {r.head for r in history.relations}
    past_messages = [
        entry.model_dump()
        for entry in history.entries
        if entry.id not in relation_heads
    ]

    if not recall:
        return past_messages, []

    # Search matching docs: embed all new messages as "<name or role>: <content>"
    # and use the first returned embedding as the query vector.
    [query_embedding, *_] = await embed.embed(
        inputs=[
            f"{msg.get('name') or msg['role']}: {msg['content']}"
            for msg in new_raw_messages
        ],
        join_inputs=True,
    )
    # The text query is the content of the most recent new message.
    query_text = new_raw_messages[-1]["content"]

    # List all the applicable owners to search docs from
    active_agent_id = chat_context.get_active_agent().id
    user_ids = [user.id for user in chat_context.users]
    owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)]

    doc_references: list[DocReference] = search_docs_hybrid(
        developer_id=developer.id,
        owners=owners,
        query=query_text,
        query_embedding=query_embedding,
    )

    return past_messages, doc_references
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ...autogen.openapi_model import make_session
from ...common.protocol.sessions import ChatContext
from ..session.prepare_session_data import prepare_session_data
from ..utils import (
cozo_query,
fix_uuid_if_present,
Expand All @@ -16,7 +17,6 @@
verify_developer_owns_resource_query,
wrap_in_class,
)
from .prepare_session_data import prepare_session_data


@rewrap_exceptions(
Expand Down
11 changes: 4 additions & 7 deletions agents-api/agents_api/models/entry/create_entries.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from uuid import UUID, uuid4

from beartype import beartype
Expand Down Expand Up @@ -34,6 +33,7 @@
"id": UUID(d.pop("entry_id")),
**d,
},
_kind="inserted",
)
@cozo_query
@beartype
Expand All @@ -55,10 +55,6 @@ def create_entries(
item["entry_id"] = item.pop("id", None) or str(uuid4())
item["created_at"] = (item.get("created_at") or utcnow()).timestamp()

if not item.get("token_count"):
item["token_count"] = len(json.dumps(item)) // 3.5
item["tokenizer"] = "character_count"

cols, rows = cozo_process_mutate_data(data_dicts)

# Construct a datalog query to insert the processed entries into the 'cozodb' database.
Expand All @@ -78,8 +74,9 @@ def create_entries(
verify_developer_owns_resource_query(
developer_id, "sessions", session_id=session_id
),
mark_session_as_updated
and mark_session_updated_query(developer_id, session_id),
mark_session_updated_query(developer_id, session_id)
if mark_session_as_updated
else "",
create_query,
]

Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/entry/get_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def get_history(
content,
source,
token_count,
tokenizer,
created_at,
timestamp,
},
Expand All @@ -75,6 +76,7 @@ def get_history(
"content": content,
"source": source,
"token_count": token_count,
"tokenizer": tokenizer,
"created_at": created_at,
"timestamp": timestamp
}
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/entry/list_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def list_entries(
content,
source,
token_count,
tokenizer,
created_at,
timestamp,
] := *entries {{
Expand All @@ -75,6 +76,7 @@ def list_entries(
content,
source,
token_count,
tokenizer,
created_at,
timestamp,
}},
Expand Down
3 changes: 0 additions & 3 deletions agents-api/agents_api/models/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@
from .create_or_update_session import create_or_update_session
from .create_session import create_session
from .delete_session import delete_session
from .get_cached_response import get_cached_response
from .get_session import get_session
from .list_sessions import list_sessions
from .patch_session import patch_session
from .prepare_chat_context import prepare_chat_context
from .prepare_session_data import prepare_session_data
from .set_cached_response import set_cached_response
from .update_session import update_session
Loading
Loading