Skip to content

Commit

Permalink
feat: Add pagination for list tools (#1907)
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Zhou <[email protected]>
  • Loading branch information
mattzh72 and Matt Zhou authored Oct 18, 2024
1 parent 180bbfe commit 61db758
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 21 deletions.
16 changes: 10 additions & 6 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def update_tool(
) -> Tool:
raise NotImplementedError

def list_tools(self) -> List[Tool]:
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
raise NotImplementedError

def get_tool(self, id: str) -> Tool:
Expand Down Expand Up @@ -1382,14 +1382,19 @@ def update_tool(
# raise ValueError(f"Failed to create tool: {response.text}")
# return ToolModel(**response.json())

def list_tools(self) -> List[Tool]:
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
"""
List available tools for the user.
Returns:
tools (List[Tool]): List of tools
"""
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", headers=self.headers)
params = {}
if cursor:
params["cursor"] = str(cursor)
if limit:
params["limit"] = limit
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools", params=params, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list tools: {response.text}")
return [Tool(**tool) for tool in response.json()]
Expand Down Expand Up @@ -2281,15 +2286,14 @@ def update_tool(
ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name), self.user_id
)

def list_tools(self):
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]:
"""
List available tools for the user.
Returns:
tools (List[Tool]): List of tools
"""
tools = self.server.list_tools(user_id=self.user_id)
return tools
return self.server.list_tools(cursor=cursor, limit=limit, user_id=self.user_id)

def get_tool(self, id: str) -> Optional[Tool]:
"""
Expand Down
19 changes: 14 additions & 5 deletions letta/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
Integer,
String,
TypeDecorator,
asc,
desc,
or_,
)
from sqlalchemy.sql import func

Expand Down Expand Up @@ -707,12 +709,19 @@ def delete_organization(self, org_id: str):
session.commit()

@enforce_types
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]:
with self.session_maker() as session:
results = session.query(ToolModel).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
# Query for public tools or user-specific tools
query = session.query(ToolModel).filter(or_(ToolModel.user_id == None, ToolModel.user_id == user_id))

# Apply cursor if provided (assuming cursor is an ID)
if cursor:
query = query.filter(ToolModel.id > cursor)

# Order by ID and apply limit
results = query.order_by(asc(ToolModel.id)).limit(limit).all()

# Convert to records
res = [r.to_record() for r in results]
return res

Expand Down
1 change: 1 addition & 0 deletions letta/server/rest_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def create_application() -> "FastAPI":
title="Letta",
summary="Create LLM agents with long-term memory and custom tools 📚🦙",
version="1.0.0", # TODO wire this up to the version in the package
debug=True,
)

if "--ade" in sys.argv:
Expand Down
15 changes: 9 additions & 6 deletions letta/server/rest_api/routers/v1/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,21 @@ def get_tool_id(

@router.get("/", response_model=List[Tool], operation_id="list_tools")
def list_all_tools(
cursor: Optional[str] = None,
limit: Optional[int] = 50,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a list of all tools available to agents created by a user
"""
actor = server.get_user_or_default(user_id=user_id)
actor.id

# TODO: add back when user-specific
return server.list_tools(user_id=actor.id)
# return server.ms.list_tools(user_id=None)
try:
actor = server.get_user_or_default(user_id=user_id)
return server.list_tools(cursor=cursor, limit=limit, user_id=actor.id)
except Exception as e:
# Log or print the full exception here for debugging
print(f"Error occurred: {e}")
raise HTTPException(status_code=500, detail=str(e))


@router.post("/", response_model=Tool, operation_id="create_tool")
Expand Down
4 changes: 2 additions & 2 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1981,9 +1981,9 @@ def delete_tool(self, tool_id: str):
"""Delete a tool"""
self.ms.delete_tool(tool_id)

def list_tools(self, user_id: str) -> List[Tool]:
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[Tool]:
"""List tools available to user_id"""
tools = self.ms.list_tools(user_id)
tools = self.ms.list_tools(cursor=cursor, limit=limit, user_id=user_id)
return tools

def add_default_tools(self, module_name="base", user_id: Optional[str] = None):
Expand Down
22 changes: 22 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,28 @@ def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
# print("CONFIG", config_response)


def test_list_tools_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
tools = client.list_tools()
visited_ids = {t.id: False for t in tools}

cursor = None
# Choose 3 for uneven buckets (only 7 default tools)
num_tools = 3
# Construct a complete pagination test to see if we can return all the tools eventually
for _ in range(0, len(tools), num_tools):
curr_tools = client.list_tools(cursor, num_tools)
assert len(curr_tools) <= num_tools

for curr_tool in curr_tools:
assert curr_tool.id in visited_ids
visited_ids[curr_tool.id] = True

cursor = curr_tools[-1].id

# Assert that everything has been visited
assert all(visited_ids.values())


def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
# clear sources
for source in client.list_sources():
Expand Down
3 changes: 1 addition & 2 deletions tests/test_new_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def agent(client):


def test_agent(client: Union[LocalClient, RESTClient]):
tools = client.list_tools()

# create agent
agent_state_test = client.create_agent(
name="test_agent2",
Expand All @@ -51,6 +49,7 @@ def test_agent(client: Union[LocalClient, RESTClient]):
assert agent_state_test.id in [a.id for a in agents]

# get agent
tools = client.list_tools()
print("TOOLS", [t.name for t in tools])
agent_state = client.get_agent(agent_state_test.id)
assert agent_state.name == "test_agent2"
Expand Down

0 comments on commit 61db758

Please sign in to comment.