Skip to content

Commit

Permalink
feat: add support for user_id in header (#1755)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Sep 13, 2024
1 parent 93f7409 commit 7e70082
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
20 changes: 17 additions & 3 deletions memgpt/server/rest_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import typer
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.cors import CORSMiddleware

from memgpt.server.constants import REST_DEFAULT_PORT
Expand Down Expand Up @@ -38,8 +39,6 @@
# TODO(ethan)
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work
interface: StreamingServerInterface = StreamingServerInterface
# global server
# server: SyncServer = None
server = SyncServer(default_interface_factory=lambda: interface())

# TODO(ethan): eventuall remove
Expand Down Expand Up @@ -77,6 +76,21 @@ def create_application() -> "FastAPI":
allow_headers=["*"],
)

@app.middleware("http")
async def set_current_user_middleware(request: Request, call_next):
user_id = request.headers.get("user_id")
if user_id:
try:
server.set_current_user(user_id)
except ValueError as e:
# Return an HTTP 401 Unauthorized response
# raise HTTPException(status_code=401, detail=str(e))
return JSONResponse(status_code=401, content={"detail": str(e)})
else:
server.set_current_user(None)
response = await call_next(request)
return response

for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
# this gives undocumented routes for "latest" and bare api calls.
Expand Down
23 changes: 23 additions & 0 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,13 +1771,36 @@ def retry_agent_message(self, agent_id: str) -> List[Message]:
memgpt_agent = self._get_or_load_agent(agent_id=agent_id)
return memgpt_agent.retry_message()

def set_current_user(self, user_id: Optional[str]):
"""Very hacky way to set the current user for the server, to be replaced once server becomes stateless
NOTE: clearly not thread-safe, only exists to provide basic user_id support for REST API for now
"""

# Make sure the user_id actually exists
if user_id is not None:
user_obj = self.get_user(user_id)
if not user_obj:
raise ValueError(f"User with id {user_id} not found")

self._current_user = user_id

# TODO(ethan) wire back to real method in future ORM PR
def get_current_user(self) -> User:
"""Returns the currently authed user.
Since server is the core gateway this needs to pass through server as the
first touchpoint.
"""

# Check if _current_user is set and if it's non-null:
if hasattr(self, "_current_user") and self._current_user is not None:
current_user = self.get_user(self._current_user)
if not current_user:
warnings.warn(f"Provided user '{self._current_user}' not found, using default user")
else:
return current_user

# NOTE: same code as local client to get the default user
config = MemGPTConfig.load()
user_id = config.anon_clientid
Expand Down

0 comments on commit 7e70082

Please sign in to comment.