Skip to content

Commit

Permalink
feat: Add per-agent locking to send message (#2109)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzh72 authored Nov 26, 2024
1 parent 91982f2 commit c2ee91c
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 26 deletions.
7 changes: 6 additions & 1 deletion letta/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ToolRule,
)
from letta.schemas.user import User
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.settings import settings
from letta.utils import enforce_types, get_utc_time, printd

Expand Down Expand Up @@ -383,7 +384,11 @@ def update_agent(self, agent: AgentState):
session.commit()

@enforce_types
def delete_agent(self, agent_id: str):
def delete_agent(self, agent_id: str, per_agent_lock_manager: PerAgentLockManager):
# TODO: Remove this once Agent is on the ORM
# Drop this agent's lock so the lock table does not grow without bound
per_agent_lock_manager.clear_lock(agent_id)

with self.session_maker() as session:

# delete agents
Expand Down
28 changes: 15 additions & 13 deletions letta/server/rest_api/routers/v1/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,19 +475,21 @@ async def send_message(
"""
actor = server.get_user_or_default(user_id=user_id)

result = await send_message_to_agent(
server=server,
agent_id=agent_id,
user_id=actor.id,
messages=request.messages,
stream_steps=request.stream_steps,
stream_tokens=request.stream_tokens,
return_message_object=request.return_message_object,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_function_name=request.assistant_message_function_name,
assistant_message_function_kwarg=request.assistant_message_function_kwarg,
)
agent_lock = server.per_agent_lock_manager.get_lock(agent_id)
async with agent_lock:
result = await send_message_to_agent(
server=server,
agent_id=agent_id,
user_id=actor.id,
messages=request.messages,
stream_steps=request.stream_steps,
stream_tokens=request.stream_tokens,
return_message_object=request.return_message_object,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_function_name=request.assistant_message_function_name,
assistant_message_function_kwarg=request.assistant_message_function_kwarg,
)
return result


Expand Down
12 changes: 10 additions & 2 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import traceback
import warnings
from abc import abstractmethod
from asyncio import Lock
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -79,6 +80,7 @@
from letta.services.block_manager import BlockManager
from letta.services.blocks_agents_manager import BlocksAgentsManager
from letta.services.organization_manager import OrganizationManager
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
Expand Down Expand Up @@ -231,6 +233,9 @@ def __init__(

self.credentials = LettaCredentials.load()

# Locks
self.send_message_lock = Lock()

# Initialize the metadata store
config = LettaConfig.load()
if settings.letta_pg_uri_no_default:
Expand All @@ -252,6 +257,9 @@ def __init__(
self.blocks_agents_manager = BlocksAgentsManager()
self.sandbox_config_manager = SandboxConfigManager(tool_settings)

# Managers that interface with parallelism
self.per_agent_lock_manager = PerAgentLockManager()

# Make default user and org
if init_with_default_org_and_user:
self.default_org = self.organization_manager.create_default_organization()
Expand Down Expand Up @@ -925,7 +933,7 @@ def create_agent(
logger.exception(e)
try:
if agent:
self.ms.delete_agent(agent_id=agent.agent_state.id)
self.ms.delete_agent(agent_id=agent.agent_state.id, per_agent_lock_manager=self.per_agent_lock_manager)
except Exception as delete_e:
logger.exception(f"Failed to delete_agent:\n{delete_e}")
raise e
Expand Down Expand Up @@ -1522,7 +1530,7 @@ def delete_agent(self, user_id: str, agent_id: str):

# Next, attempt to delete it from the actual database
try:
self.ms.delete_agent(agent_id=agent_id)
self.ms.delete_agent(agent_id=agent_id, per_agent_lock_manager=self.per_agent_lock_manager)
except Exception as e:
logger.exception(f"Failed to delete agent {agent_id} via ID with:\n{str(e)}")
raise ValueError(f"Failed to delete agent {agent_id} in database")
Expand Down
18 changes: 18 additions & 0 deletions letta/services/per_agent_lock_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import asyncio
from collections import defaultdict


class PerAgentLockManager:
    """Hand out one ``asyncio.Lock`` per agent ID.

    Locks are created lazily the first time an ID is requested and stay in
    ``self.locks`` until they are removed with :meth:`clear_lock`.
    """

    def __init__(self) -> None:
        # defaultdict builds a fresh asyncio.Lock on the first lookup of an ID.
        self.locks = defaultdict(asyncio.Lock)

    def get_lock(self, agent_id: str) -> asyncio.Lock:
        """Return the lock for ``agent_id``, creating it on first use."""
        return self.locks[agent_id]

    def clear_lock(self, agent_id: str) -> None:
        """Forget the lock stored for ``agent_id`` (prevents unbounded growth).

        Unknown IDs are ignored. Only the manager's reference is dropped: a
        later ``get_lock`` call for the same ID returns a brand-new lock.

        NOTE(review): if the lock is removed while a coroutine still holds it,
        new callers receive a different, unlocked lock and are not excluded
        by the old holder. Confirm callers only clear idle agents.
        """
        # Single lookup; the default argument makes a missing key a no-op and
        # does not trigger the defaultdict factory.
        self.locks.pop(agent_id, None)
43 changes: 43 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
import threading
import time
Expand Down Expand Up @@ -295,3 +296,45 @@ def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent

finally:
client.delete_agent(agent.id)


def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
    """Send one user message, then check that at least one message can be read back."""
    # _reset_config()

    sent = client.send_message(agent_id=agent.id, message="Test message", role="user")
    assert sent, "Sending message failed"

    fetched = client.get_messages(agent_id=agent.id, limit=1)
    assert len(fetched) > 0, "Retrieving messages failed"


@pytest.mark.asyncio
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
    """
    Send two messages to the same agent concurrently and require both to succeed.
    """
    if not isinstance(client, RESTClient):
        pytest.skip("This test only runs when the server is enabled")

    async def send_message_task(message: str):
        # client.send_message is synchronous, so run it in a worker thread
        # to let both calls be in flight at the same time.
        response = await asyncio.to_thread(client.send_message, agent.id, message, role="user")
        assert response, f"Sending message '{message}' failed"
        return response

    messages = ["Test message 1", "Test message 2"]
    pending = [send_message_task(text) for text in messages]

    # return_exceptions=True hands failures back as values so each task can be
    # reported individually below.
    responses = await asyncio.gather(*pending, return_exceptions=True)

    for i, response in enumerate(responses):
        if isinstance(response, Exception):
            pytest.fail(f"Task {i} failed with exception: {response}")
        assert response, f"Task {i} returned an invalid response: {response}"

    # Every message must have produced a result
    assert len(responses) == len(messages), "Not all messages were processed"
10 changes: 0 additions & 10 deletions tests/test_client_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,16 +223,6 @@ def test_core_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}"


def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
    """Round-trip check: a sent user message yields a response and a non-empty message list."""
    # _reset_config()

    reply = client.send_message(agent_id=agent.id, message="Test message", role="user")
    assert reply, "Sending message failed"

    history = client.get_messages(agent_id=agent.id, limit=1)
    assert len(history) > 0, "Retrieving messages failed"


def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: AgentState):
if isinstance(client, LocalClient):
pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming")
Expand Down

0 comments on commit c2ee91c

Please sign in to comment.