Skip to content

Commit

Permalink
Merge branch 'main' into update-static-files
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker committed Sep 10, 2024
2 parents d0ce25d + cdc55f5 commit d51beaf
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 305 deletions.
64 changes: 5 additions & 59 deletions memgpt/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import json
import logging
import os
import subprocess
import sys
from enum import Enum
from pathlib import Path
from typing import Annotated, Optional

import questionary
Expand All @@ -24,7 +22,6 @@
from memgpt.schemas.enums import OptionState
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memory import ChatMemory, Memory
from memgpt.server.constants import WS_DEFAULT_PORT
from memgpt.server.server import logger as server_logger

# from memgpt.interface import CLIInterface as interface # for printing to terminal
Expand Down Expand Up @@ -304,9 +301,6 @@ def server(
type: Annotated[ServerChoice, typer.Option(help="Server to run")] = "rest",
port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None,
host: Annotated[Optional[str], typer.Option(help="Host to run the server on (default to localhost)")] = None,
use_ssl: Annotated[bool, typer.Option(help="Run the server using HTTPS?")] = False,
ssl_cert: Annotated[Optional[str], typer.Option(help="Path to SSL certificate (if use_ssl is True)")] = None,
ssl_key: Annotated[Optional[str], typer.Option(help="Path to SSL key file (if use_ssl is True)")] = None,
debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = False,
):
"""Launch a MemGPT server process"""
Expand All @@ -317,71 +311,23 @@ def server(
if MemGPTConfig.exists():
config = MemGPTConfig.load()
MetadataStore(config)
client = create_client() # triggers user creation
_ = create_client() # triggers user creation
else:
typer.secho(f"No configuration exists. Run memgpt configure before starting the server.", fg=typer.colors.RED)
sys.exit(1)

try:
from memgpt.server.rest_api.server import start_server

start_server(
port=port,
host=host,
use_ssl=use_ssl,
ssl_cert=ssl_cert,
ssl_key=ssl_key,
debug=debug,
)
from memgpt.server.rest_api.app import start_server

start_server(port=port, host=host, debug=debug)

except KeyboardInterrupt:
# Handle CTRL-C
typer.secho("Terminating the server...")
sys.exit(0)

elif type == ServerChoice.ws_api:
if debug:
from memgpt.server.server import logger as server_logger

# Set the logging level
server_logger.setLevel(logging.DEBUG)
# Create a StreamHandler
stream_handler = logging.StreamHandler()
# Set the formatter (optional)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stream_handler.setFormatter(formatter)
# Add the handler to the logger
server_logger.addHandler(stream_handler)

if port is None:
port = WS_DEFAULT_PORT

# Change to the desired directory
script_path = Path(__file__).resolve()
script_dir = script_path.parent

server_directory = os.path.join(script_dir.parent, "server", "ws_api")
command = f"python server.py {port}"

# Run the command
typer.secho(f"Running WS (websockets) server: {command} (inside {server_directory})")

process = None
try:
# Start the subprocess in a new session
process = subprocess.Popen(command, shell=True, start_new_session=True, cwd=server_directory)
process.wait()
except KeyboardInterrupt:
# Handle CTRL-C
if process is not None:
typer.secho("Terminating the server...")
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
typer.secho("Server terminated with kill()")
sys.exit(0)
raise NotImplementedError("WS suppport deprecated")


def run(
Expand Down
169 changes: 169 additions & 0 deletions memgpt/server/rest_api/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import json
import logging
import secrets
from pathlib import Path
from typing import Optional

import typer
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

from memgpt.server.constants import REST_DEFAULT_PORT

# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
from memgpt.server.rest_api.auth.index import (
setup_auth_router, # TODO: probably remove right?
)
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.rest_api.routers.openai.assistants.assistants import (
router as openai_assistants_router,
)
from memgpt.server.rest_api.routers.openai.assistants.threads import (
router as openai_threads_router,
)
from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions import (
router as openai_chat_completions_router,
)

# from memgpt.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
from memgpt.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin
)
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.server import SyncServer
from memgpt.settings import settings

# TODO(ethan)
# NOTE(charles): @ethan I had to add this to get the global at the bottom to work
# `interface` holds the interface *class* (used as a factory below), not an
# instance — hence the type[...] annotation; the server builds a fresh
# interface per use via the lambda.
interface: type[StreamingServerInterface] = StreamingServerInterface
server: SyncServer = SyncServer(default_interface_factory=lambda: interface())

# TODO(ethan): eventually remove
if password := settings.server_pass:
    # The admin password was supplied via the environment; reuse it.
    print("Using existing admin server password from environment.")
else:
    # Autogenerate a password for this session and dump it to stdout
    password = secrets.token_urlsafe(16)
    typer.secho(f"Generated admin server password for this session: {password}", fg=typer.colors.GREEN)


# URL prefixes under which the router groups are mounted in create_application().
ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
OPENAI_API_PREFIX = "/openai"


def create_application() -> "FastAPI":
    """Build and fully wire up the MemGPT FastAPI application.

    Mounts: versioned v1 routes (also exposed unversioned and under
    ``/latest``, hidden from the schema), admin user routes, the
    OpenAI-compatible assistants/threads/chat-completions routes, the
    auth router, and static files; registers startup/shutdown hooks.

    Returns:
        The configured FastAPI application instance.
    """

    app = FastAPI(
        swagger_ui_parameters={"docExpansion": "none"},
        # openapi_tags=TAGS_METADATA,
        title="MemGPT",
        summary="Create LLM agents with long-term memory and custom tools 📚🦙",
        version="1.0.0",  # TODO wire this up to the version in the package
    )
    # CORS: allowed origins come from settings; wildcard methods/headers with credentials.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    for route in v1_routes:
        app.include_router(route, prefix=API_PREFIX)
        # this gives undocumented routes for "latest" and bare api calls.
        # we should always tie this to the newest version of the api.
        app.include_router(route, prefix="", include_in_schema=False)
        app.include_router(route, prefix="/latest", include_in_schema=False)

    # NOTE: ethan these are the extra routes
    # TODO(ethan) remove

    # admin/users
    app.include_router(users_router, prefix=ADMIN_PREFIX)

    # openai-compatible endpoints
    app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
    app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX)
    app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)

    # /api/auth endpoints (guarded by the module-level session password)
    app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)

    # / static files
    mount_static_files(app)

    @app.on_event("startup")
    def on_startup():
        # load the default tools
        # from memgpt.orm.tool import Tool

        # Tool.load_default_tools(get_db_session())

        # Update the OpenAPI schema
        if not app.openapi_schema:
            app.openapi_schema = app.openapi()

        # Two independent copies of the schema, one per generated spec file.
        openai_docs, memgpt_docs = [app.openapi_schema.copy() for _ in range(2)]

        # Partition paths: /openai/* goes to the OpenAI spec, everything else to MemGPT's.
        openai_docs["paths"] = {k: v for k, v in openai_docs["paths"].items() if k.startswith("/openai")}
        openai_docs["info"]["title"] = "OpenAI Assistants API"
        memgpt_docs["paths"] = {k: v for k, v in memgpt_docs["paths"].items() if not k.startswith("/openai")}
        memgpt_docs["info"]["title"] = "MemGPT API"

        # Split the API docs into MemGPT API, and OpenAI Assistants compatible API
        for name, docs in [
            (
                "openai",
                openai_docs,
            ),
            (
                "memgpt",
                memgpt_docs,
            ),
        ]:
            if settings.cors_origins:
                # NOTE(review): reuses the CORS origins as the documented server URLs — confirm intended.
                docs["servers"] = [{"url": host} for host in settings.cors_origins]
            # Writes openapi_<name>.json into the current working directory.
            Path(f"openapi_{name}.json").write_text(json.dumps(docs, indent=2))

    @app.on_event("shutdown")
    def on_shutdown():
        # Persist agent state before exit, then drop the module-level reference.
        global server
        server.save_agents()
        server = None

    return app


# Module-level ASGI application instance (served by start_server() below).
app = create_application()


def start_server(
    port: Optional[int] = None,
    host: Optional[str] = None,
    debug: bool = False,
):
    """Convenience method to start the server from within Python.

    Args:
        port: TCP port to bind; defaults to REST_DEFAULT_PORT.
        host: Interface to bind; defaults to "localhost".
        debug: When True, set the server logger to DEBUG and attach a
            stream handler so debug output reaches the console.
    """
    if debug:
        from memgpt.server.server import logger as server_logger

        # Set the logging level
        server_logger.setLevel(logging.DEBUG)
        # Create a StreamHandler
        stream_handler = logging.StreamHandler()
        # Set the formatter (optional)
        formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        stream_handler.setFormatter(formatter)
        # Add the handler to the logger
        server_logger.addHandler(stream_handler)

    # Resolve defaults once so the message and the bind use the same values.
    bind_host = host or "localhost"
    bind_port = port or REST_DEFAULT_PORT
    # Fixed: the old message printed a stale "uvicorn server:app ..." command,
    # but the server actually runs the in-process `app` object from this module.
    print(f"Running MemGPT server at http://{bind_host}:{bind_port}")
    uvicorn.run(
        app,
        host=bind_host,
        port=bind_port,
    )
Loading

0 comments on commit d51beaf

Please sign in to comment.