Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix the static file mounting handler breaking the API #1743

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions memgpt/server/rest_api/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import importlib.util
import json
import logging
import os
import secrets
from pathlib import Path
from typing import Optional

import typer
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware

from memgpt.server.constants import REST_DEFAULT_PORT
Expand All @@ -31,7 +35,6 @@
from memgpt.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin
)
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.server import SyncServer
from memgpt.settings import settings

Expand All @@ -55,6 +58,29 @@
OPENAI_API_PREFIX = "/openai"


class SmartStaticFilesMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# List of API prefixes that should bypass static handling
api_prefixes = [API_PREFIX, OPENAI_API_PREFIX, ADMIN_PREFIX]
path = request.url.path

# Check if the request path starts with any API prefix
if any(path.startswith(prefix) for prefix in api_prefixes):
# If it's an API call, process normally
print(f"API request detected: {path}")
response = await call_next(request)
else:
print(f"Static request detected: {path}")
# Try to serve static files, catch any errors like 404, etc.
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
filepath = os.path.join(static_files_path, path.lstrip("/"))
if os.path.isfile(filepath):
return FileResponse(filepath)
else:
response = await call_next(request)
return response


def create_application() -> "FastAPI":
"""the application start routine"""

Expand All @@ -72,6 +98,8 @@ def create_application() -> "FastAPI":
allow_methods=["*"],
allow_headers=["*"],
)
# Usage in your create_application function:
app.add_middleware(SmartStaticFilesMiddleware)

for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
Expand All @@ -95,7 +123,7 @@ def create_application() -> "FastAPI":
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)

# / static files
mount_static_files(app)
# mount_static_files(app)

@app.on_event("startup")
def on_startup():
Expand Down
18 changes: 16 additions & 2 deletions memgpt/server/rest_api/static_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,25 @@ def mount_static_files(app: FastAPI):
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
if os.path.exists(static_files_path):
app.mount(
# "/",
"/app",
"/",
# "/app",
SPAStaticFiles(
directory=static_files_path,
html=True,
),
name="spa-static-files",
)


# def mount_static_files(app: FastAPI):
# static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
# if os.path.exists(static_files_path):

# @app.get("/{full_path:path}")
# async def serve_spa(full_path: str):
# if full_path.startswith("v1"):
# raise HTTPException(status_code=404, detail="Not found")
# file_path = os.path.join(static_files_path, full_path)
# if os.path.isfile(file_path):
# return FileResponse(file_path)
# return FileResponse(os.path.join(static_files_path, "index.html"))
Loading