Skip to content

Commit

Permalink
fix: static files mounting bug (#1746)
Browse files Browse the repository at this point in the history
Co-authored-by: Shubham Naik <[email protected]>
  • Loading branch information
cpacker and Shubham Naik authored Sep 10, 2024
1 parent 47a6956 commit 0cc7447
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 41 deletions.
36 changes: 4 additions & 32 deletions memgpt/server/rest_api/app.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import importlib.util
import json
import logging
import os
import secrets
from pathlib import Path
from typing import Optional

import typer
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

from memgpt.server.constants import REST_DEFAULT_PORT
Expand All @@ -35,6 +31,7 @@
from memgpt.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin
)
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.server import SyncServer
from memgpt.settings import settings

Expand All @@ -58,29 +55,6 @@
OPENAI_API_PREFIX = "/openai"


class SmartStaticFilesMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# List of API prefixes that should bypass static handling
api_prefixes = [API_PREFIX, OPENAI_API_PREFIX, ADMIN_PREFIX]
path = request.url.path

# Check if the request path starts with any API prefix
if any(path.startswith(prefix) for prefix in api_prefixes):
# If it's an API call, process normally
print(f"API request detected: {path}")
response = await call_next(request)
else:
print(f"Static request detected: {path}")
# Try to serve static files, catch any errors like 404, etc.
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
filepath = os.path.join(static_files_path, path.lstrip("/"))
if os.path.isfile(filepath):
return FileResponse(filepath)
else:
response = await call_next(request)
return response


def create_application() -> "FastAPI":
"""the application start routine"""

Expand All @@ -98,14 +72,12 @@ def create_application() -> "FastAPI":
allow_methods=["*"],
allow_headers=["*"],
)
# Usage in your create_application function:
app.add_middleware(SmartStaticFilesMiddleware)

for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
# this gives undocumented routes for "latest" and bare api calls.
# we should always tie this to the newest version of the api.
app.include_router(route, prefix="", include_in_schema=False)
# app.include_router(route, prefix="", include_in_schema=False)
app.include_router(route, prefix="/latest", include_in_schema=False)

# NOTE: ethan these are the extra routes
Expand All @@ -123,7 +95,7 @@ def create_application() -> "FastAPI":
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)

# / static files
# mount_static_files(app)
mount_static_files(app)

@app.on_event("startup")
def on_startup():
Expand Down
47 changes: 38 additions & 9 deletions memgpt/server/rest_api/static_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.staticfiles import StaticFiles

Expand All @@ -20,15 +21,43 @@ async def get_response(self, path: str, scope):
def mount_static_files(app: FastAPI):
static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files")
if os.path.exists(static_files_path):
app.mount(
"/",
# "/app",
SPAStaticFiles(
directory=static_files_path,
html=True,
),
name="spa-static-files",
)
app.mount("/assets", StaticFiles(directory=os.path.join(static_files_path, "assets")), name="assets")

@app.get("/memgpt_logo_transparent.png", include_in_schema=False)
async def serve_spa():
return FileResponse(os.path.join(static_files_path, "memgpt_logo_transparent.png"))

@app.get("/", include_in_schema=False)
async def serve_spa():
return FileResponse(os.path.join(static_files_path, "index.html"))

@app.get("/agents", include_in_schema=False)
async def serve_spa():
return FileResponse(os.path.join(static_files_path, "index.html"))

@app.get("/data-sources", include_in_schema=False)
async def serve_spa():
return FileResponse(os.path.join(static_files_path, "index.html"))

@app.get("/tools", include_in_schema=False)
async def serve_spa():
return FileResponse(os.path.join(static_files_path, "index.html"))

@app.get("/agent-templates", include_in_schema=False)
async def serve_spa():
return FileResponse(os.path.join(static_files_path, "index.html"))

@app.get("/human-templates", include_in_schema=False)
async def serve_spa():
return FileResponse(os.path.join(static_files_path, "index.html"))

@app.get("/settings/profile", include_in_schema=False)
async def serve_spa():
return FileResponse(os.path.join(static_files_path, "index.html"))

@app.get("/agents/{agent-id}/chat", include_in_schema=False)
async def serve_spa():
return FileResponse(os.path.join(static_files_path, "index.html"))


# def mount_static_files(app: FastAPI):
Expand Down

0 comments on commit 0cc7447

Please sign in to comment.