Commit

feat: add new providers (bedrock, deepseek) + bugfixes (#2445)
sarahwooders authored Feb 20, 2025
2 parents 5502490 + 5d02f0d commit a678e6d
Showing 73 changed files with 2,624 additions and 492 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -3,7 +3,7 @@ repos:
     rev: v2.3.0
     hooks:
       - id: check-yaml
-        exclude: 'docs/.*|tests/data/.*|configs/.*'
+        exclude: 'docs/.*|tests/data/.*|configs/.*|helm/.*'
       - id: end-of-file-fixer
         exclude: 'docs/.*|tests/data/.*|letta/server/static_files/.*'
       - id: trailing-whitespace
3 changes: 2 additions & 1 deletion Dockerfile
@@ -53,7 +53,8 @@ ENV LETTA_ENVIRONMENT=${LETTA_ENVIRONMENT} \
     PATH="/app/.venv/bin:$PATH" \
     POSTGRES_USER=letta \
     POSTGRES_PASSWORD=letta \
-    POSTGRES_DB=letta
+    POSTGRES_DB=letta \
+    COMPOSIO_DISABLE_VERSION_CHECK=true

 WORKDIR /app
66 changes: 66 additions & 0 deletions alembic/versions/a113caac453e_add_identities_table.py
@@ -0,0 +1,66 @@
"""add identities table
Revision ID: a113caac453e
Revises: 7980d239ea08
Create Date: 2025-02-14 09:58:18.227122
"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "a113caac453e"
down_revision: Union[str, None] = "7980d239ea08"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Create identities table
    op.create_table(
        "identities",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("identifier_key", sa.String(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("identity_type", sa.String(), nullable=False),
        sa.Column("project_id", sa.String(), nullable=True),
        # From OrganizationMixin
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        # Foreign key to organizations
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        # Composite unique constraint
        sa.UniqueConstraint(
            "identifier_key",
            "project_id",
            "organization_id",
            name="unique_identifier_pid_org_id",
        ),
        sa.PrimaryKeyConstraint("id"),
    )

    # Add identity_id column to agents table
    op.add_column("agents", sa.Column("identity_id", sa.String(), nullable=True))

    # Add foreign key constraint
    op.create_foreign_key("fk_agents_identity_id", "agents", "identities", ["identity_id"], ["id"], ondelete="CASCADE")


def downgrade() -> None:
    # First remove the foreign key constraint and column from agents
    op.drop_constraint("fk_agents_identity_id", "agents", type_="foreignkey")
    op.drop_column("agents", "identity_id")

    # Then drop the table
    op.drop_table("identities")
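The migrations in this commit can be applied or rolled back with Alembic as usual; a minimal sketch of driving this revision programmatically (assumes an alembic.ini at the repo root):

import alembic.command
from alembic.config import Config

cfg = Config("alembic.ini")
alembic.command.upgrade(cfg, "a113caac453e")    # apply the identities migration
alembic.command.downgrade(cfg, "7980d239ea08")  # roll back to the prior revision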
27 changes: 27 additions & 0 deletions alembic/versions/a3047a624130_add_identifier_key_to_agents.py
@@ -0,0 +1,27 @@
"""add identifier key to agents
Revision ID: a3047a624130
Revises: a113caac453e
Create Date: 2025-02-14 12:24:16.123456
"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "a3047a624130"
down_revision: Union[str, None] = "a113caac453e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column("agents", sa.Column("identifier_key", sa.String(), nullable=True))


def downgrade() -> None:
    op.drop_column("agents", "identifier_key")
2 changes: 1 addition & 1 deletion letta/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.6.27"
__version__ = "0.6.28"

# import clients
from letta.client.client import LocalClient, RESTClient, create_client
14 changes: 13 additions & 1 deletion letta/agent.py
@@ -60,6 +60,7 @@
 from letta.settings import summarizer_settings
 from letta.streaming_interface import StreamingRefreshCLIInterface
 from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
+from letta.tracing import trace_method
 from letta.utils import (
     count_tokens,
     get_friendly_error_msg,
@@ -309,6 +310,7 @@ def _handle_function_error_response(
         # Return updated messages
         return messages

+    @trace_method("Get AI Reply")
     def _get_ai_reply(
         self,
         message_sequence: List[Message],
@@ -399,6 +401,7 @@ def _get_ai_reply(
         log_telemetry(self.logger, "_handle_ai_response finish catch-all exception")
         raise Exception("Retries exhausted and no valid response received.")

+    @trace_method("Handle AI Response")
     def _handle_ai_response(
         self,
         response_message: ChatCompletionMessage,  # TODO should we eventually move the Message creation outside of this function?
@@ -492,7 +495,10 @@ def _handle_ai_response(
             try:
                 raw_function_args = function_call.arguments
                 function_args = parse_json(raw_function_args)
-            except Exception:
+                if not isinstance(function_args, dict):
+                    raise ValueError(f"Function arguments are not a dictionary: {function_args} (raw={raw_function_args})")
+            except Exception as e:
+                print(e)
                 error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}"
                 function_response = "None"  # more like "never ran?"
                 messages = self._handle_function_error_response(
@@ -627,9 +633,15 @@ def _handle_ai_response(
             elif self.tool_rules_solver.is_terminal_tool(function_name):
                 heartbeat_request = False

+            # if continue tool rule, then must request a heartbeat
+            # TODO: dont even include heartbeats in the args
+            if self.tool_rules_solver.is_continue_tool(function_name):
+                heartbeat_request = True
+
         log_telemetry(self.logger, "_handle_ai_response finish")
         return messages, heartbeat_request, function_failed

+    @trace_method("Agent Step")
     def step(
         self,
         messages: Union[Message, List[Message]],
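The new isinstance guard matters because parsed function arguments can be any valid JSON value, not just an object; a minimal illustration of the failure mode it catches, using the standard json module in place of parse_json (whose json.loads-like behavior is assumed here):

import json

raw = '["not", "a", "dict"]'     # valid JSON, but not keyword arguments
args = json.loads(raw)
if not isinstance(args, dict):   # mirrors the guard added in _handle_ai_response
    raise ValueError(f"Function arguments are not a dictionary: {args} (raw={raw})")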
2 changes: 2 additions & 0 deletions letta/client/client.py
@@ -2351,6 +2351,7 @@ def create_agent(
         tool_rules: Optional[List[BaseToolRule]] = None,
         include_base_tools: Optional[bool] = True,
         include_multi_agent_tools: bool = False,
+        include_base_tool_rules: bool = True,
         # metadata
         metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
         description: Optional[str] = None,
@@ -2402,6 +2403,7 @@
             "tool_rules": tool_rules,
             "include_base_tools": include_base_tools,
             "include_multi_agent_tools": include_multi_agent_tools,
+            "include_base_tool_rules": include_base_tool_rules,
             "system": system,
             "agent_type": agent_type,
             "llm_config": llm_config if llm_config else self._default_llm_config,
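A hedged usage sketch of the new flag (the server URL and agent name are illustrative, not from this commit):

from letta import create_client

client = create_client(base_url="http://localhost:8283")  # illustrative local server URL
agent = client.create_agent(
    name="my-agent",                # illustrative name
    include_base_tool_rules=False,  # opt out of attaching the default base tool rules
)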
2 changes: 2 additions & 0 deletions letta/constants.py
@@ -86,6 +86,8 @@
 # The max amount of tokens supported by the underlying model (eg 8k for gpt-4 and Mistral 7B)
 LLM_MAX_TOKENS = {
     "DEFAULT": 8192,
+    "deepseek-chat": 64000,
+    "deepseek-reasoner": 64000,
     ## OpenAI models: https://platform.openai.com/docs/models/overview
     # "o1-preview
     "chatgpt-4o-latest": 128000,
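LLM_MAX_TOKENS is keyed by model name with a "DEFAULT" fallback; a minimal lookup sketch (max_tokens_for is an illustrative helper, not part of the codebase):

from letta.constants import LLM_MAX_TOKENS

def max_tokens_for(model: str) -> int:
    # Fall back to the DEFAULT entry for models not listed in the table.
    return LLM_MAX_TOKENS.get(model, LLM_MAX_TOKENS["DEFAULT"])

assert max_tokens_for("deepseek-chat") == 64000
assert max_tokens_for("some-unlisted-model") == 8192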
12 changes: 6 additions & 6 deletions letta/functions/schema_generator.py
@@ -394,12 +394,12 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
     # append the heartbeat
     # TODO: don't hard-code
     # TODO: if terminal, don't include this
-    if function.__name__ not in ["send_message"]:
-        schema["parameters"]["properties"]["request_heartbeat"] = {
-            "type": "boolean",
-            "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
-        }
-        schema["parameters"]["required"].append("request_heartbeat")
+    # if function.__name__ not in ["send_message"]:
+    schema["parameters"]["properties"]["request_heartbeat"] = {
+        "type": "boolean",
+        "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
+    }
+    schema["parameters"]["required"].append("request_heartbeat")

     return schema

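With the guard commented out, every generated schema, including send_message's, now carries the heartbeat parameter; an illustrative sketch of the resulting shape (abbreviated, not the full output of generate_schema):

schema = {
    "name": "send_message",
    "parameters": {
        "type": "object",
        "properties": {
            "message": {"type": "string"},             # the tool's own parameter
            "request_heartbeat": {"type": "boolean"},  # appended unconditionally now
        },
        "required": ["message", "request_heartbeat"],
    },
}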
153 changes: 153 additions & 0 deletions letta/helpers/converters.py
@@ -0,0 +1,153 @@
import base64
from typing import Any, Dict, List, Optional, Union

import numpy as np
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from sqlalchemy import Dialect

from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule, ToolRule

# --------------------------
# LLMConfig Serialization
# --------------------------


def serialize_llm_config(config: Union[Optional[LLMConfig], Dict]) -> Optional[Dict]:
    """Convert an LLMConfig object into a JSON-serializable dictionary."""
    if config and isinstance(config, LLMConfig):
        return config.model_dump()
    return config


def deserialize_llm_config(data: Optional[Dict]) -> Optional[LLMConfig]:
    """Convert a dictionary back into an LLMConfig object."""
    return LLMConfig(**data) if data else None
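These converters are the kind of pair typically wired into a SQLAlchemy custom column type; a hedged sketch of one way to use them (the LLMConfigColumn name and JSON storage choice are assumptions, not necessarily how letta wires it):

from sqlalchemy.types import JSON, TypeDecorator

class LLMConfigColumn(TypeDecorator):
    """Illustrative JSON column that round-trips LLMConfig via the converters above."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        return serialize_llm_config(value)

    def process_result_value(self, value, dialect):
        return deserialize_llm_config(value)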


# --------------------------
# EmbeddingConfig Serialization
# --------------------------


def serialize_embedding_config(config: Union[Optional[EmbeddingConfig], Dict]) -> Optional[Dict]:
    """Convert an EmbeddingConfig object into a JSON-serializable dictionary."""
    if config and isinstance(config, EmbeddingConfig):
        return config.model_dump()
    return config


def deserialize_embedding_config(data: Optional[Dict]) -> Optional[EmbeddingConfig]:
    """Convert a dictionary back into an EmbeddingConfig object."""
    return EmbeddingConfig(**data) if data else None


# --------------------------
# ToolRule Serialization
# --------------------------


def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str, Any]]:
    """Convert a list of ToolRules into a JSON-serializable format."""
    if not tool_rules:
        return []

    data = [{**rule.model_dump(), "type": rule.type.value} for rule in tool_rules]  # Convert Enum to string for JSON compatibility

    # Validate ToolRule structure
    for rule_data in data:
        if rule_data["type"] == ToolRuleType.constrain_child_tools.value and "children" not in rule_data:
            raise ValueError(f"Invalid ToolRule serialization: 'children' field missing for rule {rule_data}")

    return data


def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]:
    """Convert a list of dictionaries back into ToolRule objects."""
    if not data:
        return []

    return [deserialize_tool_rule(rule_data) for rule_data in data]


def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule]:
    """Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'."""
    rule_type = ToolRuleType(data.get("type"))

    if rule_type == ToolRuleType.run_first:
        return InitToolRule(**data)
    elif rule_type == ToolRuleType.exit_loop:
        return TerminalToolRule(**data)
    elif rule_type == ToolRuleType.constrain_child_tools:
        return ChildToolRule(**data)
    elif rule_type == ToolRuleType.conditional:
        return ConditionalToolRule(**data)
    elif rule_type == ToolRuleType.continue_loop:
        return ContinueToolRule(**data)
    raise ValueError(f"Unknown ToolRule type: {rule_type}")
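A round-trip sketch of the rule converters (assumes InitToolRule is constructed from a tool_name, per letta.schemas.tool_rule):

from letta.schemas.tool_rule import InitToolRule

rules = [InitToolRule(tool_name="send_message")]
payload = serialize_tool_rules(rules)       # e.g. [{"tool_name": "send_message", "type": "run_first", ...}]
restored = deserialize_tool_rules(payload)  # back to [InitToolRule(...)]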


# --------------------------
# ToolCall Serialization
# --------------------------


def serialize_tool_calls(tool_calls: Optional[List[Union[OpenAIToolCall, dict]]]) -> List[Dict]:
    """Convert a list of OpenAI ToolCall objects into JSON-serializable format."""
    if not tool_calls:
        return []

    serialized_calls = []
    for call in tool_calls:
        if isinstance(call, OpenAIToolCall):
            serialized_calls.append(call.model_dump())
        elif isinstance(call, dict):
            serialized_calls.append(call)  # Already a dictionary, leave it as-is
        else:
            raise TypeError(f"Unexpected tool call type: {type(call)}")

    return serialized_calls


def deserialize_tool_calls(data: Optional[List[Dict]]) -> List[OpenAIToolCall]:
    """Convert a JSON list back into OpenAIToolCall objects."""
    if not data:
        return []

    calls = []
    for item in data:
        func_data = item.pop("function", None)
        tool_call_function = OpenAIFunction(**func_data) if func_data else None
        calls.append(OpenAIToolCall(function=tool_call_function, **item))

    return calls
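A round-trip sketch for the tool-call converters (the id and arguments are illustrative):

from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function

call = ChatCompletionMessageToolCall(
    id="call_123",  # illustrative id
    type="function",
    function=Function(name="send_message", arguments='{"message": "hi"}'),
)
payload = serialize_tool_calls([call])      # list of plain dicts
restored = deserialize_tool_calls(payload)  # back to ChatCompletionMessageToolCall objects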


# --------------------------
# Vector Serialization
# --------------------------


def serialize_vector(vector: Optional[Union[List[float], np.ndarray]]) -> Optional[bytes]:
    """Convert a NumPy array or list into a base64-encoded byte string."""
    if vector is None:
        return None
    if isinstance(vector, list):
        vector = np.array(vector, dtype=np.float32)

    return base64.b64encode(vector.tobytes())


def deserialize_vector(data: Optional[bytes], dialect: Dialect) -> Optional[np.ndarray]:
    """Convert a base64-encoded byte string back into a NumPy array."""
    if not data:
        return None

    if dialect.name == "sqlite":
        data = base64.b64decode(data)

    return np.frombuffer(data, dtype=np.float32)
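A round-trip sketch for the vector helpers; since deserialize_vector only base64-decodes under the sqlite dialect, this sketch decodes by hand instead of constructing a Dialect:

import base64
import numpy as np

encoded = serialize_vector([1.0, 2.0, 3.0])  # base64-encoded float32 bytes
decoded = np.frombuffer(base64.b64decode(encoded), dtype=np.float32)
assert decoded.tolist() == [1.0, 2.0, 3.0]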