Skip to content

Commit

Permalink
fix: fix the static file mounting handler breaking the API (#1743)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Sep 10, 2024
1 parent 1622689 commit 865a7c9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
34 changes: 31 additions & 3 deletions memgpt/server/rest_api/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import importlib.util
import json
import logging
import os
import secrets
from pathlib import Path
from typing import Optional

import typer
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware

from memgpt.server.constants import REST_DEFAULT_PORT
Expand All @@ -31,7 +35,6 @@
from memgpt.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin
)
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.server import SyncServer
from memgpt.settings import settings

Expand All @@ -55,6 +58,29 @@
OPENAI_API_PREFIX = "/openai"


class SmartStaticFilesMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# List of API prefixes that should bypass static handling
api_prefixes = [API_PREFIX, OPENAI_API_PREFIX, ADMIN_PREFIX]
path = request.url.path

# Check if the request path starts with any API prefix
if any(path.startswith(prefix) for prefix in api_prefixes):
# If it's an API call, process normally
print(f"API request detected: {path}")
response = await call_next(request)
else:
print(f"Static request detected: {path}")
# Try to serve static files, catch any errors like 404, etc.
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
filepath = os.path.join(static_files_path, path.lstrip("/"))
if os.path.isfile(filepath):
return FileResponse(filepath)
else:
response = await call_next(request)
return response


def create_application() -> "FastAPI":
"""the application start routine"""

Expand All @@ -72,6 +98,8 @@ def create_application() -> "FastAPI":
allow_methods=["*"],
allow_headers=["*"],
)
# Usage in your create_application function:
app.add_middleware(SmartStaticFilesMiddleware)

for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
Expand All @@ -95,7 +123,7 @@ def create_application() -> "FastAPI":
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)

# / static files
mount_static_files(app)
# mount_static_files(app)

@app.on_event("startup")
def on_startup():
Expand Down
18 changes: 16 additions & 2 deletions memgpt/server/rest_api/static_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,25 @@ def mount_static_files(app: FastAPI):
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
if os.path.exists(static_files_path):
app.mount(
# "/",
"/app",
"/",
# "/app",
SPAStaticFiles(
directory=static_files_path,
html=True,
),
name="spa-static-files",
)


# def mount_static_files(app: FastAPI):
# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
# if os.path.exists(static_files_path):

# @app.get("/{full_path:path}")
# async def serve_spa(full_path: str):
# if full_path.startswith("v1"):
# raise HTTPException(status_code=404, detail="Not found")
# file_path = os.path.join(static_files_path, full_path)
# if os.path.isfile(file_path):
# return FileResponse(file_path)
# return FileResponse(os.path.join(static_files_path, "index.html"))

0 comments on commit 865a7c9

Please sign in to comment.