From fc3d4e16848f0dfa66c962e57f3410a3e82aa3b5 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 18 Oct 2024 14:25:00 -0700 Subject: [PATCH] feat: Add endpoint to get full Tool objects belonging to an agent (#1906) Co-authored-by: Matt Zhou --- letta/client/client.py | 31 ++++++++++++++++++++++ letta/server/rest_api/routers/v1/agents.py | 20 +++++++++----- letta/server/server.py | 15 +++++++++-- tests/test_new_client.py | 21 +++++++++------ 4 files changed, 71 insertions(+), 16 deletions(-) diff --git a/letta/client/client.py b/letta/client/client.py index 77468f0d36..bea579a3fc 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -96,6 +96,9 @@ def update_agent( ): raise NotImplementedError + def get_tools_from_agent(self, agent_id: str): + raise NotImplementedError + def add_tool_to_agent(self, agent_id: str, tool_id: str): raise NotImplementedError @@ -480,6 +483,21 @@ def update_agent( raise ValueError(f"Failed to update agent: {response.text}") return AgentState(**response.json()) + def get_tools_from_agent(self, agent_id: str) -> List[Tool]: + """ + Get tools to an existing agent + + Args: + agent_id (str): ID of the agent + + Returns: + List[Tool]: A List of Tool objs + """ + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/tools", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to get tools from agents: {response.text}") + return [Tool(**tool) for tool in response.json()] + def add_tool_to_agent(self, agent_id: str, tool_id: str): """ Add tool to an existing agent @@ -1692,6 +1710,19 @@ def update_agent( ) return agent_state + def get_tools_from_agent(self, agent_id: str) -> List[Tool]: + """ + Get tools from an existing agent. + + Args: + agent_id (str): ID of the agent + + Returns: + List[Tool]: A list of Tool objs + """ + self.interface.clear() + return self.server.get_tools_from_agent(agent_id=agent_id, user_id=self.user_id) + def add_tool_to_agent(self, agent_id: str, tool_id: str): """ Add tool to an existing agent diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index f3f8a79d96..b928509f1d 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -27,6 +27,7 @@ from letta.schemas.message import Message, MessageCreate, UpdateMessage from letta.schemas.passage import Passage from letta.schemas.source import Source +from letta.schemas.tool import Tool from letta.server.rest_api.interface import StreamingServerInterface from letta.server.rest_api.utils import get_letta_server, sse_async_generator from letta.server.server import SyncServer @@ -100,6 +101,17 @@ def update_agent( return server.update_agent(update_agent, user_id=actor.id) +@router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="get_tools_from_agent") +def get_tools_from_agent( + agent_id: str, + server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """Get tools from an existing agent""" + actor = server.get_user_or_default(user_id=user_id) + return server.get_tools_from_agent(agent_id=agent_id, user_id=actor.id) + + @router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent") def add_tool_to_agent( agent_id: str, @@ -107,10 +119,8 @@ def add_tool_to_agent( server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - """Add tools to an exsiting agent""" + """Add tools to an existing agent""" actor = server.get_user_or_default(user_id=user_id) - - update_agent.id = agent_id return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) @@ -121,10 +131,8 @@ def remove_tool_from_agent( server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - """Add tools to an exsiting agent""" + """Add tools to an existing agent""" actor = server.get_user_or_default(user_id=user_id) - - update_agent.id = agent_id return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) diff --git a/letta/server/server.py b/letta/server/server.py index 8eca1be92b..ff136348d9 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -977,13 +977,24 @@ def update_agent( # TODO: probably reload the agent somehow? return letta_agent.agent_state + def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]: + """Get tools from an existing agent""" + if self.ms.get_user(user_id=user_id) is None: + raise ValueError(f"User user_id={user_id} does not exist") + if self.ms.get_agent(agent_id=agent_id) is None: + raise ValueError(f"Agent agent_id={agent_id} does not exist") + + # Get the agent object (loaded in memory) + letta_agent = self._get_or_load_agent(agent_id=agent_id) + return letta_agent.tools + def add_tool_to_agent( self, agent_id: str, tool_id: str, user_id: str, ): - """Update the agents core memory block, return the new state""" + """Add tools from an existing agent""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id) is None: @@ -1022,7 +1033,7 @@ def remove_tool_from_agent( tool_id: str, user_id: str, ): - """Update the agents core memory block, return the new state""" + """Remove tools from an existing agent""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id) is None: diff --git a/tests/test_new_client.py b/tests/test_new_client.py index 76e414df8f..3ddfc1eac9 100644 --- a/tests/test_new_client.py +++ b/tests/test_new_client.py @@ -152,19 +152,24 @@ def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent): agent_state = client.add_tool_to_agent(agent_id=agent_state.id, tool_id=scrape_website_tool.id) # confirm that both tools are in the agent state - curr_tools = agent_state.tools - assert len(curr_tools) == curr_num_tools + 2 - assert github_tool.name in curr_tools - assert scrape_website_tool.name in curr_tools + # we could access it like agent_state.tools, but will use the client function instead + # this is obviously redundant as it requires retrieving the agent again + # but allows us to test the `get_tools_from_agent` pathway as well + curr_tools = client.get_tools_from_agent(agent_state.id) + curr_tool_names = [t.name for t in curr_tools] + assert len(curr_tool_names) == curr_num_tools + 2 + assert github_tool.name in curr_tool_names + assert scrape_website_tool.name in curr_tool_names # remove only the github tool agent_state = client.remove_tool_from_agent(agent_id=agent_state.id, tool_id=github_tool.id) # confirm that only one tool left - curr_tools = agent_state.tools - assert len(curr_tools) == curr_num_tools + 1 - assert github_tool.name not in curr_tools - assert scrape_website_tool.name in curr_tools + curr_tools = client.get_tools_from_agent(agent_state.id) + curr_tool_names = [t.name for t in curr_tools] + assert len(curr_tool_names) == curr_num_tools + 1 + assert github_tool.name not in curr_tool_names + assert scrape_website_tool.name in curr_tool_names def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):