Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: patch user_id in header #1843

Merged
merged 3 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions letta/server/rest_api/routers/openai/assistants/threads.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional

from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query

Expand Down Expand Up @@ -43,7 +43,7 @@
def create_thread(
request: CreateThreadRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
Expand All @@ -68,7 +68,7 @@ def create_thread(
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
Expand Down Expand Up @@ -102,7 +102,7 @@ def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
agent_id = thread_id
Expand Down Expand Up @@ -146,7 +146,7 @@ def list_messages(
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id)
after_uuid = after if before else None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from fastapi import APIRouter, Body, Depends, Header, HTTPException

Expand Down Expand Up @@ -30,7 +30,7 @@
async def create_chat_completion(
completion_request: ChatCompletionRequest = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Send a message to a Letta agent via a /chat/completions completion_request
The bearer token will be used to identify the user.
Expand Down
22 changes: 11 additions & 11 deletions letta/server/rest_api/routers/v1/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
def list_agents(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all agents associated with a given user.
Expand All @@ -55,7 +55,7 @@ def list_agents(
def create_agent(
agent: CreateAgent = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a new agent with the specified configuration.
Expand All @@ -76,7 +76,7 @@ def update_agent(
agent_id: str,
update_agent: UpdateAgentState = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Update an exsiting agent"""
actor = server.get_user_or_default(user_id=user_id)
Expand All @@ -89,7 +89,7 @@ def update_agent(
def get_agent_state(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get the state of the agent.
Expand All @@ -107,7 +107,7 @@ def get_agent_state(
def delete_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete an agent.
Expand Down Expand Up @@ -159,7 +159,7 @@ def update_agent_memory(
agent_id: str,
request: Dict = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the core memory of a specific agent.
Expand Down Expand Up @@ -202,7 +202,7 @@ def get_agent_archival_memory(
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
limit: Optional[int] = Query(None, description="How many results to include in the response."),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memories in an agent's archival memory store (paginated query).
Expand All @@ -227,7 +227,7 @@ def insert_agent_archival_memory(
agent_id: str,
request: CreateArchivalMemory = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Insert a memory into an agent's archival memory store.
Expand All @@ -245,7 +245,7 @@ def delete_agent_archival_memory(
memory_id: str,
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a memory from an agent's archival memory store.
Expand Down Expand Up @@ -276,7 +276,7 @@ def get_agent_messages(
DEFAULT_MESSAGE_TOOL_KWARG,
description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve message history for an agent.
Expand Down Expand Up @@ -315,7 +315,7 @@ async def send_message(
agent_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Process a user message and return the agent's response.
Expand Down
4 changes: 2 additions & 2 deletions letta/server/rest_api/routers/v1/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def list_blocks(
templates_only: bool = Query(True, description="Whether to include only templates"),
name: Optional[str] = Query(None, description="Name of the block"),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)

Expand All @@ -33,7 +33,7 @@ def list_blocks(
def create_block(
create_block: CreateBlock = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)

Expand Down
4 changes: 2 additions & 2 deletions letta/server/rest_api/routers/v1/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def list_jobs(
server: "SyncServer" = Depends(get_letta_server),
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all jobs.
Expand All @@ -34,7 +34,7 @@ def list_jobs(
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
def list_active_jobs(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all active jobs.
Expand Down
24 changes: 12 additions & 12 deletions letta/server/rest_api/routers/v1/sources.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import tempfile
from typing import List
from typing import List, Optional

from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile

Expand All @@ -21,7 +21,7 @@
def get_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get all sources
Expand All @@ -35,7 +35,7 @@ def get_source(
def get_source_id_by_name(
source_name: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a source by name
Expand All @@ -49,7 +49,7 @@ def get_source_id_by_name(
@router.get("/", response_model=List[Source], operation_id="list_sources")
def list_sources(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all data sources created by a user.
Expand All @@ -63,7 +63,7 @@ def list_sources(
def create_source(
source: SourceCreate,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a new data source.
Expand All @@ -78,7 +78,7 @@ def update_source(
source_id: str,
source: SourceUpdate,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the name or documentation of an existing data source.
Expand All @@ -94,7 +94,7 @@ def update_source(
def delete_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a data source.
Expand All @@ -109,7 +109,7 @@ def attach_source_to_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Attach a data source to an existing agent.
Expand All @@ -127,7 +127,7 @@ def detach_source_from_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
) -> None:
"""
Detach a data source from an existing agent.
Expand All @@ -143,7 +143,7 @@ def upload_file_to_source(
source_id: str,
background_tasks: BackgroundTasks,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Upload a file to a data source.
Expand Down Expand Up @@ -176,7 +176,7 @@ def upload_file_to_source(
def list_passages(
source_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all passages associated with a data source.
Expand All @@ -190,7 +190,7 @@ def list_passages(
def list_documents(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all documents associated with a data source.
Expand Down
12 changes: 6 additions & 6 deletions letta/server/rest_api/routers/v1/tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

from fastapi import APIRouter, Body, Depends, Header, HTTPException

Expand All @@ -13,7 +13,7 @@
def delete_tool(
tool_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a tool by name
Expand Down Expand Up @@ -43,7 +43,7 @@ def get_tool(
def get_tool_id(
tool_name: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a tool ID by name
Expand All @@ -60,7 +60,7 @@ def get_tool_id(
@router.get("/", response_model=List[Tool], operation_id="list_tools")
def list_all_tools(
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a list of all tools available to agents created by a user
Expand All @@ -78,7 +78,7 @@ def create_tool(
tool: ToolCreate = Body(...),
update: bool = False,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a new tool
Expand All @@ -98,7 +98,7 @@ def update_tool(
tool_id: str,
request: ToolUpdate = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update an existing tool
Expand Down
5 changes: 4 additions & 1 deletion letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1921,7 +1921,10 @@ def get_user_or_default(self, user_id: Optional[str]) -> User:
if user_id is None:
return self.get_default_user()
else:
return self.get_user(user_id=user_id)
try:
return self.get_user(user_id=user_id)
except ValueError:
raise HTTPException(status_code=404, detail=f"User with id {user_id} not found")

def list_llm_models(self) -> List[LLMConfig]:
"""List available models"""
Expand Down
Loading