diff --git a/memgpt/server/rest_api/app.py b/memgpt/server/rest_api/app.py index 973a362dc6..b4d08ef9e3 100644 --- a/memgpt/server/rest_api/app.py +++ b/memgpt/server/rest_api/app.py @@ -1,16 +1,12 @@ -import importlib.util import json import logging -import os import secrets from pathlib import Path from typing import Optional import typer import uvicorn -from fastapi import FastAPI, Request -from fastapi.responses import FileResponse -from starlette.middleware.base import BaseHTTPMiddleware +from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware from memgpt.server.constants import REST_DEFAULT_PORT @@ -35,6 +31,7 @@ from memgpt.server.rest_api.routers.v1.users import ( router as users_router, # TODO: decide on admin ) +from memgpt.server.rest_api.static_files import mount_static_files from memgpt.server.server import SyncServer from memgpt.settings import settings @@ -58,29 +55,6 @@ OPENAI_API_PREFIX = "/openai" -class SmartStaticFilesMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - # List of API prefixes that should bypass static handling - api_prefixes = [API_PREFIX, OPENAI_API_PREFIX, ADMIN_PREFIX] - path = request.url.path - - # Check if the request path starts with any API prefix - if any(path.startswith(prefix) for prefix in api_prefixes): - # If it's an API call, process normally - print(f"API request detected: {path}") - response = await call_next(request) - else: - print(f"Static request detected: {path}") - # Try to serve static files, catch any errors like 404, etc. - static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files") - filepath = os.path.join(static_files_path, path.lstrip("/")) - if os.path.isfile(filepath): - return FileResponse(filepath) - else: - response = await call_next(request) - return response - - def create_application() -> "FastAPI": """the application start routine""" @@ -98,14 +72,12 @@ def create_application() -> "FastAPI": allow_methods=["*"], allow_headers=["*"], ) - # Usage in your create_application function: - app.add_middleware(SmartStaticFilesMiddleware) for route in v1_routes: app.include_router(route, prefix=API_PREFIX) # this gives undocumented routes for "latest" and bare api calls. # we should always tie this to the newest version of the api. - app.include_router(route, prefix="", include_in_schema=False) + # app.include_router(route, prefix="", include_in_schema=False) app.include_router(route, prefix="/latest", include_in_schema=False) # NOTE: ethan these are the extra routes @@ -123,7 +95,7 @@ def create_application() -> "FastAPI": app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX) # / static files - # mount_static_files(app) + mount_static_files(app) @app.on_event("startup") def on_startup(): diff --git a/memgpt/server/rest_api/static_files.py b/memgpt/server/rest_api/static_files.py index d9b0b39ac1..519bc5a36a 100644 --- a/memgpt/server/rest_api/static_files.py +++ b/memgpt/server/rest_api/static_files.py @@ -2,6 +2,7 @@ import os from fastapi import FastAPI, HTTPException +from fastapi.responses import FileResponse from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.staticfiles import StaticFiles @@ -20,15 +21,43 @@ async def get_response(self, path: str, scope): def mount_static_files(app: FastAPI): static_files_path = os.path.join(os.path.dirname(importlib.util.find_spec("memgpt").origin), "server", "static_files") if os.path.exists(static_files_path): - app.mount( - "/", - # "/app", - SPAStaticFiles( - directory=static_files_path, - html=True, - ), - name="spa-static-files", - ) + app.mount("/assets", StaticFiles(directory=os.path.join(static_files_path, "assets")), name="assets") + + @app.get("/memgpt_logo_transparent.png", include_in_schema=False) + async def serve_spa(): + return FileResponse(os.path.join(static_files_path, "memgpt_logo_transparent.png")) + + @app.get("/", include_in_schema=False) + async def serve_spa(): + return FileResponse(os.path.join(static_files_path, "index.html")) + + @app.get("/agents", include_in_schema=False) + async def serve_spa(): + return FileResponse(os.path.join(static_files_path, "index.html")) + + @app.get("/data-sources", include_in_schema=False) + async def serve_spa(): + return FileResponse(os.path.join(static_files_path, "index.html")) + + @app.get("/tools", include_in_schema=False) + async def serve_spa(): + return FileResponse(os.path.join(static_files_path, "index.html")) + + @app.get("/agent-templates", include_in_schema=False) + async def serve_spa(): + return FileResponse(os.path.join(static_files_path, "index.html")) + + @app.get("/human-templates", include_in_schema=False) + async def serve_spa(): + return FileResponse(os.path.join(static_files_path, "index.html")) + + @app.get("/settings/profile", include_in_schema=False) + async def serve_spa(): + return FileResponse(os.path.join(static_files_path, "index.html")) + + @app.get("/agents/{agent-id}/chat", include_in_schema=False) + async def serve_spa(): + return FileResponse(os.path.join(static_files_path, "index.html")) # def mount_static_files(app: FastAPI):