diff --git a/letta/server/rest_api/routers/openai/assistants/threads.py b/letta/server/rest_api/routers/openai/assistants/threads.py index 1e9be774ae..43d7235faf 100644 --- a/letta/server/rest_api/routers/openai/assistants/threads.py +++ b/letta/server/rest_api/routers/openai/assistants/threads.py @@ -1,5 +1,5 @@ import uuid -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query @@ -43,7 +43,7 @@ def create_thread( request: CreateThreadRequest = Body(...), server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): # TODO: use requests.description and requests.metadata fields # TODO: handle requests.file_ids and requests.tools @@ -68,7 +68,7 @@ def create_thread( def retrieve_thread( thread_id: str = Path(..., description="The unique identifier of the thread."), server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): actor = server.get_user_or_default(user_id=user_id) agent = server.get_agent(user_id=actor.id, agent_id=thread_id) @@ -102,7 +102,7 @@ def create_message( thread_id: str = Path(..., description="The unique identifier of the thread."), request: CreateMessageRequest = Body(...), server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): actor = server.get_user_or_default(user_id=user_id) agent_id = thread_id @@ -146,7 +146,7 @@ def list_messages( after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."), server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): actor = server.get_user_or_default(user_id) after_uuid = after if before else None diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index 489abe01c6..3c1afc3938 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException @@ -30,7 +30,7 @@ async def create_chat_completion( completion_request: ChatCompletionRequest = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Send a message to a Letta agent via a /chat/completions completion_request The bearer token will be used to identify the user. diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 74012db75f..00e1cce98a 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -40,7 +40,7 @@ @router.get("/", response_model=List[AgentState], operation_id="list_agents") def list_agents( server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all agents associated with a given user. @@ -55,7 +55,7 @@ def list_agents( def create_agent( agent: CreateAgent = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Create a new agent with the specified configuration. @@ -76,7 +76,7 @@ def update_agent( agent_id: str, update_agent: UpdateAgentState = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Update an exsiting agent""" actor = server.get_user_or_default(user_id=user_id) @@ -89,7 +89,7 @@ def update_agent( def get_agent_state( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get the state of the agent. @@ -107,7 +107,7 @@ def get_agent_state( def delete_agent( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete an agent. @@ -159,7 +159,7 @@ def update_agent_memory( agent_id: str, request: Dict = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Update the core memory of a specific agent. @@ -202,7 +202,7 @@ def get_agent_archival_memory( after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."), before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."), limit: Optional[int] = Query(None, description="How many results to include in the response."), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the memories in an agent's archival memory store (paginated query). @@ -227,7 +227,7 @@ def insert_agent_archival_memory( agent_id: str, request: CreateArchivalMemory = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Insert a memory into an agent's archival memory store. @@ -245,7 +245,7 @@ def delete_agent_archival_memory( memory_id: str, # memory_id: str = Query(..., description="Unique ID of the memory to be deleted."), server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete a memory from an agent's archival memory store. @@ -276,7 +276,7 @@ def get_agent_messages( DEFAULT_MESSAGE_TOOL_KWARG, description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.", ), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve message history for an agent. @@ -315,7 +315,7 @@ async def send_message( agent_id: str, server: SyncServer = Depends(get_letta_server), request: LettaRequest = Body(...), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Process a user message and return the agent's response. diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index df9c86cc85..74dc76dad5 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -19,7 +19,7 @@ def list_blocks( templates_only: bool = Query(True, description="Whether to include only templates"), name: Optional[str] = Query(None, description="Name of the block"), server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): actor = server.get_user_or_default(user_id=user_id) @@ -33,7 +33,7 @@ def list_blocks( def create_block( create_block: CreateBlock = Body(...), server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): actor = server.get_user_or_default(user_id=user_id) diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index 113ea81d66..bd581a98ef 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -13,7 +13,7 @@ def list_jobs( server: "SyncServer" = Depends(get_letta_server), source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all jobs. @@ -34,7 +34,7 @@ def list_jobs( @router.get("/active", response_model=List[Job], operation_id="list_active_jobs") def list_active_jobs( server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all active jobs. diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 2feae2df9b..a59abd31bb 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -1,6 +1,6 @@ import os import tempfile -from typing import List +from typing import List, Optional from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile @@ -21,7 +21,7 @@ def get_source( source_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get all sources @@ -35,7 +35,7 @@ def get_source( def get_source_id_by_name( source_name: str, server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get a source by name @@ -49,7 +49,7 @@ def get_source_id_by_name( @router.get("/", response_model=List[Source], operation_id="list_sources") def list_sources( server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all data sources created by a user. @@ -63,7 +63,7 @@ def list_sources( def create_source( source: SourceCreate, server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Create a new data source. @@ -78,7 +78,7 @@ def update_source( source_id: str, source: SourceUpdate, server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Update the name or documentation of an existing data source. @@ -94,7 +94,7 @@ def update_source( def delete_source( source_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete a data source. @@ -109,7 +109,7 @@ def attach_source_to_agent( source_id: str, agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."), server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Attach a data source to an existing agent. @@ -127,7 +127,7 @@ def detach_source_from_agent( source_id: str, agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."), server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ) -> None: """ Detach a data source from an existing agent. @@ -143,7 +143,7 @@ def upload_file_to_source( source_id: str, background_tasks: BackgroundTasks, server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Upload a file to a data source. @@ -176,7 +176,7 @@ def upload_file_to_source( def list_passages( source_id: str, server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all passages associated with a data source. @@ -190,7 +190,7 @@ def list_passages( def list_documents( source_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all documents associated with a data source. diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 9b4d58a561..404fabfdd8 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException @@ -13,7 +13,7 @@ def delete_tool( tool_id: str, server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete a tool by name @@ -43,7 +43,7 @@ def get_tool( def get_tool_id( tool_name: str, server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get a tool ID by name @@ -60,7 +60,7 @@ def get_tool_id( @router.get("/", response_model=List[Tool], operation_id="list_tools") def list_all_tools( server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get a list of all tools available to agents created by a user @@ -78,7 +78,7 @@ def create_tool( tool: ToolCreate = Body(...), update: bool = False, server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Create a new tool @@ -98,7 +98,7 @@ def update_tool( tool_id: str, request: ToolUpdate = Body(...), server: SyncServer = Depends(get_letta_server), - user_id: str = Header(None), # Extract user_id from header, default to None if not present + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Update an existing tool diff --git a/letta/server/server.py b/letta/server/server.py index 9c7d721ed5..fae55f30b1 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1921,7 +1921,10 @@ def get_user_or_default(self, user_id: Optional[str]) -> User: if user_id is None: return self.get_default_user() else: - return self.get_user(user_id=user_id) + try: + return self.get_user(user_id=user_id) + except ValueError: + raise HTTPException(status_code=404, detail=f"User with id {user_id} not found") def list_llm_models(self) -> List[LLMConfig]: """List available models"""