Skip to content

Commit

Permalink
feat: Add endpoint to get full Tool objects belonging to an agent (#1906
Browse files Browse the repository at this point in the history
)

Co-authored-by: Matt Zhou <[email protected]>
  • Loading branch information
mattzh72 and Matt Zhou authored Oct 18, 2024
1 parent c51af44 commit fc3d4e1
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 16 deletions.
31 changes: 31 additions & 0 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def update_agent(
):
raise NotImplementedError

def get_tools_from_agent(self, agent_id: str):
raise NotImplementedError

def add_tool_to_agent(self, agent_id: str, tool_id: str):
raise NotImplementedError

Expand Down Expand Up @@ -480,6 +483,21 @@ def update_agent(
raise ValueError(f"Failed to update agent: {response.text}")
return AgentState(**response.json())

def get_tools_from_agent(self, agent_id: str) -> List[Tool]:
"""
Get tools to an existing agent
Args:
agent_id (str): ID of the agent
Returns:
List[Tool]: A List of Tool objs
"""
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/tools", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get tools from agents: {response.text}")
return [Tool(**tool) for tool in response.json()]

def add_tool_to_agent(self, agent_id: str, tool_id: str):
"""
Add tool to an existing agent
Expand Down Expand Up @@ -1692,6 +1710,19 @@ def update_agent(
)
return agent_state

def get_tools_from_agent(self, agent_id: str) -> List[Tool]:
"""
Get tools from an existing agent.
Args:
agent_id (str): ID of the agent
Returns:
List[Tool]: A list of Tool objs
"""
self.interface.clear()
return self.server.get_tools_from_agent(agent_id=agent_id, user_id=self.user_id)

def add_tool_to_agent(self, agent_id: str, tool_id: str):
"""
Add tool to an existing agent
Expand Down
20 changes: 14 additions & 6 deletions letta/server/rest_api/routers/v1/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from letta.schemas.message import Message, MessageCreate, UpdateMessage
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.utils import get_letta_server, sse_async_generator
from letta.server.server import SyncServer
Expand Down Expand Up @@ -100,17 +101,26 @@ def update_agent(
return server.update_agent(update_agent, user_id=actor.id)


@router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="get_tools_from_agent")
def get_tools_from_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Get tools from an existing agent"""
actor = server.get_user_or_default(user_id=user_id)
return server.get_tools_from_agent(agent_id=agent_id, user_id=actor.id)


@router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent")
def add_tool_to_agent(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Add tools to an exsiting agent"""
"""Add tools to an existing agent"""
actor = server.get_user_or_default(user_id=user_id)

update_agent.id = agent_id
return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)


Expand All @@ -121,10 +131,8 @@ def remove_tool_from_agent(
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Add tools to an exsiting agent"""
"""Add tools to an existing agent"""
actor = server.get_user_or_default(user_id=user_id)

update_agent.id = agent_id
return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)


Expand Down
15 changes: 13 additions & 2 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,13 +977,24 @@ def update_agent(
# TODO: probably reload the agent somehow?
return letta_agent.agent_state

def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
"""Get tools from an existing agent"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")

# Get the agent object (loaded in memory)
letta_agent = self._get_or_load_agent(agent_id=agent_id)
return letta_agent.tools

def add_tool_to_agent(
self,
agent_id: str,
tool_id: str,
user_id: str,
):
"""Update the agents core memory block, return the new state"""
"""Add tools from an existing agent"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id) is None:
Expand Down Expand Up @@ -1022,7 +1033,7 @@ def remove_tool_from_agent(
tool_id: str,
user_id: str,
):
"""Update the agents core memory block, return the new state"""
"""Remove tools from an existing agent"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id) is None:
Expand Down
21 changes: 13 additions & 8 deletions tests/test_new_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,19 +152,24 @@ def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent):
agent_state = client.add_tool_to_agent(agent_id=agent_state.id, tool_id=scrape_website_tool.id)

# confirm that both tools are in the agent state
curr_tools = agent_state.tools
assert len(curr_tools) == curr_num_tools + 2
assert github_tool.name in curr_tools
assert scrape_website_tool.name in curr_tools
# we could access it like agent_state.tools, but will use the client function instead
# this is obviously redundant as it requires retrieving the agent again
# but allows us to test the `get_tools_from_agent` pathway as well
curr_tools = client.get_tools_from_agent(agent_state.id)
curr_tool_names = [t.name for t in curr_tools]
assert len(curr_tool_names) == curr_num_tools + 2
assert github_tool.name in curr_tool_names
assert scrape_website_tool.name in curr_tool_names

# remove only the github tool
agent_state = client.remove_tool_from_agent(agent_id=agent_state.id, tool_id=github_tool.id)

# confirm that only one tool left
curr_tools = agent_state.tools
assert len(curr_tools) == curr_num_tools + 1
assert github_tool.name not in curr_tools
assert scrape_website_tool.name in curr_tools
curr_tools = client.get_tools_from_agent(agent_state.id)
curr_tool_names = [t.name for t in curr_tools]
assert len(curr_tool_names) == curr_num_tools + 1
assert github_tool.name not in curr_tool_names
assert scrape_website_tool.name in curr_tool_names


def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):
Expand Down

0 comments on commit fc3d4e1

Please sign in to comment.