Skip to content

Commit

Permalink
breakpoint to rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
norton120 committed Jul 9, 2024
1 parent 37a04a2 commit 7f4711e
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
6 changes: 5 additions & 1 deletion memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ class MetadataStore:
"""
db_session: "Session" = None

def __init__(self):
def __init__(self, db_session: Optional["Session"] = None):
"""
Args:
db_session: the database session to use.
"""
self.db_session = get_db_session()

def create_api_key(self,
Expand Down
8 changes: 6 additions & 2 deletions memgpt/server/rest_api/auth/index.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from uuid import UUID
from typing import Annotated, Generator

from fastapi import APIRouter
from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field

from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.orm.utilities import get_db_session
from memgpt.server.server import SyncServer

router = APIRouter()
Expand All @@ -20,7 +22,9 @@ class AuthRequest(BaseModel):
def setup_auth_router(server: SyncServer, interface: QueuingInterface, password: str) -> APIRouter:

@router.post("/auth", tags=["auth"], response_model=AuthResponse)
def authenticate_user(request: AuthRequest) -> AuthResponse:
def authenticate_user(request: AuthRequest,
db_session: Annotated["Generator", Depends(get_db_session)],
) -> AuthResponse:
"""
Authenticates the user and sends response with User related data.
Expand Down
5 changes: 3 additions & 2 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def __init__(
max_chaining_steps: bool = None,
default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
# test hooks
config: Optional["MemGPTConfig"] = None
config: Optional["MemGPTConfig"] = None,
db_session: Optional["Session"] = None,
):
"""Server process holds in-memory agents that are being run"""

Expand Down Expand Up @@ -179,7 +180,7 @@ def __init__(
assert self.server_embedding_config.embedding_model is not None, vars(self.server_embedding_config)

# Initialize the metadata store
self.ms = MetadataStore(self.config)
self.ms = MetadataStore(self.config, db_session=db_session)

# pre-fill database (users, presets, humans, personas)
# TODO: figure out how to handle default users (server is technically multi-user)
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,5 @@ def db_session(request) -> "Session":
@pytest.fixture
def test_app(db_session):
"""a per-test-function db scoped version of the rest api app"""

app.dependency_overrides[sessionmaker] = lambda: db_session
return app

0 comments on commit 7f4711e

Please sign in to comment.