Skip to content

Commit

Permalink
[wip] feat(agents-api,sdks): multimodal support (#343)
Browse files Browse the repository at this point in the history
* feat: Multimodal support

Signed-off-by: Diwank Tomer <[email protected]>

* feat: Generate types and sdks

* Update openapi.yaml

* fix: Fix image-part and text-part objects

* fix: Fix model queries to work with new chatml schema

Signed-off-by: Diwank Tomer <[email protected]>

* fix: Minor fix, do not use datetime.utcnow

Signed-off-by: Diwank Tomer <[email protected]>

* lint

Signed-off-by: Diwank Tomer <[email protected]>

* fix(python-sdk): Fix pytype error

Signed-off-by: Diwank Tomer <[email protected]>

* refactor: Lint sdks/python (CI)

* feat(agents-api): Add support for multimodal stuff in sessions.chat

Signed-off-by: Diwank Tomer <[email protected]>

* fix(agents-api): Fix openapi spec

Signed-off-by: Diwank Tomer <[email protected]>

* feat(agents-api): Generate types and sdks

Signed-off-by: Diwank Singh <[email protected]>

* fix: Small fix for pytype

Signed-off-by: Diwank Tomer <[email protected]>

* fix: Make message content transformations depend on the content type

* fix: Annotate developer_id as a string

* fix(agents-api,sdks): Lint and fix small bug

Signed-off-by: Diwank Tomer <[email protected]>

* feat(sdks): Add tests for multimodal chat in python sdk

Signed-off-by: Diwank Tomer <[email protected]>

* fix: Set remember and recall to True by default

* fix: Get content length

---------

Signed-off-by: Diwank Tomer <[email protected]>
Signed-off-by: Diwank Singh <[email protected]>
Co-authored-by: Diwank Tomer <[email protected]>
Co-authored-by: creatorrr <[email protected]>
Co-authored-by: Dmitry Paramonov <[email protected]>
  • Loading branch information
4 people authored May 29, 2024
1 parent d734550 commit 7cd34fa
Show file tree
Hide file tree
Showing 82 changed files with 1,478 additions and 352 deletions.
141 changes: 94 additions & 47 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2024-05-16T13:06:33+00:00
# timestamp: 2024-05-28T03:07:50+00:00

from __future__ import annotations

Expand Down Expand Up @@ -231,29 +231,6 @@ class Role(str, Enum):
function = "function"


# ChatML message model with required `role`, `content`, `created_at` and `id`, and an optional `name`.
# NOTE(review): this copy types `content` as a plain string only; a second
# ChatMLMessage definition further down this listing widens it to a string or a
# list of text/image content parts.
class ChatMLMessage(BaseModel):
role: Role
"""
ChatML role (system|assistant|user|function_call|function)
"""
content: str
"""
ChatML content
"""
name: str | None = None
"""
ChatML name
"""
created_at: AwareDatetime
"""
Message created at (RFC-3339 format)
"""
id: UUID
"""
Message ID
"""


class Role1(str, Enum):
"""
ChatML role (system|assistant|user|function_call|function|auto)
Expand All @@ -267,25 +244,6 @@ class Role1(str, Enum):
auto = "auto"


# Input-side ChatML message: uses Role1 (which also allows "auto") and carries no `id`/`created_at`.
# NOTE(review): this copy types `content` as a plain string only; a second
# InputChatMLMessage definition further down this listing accepts content parts too.
class InputChatMLMessage(BaseModel):
role: Role1
"""
ChatML role (system|assistant|user|function_call|function|auto)
"""
content: str
"""
ChatML content
"""
name: str | None = None
"""
ChatML name
"""
# Exposed under the alias "continue" (a reserved word in Python, hence the
# trailing underscore on the attribute); the default is False.
continue_: Annotated[bool | None, Field(False, alias="continue")]
"""
Whether to continue this message or return a new one
"""


class Function(BaseModel):
name: str
"""
Expand Down Expand Up @@ -358,10 +316,6 @@ class FinishReason(str, Enum):
function_call = "function_call"


# Wrapper model with a single optional field; note that `items` holds one
# ChatMLMessage (or None), not a list, despite the plural name.
class Response(BaseModel):
items: ChatMLMessage | None = None


class Memory(BaseModel):
agent_id: UUID
"""
Expand Down Expand Up @@ -821,6 +775,53 @@ class DocIds(BaseModel):
user_doc_ids: List[str]


# Text part of a message whose content is given as a list of parts.
# `type` is pinned to the literal "text", so it tells this part apart from
# ChatMLImageContentPart (whose `type` is "image_url").
class ChatMLTextContentPart(BaseModel):
type: Literal["text"] = "text"
"""
Fixed to 'text'
"""
text: str
"""
Text content part
"""


# Allowed values for ImageUrl.detail. Because the enum mixes in `str`, each
# member compares equal to its raw string value (e.g. Detail.auto == "auto").
class Detail(str, Enum):
"""
image detail to feed into the model can be low | high | auto
"""

low = "low"
high = "high"
auto = "auto"


# Payload of an image content part: where the image comes from (`url`) and the
# requested level of detail (`detail`).
class ImageUrl(BaseModel):
"""
Image content part, can be a URL or a base64-encoded image
"""

url: str
"""
URL or base64 data url (e.g. `data:image/jpeg;base64,<the base64 encoded image>`)
"""
# The default is the plain string "auto" rather than the Detail.auto member;
# presumably that mismatch with the annotation is why the pytype check is
# suppressed on this line.
detail: Detail | None = "auto" # pytype: disable=annotation-type-mismatch
"""
image detail to feed into the model can be low | high | auto
"""


# Image part of a message whose content is given as a list of parts.
# `type` is pinned to the literal "image_url"; the image itself is described by
# the nested ImageUrl model.
class ChatMLImageContentPart(BaseModel):
type: Literal["image_url"] = "image_url"
"""
Fixed to 'image_url'
"""
image_url: ImageUrl
"""
Image content part, can be a URL or a base64-encoded image
"""


class Agent(BaseModel):
name: str
"""
Expand Down Expand Up @@ -953,6 +954,48 @@ class UpdateAgentRequest(BaseModel):
"""


# ChatML message model with required `role`, `content`, `created_at` and `id`, and an optional `name`.
class ChatMLMessage(BaseModel):
role: Role
"""
ChatML role (system|assistant|user|function_call|function)
"""
# Either a plain string, or a list whose items may each be a text part or an
# image part (the two kinds can be mixed within one list).
content: str | List[ChatMLTextContentPart | ChatMLImageContentPart]
"""
ChatML content
"""
name: str | None = None
"""
ChatML name
"""
created_at: AwareDatetime
"""
Message created at (RFC-3339 format)
"""
id: UUID
"""
Message ID
"""


# Input-side ChatML message: uses Role1 (which also allows "auto") and carries no `id`/`created_at`.
class InputChatMLMessage(BaseModel):
role: Role1
"""
ChatML role (system|assistant|user|function_call|function|auto)
"""
# Either a plain string, or a list whose items may each be a text part or an
# image part (the two kinds can be mixed within one list).
content: str | List[ChatMLTextContentPart | ChatMLImageContentPart]
"""
ChatML content
"""
name: str | None = None
"""
ChatML name
"""
# Exposed under the alias "continue" (a reserved word in Python, hence the
# trailing underscore on the attribute); the default is False.
continue_: Annotated[bool | None, Field(False, alias="continue")]
"""
Whether to continue this message or return a new one
"""


class ChatInputData(BaseModel):
messages: Annotated[List[InputChatMLMessage], Field(min_length=1)]
"""
Expand All @@ -968,6 +1011,10 @@ class ChatInputData(BaseModel):
"""


# Wrapper model with a single optional field; note that `items` holds one
# ChatMLMessage (or None), not a list, despite the plural name.
class Response(BaseModel):
items: ChatMLMessage | None = None


class ChatResponse(BaseModel):
"""
Represents a chat completion response returned by model, based on the provided input.
Expand Down
6 changes: 5 additions & 1 deletion agents-api/agents_api/clients/worker/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Callable, Literal, Optional, Protocol
from uuid import UUID
from pydantic import BaseModel
from agents_api.autogen.openapi_model import (
ChatMLTextContentPart,
ChatMLImageContentPart,
)


class PromptModule(Protocol):
Expand All @@ -12,7 +16,7 @@ class PromptModule(Protocol):

class ChatML(BaseModel):
role: Literal["system", "user", "assistant", "function_call"]
content: str
content: str | dict | list[ChatMLTextContentPart] | list[ChatMLImageContentPart]

name: Optional[str] = None
entry_id: Optional[UUID] = None
Expand Down
31 changes: 21 additions & 10 deletions agents-api/agents_api/common/protocol/entries.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from datetime import datetime
import json
from typing import Literal
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, computed_field
from agents_api.autogen.openapi_model import Role
from agents_api.autogen.openapi_model import (
Role,
ChatMLImageContentPart,
ChatMLTextContentPart,
)
from agents_api.common.utils.datetime import utcnow

EntrySource = Literal["api_request", "api_response", "internal", "summarizer"]
Tokenizer = Literal["character_count"]
Expand All @@ -20,25 +24,32 @@ class Entry(BaseModel):
source: EntrySource = Field(default="api_request")
role: Role
name: str | None = None
content: str
content: str | list[ChatMLTextContentPart] | list[ChatMLImageContentPart] | dict
tokenizer: str = Field(default="character_count")
created_at: float = Field(
default_factory=lambda: datetime.utcnow().timestamp()
default_factory=lambda: utcnow().timestamp()
) # Uses a default factory to set the creation timestamp
timestamp: float = Field(
default_factory=lambda: datetime.utcnow().timestamp()
default_factory=lambda: utcnow().timestamp()
) # Uses a default factory to set the current timestamp

@computed_field
@property
def token_count(self) -> int:
    """Estimate the number of tokens in this entry's content.

    For the 'character_count' tokenizer the character length of the content
    is floor-divided by 3.5:

    - ``str`` content: its length.
    - ``dict`` content: the length of its JSON serialization.
    - ``list`` content: the combined length of every text part. Image parts
      currently contribute nothing to the estimate.

    Returns:
        The estimated token count.
    """
    # NOTE(review): the original docstring says unknown tokenizers raise
    # NotImplementedError; that raise is not among the lines visible in this
    # hunk -- confirm it still follows this `if` block.
    if self.tokenizer == "character_count":
        if isinstance(self.content, str):
            content_length = len(self.content)
        elif isinstance(self.content, dict):
            content_length = len(json.dumps(self.content))
        elif isinstance(self.content, list):
            # TODO: how to calc token count for images?
            # Sum the text parts directly; previously the text was collected
            # but its length was never stored, so list content always
            # counted as zero tokens.
            content_length = sum(
                len(part.text)
                for part in self.content
                if isinstance(part, ChatMLTextContentPart)
            )
        else:
            content_length = 0

        # Divide the content length by 3.5 to estimate token count based on character count.
        return int(content_length // 3.5)

Expand Down
17 changes: 17 additions & 0 deletions agents-api/agents_api/common/utils/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from functools import wraps


def pdb_on_exception(fn):
    """Wrap ``fn`` so that an uncaught exception opens a post-mortem debugger.

    Intended for local debugging only: when ``fn`` raises, the traceback is
    printed, an interactive pdb session is started on that traceback, and the
    exception is re-raised once the session ends.

    Args:
        fn: The (synchronous) callable to wrap.

    Returns:
        A wrapper with the same signature and metadata as ``fn``.

    Raises:
        Exception: Whatever ``fn`` raised, re-raised unchanged.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as exc:
            # Imported lazily so the debugger is only loaded when something fails.
            import pdb
            import traceback

            traceback.print_exc()
            # post_mortem opens the debugger on the frame that raised;
            # set_trace() here would stop in this wrapper instead, after the
            # failing frames (and their locals) are no longer reachable.
            pdb.post_mortem(exc.__traceback__)
            raise

    return wrapper
46 changes: 45 additions & 1 deletion agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


# Funcs
async def render_template(
async def render_template_string(
template_string: str, variables: dict, check: bool = False
) -> str:
# Parse template
Expand All @@ -36,3 +36,47 @@ async def render_template(
# Render
rendered = await template.render_async(**variables)
return rendered


async def render_template_parts(
    template_strings: list[dict], variables: dict, check: bool = False
) -> list[dict]:
    """Render the text parts of a multi-part message; pass other parts through.

    Args:
        template_strings: Content parts. Parts whose ``"type"`` is ``"text"``
            have their ``"text"`` value treated as a jinja template.
        variables: Values made available to every template.
        check: When true, validate ``variables`` against the schema inferred
            from each text template before rendering anything.

    Returns:
        A list with one entry per input part: a fresh
        ``{"type": "text", "text": <rendered>}`` dict for text parts, and the
        original part object for everything else.
    """
    # Phase 1: compile every text part up front (None marks a non-text part).
    compiled = []
    for part in template_strings:
        if part["type"] == "text":
            compiled.append(jinja_env.from_string(part["text"]))
        else:
            compiled.append(None)

    # Phase 2: optionally validate the variables against each template.
    if check:
        for tmpl in compiled:
            if tmpl is not None:
                schema = to_json_schema(infer(tmpl))
                validate(instance=variables, schema=schema)

    # Phase 3: render text parts in order, keeping non-text parts untouched.
    output = []
    for tmpl, part in zip(compiled, template_strings):
        if tmpl is None:
            output.append(part)
            continue

        text = await tmpl.render_async(**variables)
        output.append({"type": "text", "text": text})

    return output


async def render_template(
template_string: str | list[dict], variables: dict, check: bool = False
) -> str | list[dict]:
if isinstance(template_string, str):
return await render_template_string(template_string, variables, check)

elif isinstance(template_string, list):
return await render_template_parts(template_string, variables, check)

else:
raise ValueError("template_string should be str or list[dict]")
4 changes: 3 additions & 1 deletion agents-api/agents_api/dependencies/developer_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from .exceptions import InvalidHeaderFormat


async def get_developer_id(x_developer_id: Annotated[str | None, Header()] = None):
async def get_developer_id(
x_developer_id: Annotated[uuid.UUID | None, Header()] = None
):
if skip_check_developer_headers:
return x_developer_id or uuid.UUID("00000000-0000-0000-0000-000000000000")

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/agent/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def create_agent_query(
pd.DataFrame: A DataFrame containing the results of the query execution.
"""

preset = default_settings["preset"]
preset = default_settings.get("preset")
default_settings["preset"] = getattr(preset, "value", preset)

settings_cols, settings_vals = cozo_process_mutate_data(
Expand Down
10 changes: 8 additions & 2 deletions agents-api/agents_api/models/docs/create_docs.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from typing import Literal
from uuid import UUID

from beartype import beartype

from ...common.utils.cozo import cozo_process_mutate_data
from ..utils import cozo_query
from ...common.utils.datetime import utcnow


@cozo_query
@beartype
def create_docs_query(
owner_type: Literal["user", "agent"],
owner_id: UUID,
id: UUID,
title: str,
content: list[str],
content: list[str] | str,
metadata: dict = {},
) -> tuple[str, dict]:
):
"""
Constructs and executes a datalog query to create a new document and its associated snippets in the 'cozodb' database.
Expand All @@ -30,6 +32,10 @@ def create_docs_query(
Returns:
pd.DataFrame: A DataFrame containing the results of the query execution.
"""

if isinstance(content, str):
content = [content]

created_at: float = utcnow().timestamp()
snippet_cols, snippet_rows = "", []

Expand Down
Loading

0 comments on commit 7cd34fa

Please sign in to comment.