From 23b6082638edbd4ef5825f9a7e4ce4ac5d22115d Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Fri, 14 Jun 2024 18:55:58 -0400 Subject: [PATCH 01/45] logs red to green --- memgpt/settings.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/memgpt/settings.py b/memgpt/settings.py index 745bca0c64..142cc3baea 100644 --- a/memgpt/settings.py +++ b/memgpt/settings.py @@ -1,5 +1,7 @@ from pathlib import Path from typing import Optional +from pydantic import Field +from pathlib import Path from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -42,6 +44,5 @@ def memgpt_pg_uri_no_default(self) -> str: else: return None - # singleton settings = Settings() From ac0572d4fe9d5743765b5f9df67f2464c58dc8ac Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Fri, 14 Jun 2024 20:31:32 -0400 Subject: [PATCH 02/45] logs reflect debug status --- memgpt/cli/cli.py | 7 +++---- memgpt/config.py | 1 - memgpt/log.py | 22 ++++++++++------------ memgpt/server/rest_api/auth/index.py | 1 + memgpt/server/rest_api/server.py | 2 +- 5 files changed, 15 insertions(+), 18 deletions(-) diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index c7367306e9..dff366e9db 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -6,7 +6,7 @@ import uuid from enum import Enum from pathlib import Path -from typing import Annotated, Optional +from typing import Annotated, Optional, Tuple import questionary import requests @@ -35,7 +35,6 @@ logger = get_logger(__name__) - def migrate( debug: Annotated[bool, typer.Option(help="Print extra tracebacks for failed migrations")] = False, ): @@ -58,7 +57,7 @@ def str_to_quickstart_choice(choice_str: str) -> QuickstartChoice: raise ValueError(f"{choice_str} is not a valid QuickstartChoice. Valid options are: {valid_options}") -def set_config_with_dict(new_config: dict) -> (MemGPTConfig, bool): +def set_config_with_dict(new_config: dict) -> Tuple[MemGPTConfig, bool]: """_summary_ Args: @@ -347,7 +346,7 @@ def server( elif type == ServerChoice.ws_api: if debug: - from memgpt.server.server import logger as server_logger + from memgpt.server.server import get_logger as server_logger # Set the logging level server_logger.setLevel(logging.DEBUG) diff --git a/memgpt/config.py b/memgpt/config.py index cfda1a70ee..0dfb77f86a 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -13,7 +13,6 @@ logger = get_logger(__name__) - # helper functions for writing to configs def get_field(config, section, field): if section not in config: diff --git a/memgpt/log.py b/memgpt/log.py index ed2a8bab48..ec6c5bb452 100644 --- a/memgpt/log.py +++ b/memgpt/log.py @@ -1,3 +1,4 @@ +from typing import Optional import logging from logging.config import dictConfig from pathlib import Path @@ -8,7 +9,6 @@ selected_log_level = logging.DEBUG if settings.debug else logging.INFO - def _setup_logfile() -> "Path": """ensure the logger filepath is in place @@ -19,16 +19,17 @@ def _setup_logfile() -> "Path": logfile.touch(exist_ok=True) return logfile - # TODO: production logging should be much less invasive DEVELOPMENT_LOGGING = { "version": 1, "disable_existing_loggers": True, "formatters": { - "standard": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}, + "standard": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, "no_datetime": { "format": "%(name)s - %(levelname)s - %(message)s", - }, + } }, "handlers": { "console": { @@ -49,10 +50,7 @@ def _setup_logfile() -> "Path": "loggers": { "MemGPT": { "level": logging.DEBUG if 
settings.debug else logging.INFO, - "handlers": [ - "console", - "file", - ], + "handlers": ["console","file",], "propagate": False, }, "uvicorn": { @@ -63,14 +61,14 @@ def _setup_logfile() -> "Path": }, } - -def get_logger(name: Optional[str] = None) -> "logging.Logger": +def get_logger(name:Optional[str]=None) -> "logging.Logger": """returns the project logger, scoped to a child name if provided Args: name: will define a child logger """ - dictConfig(DEVELOPMENT_LOGGING) - parent_logger = logging.getLogger("MemGPT") + logging.config.dictConfig(DEVELOPMENT_LOGGING) + parent_logger = logging.getLogger("MemGPT") if name: return parent_logger.getChild(name) return parent_logger + diff --git a/memgpt/server/rest_api/auth/index.py b/memgpt/server/rest_api/auth/index.py index 0c3727a695..94005d1bb7 100644 --- a/memgpt/server/rest_api/auth/index.py +++ b/memgpt/server/rest_api/auth/index.py @@ -6,6 +6,7 @@ from memgpt.log import get_logger from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer +from memgpt.log import get_logger logger = get_logger(__name__) router = APIRouter() diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py index a34be5f464..81ccc25efd 100644 --- a/memgpt/server/rest_api/server.py +++ b/memgpt/server/rest_api/server.py @@ -189,7 +189,7 @@ def start_server( ): print("DEBUG", debug) if debug: - from memgpt.server.server import logger as server_logger + from memgpt.server.server import get_logger as server_logger # Set the logging level server_logger.setLevel(logging.DEBUG) From dc334211b292a2ffdbf24e02b5cde7ac366ac706 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Fri, 14 Jun 2024 21:30:15 -0400 Subject: [PATCH 03/45] import submodule --- memgpt/log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/memgpt/log.py b/memgpt/log.py index ec6c5bb452..0434e3fe04 100644 --- a/memgpt/log.py +++ b/memgpt/log.py @@ -66,7 +66,7 @@ def get_logger(name:Optional[str]=None) -> "logging.Logger": Args: name: will define a child logger """ - logging.config.dictConfig(DEVELOPMENT_LOGGING) + dictConfig(DEVELOPMENT_LOGGING) parent_logger = logging.getLogger("MemGPT") if name: return parent_logger.getChild(name) From 2488380b4272f510153218babc77b32a6380ccce Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Fri, 14 Jun 2024 21:49:52 -0400 Subject: [PATCH 04/45] using memgpt logger not global logger --- memgpt/cli/cli.py | 2 +- memgpt/server/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index dff366e9db..ef305892ed 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -346,7 +346,7 @@ def server( elif type == ServerChoice.ws_api: if debug: - from memgpt.server.server import get_logger as server_logger + from memgpt.server.server import logger as server_logger # Set the logging level server_logger.setLevel(logging.DEBUG) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index cc5baf5048..e038a3c4ac 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -10,6 +10,7 @@ from fastapi import HTTPException import memgpt.constants as constants +from memgpt.log import get_logger import memgpt.presets.presets as presets import memgpt.server.utils as server_utils import memgpt.system as system @@ -55,7 +56,6 @@ logger = get_logger(__name__) - class Server(object): """Abstract server class that supports multi-agent multi-user""" From 39fe6f21b75ea6e939dea91e63c07d36136cec87 Mon Sep 17 00:00:00 2001 From: Ethan 
Knox Date: Fri, 14 Jun 2024 23:25:42 -0400 Subject: [PATCH 05/45] found the bug duplication --- memgpt/server/rest_api/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py index 81ccc25efd..a34be5f464 100644 --- a/memgpt/server/rest_api/server.py +++ b/memgpt/server/rest_api/server.py @@ -189,7 +189,7 @@ def start_server( ): print("DEBUG", debug) if debug: - from memgpt.server.server import get_logger as server_logger + from memgpt.server.server import logger as server_logger # Set the logging level server_logger.setLevel(logging.DEBUG) From c74109b02d56deac669666a80d961642a7d37a5e Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Fri, 14 Jun 2024 23:50:54 -0400 Subject: [PATCH 06/45] black --- memgpt/cli/cli.py | 1 + memgpt/config.py | 1 + memgpt/log.py | 19 +++++++++++-------- memgpt/server/server.py | 1 + memgpt/settings.py | 1 + tests/test_client.py | 2 +- 6 files changed, 16 insertions(+), 9 deletions(-) diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index ef305892ed..421115a40a 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -35,6 +35,7 @@ logger = get_logger(__name__) + def migrate( debug: Annotated[bool, typer.Option(help="Print extra tracebacks for failed migrations")] = False, ): diff --git a/memgpt/config.py b/memgpt/config.py index 0dfb77f86a..cfda1a70ee 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -13,6 +13,7 @@ logger = get_logger(__name__) + # helper functions for writing to configs def get_field(config, section, field): if section not in config: diff --git a/memgpt/log.py b/memgpt/log.py index 0434e3fe04..499eab9084 100644 --- a/memgpt/log.py +++ b/memgpt/log.py @@ -9,6 +9,7 @@ selected_log_level = logging.DEBUG if settings.debug else logging.INFO + def _setup_logfile() -> "Path": """ensure the logger filepath is in place @@ -19,17 +20,16 @@ def _setup_logfile() -> "Path": logfile.touch(exist_ok=True) return logfile + # TODO: production logging should be much less invasive DEVELOPMENT_LOGGING = { "version": 1, "disable_existing_loggers": True, "formatters": { - "standard": { - "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - }, + "standard": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}, "no_datetime": { "format": "%(name)s - %(levelname)s - %(message)s", - } + }, }, "handlers": { "console": { @@ -50,7 +50,10 @@ def _setup_logfile() -> "Path": "loggers": { "MemGPT": { "level": logging.DEBUG if settings.debug else logging.INFO, - "handlers": ["console","file",], + "handlers": [ + "console", + "file", + ], "propagate": False, }, "uvicorn": { @@ -61,14 +64,14 @@ def _setup_logfile() -> "Path": }, } -def get_logger(name:Optional[str]=None) -> "logging.Logger": + +def get_logger(name: Optional[str] = None) -> "logging.Logger": """returns the project logger, scoped to a child name if provided Args: name: will define a child logger """ dictConfig(DEVELOPMENT_LOGGING) - parent_logger = logging.getLogger("MemGPT") + parent_logger = logging.getLogger("MemGPT") if name: return parent_logger.getChild(name) return parent_logger - diff --git a/memgpt/server/server.py b/memgpt/server/server.py index e038a3c4ac..c6fe81f804 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -56,6 +56,7 @@ logger = get_logger(__name__) + class Server(object): """Abstract server class that supports multi-agent multi-user""" diff --git a/memgpt/settings.py b/memgpt/settings.py index 142cc3baea..ee4ebad031 100644 --- a/memgpt/settings.py +++ 
b/memgpt/settings.py @@ -44,5 +44,6 @@ def memgpt_pg_uri_no_default(self) -> str: else: return None + # singleton settings = Settings() diff --git a/tests/test_client.py b/tests/test_client.py index 9e855e6242..a757c682ff 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -78,7 +78,7 @@ def client(request): if server_url is None: # run server in thread # NOTE: must set MEMGPT_SERVER_PASS enviornment variable - server_url = "http://localhost:8283" + server_url = "http://localhost:8083" print("Starting server thread") thread = threading.Thread(target=run_server, daemon=True) thread.start() From ca93678961dc79800efa330cfac9c46258e9e3eb Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Fri, 14 Jun 2024 23:54:45 -0400 Subject: [PATCH 07/45] isort --- memgpt/log.py | 1 - memgpt/server/rest_api/auth/index.py | 3 --- memgpt/server/server.py | 1 - memgpt/settings.py | 3 --- 4 files changed, 8 deletions(-) diff --git a/memgpt/log.py b/memgpt/log.py index 499eab9084..ed2a8bab48 100644 --- a/memgpt/log.py +++ b/memgpt/log.py @@ -1,4 +1,3 @@ -from typing import Optional import logging from logging.config import dictConfig from pathlib import Path diff --git a/memgpt/server/rest_api/auth/index.py b/memgpt/server/rest_api/auth/index.py index 94005d1bb7..6be07d888e 100644 --- a/memgpt/server/rest_api/auth/index.py +++ b/memgpt/server/rest_api/auth/index.py @@ -3,12 +3,9 @@ from fastapi import APIRouter from pydantic import BaseModel, Field -from memgpt.log import get_logger from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer -from memgpt.log import get_logger -logger = get_logger(__name__) router = APIRouter() diff --git a/memgpt/server/server.py b/memgpt/server/server.py index c6fe81f804..cc5baf5048 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -10,7 +10,6 @@ from fastapi import HTTPException import memgpt.constants as constants -from memgpt.log import get_logger import memgpt.presets.presets as presets import memgpt.server.utils as server_utils import memgpt.system as system diff --git a/memgpt/settings.py b/memgpt/settings.py index ee4ebad031..b70107bfad 100644 --- a/memgpt/settings.py +++ b/memgpt/settings.py @@ -1,8 +1,5 @@ from pathlib import Path from typing import Optional -from pydantic import Field -from pathlib import Path - from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict From 20bd4cc9cacfc7e077c2a9550e76f2d7a02620dd Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Fri, 14 Jun 2024 21:23:28 -0400 Subject: [PATCH 08/45] placeholder while thinking --- memgpt/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/memgpt/config.py b/memgpt/config.py index cfda1a70ee..2e5282a376 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -50,6 +50,9 @@ class MemGPTConfig: # embedding parameters default_embedding_config: EmbeddingConfig = None +# NONE OF THIS IS CONFIG ↓↓↓↓↓ +# @norton120 these are + # database configs: archival archival_storage_type: str = "chroma" # local, db archival_storage_path: str = os.path.join(MEMGPT_DIR, "chroma") From 57da67a8f3d07832ed86230e116ba668066c293d Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Sun, 16 Jun 2024 14:25:31 -0400 Subject: [PATCH 09/45] removing dead code to make it easier to refactor --- memgpt/config.py | 2 +- memgpt/server/server.py | 103 +--------------------------------------- 2 files changed, 2 insertions(+), 103 deletions(-) diff --git a/memgpt/config.py b/memgpt/config.py index 2e5282a376..ddb2211daa 100644 --- 
a/memgpt/config.py
+++ b/memgpt/config.py
@@ -51,7 +51,7 @@ class MemGPTConfig:
     default_embedding_config: EmbeddingConfig = None
 
 # NONE OF THIS IS CONFIG ↓↓↓↓↓
-# @norton120 these are
+# @norton120 these are the metadata store
 
 # database configs: archival
     archival_storage_type: str = "chroma"  # local, db
diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index cc5baf5048..cc1f7538ab 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -95,7 +95,6 @@ def create_agent(
         user_id: uuid.UUID,
         agent_config: Union[dict, AgentState],
         interface: Union[AgentInterface, None],
-        # persistence_manager: Union[PersistenceManager, None],
     ) -> str:
         """Create a new agent using a config"""
         raise NotImplementedError
@@ -119,49 +118,7 @@ def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) ->
         raise NotImplementedError
 
 
-class LockingServer(Server):
-    """Basic support for concurrency protections (all requests that modify an agent lock the agent until the operation is complete)"""
-
-    # Locks for each agent
-    _agent_locks = {}
-
-    @staticmethod
-    def agent_lock_decorator(func: Callable) -> Callable:
-        @wraps(func)
-        def wrapper(self, user_id: uuid.UUID, agent_id: uuid.UUID, *args, **kwargs):
-            # logger.info("Locking check")
-
-            # Initialize the lock for the agent_id if it doesn't exist
-            if agent_id not in self._agent_locks:
-                # logger.info(f"Creating lock for agent_id = {agent_id}")
-                self._agent_locks[agent_id] = Lock()
-
-            # Check if the agent is currently locked
-            if not self._agent_locks[agent_id].acquire(blocking=False):
-                # logger.info(f"agent_id = {agent_id} is busy")
-                raise HTTPException(status_code=423, detail=f"Agent '{agent_id}' is currently busy.")
-
-            try:
-                # Execute the function
-                # logger.info(f"running function on agent_id = {agent_id}")
-                return func(self, user_id, agent_id, *args, **kwargs)
-            finally:
-                # Release the lock
-                # logger.info(f"releasing lock on agent_id = {agent_id}")
-                self._agent_locks[agent_id].release()
-
-        return wrapper
-
-    # @agent_lock_decorator
-    def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
-        raise NotImplementedError
-
-    # @agent_lock_decorator
-    def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
-        raise NotImplementedError
-
-
-class SyncServer(LockingServer):
+class SyncServer(Server):
     """Simple single-threaded / blocking server process"""
 
     def __init__(
@@ -175,26 +132,6 @@ def __init__(
     ):
         """Server process holds in-memory agents that are being run"""
 
-        # Server supports several auth modes:
-        #   "none":
-        #       no authentication, trust the incoming requests to have access to the user_id being modified
-        #   "jwt_local":
-        #       clients send bearer JWT tokens, which decode to user_ids
-        #       JWT tokens are generated by the server process (using pyJWT) and stored in a database table
-        #   "jwt_external":
-        #       clients still send bearer JWT tokens, but token generation and validation is handled by an external service
-        #       ie the server process will call 'external.decode(token)' to get the user_id
-        # if auth_mode == "none":
-        #     self.auth_mode = auth_mode
-        #     raise NotImplementedError  # TODO
-        # elif auth_mode == "jwt_local":
-        #     self.auth_mode = auth_mode
-        # elif auth_mode == "jwt_external":
-        #     self.auth_mode = auth_mode
-        #     raise NotImplementedError  # TODO
-        # else:
-        #     raise ValueError(auth_mode)
-
         # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
         self.active_agents = []
 
@@ -210,36 +147,15 @@ def __init__(
        # 
self.default_interface = default_interface # self.default_interface = default_interface_cls() - # The default persistence manager that will get assigned to agents ON CREATION - # self.default_persistence_manager_cls = default_persistence_manager_cls - # Initialize the connection to the DB self.config = MemGPTConfig.load() logger.info(f"loading configuration from '{self.config.config_path}'") assert self.config.persona is not None, "Persona must be set in the config" assert self.config.human is not None, "Human must be set in the config" - # Update storage URI to match passed in settings - # (NOTE: no longer needed since envs being used, I think) - # for memory_type in ("archival", "recall", "metadata"): - # if settings.memgpt_pg_uri: - # # override with env - # setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri) - # self.config.save() - # TODO figure out how to handle credentials for the server self.credentials = MemGPTCredentials.load() - # Ensure valid database configuration - # TODO: add back once tests are matched - # assert ( - # self.config.metadata_storage_type == "postgres" - # ), f"Invalid metadata_storage_type for server: {self.config.metadata_storage_type}" - # assert ( - # self.config.archival_storage_type == "postgres" - # ), f"Invalid archival_storage_type for server: {self.config.archival_storage_type}" - # assert self.config.recall_storage_type == "postgres", f"Invalid recall_storage_type for server: {self.config.recall_storage_type}" - # Generate default LLM/Embedding configs for the server # TODO: we may also want to do the same thing with default persona/human/etc. self.server_llm_config = LLMConfig( @@ -248,11 +164,6 @@ def __init__( model_endpoint=self.config.default_llm_config.model_endpoint, model_wrapper=self.config.default_llm_config.model_wrapper, context_window=self.config.default_llm_config.context_window, - # openai_key=self.credentials.openai_key, - # azure_key=self.credentials.azure_key, - # azure_endpoint=self.credentials.azure_endpoint, - # azure_version=self.credentials.azure_version, - # azure_deployment=self.credentials.azure_deployment, ) self.server_embedding_config = EmbeddingConfig( embedding_endpoint_type=self.config.default_embedding_config.embedding_endpoint_type, @@ -285,7 +196,6 @@ def save_agents(self): """Saves all the agents that are in the in-memory object store""" for agent_d in self.active_agents: try: - # agent_d["agent"].save() save_agent(agent_d["agent"], self.ms) logger.info(f"Saved agent {agent_d['agent_id']}") except Exception as e: @@ -303,7 +213,6 @@ def _add_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, agent_obj: Agent) # Make sure the agent doesn't already exist if self._get_agent(user_id=user_id, agent_id=agent_id) is not None: # Can be triggered on concucrent request, so don't throw a full error - # raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is already loaded") logger.exception(f"Agent (user={user_id}, agent={agent_id}) is already loaded") return # Add Agent instance to the in-memory list @@ -330,7 +239,6 @@ def _load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, interface: Union[ if not agent_state: logger.exception(f"agent_id {agent_id} does not exist") raise ValueError(f"agent_id {agent_id} does not exist") - # print(f"server._load_agent :: load got agent state {agent_id}, messages = {agent_state.state['messages']}") # Instantiate an agent object using the state retrieved logger.info(f"Creating an agent object") @@ -540,7 +448,6 @@ def _command(self, user_id: uuid.UUID, 
agent_id: uuid.UUID, command: str) -> Uni
         input_message = system.get_token_limit_warning()
         self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)
 
-    # @LockingServer.agent_lock_decorator
     def user_message(
         self,
         user_id: uuid.UUID,
@@ -599,7 +506,6 @@ def user_message(
         usage = self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message, timestamp=timestamp)
         return usage
 
-    # @LockingServer.agent_lock_decorator
     def system_message(
         self,
         user_id: uuid.UUID,
@@ -760,8 +666,6 @@ def create_agent(
                 # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
                 first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
             )
-            # FIXME: this is a hacky way to get the system prompts injected into agent into the DB
-            # self.ms.update_agent(agent.agent_state)
         except Exception as e:
             logger.exception(e)
             try:
@@ -1438,11 +1342,6 @@ def list_all_sources(self, user_id: uuid.UUID) -> List[SourceModel]:
             # count number of passages
             passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
             num_passages = passage_conn.size({"data_source": source.name})
-
-            # TODO: add when documents table implemented
-            ## count number of documents
-            # document_conn = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
-            # num_documents = document_conn.size({"data_source": source.name})
             num_documents = 0
 
             agent_ids = self.ms.list_attached_agents(source_id=source.id)

From 56368a04b51e9a82419b426771822533869e5a9c Mon Sep 17 00:00:00 2001
From: Ethan Knox
Date: Sun, 16 Jun 2024 18:07:56 -0400
Subject: [PATCH 10/45] starting in on abstracting the metadatastore adapters

---
 memgpt/metadata.py      | 11 ++++++++++-
 memgpt/server/server.py |  5 +++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/memgpt/metadata.py b/memgpt/metadata.py
index 25137a32ab..09ff8740d6 100644
--- a/memgpt/metadata.py
+++ b/memgpt/metadata.py
@@ -4,7 +4,7 @@ import secrets
 import traceback
 import uuid
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from sqlalchemy import (
     BIGINT,
@@ -44,6 +44,15 @@ from memgpt.settings import settings
 from memgpt.utils import enforce_types, get_utc_time, printd
 
+
+def get_metadata_store() -> Type["MetadataStore"]:
+    """This uses app settings to select and configure a MetadataStore.
+    Returns:
+        A MetadataStore adapter (i.e. Postgres, SQLiteChroma, etc.)
+    """
+    # GH 1437 - cut in the lookup and config here
+    raise NotImplementedError
+
 Base = declarative_base()

diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index cc1f7538ab..fd584d29c4 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -147,6 +147,11 @@ def __init__(
         # self.default_interface = default_interface
         # self.default_interface = default_interface_cls()
 
+        # GH-1437 start overload refactoring for configs;
+        # metadata store based on the configured adapter here
+        self.metadatastore = metadatastore or get_metadata_store()
+
+
         # Initialize the connection to the DB
         self.config = MemGPTConfig.load()
         logger.info(f"loading configuration from '{self.config.config_path}'")

From 92f136e4c14e47ad5ea49fdac513c3e888d23014 Mon Sep 17 00:00:00 2001
From: Ethan Knox
Date: Wed, 12 Jun 2024 17:29:45 -0400
Subject: [PATCH 11/45] most of the initial config override in test_server is working.

Next up:
- isolate the test_server failing tests
- move the settings mock into a conftest fixture
- add a test hook for SyncServer so you can do the same thing there.
- propagate.
---
 compose.yaml                |  1 +
 configs/server_config.yaml  | 11 +++++++----
 development.compose.yml     |  1 +
 memgpt/config.py            | 31 ++++++++-----------------------
 memgpt/server/server.py     |  8 +++++---
 memgpt/settings.py          | 26 +++++++++++++++++---------
 tests/test_client.py        |  2 +-
 tests/test_load_archival.py |  4 ++--
 tests/test_server.py        | 14 +++++---------
 tests/test_storage.py       |  4 ++--
 10 files changed, 49 insertions(+), 53 deletions(-)

diff --git a/compose.yaml b/compose.yaml
index dca20fff64..329a997e5b 100644
--- a/compose.yaml
+++ b/compose.yaml
@@ -6,6 +6,7 @@ services:
       aliases:
         - pgvector_db
         - memgpt-db
+        - memgpt
     environment:
       - POSTGRES_USER=${MEMGPT_PG_USER}
       - POSTGRES_PASSWORD=${MEMGPT_PG_PASSWORD}
diff --git a/configs/server_config.yaml b/configs/server_config.yaml
index a2b27ae8bc..dd8fdf7258 100644
--- a/configs/server_config.yaml
+++ b/configs/server_config.yaml
@@ -7,7 +7,6 @@ human = basic
 model = gpt-4
 model_endpoint = https://api.openai.com/v1
 model_endpoint_type = openai
-model_wrapper = null
 context_window = 8192
 
 [embedding]
@@ -20,17 +19,21 @@ embedding_chunk_size = 300
 [archival_storage]
 type = postgres
 path = /root/.memgpt/chroma
-uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
+uri = postgresql+pg8000://memgpt:memgpt@memgpt-db:5432/memgpt
 
 [recall_storage]
 type = postgres
 path = /root/.memgpt
-uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
+uri = postgresql+pg8000://memgpt:memgpt@memgpt-db:5432/memgpt
 
 [metadata_storage]
 type = postgres
 path = /root/.memgpt
-uri = postgresql+pg8000://memgpt:memgpt@pgvector_db:5432/memgpt
+uri = postgresql+pg8000://memgpt:memgpt@memgpt-db:5432/memgpt
+
+[version]
+memgpt_version = 0.3.14
 
 [client]
 anon_clientid = 00000000-0000-0000-0000-000000000000
+
diff --git a/development.compose.yml b/development.compose.yml
index 6932537978..44713dfe4e 100644
--- a/development.compose.yml
+++ b/development.compose.yml
@@ -12,6 +12,7 @@ services:
       - memgpt_db
     env_file:
       - .env
+      # no value syntax to not set the env at all if it is not set in .env
    environment:
      - MEMGPT_SERVER_PASS=test_server_token
      - WATCHFILES_FORCE_POLLING=true
diff --git a/memgpt/config.py b/memgpt/config.py
index ddb2211daa..abb0b11905 100644
--- a/memgpt/config.py
+++ b/memgpt/config.py
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 
 import memgpt
+from memgpt.settings import settings
 import memgpt.utils as utils
 from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET, MEMGPT_DIR
 from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig
@@ -34,7 +35,7 @@ def set_field(config, section, field, value):
 
 @dataclass
 class MemGPTConfig:
-    config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")
+    config_path: str = str(settings.config_path.absolute())
     anon_clientid: str = str(uuid.UUID(int=0))
 
     # preset
@@ -107,18 +108,11 @@ def load(cls) -> "MemGPTConfig":
 
         config = configparser.ConfigParser()
 
-        # allow overriding with env variables
-        if os.getenv("MEMGPT_CONFIG_PATH"):
-            config_path = os.getenv("MEMGPT_CONFIG_PATH")
-        else:
-            config_path = MemGPTConfig.config_path
-
         # insure all configuration directories exist
         cls.create_config_dir()
-        printd(f"Loading config from {config_path}")
-        if os.path.exists(config_path):
-            # read existing config
-            config.read(config_path)
+        printd(f"Loading config from {settings.config_path}")
+        if 
settings.config_path.exists(): + config.read(str(settings.config_path.absolute())) # Handle extraction of nested LLMConfig and EmbeddingConfig llm_config_dict = { @@ -173,7 +167,7 @@ def load(cls) -> "MemGPTConfig": "metadata_storage_uri": get_field(config, "metadata_storage", "uri"), # Misc "anon_clientid": get_field(config, "client", "anon_clientid"), - "config_path": config_path, + "config_path": settings.config_path, "memgpt_version": get_field(config, "version", "memgpt_version"), } # Don't include null values @@ -183,8 +177,7 @@ def load(cls) -> "MemGPTConfig": # create new config anon_clientid = MemGPTConfig.generate_uuid() - config = cls(anon_clientid=anon_clientid, config_path=config_path) - + config = cls(anon_clientid=anon_clientid, config_path=settings.config_path) config.create_config_dir() # create dirs return config @@ -193,7 +186,6 @@ def save(self): import memgpt config = configparser.ConfigParser() - # CLI defaults set_field(config, "defaults", "preset", self.preset) set_field(config, "defaults", "persona", self.persona) @@ -281,14 +273,7 @@ def save(self): @staticmethod def exists(): - # allow overriding with env variables - if os.getenv("MEMGPT_CONFIG_PATH"): - config_path = os.getenv("MEMGPT_CONFIG_PATH") - else: - config_path = MemGPTConfig.config_path - - assert not os.path.isdir(config_path), f"Config path {config_path} cannot be set to a directory." - return os.path.exists(config_path) + return settings.config_path.exists() and not settings.config_path.is_dir() @staticmethod def create_config_dir(): diff --git a/memgpt/server/server.py b/memgpt/server/server.py index fd584d29c4..5aec9fd772 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1,3 +1,4 @@ +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union import json import uuid import warnings @@ -5,7 +6,6 @@ from datetime import datetime from functools import wraps from threading import Lock -from typing import Callable, List, Optional, Tuple, Union from fastapi import HTTPException @@ -153,8 +153,10 @@ def __init__( # Initialize the connection to the DB - self.config = MemGPTConfig.load() - logger.info(f"loading configuration from '{self.config.config_path}'") + self.config = config or MemGPTConfig.load() + msg = "server :: loading configuration as passed" if config else \ + f"server :: loading configuration from '{self.config.config_path}'" + print(msg) assert self.config.persona is not None, "Persona must be set in the config" assert self.config.human is not None, "Human must be set in the config" diff --git a/memgpt/settings.py b/memgpt/settings.py index b70107bfad..a72c25b3d5 100644 --- a/memgpt/settings.py +++ b/memgpt/settings.py @@ -10,21 +10,23 @@ class Settings(BaseSettings): memgpt_dir: Optional[Path] = Field(Path.home() / ".memgpt", env="MEMGPT_DIR") debug: Optional[bool] = False server_pass: Optional[str] = None - pg_db: Optional[str] = None - pg_user: Optional[str] = None - pg_password: Optional[str] = None - pg_host: Optional[str] = None - pg_port: Optional[int] = None - pg_uri: Optional[str] = None # option to specifiy full uri + pg_db: Optional[str] = "memgpt" + pg_user: Optional[str] = "memgpt" + pg_password: Optional[str] = "memgpt" + pg_host: Optional[str] = "localhost" + pg_port: Optional[int] = 5432 cors_origins: Optional[list] = ["http://memgpt.localhost", "http://localhost:8283", "http://localhost:8083"] + _pg_uri: Optional[str] = None # calculated to specify full uri + # configurations + config_path: Optional[Path] = 
Path("~/.memgpt/config").expanduser() # agent configuration defaults default_preset: Optional[str] = "memgpt_chat" @property - def memgpt_pg_uri(self) -> str: - if self.pg_uri: - return self.pg_uri + def pg_uri(self) -> str: + if self._pg_uri: + return self._pg_uri elif self.pg_db and self.pg_user and self.pg_password and self.pg_host and self.pg_port: return f"postgresql+pg8000://{self.pg_user}:{self.pg_password}@{self.pg_host}:{self.pg_port}/{self.pg_db}" else: @@ -41,6 +43,12 @@ def memgpt_pg_uri_no_default(self) -> str: else: return None + @pg_uri.setter + def pg_uri(self, value: str): + self._pg_uri = value + + + # singleton settings = Settings() diff --git a/tests/test_client.py b/tests/test_client.py index a757c682ff..365d3940c5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -30,7 +30,7 @@ def _reset_config(): # Use os.getenv with a fallback to os.environ.get - db_url = settings.memgpt_pg_uri + db_url = settings.pg_uri if os.getenv("OPENAI_API_KEY"): create_config("openai") diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index 08df688b7b..460605dac1 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -63,7 +63,7 @@ def test_load_directory( # setup config if metadata_storage_connector == "postgres": - TEST_MEMGPT_CONFIG.metadata_storage_uri = settings.memgpt_pg_uri + TEST_MEMGPT_CONFIG.metadata_storage_uri = settings.pg_uri TEST_MEMGPT_CONFIG.metadata_storage_type = "postgres" elif metadata_storage_connector == "sqlite": print("testing sqlite metadata") @@ -71,7 +71,7 @@ def test_load_directory( else: raise NotImplementedError(f"Storage type {metadata_storage_connector} not implemented") if passage_storage_connector == "postgres": - TEST_MEMGPT_CONFIG.archival_storage_uri = settings.memgpt_pg_uri + TEST_MEMGPT_CONFIG.archival_storage_uri = settings.pg_uri TEST_MEMGPT_CONFIG.archival_storage_type = "postgres" elif passage_storage_connector == "chroma": print("testing chroma passage storage") diff --git a/tests/test_server.py b/tests/test_server.py index 9b79f4516d..99f99e0fff 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -16,17 +16,14 @@ from .utils import DummyDataConnector, create_config, wipe_config, wipe_memgpt_home - @pytest.fixture(scope="module") -def server(): +def server(tmp_path_factory): load_dotenv() - wipe_config() + #wipe_config() + settings.config_path = tmp_path_factory.mktemp("test") / "config" wipe_memgpt_home() - db_url = settings.memgpt_pg_uri - - # Use os.getenv with a fallback to os.environ.get - db_url = settings.memgpt_pg_uri + db_url = settings.pg_db # start of the conftest hook here if os.getenv("OPENAI_API_KEY"): create_config("openai") @@ -49,8 +46,7 @@ def server(): config.save() credentials.save() - - server = SyncServer() + server = SyncServer(config=config) return server diff --git a/tests/test_storage.py b/tests/test_storage.py index d6feb1eae3..c7c6856336 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -136,8 +136,8 @@ def test_storage( TEST_MEMGPT_CONFIG.default_llm_config = config.default_llm_config if storage_connector == "postgres": - TEST_MEMGPT_CONFIG.archival_storage_uri = settings.memgpt_pg_uri - TEST_MEMGPT_CONFIG.recall_storage_uri = settings.memgpt_pg_uri + TEST_MEMGPT_CONFIG.archival_storage_uri = settings.pg_uri + TEST_MEMGPT_CONFIG.recall_storage_uri = settings.pg_uri TEST_MEMGPT_CONFIG.archival_storage_type = "postgres" TEST_MEMGPT_CONFIG.recall_storage_type = "postgres" if storage_connector == "lancedb": From 
72b88c7f5a1705f392e05944870e2df47b9f2753 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 12 Jun 2024 17:59:00 -0400 Subject: [PATCH 12/45] abstracted fixture --- configs/server_config.yaml | 14 +++--- memgpt/settings.py | 2 + tests/conftest.py | 99 ++++++++++++++++++++++++++++++++++++++ tests/test_server.py | 4 +- 4 files changed, 109 insertions(+), 10 deletions(-) create mode 100644 tests/conftest.py diff --git a/configs/server_config.yaml b/configs/server_config.yaml index dd8fdf7258..0aa49dda99 100644 --- a/configs/server_config.yaml +++ b/configs/server_config.yaml @@ -4,16 +4,16 @@ persona = sam_pov human = basic [model] -model = gpt-4 -model_endpoint = https://api.openai.com/v1 -model_endpoint_type = openai +model = ehartford/dolphin-2.5-mixtral-8x7b +model_endpoint = https://api.memgpt.ai +model_endpoint_type = vllm context_window = 8192 [embedding] -embedding_endpoint_type = openai -embedding_endpoint = https://api.openai.com/v1 -embedding_model = text-embedding-ada-002 -embedding_dim = 1536 +embedding_endpoint_type = hugging-face +embedding_endpoint = https://embeddings.memgpt.ai +embedding_model = BAAI/bge-large-en-v1.5 +embedding_dim = 1024 embedding_chunk_size = 300 [archival_storage] diff --git a/memgpt/settings.py b/memgpt/settings.py index a72c25b3d5..a998f15cc8 100644 --- a/memgpt/settings.py +++ b/memgpt/settings.py @@ -22,6 +22,8 @@ class Settings(BaseSettings): # agent configuration defaults default_preset: Optional[str] = "memgpt_chat" + # TODO: extract to vendor plugin + openai_api_key: Optional[str] = None @property def pg_uri(self) -> str: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..1a0b8cb0ad --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,99 @@ +import pytest + +from memgpt.settings import settings +from tests.utils import wipe_memgpt_home +from memgpt.data_types import EmbeddingConfig, LLMConfig +from memgpt.credentials import MemGPTCredentials +from memgpt.server.server import SyncServer + +from tests.config import TestMGPTConfig + +@pytest.fixture(scope="module") +def server(tmp_path_factory): + settings.config_path = tmp_path_factory.mktemp("test") / "config" + wipe_memgpt_home() + + db_url = settings.pg_db # start of the conftest hook here + + if settings.openai_api_key: + config = TestMGPTConfig( + archival_storage_uri=db_url, + recall_storage_uri=db_url, + metadata_storage_uri=db_url, + archival_storage_type="postgres", + recall_storage_type="postgres", + metadata_storage_type="postgres", + # embeddings + default_embedding_config=EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_model="text-embedding-ada-002", + embedding_dim=1536, + ), + # llms + default_llm_config=LLMConfig( + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + model="gpt-4", + ), + ) + credentials = MemGPTCredentials( + openai_key=os.getenv("OPENAI_API_KEY"), + ) + else: # hosted + config = TestMGPTConfig( + archival_storage_uri=db_url, + recall_storage_uri=db_url, + metadata_storage_uri=db_url, + archival_storage_type="postgres", + recall_storage_type="postgres", + metadata_storage_type="postgres", + # embeddings + default_embedding_config=EmbeddingConfig( + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_model="BAAI/bge-large-en-v1.5", + embedding_dim=1024, + ), + # llms + default_llm_config=LLMConfig( + model_endpoint_type="vllm", + model_endpoint="https://api.memgpt.ai", + 
model="ehartford/dolphin-2.5-mixtral-8x7b", + ), + ) + credentials = MemGPTCredentials() + + config.save() + credentials.save() + server = SyncServer(config=config) + return server + + +@pytest.fixture(scope="module") +def user_id(server): + # create user + user = server.create_user() + print(f"Created user\n{user.id}") + + # initialize with default presets + server.initialize_default_presets(user.id) + yield user.id + + # cleanup + server.delete_user(user.id) + + +@pytest.fixture(scope="module") +def agent_id(server, user_id): + # create agent + agent_state = server.create_agent( + user_id=user_id, + name="test_agent", + preset="memgpt_chat", + ) + print(f"Created agent\n{agent_state}") + yield agent_state.id + + # cleanup + server.delete_agent(user_id, agent_state.id) diff --git a/tests/test_server.py b/tests/test_server.py index 99f99e0fff..0c79fceb2d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,8 +1,5 @@ -import os import uuid - import pytest -from dotenv import load_dotenv import memgpt.utils as utils from memgpt.constants import BASE_TOOLS @@ -74,6 +71,7 @@ def agent_id(server, user_id): # cleanup server.delete_agent(user_id, agent_state.id) +from .utils import DummyDataConnector def test_error_on_nonexistent_agent(server, user_id, agent_id): try: From af7bb375d61704791b4be9cc35b3abaf16707ab6 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 12 Jun 2024 20:11:17 -0400 Subject: [PATCH 13/45] moving more to fixtures --- memgpt/client/client.py | 7 +-- tests/conftest.py | 86 ++++++++++++++---------------------- tests/test_base_functions.py | 16 +------ 3 files changed, 40 insertions(+), 69 deletions(-) diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 2b2bc316c6..15b4cfb419 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -54,9 +54,9 @@ from memgpt.utils import get_human_text -def create_client(base_url: Optional[str] = None, token: Optional[str] = None): +def create_client(base_url: Optional[str] = None, token: Optional[str] = None, config: Optional[MemGPTConfig] = None): if base_url is None: - return LocalClient() + return LocalClient(config=config) else: return RESTClient(base_url, token) @@ -675,6 +675,7 @@ def __init__( auto_save: bool = False, user_id: Optional[str] = None, debug: bool = False, + config: "MemGPTConfig" = None, ): """ Initializes a new instance of Client class. 
@@ -686,7 +687,7 @@ def __init__( self.auto_save = auto_save # determine user_id (pulled from local config) - config = MemGPTConfig.load() + config = config or MemGPTConfig.load() if user_id: self.user_id = uuid.UUID(user_id) else: diff --git a/tests/conftest.py b/tests/conftest.py index 1a0b8cb0ad..379e4fdc36 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,61 +9,43 @@ from tests.config import TestMGPTConfig @pytest.fixture(scope="module") -def server(tmp_path_factory): - settings.config_path = tmp_path_factory.mktemp("test") / "config" - wipe_memgpt_home() +def config(): + db_url = settings.pg_db + use_oai = settings.openai_api_key + config_args = {} + arg_pairs = ( + (db_url, ("archival_storage_uri", "recall_storage_uri", "metadata_storage_uri")), + ("postgres", ("archival_storage_type", "recall_storage_type", "metadata_storage_type")), + ) + for arg, keys in arg_pairs: + for key in keys: + config_args[key] = arg - db_url = settings.pg_db # start of the conftest hook here + default_embedding_config=EmbeddingConfig( + embedding_endpoint_type="openai" if use_oai else "hugging-face", + embedding_endpoint="https://api.openai.com/v1" if use_oai else "https://embeddings.memgpt.ai", + embedding_model="text-embedding-ada-002" if use_oai else "BAAI/bge-large-en-v1.5", + embedding_dim=1536 if use_oai else 1024, + ) + default_llm_config=LLMConfig( + model_endpoint_type="openai" if use_oai else "vllm", + model_endpoint="https://api.openai.com/v1" if use_oai else "https://api.memgpt.ai", + model="gpt-4" if use_oai else "ehartford/dolphin-2.5-mixtral-8x7b", + ) + return TestMGPTConfig( + default_embedding_config=default_embedding_config, + default_llm_config=default_llm_config, + **config_args,) - if settings.openai_api_key: - config = TestMGPTConfig( - archival_storage_uri=db_url, - recall_storage_uri=db_url, - metadata_storage_uri=db_url, - archival_storage_type="postgres", - recall_storage_type="postgres", - metadata_storage_type="postgres", - # embeddings - default_embedding_config=EmbeddingConfig( - embedding_endpoint_type="openai", - embedding_endpoint="https://api.openai.com/v1", - embedding_model="text-embedding-ada-002", - embedding_dim=1536, - ), - # llms - default_llm_config=LLMConfig( - model_endpoint_type="openai", - model_endpoint="https://api.openai.com/v1", - model="gpt-4", - ), - ) - credentials = MemGPTCredentials( - openai_key=os.getenv("OPENAI_API_KEY"), - ) - else: # hosted - config = TestMGPTConfig( - archival_storage_uri=db_url, - recall_storage_uri=db_url, - metadata_storage_uri=db_url, - archival_storage_type="postgres", - recall_storage_type="postgres", - metadata_storage_type="postgres", - # embeddings - default_embedding_config=EmbeddingConfig( - embedding_endpoint_type="hugging-face", - embedding_endpoint="https://embeddings.memgpt.ai", - embedding_model="BAAI/bge-large-en-v1.5", - embedding_dim=1024, - ), - # llms - default_llm_config=LLMConfig( - model_endpoint_type="vllm", - model_endpoint="https://api.memgpt.ai", - model="ehartford/dolphin-2.5-mixtral-8x7b", - ), - ) - credentials = MemGPTCredentials() +@pytest.fixture(scope="module") +def server(tmp_path_factory, config): + settings.config_path = tmp_path_factory.mktemp("test") / "config" + wipe_memgpt_home() + + if key := settings.openai_api_key: + creds_config = {"openai_key": key} + credentials = MemGPTCredentials(**creds_config) config.save() credentials.save() server = SyncServer(config=config) diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index a1d8143b67..b07aea7ef1 
100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -1,27 +1,15 @@ -import os - import pytest import memgpt.functions.function_sets.base as base_functions from memgpt import create_client - -from .utils import create_config, wipe_config - -# test_agent_id = "test_agent" client = None @pytest.fixture(scope="module") -def agent_obj(): +def agent_obj(config): """Create a test agent that we can call functions on""" - wipe_config() global client - if os.getenv("OPENAI_API_KEY"): - create_config("openai") - else: - create_config("memgpt_hosted") - - client = create_client() + client = create_client(config=config) agent_state = client.create_agent() From 6e6d3108c385aa234f115f5597f976c6e2caa2f8 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 12 Jun 2024 20:19:16 -0400 Subject: [PATCH 14/45] defaults --- memgpt/server/rest_api/agents/index.py | 4 ++-- memgpt/settings.py | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/memgpt/server/rest_api/agents/index.py b/memgpt/server/rest_api/agents/index.py index 77d574770f..a16c99cc11 100644 --- a/memgpt/server/rest_api/agents/index.py +++ b/memgpt/server/rest_api/agents/index.py @@ -70,8 +70,8 @@ def create_agent( human = request.config["human"] if "human" in request.config else None persona_name = request.config["persona_name"] if "persona_name" in request.config else None persona = request.config["persona"] if "persona" in request.config else None - request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset - tool_names = request.config["function_names"] if ("function_names" in request.config and request.config["function_names"]) else None + request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.preset + tool_names = request.config["function_names"] metadata = request.config["metadata"] if "metadata" in request.config else {} metadata["human"] = human_name metadata["persona"] = persona_name diff --git a/memgpt/settings.py b/memgpt/settings.py index a998f15cc8..0f742cf4e7 100644 --- a/memgpt/settings.py +++ b/memgpt/settings.py @@ -20,8 +20,11 @@ class Settings(BaseSettings): # configurations config_path: Optional[Path] = Path("~/.memgpt/config").expanduser() - # agent configuration defaults - default_preset: Optional[str] = "memgpt_chat" + # application default starter settings + persona: Optional[str] = "sam_pov" + human: Optional[str] = "basic" + preset: Optional[str] = "memgpt_chat" + # TODO: extract to vendor plugin openai_api_key: Optional[str] = None From 5c3d7fede8b80b10d77400caa2b316dbdc8dba8e Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 12 Jun 2024 20:53:22 -0400 Subject: [PATCH 15/45] Removed duplication of constants/configs/envars for default persona, human, and preset. Now all derived from settings (which is in turn derived from envars). Still need to square away with the config file hierarchy, so once we resolve the value there is only one definitive source of truth across the rest of the code. 
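As an illustrative sketch of where this lands (not part of the diff below; the MEMGPT_ env prefix and the exact SettingsConfigDict arguments are assumptions, since they are not shown in this series):

    # memgpt/settings.py, sketched
    from typing import Optional

    from pydantic_settings import BaseSettings, SettingsConfigDict

    class Settings(BaseSettings):
        model_config = SettingsConfigDict(env_prefix="memgpt_")  # assumed prefix

        # application starter defaults; e.g. MEMGPT_PERSONA overrides persona
        persona: Optional[str] = "sam_pov"
        human: Optional[str] = "basic"
        preset: Optional[str] = "memgpt_chat"

    settings = Settings()  # module-level singleton

    # consumers read the resolved value instead of importing a constant:
    # from memgpt.settings import settings
    # default_persona = settings.persona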
--- examples/memgpt_client.py | 8 +- memgpt/autogen/examples/agent_autoreply.py | 9 +- memgpt/autogen/examples/agent_docs.py | 9 +- memgpt/autogen/examples/agent_groupchat.py | 9 +- memgpt/autogen/memgpt_agent.py | 3 +- memgpt/config.py | 8 +- memgpt/constants.py | 3 - memgpt/data_types.py | 14 +- memgpt/models/pydantic_models.py | 11 +- memgpt/presets/presets.py | 93 ++++++++++++ .../rest_api/openai_assistants/assistants.py | 12 +- memgpt/server/rest_api/presets/index.py | 8 +- tests/test_agent_function_update.py | 124 ++++++++++++++++ tests/test_base_functions.py | 5 +- tests/test_client.py | 5 +- tests/test_metadata_store.py | 139 ++++++++++++++++++ 16 files changed, 406 insertions(+), 54 deletions(-) create mode 100644 tests/test_agent_function_update.py create mode 100644 tests/test_metadata_store.py diff --git a/examples/memgpt_client.py b/examples/memgpt_client.py index 065770605c..20d19ec56f 100644 --- a/examples/memgpt_client.py +++ b/examples/memgpt_client.py @@ -1,7 +1,7 @@ import json from memgpt import Admin, create_client -from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET +from memgpt.settings import settings from memgpt.utils import get_human_text, get_persona_text """ @@ -33,9 +33,9 @@ def main(): # Create an agent agent_info = client.create_agent( name="my_agent", - preset=DEFAULT_PRESET, - persona=get_persona_text(DEFAULT_PERSONA), - human=get_human_text(DEFAULT_HUMAN), + preset=settings.preset, + persona=get_persona_text(settings.persona), + human=get_human_text(settings.human), ) print(f"Created agent: {agent_info.name} with ID {str(agent_info.id)}") diff --git a/memgpt/autogen/examples/agent_autoreply.py b/memgpt/autogen/examples/agent_autoreply.py index 80304b6c70..67c9a50f70 100644 --- a/memgpt/autogen/examples/agent_autoreply.py +++ b/memgpt/autogen/examples/agent_autoreply.py @@ -13,8 +13,9 @@ import autogen +from memgpt.settings import settings from memgpt.autogen.memgpt_agent import create_memgpt_autogen_agent_from_config -from memgpt.constants import DEFAULT_PRESET, LLM_MAX_TOKENS +from memgpt.constants import LLM_MAX_TOKENS LLM_BACKEND = "openai" # LLM_BACKEND = "azure" @@ -40,7 +41,7 @@ { "model": model, "context_window": LLM_MAX_TOKENS[model], - "preset": DEFAULT_PRESET, + "preset": settings.default_preset, "model_wrapper": None, # OpenAI specific "model_endpoint_type": "openai", @@ -79,7 +80,7 @@ { "model": model, "context_window": LLM_MAX_TOKENS[model], - "preset": DEFAULT_PRESET, + "preset": settings.default_preset, "model_wrapper": None, # Azure specific "model_endpoint_type": "azure", @@ -108,7 +109,7 @@ # MemGPT-powered agents will also use local LLMs, but they need additional setup (also they use the Completions endpoint) config_list_memgpt = [ { - "preset": DEFAULT_PRESET, + "preset": settings.default_preset, "model": None, # only required for Ollama, see: https://memgpt.readme.io/docs/ollama "context_window": 8192, # the context window of your model (for Mistral 7B-based models, it's likely 8192) "model_wrapper": "chatml", # chatml is the default wrapper diff --git a/memgpt/autogen/examples/agent_docs.py b/memgpt/autogen/examples/agent_docs.py index c3c885a015..b62ffa4a18 100644 --- a/memgpt/autogen/examples/agent_docs.py +++ b/memgpt/autogen/examples/agent_docs.py @@ -15,8 +15,9 @@ import autogen +from memgpt.settings import settings from memgpt.autogen.memgpt_agent import create_memgpt_autogen_agent_from_config -from memgpt.constants import DEFAULT_PRESET, LLM_MAX_TOKENS +from memgpt.constants import LLM_MAX_TOKENS 
LLM_BACKEND = "openai" # LLM_BACKEND = "azure" @@ -42,7 +43,7 @@ { "model": model, "context_window": LLM_MAX_TOKENS[model], - "preset": DEFAULT_PRESET, + "preset": settings.default_preset, "model_wrapper": None, # OpenAI specific "model_endpoint_type": "openai", @@ -81,7 +82,7 @@ { "model": model, "context_window": LLM_MAX_TOKENS[model], - "preset": DEFAULT_PRESET, + "preset": settings.default_preset, "model_wrapper": None, # Azure specific "model_endpoint_type": "azure", @@ -110,7 +111,7 @@ # MemGPT-powered agents will also use local LLMs, but they need additional setup (also they use the Completions endpoint) config_list_memgpt = [ { - "preset": DEFAULT_PRESET, + "preset": settings.default_preset, "model": None, # only required for Ollama, see: https://memgpt.readme.io/docs/ollama "context_window": 8192, # the context window of your model (for Mistral 7B-based models, it's likely 8192) "model_wrapper": "chatml", # chatml is the default wrapper diff --git a/memgpt/autogen/examples/agent_groupchat.py b/memgpt/autogen/examples/agent_groupchat.py index e57ab7d9d8..ed36e37f35 100644 --- a/memgpt/autogen/examples/agent_groupchat.py +++ b/memgpt/autogen/examples/agent_groupchat.py @@ -13,8 +13,9 @@ import autogen +from memgpt.settings import settings from memgpt.autogen.memgpt_agent import create_memgpt_autogen_agent_from_config -from memgpt.constants import DEFAULT_PRESET, LLM_MAX_TOKENS +from memgpt.constants import LLM_MAX_TOKENS LLM_BACKEND = "openai" # LLM_BACKEND = "azure" @@ -40,7 +41,7 @@ { "model": model, "context_window": LLM_MAX_TOKENS[model], - "preset": DEFAULT_PRESET, + "preset": settings.default_preset, "model_wrapper": None, # OpenAI specific "model_endpoint_type": "openai", @@ -79,7 +80,7 @@ { "model": model, "context_window": LLM_MAX_TOKENS[model], - "preset": DEFAULT_PRESET, + "preset": settings.default_preset, "model_wrapper": None, # Azure specific "model_endpoint_type": "azure", @@ -108,7 +109,7 @@ # MemGPT-powered agents will also use local LLMs, but they need additional setup (also they use the Completions endpoint) config_list_memgpt = [ { - "preset": DEFAULT_PRESET, + "preset": settings.default_preset, "model": None, # only required for Ollama, see: https://memgpt.readme.io/docs/ollama "context_window": 8192, # the context window of your model (for Mistral 7B-based models, it's likely 8192) "model_wrapper": "chatml", # chatml is the default wrapper diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py index 04d95dfbe9..910d3e625c 100644 --- a/memgpt/autogen/memgpt_agent.py +++ b/memgpt/autogen/memgpt_agent.py @@ -10,6 +10,7 @@ UserProxyAgent, ) +from memgpt.settings import settings import memgpt.constants as constants import memgpt.system as system import memgpt.utils as utils @@ -414,7 +415,7 @@ def create_memgpt_autogen_agent_from_config( interface_kwargs = {} # The "system message" in AutoGen becomes the persona in MemGPT - persona_desc = utils.get_persona_text(constants.DEFAULT_PERSONA) if system_message == "" else system_message + persona_desc = utils.get_persona_text(settings.persona) if system_message == "" else system_message # The user profile is based on the input mode if human_input_mode == "ALWAYS": user_desc = "" diff --git a/memgpt/config.py b/memgpt/config.py index abb0b11905..985e5a818b 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -8,7 +8,7 @@ import memgpt from memgpt.settings import settings import memgpt.utils as utils -from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET, MEMGPT_DIR +from 
memgpt.constants import MEMGPT_DIR from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig from memgpt.log import get_logger @@ -39,11 +39,11 @@ class MemGPTConfig: anon_clientid: str = str(uuid.UUID(int=0)) # preset - preset: str = DEFAULT_PRESET # TODO: rename to system prompt + preset: str = settings.preset # persona parameters - persona: str = DEFAULT_PERSONA - human: str = DEFAULT_HUMAN + persona: str = settings.persona + human: str = settings.human # model parameters default_llm_config: LLMConfig = None diff --git a/memgpt/constants.py b/memgpt/constants.py index 254dd2d889..4f20317e6b 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -17,9 +17,6 @@ DEFAULT_MEMGPT_MODEL = "gpt-4" -DEFAULT_PERSONA = "sam_pov" -DEFAULT_HUMAN = "basic" -DEFAULT_PRESET = "memgpt_chat" # Tools BASE_TOOLS = [ diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 80af95702a..ff31d54b85 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -8,11 +8,8 @@ import numpy as np from pydantic import BaseModel, Field +from memgpt.settings import settings from memgpt.constants import ( - DEFAULT_HUMAN, - DEFAULT_PERSONA, - DEFAULT_PRESET, - JSON_ENSURE_ASCII, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM, TOOL_CALL_ID_MAX_LEN, @@ -862,13 +859,10 @@ class Preset(BaseModel): user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.") description: Optional[str] = Field(None, description="The description of the preset.") created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.") - system: str = Field( - gpt_system.get_system_text(DEFAULT_PRESET), description="The system prompt of the preset." - ) # default system prompt is same as default preset name - # system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.") - persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.") + system: str = Field(..., description="The system prompt of the preset.") + persona: str = Field(default=get_persona_text(settings.persona), description="The persona of the preset.") persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.") - human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.") + human: str = Field(default=get_human_text(settings.human), description="The human of the preset.") human_name: Optional[str] = Field(None, description="The name of the human of the preset.") functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.") # functions: List[str] = Field(..., description="The functions of the preset.") # TODO: convert to ID diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index 24ea2af69b..c7d4ce97c0 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -9,7 +9,7 @@ from sqlalchemy_utils import ChoiceType from sqlmodel import Field, SQLModel -from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA +from memgpt.settings import settings from memgpt.utils import get_human_text, get_persona_text, get_utc_time @@ -46,10 +46,9 @@ class PresetModel(BaseModel): description: Optional[str] = Field(None, description="The description of the preset.") created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.") system: str = Field(..., description="The system 
prompt of the preset.") - system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.") - persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.") + persona: str = Field(default=get_persona_text(settings.persona), description="The persona of the preset.") persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.") - human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.") + human: str = Field(default=get_human_text(settings.human), description="The human of the preset.") human_name: Optional[str] = Field(None, description="The name of the human of the preset.") functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.") @@ -117,14 +116,14 @@ class CoreMemory(BaseModel): class HumanModel(SQLModel, table=True): - text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.") + text: str = Field(default=get_human_text(settings.human), description="The human text.") name: str = Field(..., description="The name of the human.") id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the human.", primary_key=True) user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.") class PersonaModel(SQLModel, table=True): - text: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona text.") + text: str = Field(default=get_persona_text(settings.persona), description="The persona text.") name: str = Field(..., description="The name of the persona.") id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True) user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.") diff --git a/memgpt/presets/presets.py b/memgpt/presets/presets.py index b17f527702..7b39d1ae25 100644 --- a/memgpt/presets/presets.py +++ b/memgpt/presets/presets.py @@ -83,6 +83,99 @@ def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore): ms.add_human(human) +# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager): +def create_agent_from_preset( + agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True +): + """Initialize a new agent from a preset (combination of system + function)""" + raise DeprecationWarning("Function no longer supported - pass a Preset object to Agent.__init__ instead") + +def create_preset_from_file(filename: str, name: str, user_id: uuid.UUID, ms: MetadataStore) -> Preset: + preset_config = load_yaml_file(filename) + preset_system_prompt = preset_config["system_prompt"] + preset_function_set_names = preset_config["functions"] + functions_schema = generate_functions_json(preset_function_set_names) + + if ms.get_preset(user_id=user_id, name=name) is not None: + printd(f"Preset '{name}' already exists for user '{user_id}'") + return ms.get_preset(user_id=user_id, name=name) + + preset = Preset( + user_id=user_id, + name=name, + system=gpt_system.get_system_text(preset_system_prompt), + persona=get_persona_text(settings.persona), + human=get_human_text(settings.human), + persona_name=settings.persona, + human_name=settings.human, + functions_schema=functions_schema, + ) + ms.create_preset(preset) + return preset + + +def 
+def load_preset(preset_name: str, user_id: uuid.UUID):
+    preset_config = available_presets[preset_name]
+    preset_system_prompt = preset_config["system_prompt"]
+    preset_function_set_names = preset_config["functions"]
+    functions_schema = generate_functions_json(preset_function_set_names)
+
+    preset = Preset(
+        user_id=user_id,
+        name=preset_name,
+        system=gpt_system.get_system_text(preset_system_prompt),
+        persona=get_persona_text(settings.persona),
+        persona_name=settings.persona,
+        human=get_human_text(settings.human),
+        human_name=settings.human,
+        functions_schema=functions_schema,
+    )
+    return preset
+
+
+def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
+    """Add the default presets to the metadata store"""
+    # make sure humans/personas added
+    add_default_humans_and_personas(user_id=user_id, ms=ms)
+
+    # make sure base functions added
+    # TODO: pull from functions instead
+    add_default_tools(user_id=user_id, ms=ms)
+
+    # add default presets
+    for preset_name in preset_options:
+        if ms.get_preset(user_id=user_id, name=preset_name) is not None:
+            printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
+            continue
+
+        preset = load_preset(preset_name, user_id)
+        ms.create_preset(preset)
+
+
+def generate_functions_json(preset_functions: List[str]):
+    """
+    Generate JSON schema for the functions based on what is locally available.
+
+    TODO: store function definitions in the DB, instead of locally
+    """
+    # Available functions is a mapping from:
+    # function_name -> {
+    #    json_schema: schema
+    #    python_function: function
+    # }
+    available_functions = load_all_function_sets()
+    # Filter down the function set based on what the preset requested
+    preset_function_set = {}
+    for f_name in preset_functions:
+        if f_name not in available_functions:
+            raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}")
+        preset_function_set[f_name] = available_functions[f_name]
+    assert len(preset_functions) == len(preset_function_set)
+    preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()]
+    printd(f"Available functions:\n", list(preset_function_set.keys()))
+    return preset_function_set_schemas
+
+
 # def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
 def create_agent_from_preset(
     agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True
diff --git a/memgpt/server/rest_api/openai_assistants/assistants.py b/memgpt/server/rest_api/openai_assistants/assistants.py
index 16d11d2643..b4b0dbc54f 100644
--- a/memgpt/server/rest_api/openai_assistants/assistants.py
+++ b/memgpt/server/rest_api/openai_assistants/assistants.py
@@ -5,7 +5,7 @@
 from pydantic import BaseModel, Field
 
 from memgpt.config import MemGPTConfig
-from memgpt.constants import DEFAULT_PRESET
+from memgpt.settings import settings
 from memgpt.data_types import Message
 from memgpt.models.openai import (
     AssistantFile,
@@ -148,7 +148,7 @@ def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterfac
     def create_assistant(request: CreateAssistantRequest = Body(...)):
         # TODO: create preset
         return OpenAIAssistant(
-            id=DEFAULT_PRESET,
+            id=settings.preset,
             name="default_preset",
             description=request.description,
             created_at=int(get_utc_time().timestamp()),
@@ -303,7 +303,7 @@ def create_message(
             content=[Text(text=message.text)],
             role=message.role,
             thread_id=str(message.agent_id),
-            assistant_id=DEFAULT_PRESET,  # TODO: update this
+            assistant_id=settings.preset,  # TODO: update this
             # file_ids=message.file_ids,
             # metadata=message.metadata,
         )
@@ -343,7 +343,7 @@ def list_messages(
                 content=[Text(text=message["text"])],
                 role=message["role"],
                 thread_id=str(message["agent_id"]),
-                assistant_id=DEFAULT_PRESET,  # TODO: update this
+                assistant_id=settings.preset,  # TODO: update this
                 # file_ids=message.file_ids,
                 # metadata=message.metadata,
             )
@@ -368,7 +368,7 @@ def retrieve_message(
             content=[Text(text=message.text)],
             role=message.role,
             thread_id=str(message.agent_id),
-            assistant_id=DEFAULT_PRESET,  # TODO: update this
+            assistant_id=settings.preset,  # TODO: update this
             # file_ids=message.file_ids,
             # metadata=message.metadata,
         )
@@ -407,7 +407,7 @@ def create_run(
             id=run_id,
             created_at=create_time,
             thread_id=str(agent_id),
-            assistant_id=DEFAULT_PRESET,  # TODO: update this
+            assistant_id=settings.preset,  # TODO: update this
             status="completed",  # TODO: eventaully allow offline execution
             expires_at=create_time,
             model=agent.agent_state.llm_config.model,
diff --git a/memgpt/server/rest_api/presets/index.py b/memgpt/server/rest_api/presets/index.py
index 4702371f1e..b7d31e8f16 100644
--- a/memgpt/server/rest_api/presets/index.py
+++ b/memgpt/server/rest_api/presets/index.py
@@ -6,7 +6,7 @@
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel, Field
 
-from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
+from memgpt.settings import settings
 from memgpt.data_types import Preset  # TODO remove
 from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel
 from memgpt.prompts import gpt_system
@@ -37,9 +37,9 @@ class CreatePresetsRequest(BaseModel):
     # user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
     description: Optional[str] = Field(None, description="The description of the preset.")
     # created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
-    system: Optional[str] = Field(None, description="The system prompt of the preset.")  # TODO: make optional and allow defaults
-    persona: Optional[str] = Field(default=None, description="The persona of the preset.")
-    human: Optional[str] = Field(default=None, description="The human of the preset.")
+    system: str = Field(..., description="The system prompt of the preset.")
+    persona: str = Field(default=get_persona_text(settings.persona), description="The persona of the preset.")
+    human: str = Field(default=get_human_text(settings.human), description="The human of the preset.")
     functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")  # TODO
     persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
diff --git a/tests/test_agent_function_update.py b/tests/test_agent_function_update.py
new file mode 100644
index 0000000000..d22bac5197
--- /dev/null
+++ b/tests/test_agent_function_update.py
@@ -0,0 +1,124 @@
+import inspect
+import json
+import os
+import uuid
+
+import pytest
+
+from memgpt.settings import settings
+from memgpt import constants, create_client
+from memgpt.functions.functions import USER_FUNCTIONS_DIR
+from memgpt.models import chat_completion_response
+from memgpt.utils import assistant_function_to_tool
+from tests import TEST_MEMGPT_CONFIG
+from tests.utils import create_config, wipe_config
+
+
+def hello_world(self) -> str:
+    """Test function for agent to gain access to
+
+    Returns:
+        str: A message for the world
+    """
+    return "hello, world!"
+
+
+@pytest.fixture(scope="module")
+def agent():
+    """Create a test agent that we can call functions on"""
+    wipe_config()
+    global client
+    if os.getenv("OPENAI_API_KEY"):
+        create_config("openai")
+    else:
+        create_config("memgpt_hosted")
+
+    # create memgpt client
+    client = create_client()
+
+    # ensure user exists
+    user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
+    if not client.server.get_user(user_id=user_id):
+        client.server.create_user({"id": user_id})
+
+    agent_state = client.create_agent(
+        preset=settings.preset,
+    )
+
+    return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
+
+
+@pytest.fixture(scope="module")
+def hello_world_function():
+    with open(os.path.join(USER_FUNCTIONS_DIR, "hello_world.py"), "w", encoding="utf-8") as f:
+        f.write(inspect.getsource(hello_world))
+
+
+@pytest.fixture(scope="module")
+def ai_function_call():
+    return chat_completion_response.Message(
+        **assistant_function_to_tool(
+            {
+                "role": "assistant",
+                "content": "I will now call hello world",
+                "function_call": {
+                    "name": "hello_world",
+                    "arguments": json.dumps({}),
+                },
+            }
+        )
+    )
+
+
+def test_add_function_happy(agent, hello_world_function, ai_function_call):
+    agent.add_function("hello_world")
+
+    assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
+    assert "hello_world" in agent.functions_python.keys()
+
+    msgs, heartbeat_req, function_failed = agent._handle_ai_response(ai_function_call)
+    content = json.loads(msgs[-1].to_openai_dict()["content"], strict=constants.JSON_LOADS_STRICT)
+    assert content["message"] == "hello, world!"
+    assert content["status"] == "OK"
+    assert not function_failed
+
+
+def test_add_function_already_loaded(agent, hello_world_function):
+    agent.add_function("hello_world")
+    # no exception for duplicate loading
+    agent.add_function("hello_world")
+
+
+def test_add_function_not_exist(agent):
+    # adding an unknown function should raise
+    with pytest.raises(ValueError):
+        agent.add_function("non_existent")
+
+
+def test_remove_function_happy(agent, hello_world_function):
+    agent.add_function("hello_world")
+
+    # ensure function is loaded
+    assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
+    assert "hello_world" in agent.functions_python.keys()
+
+    agent.remove_function("hello_world")
+
+    assert "hello_world" not in [f_schema["name"] for f_schema in agent.functions]
+    assert "hello_world" not in agent.functions_python.keys()
+
+
+def test_remove_function_not_exist(agent):
+    # do not raise error
+    agent.remove_function("non_existent")
+
+
+def test_remove_base_function_fails(agent):
+    with pytest.raises(ValueError):
+        agent.remove_function("send_message")
+
+
+if __name__ == "__main__":
+    pytest.main(["-vv", os.path.abspath(__file__)])
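A quick aside on the contract the fixtures above rely on: MemGPT loads user functions from `USER_FUNCTIONS_DIR` and derives each tool's JSON schema from the function's typed signature and docstring. A minimal sketch of another well-formed user function, assuming the loader follows the same convention as the `hello_world` example above (the function name and module are hypothetical, not part of this patch):

```python
# Hypothetical user function saved as USER_FUNCTIONS_DIR/roll_d20.py.
# Like hello_world above, it takes `self` (the agent), returns a string,
# and documents its return value so a JSON schema can be generated from it.
import random


def roll_d20(self) -> str:
    """Roll a 20-sided die and report the result.

    Returns:
        str: The outcome of the roll, e.g. "rolled a 17"
    """
    return f"rolled a {random.randint(1, 20)}"
```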
diff --git a/tests/test_client.py b/tests/test_client.py
index 365d3940c5..f1af0bfdba 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -6,9 +6,8 @@
 import pytest
 from dotenv import load_dotenv
 
+from memgpt.settings import settings
 from memgpt import Admin, create_client
-from memgpt.config import MemGPTConfig
-from memgpt.constants import DEFAULT_PRESET
 from memgpt.credentials import MemGPTCredentials
 from memgpt.data_types import Preset  # TODO move to PresetModel
 from memgpt.settings import settings
@@ -16,7 +15,7 @@
 
 test_agent_name = f"test_client_{str(uuid.uuid4())}"
 # test_preset_name = "test_preset"
-test_preset_name = DEFAULT_PRESET
+test_preset_name = settings.preset
 test_agent_state = None
 client = None
diff --git a/tests/test_metadata_store.py b/tests/test_metadata_store.py
new file mode 100644
index 0000000000..6dd05b78e3
--- /dev/null
+++ b/tests/test_metadata_store.py
@@ -0,0 +1,139 @@
+import pytest
+
+from memgpt.agent import Agent, save_agent
+from memgpt.settings import settings
+from memgpt.data_types import AgentState, LLMConfig, Source, User
+from memgpt.metadata import MetadataStore
+from memgpt.models.pydantic_models import HumanModel, PersonaModel
+from memgpt.presets.presets import add_default_presets
+from memgpt.utils import get_human_text, get_persona_text
+from tests import TEST_MEMGPT_CONFIG
+
+
+# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
+@pytest.mark.parametrize("storage_connector", ["sqlite"])
+def test_storage(storage_connector):
+    if storage_connector == "postgres":
+        TEST_MEMGPT_CONFIG.archival_storage_uri = settings.pg_uri
+        TEST_MEMGPT_CONFIG.recall_storage_uri = settings.pg_uri
+        TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
+        TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
+    if storage_connector == "sqlite":
+        TEST_MEMGPT_CONFIG.recall_storage_type = "local"
+
+    ms = MetadataStore(TEST_MEMGPT_CONFIG)
+
+    # users
+    user_1 = User()
+    user_2 = User()
+    ms.create_user(user_1)
+    ms.create_user(user_2)
+
+    # test adding default humans/personas/presets
+    # add_default_humans_and_personas(user_id=user_1.id, ms=ms)
+    # add_default_humans_and_personas(user_id=user_2.id, ms=ms)
+    ms.add_human(human=HumanModel(name="test_human", text="This is a test human"))
+    ms.add_persona(persona=PersonaModel(name="test_persona", text="This is a test persona"))
+    add_default_presets(user_id=user_1.id, ms=ms)
+    add_default_presets(user_id=user_2.id, ms=ms)
+    assert len(ms.list_humans(user_id=user_1.id)) > 0, ms.list_humans(user_id=user_1.id)
+    assert len(ms.list_personas(user_id=user_1.id)) > 0, ms.list_personas(user_id=user_1.id)
+
+    # generate data
+    agent_1 = AgentState(
+        user_id=user_1.id,
+        name="agent_1",
+        preset=settings.preset,
+        persona=settings.persona,
+        human=settings.human,
+        llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
+        embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
+    )
+    source_1 = Source(user_id=user_1.id, name="source_1")
+
+    # test creation
+    ms.create_agent(agent_1)
+    ms.create_source(source_1)
+
+    # test listing
+    assert len(ms.list_agents(user_id=user_1.id)) == 1
+    assert len(ms.list_agents(user_id=user_2.id)) == 0
+    assert len(ms.list_sources(user_id=user_1.id)) == 1
+    assert len(ms.list_sources(user_id=user_2.id)) == 0
+
+    # test agent_state saving
+    agent_state = ms.get_agent(agent_1.id).state
+    assert agent_state == {}, agent_state  # when created via create_agent, it should be empty
+
+    from memgpt.presets.presets import add_default_presets
+
+    add_default_presets(user_1.id, ms)
+    preset_obj = ms.get_preset(name=settings.preset, user_id=user_1.id)
+    from memgpt.interface import CLIInterface as interface  # for printing to terminal
+
+    # Overwrite fields in the preset if they were specified
+    preset_obj.human = get_human_text(settings.human)
+    preset_obj.persona = get_persona_text(settings.persona)
+
+    # Create the agent
+    agent = Agent(
+        interface=interface(),
+        created_by=user_1.id,
+        name="agent_test_agent_state",
+        preset=preset_obj,
+        llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
+        embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
+        # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
+        first_message_verify_mono=(
+            True if (TEST_MEMGPT_CONFIG.default_llm_config.model is not None and "gpt-4" in TEST_MEMGPT_CONFIG.default_llm_config.model) else False
+        ),
+    )
+    agent_with_agent_state = agent.agent_state
+    save_agent(agent=agent, ms=ms)
+
+    agent_state = ms.get_agent(agent_with_agent_state.id).state
+    assert agent_state is not None, agent_state  # when created via create_agent_from_preset, it should be non-empty
+
+    # test: updating
+
+    # test: update JSON-stored LLMConfig class
+    print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config)
+    llm_config = ms.get_agent(agent_1.id).llm_config
+    assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}"
+    assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}"
+    llm_config.model = "gpt3.5-turbo"
+    agent_1.llm_config = llm_config
+    ms.update_agent(agent_1)
+    assert ms.get_agent(agent_1.id).llm_config.model == "gpt3.5-turbo", f"Updated LLMConfig to {ms.get_agent(agent_1.id).llm_config.model}"
+
+    # test attaching sources
+    assert len(ms.list_attached_sources(agent_id=agent_1.id)) == 0
+    ms.attach_source(user_1.id, agent_1.id, source_1.id)
+    assert len(ms.list_attached_sources(agent_id=agent_1.id)) == 1
+
+    # test: detaching sources
+    ms.detach_source(agent_1.id, source_1.id)
+    assert len(ms.list_attached_sources(agent_id=agent_1.id)) == 0
+
+    # test getting
+    ms.get_user(user_1.id)
+    ms.get_agent(agent_1.id)
+    ms.get_source(source_1.id)
+
+    # test api keys
+    api_key = ms.create_api_key(user_id=user_1.id)
+    print("api_key=", api_key.token, api_key.user_id)
+    api_key_result = ms.get_api_key(api_key=api_key.token)
+    assert api_key.token == api_key_result.token, (api_key, api_key_result)
+    user_result = ms.get_user_from_api_key(api_key=api_key.token)
+    assert user_1.id == user_result.id, (user_1, user_result)
+    all_keys_for_user = ms.get_all_api_keys_for_user(user_id=user_1.id)
+    assert len(all_keys_for_user) > 0, all_keys_for_user
+    ms.delete_api_key(api_key=api_key.token)
+
+    # test deletion
+    ms.delete_user(user_1.id)
+    ms.delete_user(user_2.id)
+    ms.delete_agent(agent_1.id)
+    ms.delete_source(source_1.id)

From 094abd060eee2be54998eff1341bd1599a57d25d Mon Sep 17 00:00:00 2001
From: Ethan Knox
Date: Tue, 18 Jun 2024 08:18:08 -0400
Subject: [PATCH 16/45] conflicting persist

---
 .../sqlite/test_prefixed_ids_sqlite_chroma_.db  | Bin 0 -> 151552 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 .persist/sqlite/test_prefixed_ids_sqlite_chroma_.db

diff --git a/.persist/sqlite/test_prefixed_ids_sqlite_chroma_.db b/.persist/sqlite/test_prefixed_ids_sqlite_chroma_.db
new file mode 100644
index 0000000000000000000000000000000000000000..7accc88b0d55c2e1311864c75b2841771a53c00d
GIT binary patch
literal 151552
[base85-encoded binary payload omitted: an accidentally committed sqlite test database]

literal 0
HcmV?d00001
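The binary payload in the commit above is a sqlite database generated by the test suite and committed by accident; patch 22 later moves these files under `.persist/` to keep them contained. A hedged alternative sketch — not what this branch does — is to let pytest own the location entirely, so nothing can land in the working tree:

```python
# Sketch: point per-run sqlite files at pytest's managed temp directory.
# tmp_path_factory is a built-in pytest fixture; create_engine here is
# sqlalchemy's, not the helper this branch adds later.
import pytest
from sqlalchemy import create_engine


@pytest.fixture(scope="session")
def sqlite_engine(tmp_path_factory):
    db_path = tmp_path_factory.mktemp("sqlite") / "test_memgpt.db"
    return create_engine(f"sqlite:///{db_path}")
```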
z@l)YtU3|>V?aTb#EA#!#KY|b4Rv;Y#1Q0*~0R#|0009ILKmY**zD0p2pLRF)#PEpw zDrSB!xKADkAbJUYeS^G&z@$ zH00Jx%zPoZPaX&$fB*srAbT009IL zKmY**5I_I{1Q0*~fg>pJ(y_weL8A!#{r?ecSaul!1Q0*~0R#|0009ILKmdV53h?*; zhXkWI0tg_000IagfB*srAb@orfAbN)wXT6)U5YwVLUrc|XI8WhGZp2rBYtW6(f0i%hcRF zRdeJe_xma<1w~#`mUGp7MIKwMmdi@9qI#*Tm0aQKxVO3Xrg7L+EL)B5MEE7kx%{dU z*(EjHr?suHI^Nf+HO+R^@XKR)&DMt2u++ae!mr74 zsjRHLQS=hW)YvA*<+8G@xYb`&R^`sTVD+K#doK>8rHKjg{+8cxb+g`VZK;mw{x5A^ zp4j4Gq&Z`m@s^G3xMA9Q%i8oR7B+Mw-`U#S4T)%mdD}@`Bhe?`Je~0qwg)lN83$ev z9TX$44W^`P6Ww}&SQx0;Si;b+UVSVrojxsoyy}P8X3MlqCwNUo0xC#5Y#ez{E=Xir7h z9;RzsEEQKP<=jfKBDWjXt+wgYM>;vVT3oqSRown87S`uwuZ_QX4>^%eNk2N>E&79) zRi}dYfr->36wU^P)Ay`nX-SgAkFx%Op>3EA$A9UM92S03{HX9^e7{h&9k=^V2*0}N zXzqzGAu_ukF+0cQ&KWy$MAT|^Rc|)dt&P3kW!BeB!>}3~T}yB<9Ohv7M(6qQ{*?5Z z)Gdnr5L2`M+hccsBBrI$QSsi-{IKjitU5pUMS`?5Cn;9{Pn&wr1>4a0d|yhsGTN;r zIvv$)=h89v0^R+$2L!?U{r}1S6 zZ$0;%%RH@bbW!sDou~aPjqQ0qCcN<>TlFtO?9R`pBk|Fh<3&lZLSmZ)BUDL)Nzu9l zmp_(qUJlpTd)p9}iQFywiCsowy!~|dDpKA{3rAi?Po<=VlijWZ-fga_S@-3c?Zv0k z(#ezJ2X)`bO>&=7M8r+-AiPWdwjGB0*ur?^h|rtfzfu$%B=cG@ydzNkD|r71*xB{H zV{nPeME=-r*mMr1bT;<*$^9M@`gesNyyrK#pXk4wiZr^PX&=rlsCAx57B^8uW%{-w2l@LD4x1;V&+c%Ny?0Mt=Kp)IO{ozA2q1s}0tg_000IagfB*t<0p|bXA}k|-00IagfB*sr zAb#@4YsqMhGB)00IagfB*srAbLrF z+qY_#W2#!)Y5L2m(`?pkb+cvKremww$l~Z-QBF(e&WYPOM_a3zk<7@#ON(VCS5f3j z?s8s{BU$p;(BRJYRm+g8)s-c=SgOdyYCeBncJsW_i8)uxD}`M7x_m{sE{{Fj#CbVf zW!zhTxl~qG-Y9ydItq=;Wo20@E5$`+RSs(BZo@M6uCnv>krlkOp!m@E{0k}R%IN-l z+J@P9s241K@BKT^r=`(RaeKZ4FUSjkg3m7|&;>cZoc}SACsf!=8ifnDu%BC4CZgac z{+nvAolQyCCc3HF&2~#SJHjrOAm7bAmzE|b#OCSnF94+_ZLdK@Aoch}h~>3^wr%yz2yiCu#7;kG*02vS4G3#U_3?o2nw zP0hC5Z&0&4AD-L(#W!>1NM~o(*G$8( z8XNLktEHmXeeI^Gdb6=^JYY!-2~I( zAzv<)@=C54*+so&YHpsYIr5U*rB_x8ioB#O=c@UNJhoUZyM1a!^-@BX4j(_H;X||(=Umnei9=b{Weur8lz{2z}VCJR$ t`v10Vw!*_q&2}z5%kTexXO}c=7y$$jKmY**5I_I{1Q0*~f&UqSe*tP!&cOfx literal 0 HcmV?d00001 From 3e6e3bb86823b3f2e3dd7d305d8e54f00e930d29 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Tue, 18 Jun 2024 08:56:46 -0400 Subject: [PATCH 17/45] Started a working README for refactor, so if I get hit by a bus the next person doesn't need to spend a week getting up to speed. This helps clarify the goal in this PR: one config hierarchy assembled once, with one mega hook. --- memgpt/server/README.md | 14 ++++++++++++++ memgpt/server/server.py | 3 +++ 2 files changed, 17 insertions(+) create mode 100644 memgpt/server/README.md diff --git a/memgpt/server/README.md b/memgpt/server/README.md new file mode 100644 index 0000000000..fe3c2f5c82 --- /dev/null +++ b/memgpt/server/README.md @@ -0,0 +1,14 @@ +# SyncServer + +## Preamble +MemGPT is undergoing significant refactoring and migrating to a classic MVC pattern.The goal is to introduce software best practices to greatly increase reliability, reduce cycle times, and support rapid test driven development. + +**Note:** This README should represent current state as it evolves and serve as a cheat sheet for developers during the process. Please keep this up to date as you evolve the code! Just like MemGPT agents manage their own core memory to preserve state, use this README as the shared developer "core memory" so we don't waste cycles on cognative load. + +## Current State +[SyncServer](./server.py) behaves as a single monolith MVC Controller. On either side of the SyncServer (Controller) we have: +- Models are currently not managed via ORM. DB/memory syncing does **not** use sqlalchemy metadata - it is managed externally by a bespoke in-memory state manager called[PersistanceManger](../memgpt/persistence_manager.py). 
diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index 5aec9fd772..427c71b35e 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -1,3 +1,6 @@
+"""See memgpt/server/README.md for context on SyncServer"""
+
+
 from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
 import json
 import uuid

From 4fe967044aeeea46fcb7472c122a58d76490d3c8 Mon Sep 17 00:00:00 2001
From: Ethan Knox
Date: Tue, 18 Jun 2024 11:38:57 -0400
Subject: [PATCH 18/45] ORM abstraction testing pattern set up

---
 memgpt/constants.py           |   3 -
 memgpt/orm/__init__.py        |   0
 memgpt/orm/base.py            |  82 +++++++++++++++++++++++++++
 memgpt/orm/errors.py          |   4 ++
 memgpt/orm/sqlalchemy_base.py | 103 ++++++++++++++++++++++++++++++++++
 pyproject.toml                |   1 +
 tests/conftest.py             |  32 +++++++++++
 tests/orm/__init__.py         |   0
 tests/orm/test_bases.py       |  23 ++++++++
 9 files changed, 245 insertions(+), 3 deletions(-)
 create mode 100644 memgpt/orm/__init__.py
 create mode 100644 memgpt/orm/base.py
 create mode 100644 memgpt/orm/errors.py
 create mode 100644 memgpt/orm/sqlalchemy_base.py
 create mode 100644 tests/orm/__init__.py
 create mode 100644 tests/orm/test_bases.py

diff --git a/memgpt/constants.py b/memgpt/constants.py
index 4f20317e6b..93152eab27 100644
--- a/memgpt/constants.py
+++ b/memgpt/constants.py
@@ -28,9 +28,6 @@
     "archival_memory_search",
 ]
 
-# LOGGER_LOG_LEVEL is use to convert Text to Logging level value for logging mostly for Cli input to setting level
-LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET}
-
 FIRST_MESSAGE_ATTEMPTS = 10
 
 INITIAL_BOOT_MESSAGE = "Boot sequence complete. Persona activated."
diff --git a/memgpt/orm/__init__.py b/memgpt/orm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/memgpt/orm/base.py b/memgpt/orm/base.py
new file mode 100644
index 0000000000..6600cb9668
--- /dev/null
+++ b/memgpt/orm/base.py
@@ -0,0 +1,82 @@
+from typing import Optional
+from datetime import datetime
+from uuid import UUID
+from sqlalchemy import Boolean, DateTime, func, text, UUID as SQLUUID
+from sqlalchemy.orm import (
+    DeclarativeBase,
+    Mapped,
+    mapped_column,
+    declarative_mixin,
+    declared_attr,
+)
+
+class AbsoluteBase(DeclarativeBase):
+    """For the few rare instances where we need a bare table
+    (like through m2m joins), extending AbsoluteBase ensures
+    all models inherit from the same DeclarativeBase.
+ """ + +class Base(DeclarativeBase): + """absolute base for sqlalchemy classes""" + + +@declarative_mixin +class CommonSqlalchemyMetaMixins(Base): + __abstract__ = True + + created_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + updated_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), server_default=func.now(), server_onupdate=func.now() + ) + is_deleted: Mapped[bool] = mapped_column(Boolean, server_default=text("FALSE")) + + @declared_attr + def _created_by_id(cls): + return cls._user_by_id() + + @declared_attr + def _last_updated_by_id(cls): + return cls._user_by_id() + + @classmethod + def _user_by_id(cls): + """a flexible non-constrained record of a user. + This way users can get added, deleted etc without history freaking out + """ + return mapped_column(SQLUUID(), nullable=True) + + @property + def last_updated_by_id(self) -> Optional[str]: + return self._user_id_getter("last_updated") + + @last_updated_by_id.setter + def last_updated_by_id(self, value: str) -> None: + self._user_id_setter("last_updated", value) + + @property + def created_by_id(self) -> Optional[str]: + return self._user_id_getter("created") + + @created_by_id.setter + def created_by_id(self, value: str) -> None: + self._user_id_setter("created", value) + + def _user_id_getter(self, prop: str) -> Optional[str]: + """returns the user id for the specified property""" + full_prop = f"_{prop}_by_id" + prop_value = getattr(self, full_prop, None) + if not prop_value: + return + return f"user-{prop_value}" + + def _user_id_setter(self, prop: str, value: str) -> None: + """returns the user id for the specified property""" + full_prop = f"_{prop}_by_id" + if not value: + setattr(self, full_prop, None) + return + prefix, id_ = value.split("-",1) + assert prefix == "user", f"{prefix} is not a valid id prefix for a user id" + setattr(self, full_prop, UUID(id_)) diff --git a/memgpt/orm/errors.py b/memgpt/orm/errors.py new file mode 100644 index 0000000000..4d29a2a248 --- /dev/null +++ b/memgpt/orm/errors.py @@ -0,0 +1,4 @@ + + +class NoResultFound(Exception): + """A record or records cannot be found given the provided search params""" \ No newline at end of file diff --git a/memgpt/orm/sqlalchemy_base.py b/memgpt/orm/sqlalchemy_base.py new file mode 100644 index 0000000000..48a08d492c --- /dev/null +++ b/memgpt/orm/sqlalchemy_base.py @@ -0,0 +1,103 @@ +from uuid import uuid4, UUID, Type, Union, List, Literal +from typing import Optional, TYPE_CHECKING +from humps import depascalize +from sqlalchemy import select, UUID as SQLUUID +from sqlalchemy.orm import ( + Mapped, + mapped_column +) +from memgpt.log import get_logger +from memgpt.orm.base import CommonSqlalchemyMetaMixins +from memgpt.orm.errors import NoResultFound + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + from sqlalchemy import Select + from memgpt.orm.user import User + +logger = get_logger(__name__) + + +class SqlalchemyBase(CommonSqlalchemyMetaMixins): + __abstract__ = True + + __order_by_default__ = "created_at" + + _id: Mapped[UUID] = mapped_column(SQLUUID(), primary_key=True, default=uuid4) + + @property + def __prefix__(self) -> str: + return depascalize(self.__class__.__name__) + + @property + def id(self) -> Optional[str]: + if self._id: + return f"{self.__prefix__}-{self._id}" + + @id.setter + def id(self, value: str) -> None: + if not value: + return + prefix, id_ = value.split("-", 1) + assert ( + prefix == self.__prefix__ + ), f"{prefix} is not a valid id 
+        self._id = UUID(id_)
+
+    @classmethod
+    def list(cls, db_session: "Session") -> list[Type["Base"]]:
+        with db_session as session:
+            return session.query(cls).all()
+
+    @classmethod
+    def read(
+        cls, db_session: "Session", identifier: Union[str, UUID], **kwargs
+    ) -> Type["SqlalchemyBase"]:
+        del kwargs
+        if isinstance(identifier, str):  # accept prefixed ids ("user-<uuid>") as well as bare UUIDs
+            identifier = UUID(identifier.split("-", 1)[1])
+        if found := db_session.get(cls, identifier):
+            return found
+        raise NoResultFound(f"{cls.__name__} with id {identifier} not found")
+
+    def create(self, db_session: "Session") -> Type["SqlalchemyBase"]:
+        with db_session as session:
+            session.add(self)
+            session.commit()
+            session.refresh(self)
+            return self
+
+    def delete(self, db_session: "Session") -> Type["SqlalchemyBase"]:
+        self.is_deleted = True
+        return self.update(db_session)
+
+    def update(self, db_session: "Session") -> Type["SqlalchemyBase"]:
+        with db_session as session:
+            session.add(self)
+            session.commit()
+            session.refresh(self)
+            return self
+
+    @classmethod
+    def apply_access_predicate(
+        cls,
+        query: "Select",
+        actor: "User",
+        access: List[Literal["read", "write", "admin"]],
+    ) -> "Select":
+        """applies a WHERE clause restricting results to the given actor and access level
+        Args:
+            query: The initial sqlalchemy select statement
+            actor: The user acting on the query. **Note**: this is called 'actor' to identify the
+                person or system acting. Users can act on users, making naming very sticky otherwise.
+            access:
+                what mode of access should the query restrict to? This will be used with granular permissions,
+                but because of how it will impact every query we want to be explicitly calling access ahead of time.
+        Returns:
+            the sqlalchemy select statement restricted to the given access.
+        """
+        del access  # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
+        org_uid = getattr(
+            actor, "_organization_id", getattr(actor.organization, "_id", None)
+        )
+        if not org_uid:
+            raise ValueError(f"object {actor} has no organization accessor")
+        return query.where(cls._organization_id == org_uid)
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index bbeb1f6a22..a3b797747b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,6 +78,7 @@ ollama = ["llama-index-embeddings-ollama"]
 
 [tool.poetry.group.dev.dependencies]
 black = "^24.4.2"
+faker = "^25.8.0"
 
 [tool.black]
 line-length = 140
diff --git a/tests/conftest.py b/tests/conftest.py
index 379e4fdc36..673eff559c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,12 @@
+from typing import TYPE_CHECKING
 import pytest
 
+from sqlalchemy import text
+from sqlalchemy.orm import sessionmaker
+
 from memgpt.settings import settings
+from memgpt.orm.utilities import _create_engine_for_adapter
+from memgpt.orm.base import Base
 from tests.utils import wipe_memgpt_home
 from memgpt.data_types import EmbeddingConfig, LLMConfig
 from memgpt.credentials import MemGPTCredentials
@@ -8,6 +14,9 @@
 
 from tests.config import TestMGPTConfig
 
+if TYPE_CHECKING:
+    from sqlalchemy import Session
+
 @pytest.fixture(scope="module")
 def config():
     db_url = settings.pg_db
@@ -79,3 +88,26 @@ def agent_id(server, user_id):
 
     # cleanup
     server.delete_agent(user_id, agent_state.id)
+
+
+# new ORM
+@pytest.fixture(params=["sqlite_chroma","postgres",])
+def db_session(request) -> "Session":
+    """Creates a function-scoped orm session for the given test and adapter.
+    Note: both pg and sqlite/chroma will have results scoped to each test function - so 2x results
+    for each. These are cleared at the _beginning_ of each test run - so states are persisted for inspection
+    after the end of the test!
+
+    """
+    function_ = request.node.name
+    engine = _create_engine_for_adapter(adapter=request.param, database="memgpt_test")
+    with engine.begin() as connection:
+        for statement in (
+            text(f"CREATE SCHEMA IF NOT EXISTS {function_}"),
+            text(f"SET search_path TO {function_},public"),
+        ):
+            connection.execute(statement)
+        Base.metadata.drop_all(bind=connection)
+        Base.metadata.create_all(bind=connection)
+    with sessionmaker(bind=engine)() as session:
+        yield session
diff --git a/tests/orm/__init__.py b/tests/orm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/orm/test_bases.py b/tests/orm/test_bases.py
new file mode 100644
index 0000000000..176296fb0a
--- /dev/null
+++ b/tests/orm/test_bases.py
@@ -0,0 +1,23 @@
+from pytest import mark as m
+from faker import Faker
+
+from memgpt.orm.user import User
+from memgpt.orm.organization import Organization
+
+faker = Faker()
+
+@m.unit
+class TestORMBases:
+    """eyeball unit tests of accessors, id logic etc"""
+
+    def test_prefixed_ids(db_session):
+
+        user = User(
+            email=faker.email,
+            organization=Organization.default()
+        ).create(db_session)
+
+        assert user.id.startswith('user-')
+        assert str(user._id) in user.id
+        assert user.organization.id.startswith('organization-')
+        assert str(user.organization._id) in user.organization.id
\ No newline at end of file

From 73c636c8eb8fefed48c134f58e986d97aaad6f17 Mon Sep 17 00:00:00 2001
From: Ethan Knox
Date: Tue, 18 Jun 2024 19:31:17 -0400
Subject: [PATCH 19/45] almost have the conftest pattern set up

---
 docs/autogen.md                            |  8 +--
 memgpt/autogen/examples/agent_autoreply.py |  6 +--
 memgpt/autogen/examples/agent_docs.py      |  6 +--
 memgpt/autogen/examples/agent_groupchat.py |  6 +--
 memgpt/client/client.py                    |  5 +-
 memgpt/metadata.py                         |  8 ---
 memgpt/orm/README.md                       | 16 ++++++
 memgpt/orm/mixins.py                       | 61 ++++++++++++++++++++++
 memgpt/orm/organization.py                 | 22 ++++++++
 memgpt/orm/sqlalchemy_base.py              |  4 +-
 memgpt/orm/user.py                         | 20 +++++++
 memgpt/orm/utilities.py                    | 40 ++++++++++++++
 memgpt/server/rest_api/presets/index.py    |  6 +--
 memgpt/server/server.py                    |  5 -
 tests/conftest.py                          | 16 +++---
 tests/orm/test_bases.py                    |  4 +-
 tests/test_concurrent_connections.py       |  4 +-
 tests/test_openai_assistant_api.py         |  2 +-
 tests/test_websocket_interface.py          |  2 +-
 19 files changed, 196 insertions(+), 45 deletions(-)
 create mode 100644 memgpt/orm/README.md
 create mode 100644 memgpt/orm/mixins.py
 create mode 100644 memgpt/orm/organization.py
 create mode 100644 memgpt/orm/user.py
 create mode 100644 memgpt/orm/utilities.py

diff --git a/docs/autogen.md b/docs/autogen.md
index 1a04a3d73e..5e6fede0ef 100644
--- a/docs/autogen.md
+++ b/docs/autogen.md
@@ -156,7 +156,7 @@ config_list = [
 # MemGPT-powered agents will also use local LLMs, but they need additional setup (also they use the Completions endpoint)
 config_list_memgpt = [
     {
-        "preset": DEFAULT_PRESET,
+        "preset": settings.preset,
         "model": None,  # not required for web UI, only required for Ollama, see: https://memgpt.readme.io/docs/ollama
         "model_wrapper": "airoboros-l2-70b-2.1",  # airoboros is the default wrapper and should work for most models
         "model_endpoint_type": "webui",
@@ -183,7 +183,7 @@ config_list = [
 # MemGPT-powered agents will also use local LLMs, but they need additional setup (also they use the Completions endpoint)
 config_list_memgpt = [
     {
-        "preset": DEFAULT_PRESET,
+        "preset": settings.preset,
         "model": None,
         "model_wrapper": "airoboros-l2-70b-2.1",
         "model_endpoint_type": "lmstudio",
@@ -209,7 +209,7 @@ config_list = [
 # This config is for autogen agents that powered by MemGPT
 config_list_memgpt = [
     {
-        "preset": DEFAULT_PRESET,
+        "preset": settings.preset,
         "model": "gpt-4",
         "context_window": 8192,  # gpt-4 context window
         "model_wrapper": None,
@@ -240,7 +240,7 @@ config_list = [
 # This config is for autogen agents that powered by MemGPT
 config_list_memgpt = [
     {
-        "preset": DEFAULT_PRESET,
+        "preset": settings.preset,
         "model": "gpt-4",  # make sure you choose a model that you have access to deploy on your Azure account
         "model_wrapper": None,
         "context_window": 8192,  # gpt-4 context window
diff --git a/memgpt/autogen/examples/agent_autoreply.py b/memgpt/autogen/examples/agent_autoreply.py
index 67c9a50f70..3b94631e60 100644
--- a/memgpt/autogen/examples/agent_autoreply.py
+++ b/memgpt/autogen/examples/agent_autoreply.py
@@ -41,7 +41,7 @@
     {
         "model": model,
         "context_window": LLM_MAX_TOKENS[model],
-        "preset": settings.default_preset,
+        "preset": settings.preset,
         "model_wrapper": None,
         # OpenAI specific
         "model_endpoint_type": "openai",
@@ -80,7 +80,7 @@
     {
         "model": model,
         "context_window": LLM_MAX_TOKENS[model],
-        "preset": settings.default_preset,
+        "preset": settings.preset,
         "model_wrapper": None,
         # Azure specific
         "model_endpoint_type": "azure",
@@ -109,7 +109,7 @@
 # MemGPT-powered agents will also use local LLMs, but they need additional setup (also they use the Completions endpoint)
 config_list_memgpt = [
     {
-        "preset": settings.default_preset,
+        "preset": settings.preset,
         "model": None,  # only required for Ollama, see: https://memgpt.readme.io/docs/ollama
         "context_window": 8192,  # the context window of your model (for Mistral 7B-based models, it's likely 8192)
         "model_wrapper": "chatml",  # chatml is the default wrapper
diff --git a/memgpt/autogen/examples/agent_docs.py b/memgpt/autogen/examples/agent_docs.py
index b62ffa4a18..37c9e8389e 100644
--- a/memgpt/autogen/examples/agent_docs.py
+++ b/memgpt/autogen/examples/agent_docs.py
@@ -43,7 +43,7 @@
     {
         "model": model,
         "context_window": LLM_MAX_TOKENS[model],
-        "preset": settings.default_preset,
+        "preset": settings.preset,
         "model_wrapper": None,
         # OpenAI specific
         "model_endpoint_type": "openai",
@@ -82,7 +82,7 @@
     {
         "model": model,
         "context_window": LLM_MAX_TOKENS[model],
-        "preset": settings.default_preset,
+        "preset": settings.preset,
         "model_wrapper": None,
         # Azure specific
         "model_endpoint_type": "azure",
@@ -111,7 +111,7 @@
 # MemGPT-powered agents will also use local LLMs, but they need additional setup (also they use the Completions endpoint)
 config_list_memgpt = [
     {
-        "preset": settings.default_preset,
+        "preset": settings.preset,
         "model": None,  # only required for Ollama, see: https://memgpt.readme.io/docs/ollama
         "context_window": 8192,  # the context window of your model (for Mistral 7B-based models, it's likely 8192)
         "model_wrapper": "chatml",  # chatml is the default wrapper
diff --git a/memgpt/autogen/examples/agent_groupchat.py b/memgpt/autogen/examples/agent_groupchat.py
index ed36e37f35..061373b608 100644
--- a/memgpt/autogen/examples/agent_groupchat.py
+++ b/memgpt/autogen/examples/agent_groupchat.py
@@ -41,7 +41,7 @@
     {
         "model": model,
         "context_window": LLM_MAX_TOKENS[model],
-        "preset": settings.default_preset,
+        "preset": settings.preset,
         "model_wrapper": None,
         # OpenAI specific
         "model_endpoint_type": "openai",
@@ -80,7 +80,7 @@
     {
         "model": model,
         "context_window": LLM_MAX_TOKENS[model],
-        "preset": settings.default_preset,
+        "preset": settings.preset,
         "model_wrapper": None,
         # Azure specific
         "model_endpoint_type": "azure",
@@ -109,7 +109,7 @@
 # MemGPT-powered agents will also use local LLMs, but they need additional setup (also they use the Completions endpoint)
 config_list_memgpt = [
     {
-        "preset": settings.default_preset,
+        "preset": settings.preset,
         "model": None,  # only required for Ollama, see: https://memgpt.readme.io/docs/ollama
         "context_window": 8192,  # the context window of your model (for Mistral 7B-based models, it's likely 8192)
         "model_wrapper": "chatml",  # chatml is the default wrapper
diff --git a/memgpt/client/client.py b/memgpt/client/client.py
index 15b4cfb419..09b9a35925 100644
--- a/memgpt/client/client.py
+++ b/memgpt/client/client.py
@@ -6,7 +6,8 @@
 import requests
 
 from memgpt.config import MemGPTConfig
-from memgpt.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
+from memgpt.constants import BASE_TOOLS
+from memgpt.settings import settings
 from memgpt.data_sources.connectors import DataConnector
 from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, Preset, Source
 from memgpt.functions.functions import parse_source_code
@@ -399,7 +400,7 @@ def create_preset(
             schema.append(tool.json_schema)
 
         # include default tools
-        default_preset = self.get_preset(name=DEFAULT_PRESET)
+        default_preset = self.get_preset(name=settings.preset)
         if default_tools:
             # TODO
             # from memgpt.functions.functions import load_function_set
diff --git a/memgpt/metadata.py b/memgpt/metadata.py
index 09ff8740d6..cfdc7ad762 100644
--- a/memgpt/metadata.py
+++ b/memgpt/metadata.py
@@ -45,14 +45,6 @@
 from memgpt.utils import enforce_types, get_utc_time, printd
 
 
-def get_metadata_store() -> Type["MetadataStore"]:
-    """This uses app settings to select and configure a MetadataStore
-    Returns:
-        A metadataStore Adapter (ie Posgtres, SQLiteChroma etc)
-    """
-    # GH 1437 - cut in the lookup and config here
-    raise NotImplementedError
-
 Base = declarative_base()
 
 
diff --git a/memgpt/orm/README.md b/memgpt/orm/README.md
new file mode 100644
index 0000000000..a8ba975364
--- /dev/null
+++ b/memgpt/orm/README.md
@@ -0,0 +1,16 @@
+
+### ORM Basic Design Patterns
+
+Standard 3NF ORM pattern.
+- First-class entities get their own module. Modules with a single entity are named singularly ("user", "agent", etc.) because they represent one model.
+- Polymorphic models share a module, and these module names are pluralized ("Memories") because there are multiple.
+- Mixins, helpers, errors etc. are collections, so the module names are pluralized.
+- Imports are always lazy whenever possible to guard against circular deps.
+
+
+## Mixin magic
+Relationship mixins expect standard naming (i.e. `_organization_id` on the child side of a 1:M to an organization); a sketch of the convention follows this file.
+
+The relationship is declared explicitly in the model, as the lazy joining rules, back populates etc. will be bespoke per class.
+
+If you need to reference the same entity more than once, you'll need to skip the mixin and do it by hand.
\ No newline at end of file
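The convention the README describes looks like this in practice: a child model inherits `OrganizationMixin` (defined in the next file) and picks up both the `_organization_id` foreign key and the prefixed `organization_id` accessor for free. The `Agent` model below is hypothetical, not something this patch adds:

```python
# Hypothetical child model following the naming convention described above.
from sqlalchemy.orm import Mapped, relationship

from memgpt.orm.mixins import OrganizationMixin
from memgpt.orm.sqlalchemy_base import SqlalchemyBase


class Agent(SqlalchemyBase, OrganizationMixin):
    __tablename__ = "agent"

    # _organization_id and the organization_id property come from the mixin;
    # only the relationship is declared per-class, since lazy/backref rules vary
    organization: Mapped["Organization"] = relationship("Organization")
```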
diff --git a/memgpt/orm/mixins.py b/memgpt/orm/mixins.py
new file mode 100644
index 0000000000..eefc8a0619
--- /dev/null
+++ b/memgpt/orm/mixins.py
@@ -0,0 +1,61 @@
+from typing import Optional, Type
+from uuid import UUID
+from sqlalchemy import UUID as SQLUUID, ForeignKey
+from sqlalchemy.orm import Mapped, mapped_column
+
+from memgpt.orm.base import Base
+
+
+class MalformedIdError(Exception):
+    pass
+
+
+def _relation_getter(instance: "Base", prop: str) -> Optional[str]:
+    prefix = prop.replace("_", "")
+    formatted_prop = f"_{prop}_id"
+    try:
+        uuid_ = getattr(instance, formatted_prop)
+        return f"{prefix}-{uuid_}"
+    except AttributeError:
+        return None
+
+
+def _relation_setter(instance: Type["Base"], prop: str, value: str) -> None:
+    formatted_prop = f"_{prop}_id"
+    prefix = prop.replace("_", "")
+    if not value:
+        setattr(instance, formatted_prop, None)
+        return
+    try:
+        found_prefix, id_ = value.split("-", 1)
+    except ValueError as e:
+        raise MalformedIdError(f"{value} is not a valid ID.") from e
+    assert (
+        # TODO: should be able to get this from the Mapped typing, not sure how though
+        # prefix = getattr(?, "prefix")
+        found_prefix
+        == prefix
+    ), f"{found_prefix} is not a valid id prefix, expecting {prefix}"
+    try:
+        setattr(instance, formatted_prop, UUID(id_))
+    except ValueError as e:
+        raise MalformedIdError(f"Hash segment of {value} is not a valid UUID") from e
+
+
+class OrganizationMixin(Base):
+    """Mixin for models that belong to an organization."""
+
+    __abstract__ = True
+
+    _organization_id: Mapped[UUID] = mapped_column(
+        SQLUUID(), ForeignKey("organization._id")
+    )
+
+    @property
+    def organization_id(self) -> str:
+        return _relation_getter(self, "organization")
+
+    @organization_id.setter
+    def organization_id(self, value: str) -> None:
+        _relation_setter(self, "organization", value)
+
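Expected round-trip behavior of the accessors above, as a sketch (it assumes the `User` model added later in this same patch):

```python
# Illustrative round trip through the prefixed-id accessors.
from uuid import uuid4

from memgpt.orm.mixins import MalformedIdError
from memgpt.orm.user import User

user = User()
user.organization_id = f"organization-{uuid4()}"  # stores a bare UUID in _organization_id
assert user.organization_id.startswith("organization-")

try:
    user.organization_id = "organization-not-a-uuid"  # bad hash segment
except MalformedIdError as e:
    print(e)  # "Hash segment of organization-not-a-uuid is not a valid UUID"
```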
this is mid-development") + org = db_session.query(cls).one().scalar() diff --git a/memgpt/orm/sqlalchemy_base.py b/memgpt/orm/sqlalchemy_base.py index 48a08d492c..b4629b7c3d 100644 --- a/memgpt/orm/sqlalchemy_base.py +++ b/memgpt/orm/sqlalchemy_base.py @@ -1,5 +1,5 @@ -from uuid import uuid4, UUID, Type, Union, List, Literal -from typing import Optional, TYPE_CHECKING +from uuid import uuid4, UUID +from typing import Optional, TYPE_CHECKING,Type, Union, List, Literal from humps import depascalize from sqlalchemy import select, UUID as SQLUUID from sqlalchemy.orm import ( diff --git a/memgpt/orm/user.py b/memgpt/orm/user.py new file mode 100644 index 0000000000..6d576e46db --- /dev/null +++ b/memgpt/orm/user.py @@ -0,0 +1,20 @@ +from typing import Optional, TYPE_CHECKING +from pydantic import EmailStr +from sqlalchemy import String +from sqlalchemy.orm import Mapped, relationship, mapped_column + +from memgpt.orm.sqlalchemy_base import SqlalchemyBase +from memgpt.orm.mixins import OrganizationMixin + +class User(SqlalchemyBase, OrganizationMixin): + """User ORM class""" + __tablename__ = "user" + + name:Mapped[Optional[str]] = mapped_column(nullable=True, doc="The display name of the user.") + email:Mapped[Optional[EmailStr]] = mapped_column(String, + nullable=True, + doc="The email address of the user. Uninforced at this time.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="users") + diff --git a/memgpt/orm/utilities.py b/memgpt/orm/utilities.py new file mode 100644 index 0000000000..70a2dd6bb0 --- /dev/null +++ b/memgpt/orm/utilities.py @@ -0,0 +1,40 @@ +from typing import Optional, TYPE_CHECKING, Generator +from urllib.parse import urlsplit, urlunsplit +from sqlalchemy import create_engine as sqlalchemy_create_engine +from sqlalchemy.orm import sessionmaker + + +from memgpt.settings import settings + +if TYPE_CHECKING: + from sqlalchemy.engine import Engine + + +def create_engine( + storage_type: Optional[str] = None, + database: Optional[str] = None, +) -> "Engine": + """creates an engine for the storage_type designated by settings + Args: + storage_type: test hook to inject storage_type, you should not be setting this + database: test hook to inject database, you should not be setting this + Returns: a sqlalchemy engine + """ + storage_type = storage_type or settings.storage_type + match storage_type: + case "postgres": + url_parts = list(urlsplit(settings.pg_uri)) + PATH_PARAM = 2 # avoid the magic number! 
diff --git a/memgpt/server/rest_api/presets/index.py b/memgpt/server/rest_api/presets/index.py
index b7d31e8f16..3ef44081ea 100644
--- a/memgpt/server/rest_api/presets/index.py
+++ b/memgpt/server/rest_api/presets/index.py
@@ -106,7 +106,7 @@ async def create_preset(
         system = request.system
         # TODO: insert into system table
     else:
-        system_name = request.system_name if request.system_name else DEFAULT_PRESET
+        system_name = request.system_name if request.system_name else settings.preset
         system = request.system if request.system else gpt_system.get_system_text(system_name)
 
     if not request.human_name and request.human:
@@ -115,7 +115,7 @@ async def create_preset(
         human = request.human
         server.ms.add_human(HumanModel(text=human, name=human_name, user_id=user_id))
     else:
-        human_name = request.human_name if request.human_name else DEFAULT_HUMAN
+        human_name = request.human_name if request.human_name else settings.human
         human = request.human if request.human else get_human_text(human_name)
 
     if not request.persona_name and request.persona:
@@ -124,7 +124,7 @@ async def create_preset(
         persona = request.persona
         server.ms.add_persona(PersonaModel(text=persona, name=persona_name, user_id=user_id))
     else:
-        persona_name = request.persona_name if request.persona_name else DEFAULT_PERSONA
+        persona_name = request.persona_name if request.persona_name else settings.persona
         persona = request.persona if request.persona else get_persona_text(persona_name)
 
     # create preset
diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index 427c71b35e..3fcfeedb14 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -150,11 +150,6 @@ def __init__(
         # self.default_interface = default_interface
         # self.default_interface = default_interface_cls()
 
-        # GH-1437 start overload refactoring for configs;
-        # metastore based on configured adapter here
-        self.metadatastore = metadatastore or get_metadatastore_adapter()
-
-        # Initialize the connection to the DB
         self.config = config or MemGPTConfig.load()
         msg = "server :: loading configuration as passed" if config else \
diff --git a/tests/conftest.py b/tests/conftest.py
index 673eff559c..a3ccf7e86d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,7 +5,7 @@
 import pytest
 
 from memgpt.settings import settings
-from memgpt.orm.utilities import _create_engine_for_adapter
+from memgpt.orm.utilities import create_engine
 from memgpt.orm.base import Base
 from tests.utils import wipe_memgpt_home
 from memgpt.data_types import EmbeddingConfig, LLMConfig
 from memgpt.credentials import MemGPTCredentials
@@ -96,16 +96,20 @@ def db_session(request) -> "Session":
     """Creates a function-scoped orm session for the given test and adapter.
     Note: both pg and sqlite/chroma will have results scoped to each test function - so 2x results
     for each. These are cleared at the _beginning_ of each test run - so states are persisted for inspection
-    after the end of the test!
+    after the end of the test.
""" function_ = request.node.name - engine = _create_engine_for_adapter(adapter=request.param, database="memgpt_test") - with engine.begin() as connection: - for statement in ( + engine = create_engine(storage_type=request.param, database="memgpt_test") + adapter_statements = { + "sqlite_chroma": (text(f"attach ':memory:' as {function_}"),), + "postgres": ( text(f"CREATE SCHEMA IF NOT EXISTS {function_}"), text(f"SET search_path TO {function_},public"), - ): + ), + } + with engine.begin() as connection: + for statement in adapter_statements[request.param]: connection.execute(statement) Base.metadata.drop_all(bind=connection) Base.metadata.create_all(bind=connection) diff --git a/tests/orm/test_bases.py b/tests/orm/test_bases.py index 176296fb0a..58f20c6d3e 100644 --- a/tests/orm/test_bases.py +++ b/tests/orm/test_bases.py @@ -10,11 +10,11 @@ class TestORMBases: """eyeball unit tests of accessors, id logic etc""" - def test_prefixed_ids(db_session): + def test_prefixed_ids(self, db_session): user = User( email=faker.email, - organization=Organization.default() + organization=Organization.default(db_session=db_session), ).create(db_session) assert user.id.startswith('user-') diff --git a/tests/test_concurrent_connections.py b/tests/test_concurrent_connections.py index 060acfc095..6fcb5b4eb4 100644 --- a/tests/test_concurrent_connections.py +++ b/tests/test_concurrent_connections.py @@ -8,7 +8,7 @@ from memgpt import Admin, create_client from memgpt.config import MemGPTConfig -from memgpt.constants import DEFAULT_PRESET +from memgpt.settings import settings from memgpt.credentials import MemGPTCredentials from memgpt.data_types import Preset # TODO move to PresetModel from memgpt.settings import settings @@ -16,7 +16,7 @@ test_agent_name = f"test_client_{str(uuid.uuid4())}" # test_preset_name = "test_preset" -test_preset_name = DEFAULT_PRESET +test_preset_name = settings.preset test_agent_state = None client = None diff --git a/tests/test_openai_assistant_api.py b/tests/test_openai_assistant_api.py index 54d8c8485f..586edf5fbe 100644 --- a/tests/test_openai_assistant_api.py +++ b/tests/test_openai_assistant_api.py @@ -17,7 +17,7 @@ # # # test: create agent # request_body = { -# "assistant_name": DEFAULT_PRESET, +# "assistant_name": settings.preset, # } # print(request_body) # response = client.post("/v1/threads", json=request_body) diff --git a/tests/test_websocket_interface.py b/tests/test_websocket_interface.py index 8b7ab7f0b0..f42de7f0f4 100644 --- a/tests/test_websocket_interface.py +++ b/tests/test_websocket_interface.py @@ -18,7 +18,7 @@ # # Create an agent and hook it up to the WebSocket interface # memgpt_agent = presets.create_agent_from_preset( -# presets.DEFAULT_PRESET, +# presets.settings.preset, # None, # no agent config to provide # "gpt-4-1106-preview", # personas.get_persona_text("sam_pov"), From 62da64662c1b80ff7fa00ef0a4dc7b022e8b2c74 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 19 Jun 2024 10:42:46 -0400 Subject: [PATCH 20/45] This is the basic 2 backend pattern. 
TODO: - mount the test sqlite/chroma somewhere that doesn't clutter up the repo --- init.sql | 4 ++++ memgpt/orm/organization.py | 7 +++++-- memgpt/orm/utilities.py | 2 +- tests/conftest.py | 5 +++-- tests/db_setup.sql | 2 ++ tests/orm/test_bases.py | 7 ++++--- 6 files changed, 19 insertions(+), 8 deletions(-) create mode 100644 tests/db_setup.sql diff --git a/init.sql b/init.sql index c1244ff087..7b87971a0c 100644 --- a/init.sql +++ b/init.sql @@ -34,3 +34,7 @@ ALTER DATABASE :"db_name" CREATE EXTENSION IF NOT EXISTS vector WITH SCHEMA :"db_name"; DROP SCHEMA IF EXISTS public CASCADE; + +CREATE DATABASE test_memgpt; +GRANT ALL PRIVILEGES ON DATABASE test_memgpt to "memgpt"; + diff --git a/memgpt/orm/organization.py b/memgpt/orm/organization.py index c1e96d7c13..d00360c352 100644 --- a/memgpt/orm/organization.py +++ b/memgpt/orm/organization.py @@ -1,5 +1,6 @@ from typing import Optional, TYPE_CHECKING from pydantic import EmailStr +from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Mapped, relationship, mapped_column from memgpt.orm.sqlalchemy_base import SqlalchemyBase @@ -18,5 +19,7 @@ class Organization(SqlalchemyBase): @classmethod def default(cls, db_session:"Session") -> "Organization": """Get the default org, or create it if it doesn't exist.""" - raise ValueError("Kaboom! this is mid-development") - org = db_session.query(cls).one().scalar() + try: + return db_session.query(cls).one().scalar() + except NoResultFound: + return cls(name="Default Organization").create(db_session) diff --git a/memgpt/orm/utilities.py b/memgpt/orm/utilities.py index 70a2dd6bb0..a0486de9f6 100644 --- a/memgpt/orm/utilities.py +++ b/memgpt/orm/utilities.py @@ -27,7 +27,7 @@ def create_engine( PATH_PARAM = 2 # avoid the magic number! url_parts[PATH_PARAM] = f"/{database}" if database else url_parts.path return sqlalchemy_create_engine(urlunsplit(url_parts)) - case "sqlite-chroma": + case "sqlite_chroma": return sqlalchemy_create_engine(f"sqlite:///{database}") case _: raise ValueError(f"Unsupported storage_type: {storage_type}") diff --git a/tests/conftest.py b/tests/conftest.py index a3ccf7e86d..15de7ba7e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,12 +99,13 @@ def db_session(request) -> "Session": after the end of the test. 
""" - function_ = request.node.name - engine = create_engine(storage_type=request.param, database="memgpt_test") + function_ = request.node.name.replace("[","_").replace("]","_") + engine = create_engine(storage_type=request.param, database="test_memgpt") adapter_statements = { "sqlite_chroma": (text(f"attach ':memory:' as {function_}"),), "postgres": ( text(f"CREATE SCHEMA IF NOT EXISTS {function_}"), + text(f"CREATE EXTENSION IF NOT EXISTS vector"), text(f"SET search_path TO {function_},public"), ), } diff --git a/tests/db_setup.sql b/tests/db_setup.sql new file mode 100644 index 0000000000..34f0a06b01 --- /dev/null +++ b/tests/db_setup.sql @@ -0,0 +1,2 @@ +CREATE DATABASE test_memgpt; +GRANT ALL PRIVILEGES ON DATABASE test_memgpt to "memgpt"; diff --git a/tests/orm/test_bases.py b/tests/orm/test_bases.py index 58f20c6d3e..084418dabf 100644 --- a/tests/orm/test_bases.py +++ b/tests/orm/test_bases.py @@ -13,11 +13,12 @@ class TestORMBases: def test_prefixed_ids(self, db_session): user = User( - email=faker.email, + email=faker.email(), organization=Organization.default(db_session=db_session), ).create(db_session) assert user.id.startswith('user-') assert str(user._id) in user.id - assert user.organization.id.startswith('organization-') - assert str(user.organization._id) in user.organization.id \ No newline at end of file + with db_session as session: + assert user.organization.id.startswith('organization-') + assert str(user.organization._id) in user.organization.id \ No newline at end of file From 2ea713b1b7a49e25eb8d526be86f0efb39f7041a Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 19 Jun 2024 16:55:48 -0400 Subject: [PATCH 21/45] conftest respects relationships --- tests/orm/test_bases.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/orm/test_bases.py b/tests/orm/test_bases.py index 084418dabf..86254c1651 100644 --- a/tests/orm/test_bases.py +++ b/tests/orm/test_bases.py @@ -20,5 +20,6 @@ def test_prefixed_ids(self, db_session): assert user.id.startswith('user-') assert str(user._id) in user.id with db_session as session: - assert user.organization.id.startswith('organization-') - assert str(user.organization._id) in user.organization.id \ No newline at end of file + session.add(user) + assert user.organization.id.startswith('organization-'), "Organization id is prefixed incorrectly" + assert str(user.organization._id) in user.organization.id, "Organization id is not using the correct uuid" \ No newline at end of file From d0d18d7a02246a762c4a479f86752e5abf31c47d Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 19 Jun 2024 17:38:56 -0400 Subject: [PATCH 22/45] sqlite now stores all the test databases in the .persist folder to keep things clean --- development.compose.yml | 1 + tests/conftest.py | 26 +++++++++++++++++--------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/development.compose.yml b/development.compose.yml index 44713dfe4e..324722f156 100644 --- a/development.compose.yml +++ b/development.compose.yml @@ -26,6 +26,7 @@ services: - ./tests/pytest.ini:/memgpt/pytest.ini - ./pyproject.toml:/pyproject.toml - ./tests:/tests + - ./.persist/sqlite:/sqlite ports: - "8083:8083" - "8283:8283" diff --git a/tests/conftest.py b/tests/conftest.py index 15de7ba7e5..a68b5edfe1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,17 +100,25 @@ def db_session(request) -> "Session": """ function_ = request.node.name.replace("[","_").replace("]","_") - engine = create_engine(storage_type=request.param, 
database="test_memgpt") - adapter_statements = { - "sqlite_chroma": (text(f"attach ':memory:' as {function_}"),), - "postgres": ( - text(f"CREATE SCHEMA IF NOT EXISTS {function_}"), - text(f"CREATE EXTENSION IF NOT EXISTS vector"), - text(f"SET search_path TO {function_},public"), - ), + adapter_test_configurations = { + "sqlite_chroma": { + "statements": (text(f"attach ':memory:' as {function_}"),), + "database": f"/sqlite/{function_}.db" + }, + "postgres": { + "statements":( + text(f"CREATE SCHEMA IF NOT EXISTS {function_}"), + text(f"CREATE EXTENSION IF NOT EXISTS vector"), + text(f"SET search_path TO {function_},public"), + ), + "database": "test_memgpt" + } } + adapter = adapter_test_configurations[request.param] + engine = create_engine(storage_type=request.param, database=adapter["database"]) + with engine.begin() as connection: - for statement in adapter_statements[request.param]: + for statement in adapter["statements"]: connection.execute(statement) Base.metadata.drop_all(bind=connection) Base.metadata.create_all(bind=connection) From 0187d8192074c2a972825c769474eb1c8c681ced Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 19 Jun 2024 17:41:46 -0400 Subject: [PATCH 23/45] updating readme --- CONTRIBUTING.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 31744f0e76..37783997da 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -145,3 +145,5 @@ If you prefer to keep your resources isolated by developing purely in containers docker compose -f compose.yaml -f development.compose.yml up ``` This will volume mount your local codebase and reload the server on file changes. + +MemGPT supports 2 alternate application backends, Postgres (with PGVector) and SQLite + Chromadb. Any time your unit or integration tests interact with the application data model (so almost always), the test suite will be run against _both_ database backends to ensure compatability. From 732cbbb542930f85131088aa755d1fbbcead6033 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 19 Jun 2024 19:13:31 -0400 Subject: [PATCH 24/45] more readme --- CONTRIBUTING.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 37783997da..e938e2d80f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -147,3 +147,9 @@ docker compose -f compose.yaml -f development.compose.yml up This will volume mount your local codebase and reload the server on file changes. MemGPT supports 2 alternate application backends, Postgres (with PGVector) and SQLite + Chromadb. Any time your unit or integration tests interact with the application data model (so almost always), the test suite will be run against _both_ database backends to ensure compatability. +After each run you can find test artifacts in the `.persist` folder. Connect to the `SQLite`, `Chroma` and `pgdata` directories to inspect individual test artifacts. + +- for Postgres, each individual test will have a schema named after that test with the full data model. +- For SQLite, each individual test will be a unique database file (named after the test). 
+ + From 347c9f0d3fea0b6612448901c16218cef56581f9 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Fri, 21 Jun 2024 15:03:29 -0400 Subject: [PATCH 25/45] more models, bringing up lots of questions about the data model --- memgpt/orm/agent.py | 14 ++++++++++++++ memgpt/orm/mixins.py | 17 +++++++++++++++++ memgpt/orm/token.py | 15 +++++++++++++++ tests/orm/agent.py | 0 4 files changed, 46 insertions(+) create mode 100644 memgpt/orm/agent.py create mode 100644 memgpt/orm/token.py create mode 100644 tests/orm/agent.py diff --git a/memgpt/orm/agent.py b/memgpt/orm/agent.py new file mode 100644 index 0000000000..8423d1a686 --- /dev/null +++ b/memgpt/orm/agent.py @@ -0,0 +1,14 @@ +from typing import Optional +from sqlalchemy import String +from sqlachemy.orm import Mapped, mapped_column, relationship + +from memgpt.orm.sqlalchemy_base import SqlalchemyBase +from memgpt.orm.mixins import OrganizationMixin + +class Agent(SqlalchemyBase): + __tablename__ = 'agent' + + name:Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="a name to identify the token") + + # is this correct? do agents belong to a single user? + user: Mapped["User"] = relationship("User", back_populates="agents") \ No newline at end of file diff --git a/memgpt/orm/mixins.py b/memgpt/orm/mixins.py index eefc8a0619..64f6e939ab 100644 --- a/memgpt/orm/mixins.py +++ b/memgpt/orm/mixins.py @@ -59,3 +59,20 @@ def organization_id(self) -> str: def organization_id(self, value: str) -> None: _relation_setter(self, "organization", value) + +class UserMixin(Base): + """Mixin for models that belong to a user.""" + + __abstract__ = True + + _user_id: Mapped[UUID] = mapped_column( + SQLUUID(), ForeignKey("user._id") + ) + + @property + def user_id(self) -> str: + return _relation_getter(self, "user") + + @user_id.setter + def user_id(self, value: str) -> None: + _relation_setter(self, "user", value) \ No newline at end of file diff --git a/memgpt/orm/token.py b/memgpt/orm/token.py new file mode 100644 index 0000000000..d3715b6e5e --- /dev/null +++ b/memgpt/orm/token.py @@ -0,0 +1,15 @@ +from typing import Optional +from sqlalchemy import String +from sqlachemy.orm import Mapped, mapped_column, relationship + +from memgpt.orm.sqlalchemy_base import SqlalchemyBase +from memgpt.orm.mixins import UserMixin + + +class Token(SqlalchemyBase, UserMixin): + __tablename__ = 'token' + + hash:Mapped[str] = mapped_column(String, doc="the secured one-way hash of the token") + name:Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="a name to identify the token") + + user: Mapped["User"] = relationship("User", back_populates="tokens") \ No newline at end of file diff --git a/tests/orm/agent.py b/tests/orm/agent.py new file mode 100644 index 0000000000..e69de29bb2 From 7cae7cf58b9cd729cb4db0cdeb1be8a321cf8b06 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Mon, 24 Jun 2024 17:41:14 -0400 Subject: [PATCH 26/45] I'm trying to keep this as close to the current model as possible while stripping out extraneous elements. 
The memory concept still needs to be abstracted at a later point; it is never
clear whether these are strings, templates, or references to a related object.
---
 memgpt/orm/agent.py      | 21 ++++++++++++++++-----
 memgpt/orm/mixins.py     | 19 ++++++++++++++++++-
 memgpt/orm/user.py       |  6 +++++-
 memgpt/orm/user_agent.py | 15 +++++++++++++++
 4 files changed, 54 insertions(+), 7 deletions(-)
 create mode 100644 memgpt/orm/user_agent.py

diff --git a/memgpt/orm/agent.py b/memgpt/orm/agent.py
index 8423d1a686..ec3f5e62bf 100644
--- a/memgpt/orm/agent.py
+++ b/memgpt/orm/agent.py
@@ -1,14 +1,25 @@
-from typing import Optional
+from typing import Optional, List, TYPE_CHECKING
 from sqlalchemy import String
 from sqlachemy.orm import Mapped, mapped_column, relationship
 
 from memgpt.orm.sqlalchemy_base import SqlalchemyBase
 from memgpt.orm.mixins import OrganizationMixin
+if TYPE_CHECKING:
+    from memgpt.orm.organization import Organization
+    from memgpt.orm.user import User
 
-class Agent(SqlalchemyBase):
+class Agent(SqlalchemyBase, OrganizationMixin):
     __tablename__ = 'agent'
 
-    name:Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="a name to identify the token")
+    name:Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="a human-readable identifier for an agent, non-unique.")
+    persona: Mapped[str] = mapped_column(doc="the persona text for the agent, current state.")
+    # todo: this doesn't align with 1:M agents to users!
+    human: Mapped[str] = mapped_column(doc="the human text for the agent and the current user, current state.")
+    preset: Mapped[str] = mapped_column(doc="the preset text for the agent, current state.")
 
-    # is this correct? do agents belong to a single user?
-    user: Mapped["User"] = relationship("User", back_populates="agents")
\ No newline at end of file
+    # relationships
+    organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
+    users: Mapped[List["User"]] = relationship("User",
+        back_populates="agents",
+        secondary="user_agent",
+        doc="the users associated with this agent.")
diff --git a/memgpt/orm/mixins.py b/memgpt/orm/mixins.py
index 64f6e939ab..d22c633e2a 100644
--- a/memgpt/orm/mixins.py
+++ b/memgpt/orm/mixins.py
@@ -75,4 +75,21 @@ def user_id(self) -> str:
 
     @user_id.setter
     def user_id(self, value: str) -> None:
-        _relation_setter(self, "user", value)
\ No newline at end of file
+        _relation_setter(self, "user", value)
+
+class AgentMixin(Base):
+    """Mixin for models that belong to an agent."""
+
+    __abstract__ = True
+
+    _agent_id: Mapped[UUID] = mapped_column(
+        SQLUUID(), ForeignKey("agent._id")
+    )
+
+    @property
+    def agent_id(self) -> str:
+        return _relation_getter(self, "agent")
+
+    @agent_id.setter
+    def agent_id(self, value: str) -> None:
+        _relation_setter(self, "agent", value)
\ No newline at end of file
diff --git a/memgpt/orm/user.py b/memgpt/orm/user.py
index 6d576e46db..d0c18bdd6d 100644
--- a/memgpt/orm/user.py
+++ b/memgpt/orm/user.py
@@ -1,4 +1,4 @@
-from typing import Optional, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING, List
 from pydantic import EmailStr
 from sqlalchemy import String
 from sqlalchemy.orm import Mapped, relationship, mapped_column
@@ -17,4 +17,8 @@ class User(SqlalchemyBase, OrganizationMixin):
 
     # relationships
     organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
+    agents: Mapped[List["Agent"]] = relationship("Agent",
+        back_populates="users",
+        secondary="user_agent",
+        doc="the agents associated with this user.")
diff --git a/memgpt/orm/user_agent.py
b/memgpt/orm/user_agent.py new file mode 100644 index 0000000000..b29c55aff3 --- /dev/null +++ b/memgpt/orm/user_agent.py @@ -0,0 +1,15 @@ +from sqlachemy.orm import UniqueConstraint + + +from memgpt.orm.base import Base +from memgpt.orm.mixins import UserMixin, AgentMixin + +class UserAgent(Base, UserMixin, AgentMixin): + __tablename__ = 'user_agent' + __table_args__ = ( + UniqueConstraint( + "_agent_id", + "_user_id", + name="unique_agent_user_constraint", + ), + ) \ No newline at end of file From 0f3e8c81865d80e5f719336aca5cd1f8a314a349 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Tue, 25 Jun 2024 11:33:52 -0400 Subject: [PATCH 27/45] basic ORM pattern for most objects --- memgpt/orm/__all__.py | 8 ++++++++ memgpt/orm/agent.py | 5 +++-- memgpt/orm/base.py | 7 ------- memgpt/orm/organization.py | 2 ++ memgpt/orm/token.py | 2 +- memgpt/orm/user.py | 11 ++++++++--- memgpt/orm/user_agent.py | 15 --------------- memgpt/orm/users_agents.py | 16 ++++++++++++++++ memgpt/server/server.py | 5 ++--- tests/conftest.py | 2 +- 10 files changed, 41 insertions(+), 32 deletions(-) create mode 100644 memgpt/orm/__all__.py delete mode 100644 memgpt/orm/user_agent.py create mode 100644 memgpt/orm/users_agents.py diff --git a/memgpt/orm/__all__.py b/memgpt/orm/__all__.py new file mode 100644 index 0000000000..ecb6476188 --- /dev/null +++ b/memgpt/orm/__all__.py @@ -0,0 +1,8 @@ +from memgpt.orm.organization import Organization +from memgpt.orm.user import User +from memgpt.orm.agent import Agent +from memgpt.orm.users_agents import UsersAgents +from memgpt.orm.token import Token + + +from memgpt.orm.base import Base \ No newline at end of file diff --git a/memgpt/orm/agent.py b/memgpt/orm/agent.py index ec3f5e62bf..97f5a3c64d 100644 --- a/memgpt/orm/agent.py +++ b/memgpt/orm/agent.py @@ -1,9 +1,10 @@ from typing import Optional, List, TYPE_CHECKING from sqlalchemy import String -from sqlachemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship from memgpt.orm.sqlalchemy_base import SqlalchemyBase from memgpt.orm.mixins import OrganizationMixin +from memgpt.orm.users_agents import UsersAgents if TYPE_CHECKING: from memgpt.orm.organization import Organization from memgpt.orm.user import User @@ -21,5 +22,5 @@ class Agent(SqlalchemyBase, OrganizationMixin): organization: Mapped["Organization"] = relationship("Organization", back_populates="agents") users: Mapped[List["User"]] = relationship("User", back_populates="agents", - secondary="user_agent", + secondary="users_agents", doc="the users associated with this agent.") diff --git a/memgpt/orm/base.py b/memgpt/orm/base.py index 6600cb9668..1577b807e1 100644 --- a/memgpt/orm/base.py +++ b/memgpt/orm/base.py @@ -10,16 +10,9 @@ declared_attr, ) -class AbsoluteBase(DeclarativeBase): - """For the few rare instances where we need a bare table - (like through m2m joins) extending AbsoluteBase ensures - all models inherit from the same DeclarativeBase. 
- """ - class Base(DeclarativeBase): """absolute base for sqlalchemy classes""" - @declarative_mixin class CommonSqlalchemyMetaMixins(Base): __abstract__ = True diff --git a/memgpt/orm/organization.py b/memgpt/orm/organization.py index d00360c352..3d0c3bae87 100644 --- a/memgpt/orm/organization.py +++ b/memgpt/orm/organization.py @@ -15,6 +15,7 @@ class Organization(SqlalchemyBase): # relationships users: Mapped["User"] = relationship("User", back_populates="organization") + agents: Mapped["Agent"] = relationship("Agent", back_populates="organization") @classmethod def default(cls, db_session:"Session") -> "Organization": @@ -23,3 +24,4 @@ def default(cls, db_session:"Session") -> "Organization": return db_session.query(cls).one().scalar() except NoResultFound: return cls(name="Default Organization").create(db_session) + diff --git a/memgpt/orm/token.py b/memgpt/orm/token.py index d3715b6e5e..7c382e0112 100644 --- a/memgpt/orm/token.py +++ b/memgpt/orm/token.py @@ -1,6 +1,6 @@ from typing import Optional from sqlalchemy import String -from sqlachemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship from memgpt.orm.sqlalchemy_base import SqlalchemyBase from memgpt.orm.mixins import UserMixin diff --git a/memgpt/orm/user.py b/memgpt/orm/user.py index d0c18bdd6d..2c64e51171 100644 --- a/memgpt/orm/user.py +++ b/memgpt/orm/user.py @@ -5,6 +5,10 @@ from memgpt.orm.sqlalchemy_base import SqlalchemyBase from memgpt.orm.mixins import OrganizationMixin +from memgpt.orm.users_agents import UsersAgents +if TYPE_CHECKING: + from memgpt.orm.agent import Agent + from memgpt.orm.token import Token class User(SqlalchemyBase, OrganizationMixin): """User ORM class""" @@ -18,7 +22,8 @@ class User(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="users") agents: Mapped[List["Agent"]] = relationship("Agent", - back_populates="users", - secondary="user_agent", - doc="the agents associated with this user.") + secondary="users_agents", + back_populates="users", + doc="the agents associated with this user.") + tokens: Mapped[List["Token"]] = relationship("Token", back_populates="user", doc="the tokens associated with this user.") diff --git a/memgpt/orm/user_agent.py b/memgpt/orm/user_agent.py deleted file mode 100644 index b29c55aff3..0000000000 --- a/memgpt/orm/user_agent.py +++ /dev/null @@ -1,15 +0,0 @@ -from sqlachemy.orm import UniqueConstraint - - -from memgpt.orm.base import Base -from memgpt.orm.mixins import UserMixin, AgentMixin - -class UserAgent(Base, UserMixin, AgentMixin): - __tablename__ = 'user_agent' - __table_args__ = ( - UniqueConstraint( - "_agent_id", - "_user_id", - name="unique_agent_user_constraint", - ), - ) \ No newline at end of file diff --git a/memgpt/orm/users_agents.py b/memgpt/orm/users_agents.py new file mode 100644 index 0000000000..4155005e80 --- /dev/null +++ b/memgpt/orm/users_agents.py @@ -0,0 +1,16 @@ +from uuid import UUID +from sqlalchemy import ForeignKey, UUID as SQLUUID +from sqlalchemy.orm import Mapped, mapped_column + + +from memgpt.orm.base import Base + +class UsersAgents(Base): + __tablename__ = 'users_agents' + + _agent_id: Mapped[UUID] = mapped_column( + SQLUUID(), ForeignKey("agent._id"), primary_key=True + ) + _user_id: Mapped[UUID] = mapped_column( + SQLUUID(), ForeignKey("user._id"), primary_key=True + ) \ No newline at end of file diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 
3fcfeedb14..6f27ad9440 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -129,9 +129,8 @@ def __init__( chaining: bool = True, max_chaining_steps: bool = None, default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(), - # default_interface: AgentInterface = CLIInterface(), - # default_persistence_manager_cls: PersistenceManager = LocalStateManager, - # auth_mode: str = "none", # "none, "jwt", "external" + # test hooks + config: Optional["MemGPTConfig"] = None ): """Server process holds in-memory agents that are being run""" diff --git a/tests/conftest.py b/tests/conftest.py index a68b5edfe1..69b163a4c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from memgpt.settings import settings from memgpt.orm.utilities import create_engine -from memgpt.orm.base import Base +from memgpt.orm.__all__ import Base from tests.utils import wipe_memgpt_home from memgpt.data_types import EmbeddingConfig, LLMConfig from memgpt.credentials import MemGPTCredentials From 4e2236917859ec234b4381c5a5fde54128e2a289 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Tue, 25 Jun 2024 15:41:38 -0400 Subject: [PATCH 28/45] presets model started, lots of questions. pushing to this notion doc to clarify https://www.notion.so/Data-Model-Questions-43ef1336483f49c1bf77daddf3f320fa --- memgpt/orm/preset.py | 30 ++++++++++++++++++++++++++++++ memgpt/orm/source.py | 26 ++++++++++++++++++++++++++ memgpt/orm/sources_agents.py | 13 +++++++++++++ memgpt/orm/sources_presets.py | 13 +++++++++++++ memgpt/orm/users_agents.py | 4 ++-- 5 files changed, 84 insertions(+), 2 deletions(-) create mode 100644 memgpt/orm/preset.py create mode 100644 memgpt/orm/source.py create mode 100644 memgpt/orm/sources_agents.py create mode 100644 memgpt/orm/sources_presets.py diff --git a/memgpt/orm/preset.py b/memgpt/orm/preset.py new file mode 100644 index 0000000000..77b8b9e867 --- /dev/null +++ b/memgpt/orm/preset.py @@ -0,0 +1,30 @@ +from sqlalchemy import UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from memgpt.orm.sqlalchemy_base import SqlalchemyBase +from memgpt.orm.mixins import OrganizationMixin + +class Preset(SqlalchemyBase, OrganizationMixin): + + __tablename__ = 'preset' + __table_args__ = ( + UniqueConstraint( + "_organization_id", + "name", + name="unique_name_organization", + ), + ) + + name: Mapped[str] = mapped_column(doc="the name of the preset, must be unique within the org", nullable=False) + description: Mapped[str] = mapped_column(nullable=True, doc="a human-readable description of the preset") + + ## TODO: these are unclear - human vs human_name for example, what and why? + system = Column(String) + human = Column(String) + human_name = Column(String, nullable=False) + persona = Column(String) + persona_name = Column(String, nullable=False) + ## TODO: What is this? 
+ preset = Column(String) + + functions_schema = Column(JSON) \ No newline at end of file diff --git a/memgpt/orm/source.py b/memgpt/orm/source.py new file mode 100644 index 0000000000..a08a80b514 --- /dev/null +++ b/memgpt/orm/source.py @@ -0,0 +1,26 @@ +from typing import List, TYPE_CHECKING +from sqlalchemy.orm import relationship, Mapped, mapped_column + + +from memgpt.orm.sqlalchemy_base import SqlalchemyBase +from memgpt.orm.mixins import OrganizationMixin + +if TYPE_CHECKING: + from memgpt.orm.agent import Agent + from memgpt.orm.organization import Organization + from memgpt.orm.preset import Preset + +class Source(OrganizationMixin,SqlalchemyBase): + """A source represents an embedded text passage""" + __tablename__ = 'source' + + name: Mapped[str] = mapped_column(doc="the name of the source, must be unique within the org", nullable=False) + # TODO: feels like embeddings should be a first class object + embedding_dim:Mapped[int] = mapped_column(doc="the max number of dimensions for embedding vectors", nullable=False) + embedding_model:Mapped[str] = mapped_column(doc="the name of the embedding model used to generate the embedding", nullable=False) + description:Mapped[str] = mapped_column(nullable=True, doc="a human-readable description of the source") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") + agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents") + presets: Mapped[List["Preset"]] = relationship("Preset", secondary="presets_sources") \ No newline at end of file diff --git a/memgpt/orm/sources_agents.py b/memgpt/orm/sources_agents.py new file mode 100644 index 0000000000..20f3167860 --- /dev/null +++ b/memgpt/orm/sources_agents.py @@ -0,0 +1,13 @@ +from sqlalchemy import ForeignKey, UUID as SQLUUID +from uuid import UUID +from sqlalchemy.orm import relationship, Mapped, mapped_column + +from memgpt.orm.base import Base + + +class SourcesAgents(Base): + """Agents can have zero to many sources""" + __tablename__ = 'sources_agents' + + _agent_id:Mapped[UUID] = mapped_column(SQLUUID, ForeignKey('agent._id'), primary_key=True) + _source_id:Mapped[UUID] = mapped_column(SQLUUID, ForeignKey('source._id'), primary_key=True) \ No newline at end of file diff --git a/memgpt/orm/sources_presets.py b/memgpt/orm/sources_presets.py new file mode 100644 index 0000000000..11d3296f81 --- /dev/null +++ b/memgpt/orm/sources_presets.py @@ -0,0 +1,13 @@ +from sqlalchemy import ForeignKey, UUID as SQLUUID +from uuid import UUID +from sqlalchemy.orm import relationship, Mapped, mapped_column + +from memgpt.orm.base import Base + + +class SourcesPresets(Base): + """Sources can be used by zero to many Presets""" + __tablename__ = 'sources_presets' + + _preset_id:Mapped[UUID] = mapped_column(SQLUUID, ForeignKey('preset._id'), primary_key=True) + _source_id:Mapped[UUID] = mapped_column(SQLUUID, ForeignKey('source._id'), primary_key=True) \ No newline at end of file diff --git a/memgpt/orm/users_agents.py b/memgpt/orm/users_agents.py index 4155005e80..65a5d8a70c 100644 --- a/memgpt/orm/users_agents.py +++ b/memgpt/orm/users_agents.py @@ -9,8 +9,8 @@ class UsersAgents(Base): __tablename__ = 'users_agents' _agent_id: Mapped[UUID] = mapped_column( - SQLUUID(), ForeignKey("agent._id"), primary_key=True + SQLUUID, ForeignKey("agent._id"), primary_key=True ) _user_id: Mapped[UUID] = mapped_column( - SQLUUID(), ForeignKey("user._id"), primary_key=True + SQLUUID, ForeignKey("user._id"), primary_key=True ) \ No newline 
at end of file From 4347d3963faaa27d864375cb53edbfd02d2de712 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 26 Jun 2024 13:10:48 -0400 Subject: [PATCH 29/45] pretty sure this is the current model --- memgpt/orm/__all__.py | 9 +++++++- memgpt/orm/agent.py | 7 ++++-- memgpt/orm/enums.py | 7 ++++++ memgpt/orm/memory_templates.py | 42 ++++++++++++++++++++++++++++++++++ memgpt/orm/organization.py | 11 +++++++-- memgpt/orm/preset.py | 34 ++++++++++++++++++--------- memgpt/orm/source.py | 4 ++-- memgpt/orm/tool.py | 32 ++++++++++++++++++++++++++ memgpt/orm/tools_agents.py | 13 +++++++++++ memgpt/orm/tools_presets.py | 13 +++++++++++ 10 files changed, 154 insertions(+), 18 deletions(-) create mode 100644 memgpt/orm/enums.py create mode 100644 memgpt/orm/memory_templates.py create mode 100644 memgpt/orm/tool.py create mode 100644 memgpt/orm/tools_agents.py create mode 100644 memgpt/orm/tools_presets.py diff --git a/memgpt/orm/__all__.py b/memgpt/orm/__all__.py index ecb6476188..3056dc4e26 100644 --- a/memgpt/orm/__all__.py +++ b/memgpt/orm/__all__.py @@ -3,6 +3,13 @@ from memgpt.orm.agent import Agent from memgpt.orm.users_agents import UsersAgents from memgpt.orm.token import Token - +from memgpt.orm.source import Source +from memgpt.orm.tool import Tool +from memgpt.orm.preset import Preset +from memgpt.orm.memory_templates import MemoryTemplate, HumanMemoryTemplate, PersonaMemoryTemplate +from memgpt.orm.sources_agents import SourcesAgents +from memgpt.orm.sources_presets import SourcesPresets +from memgpt.orm.tools_agents import ToolsAgents +from memgpt.orm.tools_presets import ToolsPresets from memgpt.orm.base import Base \ No newline at end of file diff --git a/memgpt/orm/agent.py b/memgpt/orm/agent.py index 97f5a3c64d..e0c761eb1b 100644 --- a/memgpt/orm/agent.py +++ b/memgpt/orm/agent.py @@ -7,7 +7,9 @@ from memgpt.orm.users_agents import UsersAgents if TYPE_CHECKING: from memgpt.orm.organization import Organization + from memgpt.orm.source import Source from memgpt.orm.user import User + from memgpt.orm.tool import Tool class Agent(SqlalchemyBase, OrganizationMixin): __tablename__ = 'agent' @@ -22,5 +24,6 @@ class Agent(SqlalchemyBase, OrganizationMixin): organization: Mapped["Organization"] = relationship("Organization", back_populates="agents") users: Mapped[List["User"]] = relationship("User", back_populates="agents", - secondary="users_agents", - doc="the users associated with this agent.") + secondary="users_agents") + sources: Mapped[List["Source"]] = relationship("Source", secondary="sources_agents") + tools: Mapped[List["Tool"]] = relationship("Tool", secondary="tools_agents") \ No newline at end of file diff --git a/memgpt/orm/enums.py b/memgpt/orm/enums.py new file mode 100644 index 0000000000..8d26f7a8ed --- /dev/null +++ b/memgpt/orm/enums.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class ToolSourceType(str, Enum): + """Defines what a tool was derived from""" + python = "python" + json = "json" \ No newline at end of file diff --git a/memgpt/orm/memory_templates.py b/memgpt/orm/memory_templates.py new file mode 100644 index 0000000000..fb50f0b00b --- /dev/null +++ b/memgpt/orm/memory_templates.py @@ -0,0 +1,42 @@ +from typing import TYPE_CHECKING +from sqlalchemy.orm import mapped_column, Mapped, relationship + +from memgpt.orm.sqlalchemy_base import SqlalchemyBase +from memgpt.orm.mixins import OrganizationMixin +if TYPE_CHECKING: + from memgpt.orm.organization import Organization + +class MemoryTemplate(SqlalchemyBase, OrganizationMixin): + """Memory templates define 
the structure and starting point for a given memory type."""
+    __tablename__ = 'memory_template'
+
+    name:Mapped[str] = mapped_column(doc="the unique name that identifies a memory template")
+    description:Mapped[str] = mapped_column(doc="a description of the memory template")
+    type:Mapped[str] = mapped_column(doc="the type of memory template in use")
+    text:Mapped[str] = mapped_column(doc="the starting memory text provided for the template")
+
+    # relationships
+    organization:Mapped["Organization"] = relationship("Organization")
+
+    __mapper_args__ = {
+        "polymorphic_identity": "memory_template",
+        "polymorphic_on": "type",
+    }
+
+class HumanMemoryTemplate(MemoryTemplate):
+    """Template for the structured 'human' section of core memory.
+    Note: will be migrated to dynamic memory templates in the future.
+    """
+
+    __mapper_args__ = {
+        "polymorphic_identity": "human",
+    }
+
+class PersonaMemoryTemplate(MemoryTemplate):
+    """Template for the structured 'persona' section of core memory.
+    Note: will be migrated to dynamic memory templates in the future.
+    """
+
+    __mapper_args__ = {
+        "polymorphic_identity": "persona",
+    }
\ No newline at end of file
diff --git a/memgpt/orm/organization.py b/memgpt/orm/organization.py
index 3d0c3bae87..7d1d4736e5 100644
--- a/memgpt/orm/organization.py
+++ b/memgpt/orm/organization.py
@@ -6,6 +6,10 @@ from memgpt.orm.sqlalchemy_base import SqlalchemyBase
 if TYPE_CHECKING:
     from memgpt.orm.user import User
+    from memgpt.orm.agent import Agent
+    from memgpt.orm.source import Source
+    from memgpt.orm.tool import Tool
+    from memgpt.orm.preset import Preset
     from sqlalchemy.orm.session import Session
 
 class Organization(SqlalchemyBase):
@@ -14,8 +18,11 @@ class Organization(SqlalchemyBase):
     name:Mapped[Optional[str]] = mapped_column(nullable=True, doc="The display name of the organization.")
 
     # relationships
-    users: Mapped["User"] = relationship("User", back_populates="organization")
-    agents: Mapped["Agent"] = relationship("Agent", back_populates="organization")
+    users: Mapped["User"] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
+    agents: Mapped["Agent"] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
+    sources: Mapped["Source"] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
+    tools: Mapped["Tool"] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
+    presets: Mapped["Preset"] = relationship("Preset", back_populates="organization", cascade="all, delete-orphan")
 
     @classmethod
     def default(cls, db_session:"Session") -> "Organization":
diff --git a/memgpt/orm/preset.py b/memgpt/orm/preset.py
index 77b8b9e867..ca8e0f61eb 100644
--- a/memgpt/orm/preset.py
+++ b/memgpt/orm/preset.py
@@ -1,11 +1,19 @@
-from sqlalchemy import UniqueConstraint
-from sqlalchemy.orm import Mapped, mapped_column
+from typing import Optional, List, TYPE_CHECKING
+from sqlalchemy import UniqueConstraint, JSON # TODO: jsonb for pg
+from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from memgpt.orm.sqlalchemy_base import SqlalchemyBase
 from memgpt.orm.mixins import OrganizationMixin
 
-class Preset(SqlalchemyBase, OrganizationMixin):
+if TYPE_CHECKING:
+    from memgpt.orm.organization import Organization
+    from memgpt.orm.source import Source
+    from memgpt.orm.tool import Tool
 
+class Preset(SqlalchemyBase, OrganizationMixin):
+    """A preset represents a fixed starting point for an Agent, like a template of sorts.
+ It is similar to OpenAI's concept of an `assistant`_ + """ __tablename__ = 'preset' __table_args__ = ( UniqueConstraint( @@ -19,12 +27,16 @@ class Preset(SqlalchemyBase, OrganizationMixin): description: Mapped[str] = mapped_column(nullable=True, doc="a human-readable description of the preset") ## TODO: these are unclear - human vs human_name for example, what and why? - system = Column(String) - human = Column(String) - human_name = Column(String, nullable=False) - persona = Column(String) - persona_name = Column(String, nullable=False) - ## TODO: What is this? - preset = Column(String) + system:Mapped[Optional[str]] = mapped_column(doc="the current system message for the agent.") + human:Mapped[str] = mapped_column(doc="the current human message for the agent.") + human_name:Mapped[str] = mapped_column(doc="the name of the human message for the agent - DEPRECATED") + persona:Mapped[str] = mapped_column(doc="the current persona message for the agent.") + persona_name:Mapped[str] = mapped_column(doc="the name of the persona message for the agent - DEPRECATED") + # legacy JSON for support in sqlite and postgres. TODO: jsonb for pg + functions_schema:Mapped[dict] = mapped_column(JSON, doc="the schema for the functions in the preset - DEPRECATED") - functions_schema = Column(JSON) \ No newline at end of file + # relationships + organization: Mapped["Organization"] = relationship("Organization", + back_populates="presets") + sources: Mapped[List["Source"]] = relationship("Source", secondary="sources_presets") + tools: Mapped[List["Tool"]] = relationship("Tool", secondary="tools_presets") \ No newline at end of file diff --git a/memgpt/orm/source.py b/memgpt/orm/source.py index a08a80b514..abf3e9ef08 100644 --- a/memgpt/orm/source.py +++ b/memgpt/orm/source.py @@ -22,5 +22,5 @@ class Source(OrganizationMixin,SqlalchemyBase): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") - agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents") - presets: Mapped[List["Preset"]] = relationship("Preset", secondary="presets_sources") \ No newline at end of file + agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources") + presets: Mapped[List["Preset"]] = relationship("Preset", secondary="sources_presets", back_populates="sources") \ No newline at end of file diff --git a/memgpt/orm/tool.py b/memgpt/orm/tool.py new file mode 100644 index 0000000000..9782bfe422 --- /dev/null +++ b/memgpt/orm/tool.py @@ -0,0 +1,32 @@ +from typing import Optional, TYPE_CHECKING, List +from sqlalchemy import String, JSON +from sqlalchemy.orm import Mapped, relationship, mapped_column + +from memgpt.orm.enums import ToolSourceType +from memgpt.orm.sqlalchemy_base import SqlalchemyBase +from memgpt.orm.mixins import OrganizationMixin +from memgpt.orm.users_agents import UsersAgents + + +if TYPE_CHECKING: + from memgpt.orm.agent import Agent + from memgpt.orm.token import Token + +class Tool(SqlalchemyBase, OrganizationMixin): + """Represents an available tool that the LLM can invoke. + + NOTE: polymorphic inheritance makes more sense here as a TODO. We want a superset of tools + that are always available, and a subset scoped to the organization. Alternatively, we could use the apply_access_predicate to build + more granular permissions. 
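+    For now each tool row is scoped to a single organization via
+    OrganizationMixin; the always-available superset described above does not
+    exist yet.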
+ """ + __tablename__ = "tool" + + name:Mapped[Optional[str]] = mapped_column(nullable=True, doc="The display name of the tool.") + # TODO: this needs to be a lookup table to have any value + tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.") + source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json) + source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function if provided.", default=None, nullable=True) + json_schema: Mapped[dict] = mapped_column(JSON, default=lambda : {}, doc="The OAI compatable JSON schema of the function.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="tools") diff --git a/memgpt/orm/tools_agents.py b/memgpt/orm/tools_agents.py new file mode 100644 index 0000000000..1e2bae27a7 --- /dev/null +++ b/memgpt/orm/tools_agents.py @@ -0,0 +1,13 @@ +from sqlalchemy import ForeignKey, UUID as SQLUUID +from uuid import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from memgpt.orm.base import Base + + +class ToolsAgents(Base): + """Agents can have zero to many tools""" + __tablename__ = 'tools_agents' + + _agent_id:Mapped[UUID] = mapped_column(SQLUUID, ForeignKey('agent._id'), primary_key=True) + _tool_id:Mapped[UUID] = mapped_column(SQLUUID, ForeignKey('tool._id'), primary_key=True) \ No newline at end of file diff --git a/memgpt/orm/tools_presets.py b/memgpt/orm/tools_presets.py new file mode 100644 index 0000000000..4edc0dc591 --- /dev/null +++ b/memgpt/orm/tools_presets.py @@ -0,0 +1,13 @@ +from sqlalchemy import ForeignKey, UUID as SQLUUID +from uuid import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from memgpt.orm.base import Base + + +class ToolsPresets(Base): + """Tools can be used by zero to many Presets""" + __tablename__ = 'tools_presets' + + _preset_id:Mapped[UUID] = mapped_column(SQLUUID, ForeignKey('preset._id'), primary_key=True) + _tool_id:Mapped[UUID] = mapped_column(SQLUUID, ForeignKey('tool._id'), primary_key=True) \ No newline at end of file From 90349f085c792c4b127b98093c0e2a83ed5c38d8 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 26 Jun 2024 15:10:53 -0400 Subject: [PATCH 30/45] alembic-managed migrations --- memgpt/alembic.ini | 116 ++++++++++++ memgpt/migrations/README | 1 + memgpt/migrations/env.py | 83 +++++++++ memgpt/migrations/script.py.mako | 26 +++ memgpt/migrations/versions/fcd6c014e6a8_.py | 193 ++++++++++++++++++++ memgpt/settings.py | 22 ++- pyproject.toml | 2 + 7 files changed, 437 insertions(+), 6 deletions(-) create mode 100644 memgpt/alembic.ini create mode 100644 memgpt/migrations/README create mode 100644 memgpt/migrations/env.py create mode 100644 memgpt/migrations/script.py.mako create mode 100644 memgpt/migrations/versions/fcd6c014e6a8_.py diff --git a/memgpt/alembic.ini b/memgpt/alembic.ini new file mode 100644 index 0000000000..c404ec57dc --- /dev/null +++ b/memgpt/alembic.ini @@ -0,0 +1,116 @@ +# A generic, single database configuration. 
+ +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = migrations + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = %(MEMGPT_DATABASE_URL)s + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/memgpt/migrations/README b/memgpt/migrations/README new file mode 100644 index 0000000000..98e4f9c44e --- /dev/null +++ b/memgpt/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/memgpt/migrations/env.py b/memgpt/migrations/env.py new file mode 100644 index 0000000000..80835623c2 --- /dev/null +++ b/memgpt/migrations/env.py @@ -0,0 +1,83 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +from memgpt.settings import settings +from memgpt.orm.__all__ import Base + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config +section = config.config_ini_section +# set the metadata database url from settings +config.set_section_option(section, "MEMGPT_DATABASE_URL", settings.database_url) +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
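+
+    The engine is built from the alembic config section, whose
+    MEMGPT_DATABASE_URL was pointed at settings.database_url above, so
+    migrations run against the same backend the server itself uses.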
+ + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/memgpt/migrations/script.py.mako b/memgpt/migrations/script.py.mako new file mode 100644 index 0000000000..fbc4b07dce --- /dev/null +++ b/memgpt/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/memgpt/migrations/versions/fcd6c014e6a8_.py b/memgpt/migrations/versions/fcd6c014e6a8_.py new file mode 100644 index 0000000000..10e6a501cc --- /dev/null +++ b/memgpt/migrations/versions/fcd6c014e6a8_.py @@ -0,0 +1,193 @@ +"""empty message + +Revision ID: fcd6c014e6a8 +Revises: +Create Date: 2024-06-26 18:52:23.655166 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'fcd6c014e6a8' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('organization', + sa.Column('name', sa.String(), nullable=True), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.PrimaryKeyConstraint('_id') + ) + op.create_table('agent', + sa.Column('name', sa.String(), nullable=True), + sa.Column('persona', sa.String(), nullable=False), + sa.Column('human', sa.String(), nullable=False), + sa.Column('preset', sa.String(), nullable=False), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.Column('_organization_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_organization_id'], ['organization._id'], ), + sa.PrimaryKeyConstraint('_id') + ) + op.create_table('memory_template', + sa.Column('name', sa.String(), nullable=False), + sa.Column('description', sa.String(), nullable=False), + sa.Column('type', sa.String(), nullable=False), + sa.Column('text', sa.String(), nullable=False), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.Column('_organization_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_organization_id'], ['organization._id'], ), + sa.PrimaryKeyConstraint('_id') + ) + op.create_table('preset', + sa.Column('name', sa.String(), nullable=False), + sa.Column('description', sa.String(), nullable=True), + sa.Column('system', sa.String(), nullable=True), + sa.Column('human', sa.String(), nullable=False), + sa.Column('human_name', sa.String(), nullable=False), + sa.Column('persona', sa.String(), nullable=False), + sa.Column('persona_name', sa.String(), nullable=False), + sa.Column('functions_schema', sa.JSON(), nullable=False), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.Column('_organization_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_organization_id'], ['organization._id'], ), + sa.PrimaryKeyConstraint('_id'), + 
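+    # mirrors Preset.__table_args__: preset names must be unique per organization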
sa.UniqueConstraint('_organization_id', 'name', name='unique_name_organization') + ) + op.create_table('source', + sa.Column('name', sa.String(), nullable=False), + sa.Column('embedding_dim', sa.Integer(), nullable=False), + sa.Column('embedding_model', sa.String(), nullable=False), + sa.Column('description', sa.String(), nullable=True), + sa.Column('_organization_id', sa.UUID(), nullable=False), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(['_organization_id'], ['organization._id'], ), + sa.PrimaryKeyConstraint('_id') + ) + op.create_table('tool', + sa.Column('name', sa.String(), nullable=True), + sa.Column('tags', sa.JSON(), nullable=False), + sa.Column('source_type', sa.String(), nullable=False), + sa.Column('source_code', sa.String(), nullable=True), + sa.Column('json_schema', sa.JSON(), nullable=False), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.Column('_organization_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_organization_id'], ['organization._id'], ), + sa.PrimaryKeyConstraint('_id') + ) + op.create_table('user', + sa.Column('name', sa.String(), nullable=True), + sa.Column('email', sa.String(), nullable=True), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.Column('_organization_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_organization_id'], ['organization._id'], ), + sa.PrimaryKeyConstraint('_id') + ) + op.create_table('sources_agents', + sa.Column('_agent_id', sa.UUID(), nullable=False), + sa.Column('_source_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_agent_id'], ['agent._id'], ), + sa.ForeignKeyConstraint(['_source_id'], ['source._id'], ), + sa.PrimaryKeyConstraint('_agent_id', '_source_id') + ) + op.create_table('sources_presets', + sa.Column('_preset_id', sa.UUID(), nullable=False), + sa.Column('_source_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_preset_id'], ['preset._id'], ), + sa.ForeignKeyConstraint(['_source_id'], ['source._id'], ), + sa.PrimaryKeyConstraint('_preset_id', '_source_id') + ) + op.create_table('token', + sa.Column('hash', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('_id', sa.UUID(), 
nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.Column('_user_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_user_id'], ['user._id'], ), + sa.PrimaryKeyConstraint('_id') + ) + op.create_table('tools_agents', + sa.Column('_agent_id', sa.UUID(), nullable=False), + sa.Column('_tool_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_agent_id'], ['agent._id'], ), + sa.ForeignKeyConstraint(['_tool_id'], ['tool._id'], ), + sa.PrimaryKeyConstraint('_agent_id', '_tool_id') + ) + op.create_table('tools_presets', + sa.Column('_preset_id', sa.UUID(), nullable=False), + sa.Column('_tool_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_preset_id'], ['preset._id'], ), + sa.ForeignKeyConstraint(['_tool_id'], ['tool._id'], ), + sa.PrimaryKeyConstraint('_preset_id', '_tool_id') + ) + op.create_table('users_agents', + sa.Column('_agent_id', sa.UUID(), nullable=False), + sa.Column('_user_id', sa.UUID(), nullable=False), + sa.ForeignKeyConstraint(['_agent_id'], ['agent._id'], ), + sa.ForeignKeyConstraint(['_user_id'], ['user._id'], ), + sa.PrimaryKeyConstraint('_agent_id', '_user_id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('users_agents') + op.drop_table('tools_presets') + op.drop_table('tools_agents') + op.drop_table('token') + op.drop_table('sources_presets') + op.drop_table('sources_agents') + op.drop_table('user') + op.drop_table('tool') + op.drop_table('source') + op.drop_table('preset') + op.drop_table('memory_template') + op.drop_table('agent') + op.drop_table('organization') + # ### end Alembic commands ### diff --git a/memgpt/settings.py b/memgpt/settings.py index 0f742cf4e7..5bd1fc487a 100644 --- a/memgpt/settings.py +++ b/memgpt/settings.py @@ -1,21 +1,25 @@ from pathlib import Path from typing import Optional +from enum import Enum from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict +class StorageType(str, Enum): + sqlite = "sqlite" + postgres = "postgres" class Settings(BaseSettings): model_config = SettingsConfigDict(env_prefix="memgpt_") - + storage_type: Optional[StorageType] = Field(description="What is the RDBMS type associated with the database url?", default=StorageType.sqlite) memgpt_dir: Optional[Path] = Field(Path.home() / ".memgpt", env="MEMGPT_DIR") debug: Optional[bool] = False server_pass: Optional[str] = None - pg_db: Optional[str] = "memgpt" - pg_user: Optional[str] = "memgpt" - pg_password: Optional[str] = "memgpt" - pg_host: Optional[str] = "localhost" - pg_port: Optional[int] = 5432 cors_origins: Optional[list] = ["http://memgpt.localhost", "http://localhost:8283", "http://localhost:8083"] + pg_db: Optional[str] = None + pg_user: Optional[str] = None + pg_password: Optional[str] = None + pg_host: Optional[str] = None + pg_port: Optional[int] = None _pg_uri: Optional[str] = None # calculated to specify full uri # configurations config_path: Optional[Path] = Path("~/.memgpt/config").expanduser() @@ -28,6 +32,12 @@ class Settings(BaseSettings): # TODO: extract to vendor plugin 
openai_api_key: Optional[str] = None + @property + def database_url(self) -> str: + if self.storage_type == StorageType.sqlite: + return f"sqlite:///{self.memgpt_dir}/memgpt.db" + return self.pg_uri + @property def pg_uri(self) -> str: if self._pg_uri: diff --git a/pyproject.toml b/pyproject.toml index a3b797747b..1276ae6dfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,8 @@ httpx-sse = "^0.4.0" isort = { version = "^5.13.2", optional = true } llama-index-embeddings-ollama = {version = "^0.1.2", optional = true} protobuf = "3.20.0" +alembic = "^1.13.2" +pyhumps = "^3.8.0" [tool.poetry.extras] local = ["llama-index-embeddings-huggingface"] From b8391a9281f6a2919ff124f004dc8d84d18a50d8 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 26 Jun 2024 16:08:48 -0400 Subject: [PATCH 31/45] migrations now included on startup. we need to add it to every possible entrypoint to be good to go --- memgpt/server/startup.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/memgpt/server/startup.sh b/memgpt/server/startup.sh index f91c669a51..fb59c5a4c7 100755 --- a/memgpt/server/startup.sh +++ b/memgpt/server/startup.sh @@ -1,4 +1,8 @@ #!/bin/sh +set -e +echo "updating database..." +alembic upgrade head +echo "Database updated!" echo "Starting MEMGPT server..." if [ "$MEMGPT_ENVIRONMENT" = "DEVELOPMENT" ] ; then echo "Starting in development mode!" From 9f0b8dc71110e0f353b849ff982f251b2f40d1bb Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 26 Jun 2024 16:29:57 -0400 Subject: [PATCH 32/45] added jobs model --- memgpt/migrations/versions/e1e15ff9ab6e_.py | 151 ++++++++++++++++++++ memgpt/orm/__all__.py | 1 + memgpt/orm/enums.py | 8 +- memgpt/orm/job.py | 23 +++ 4 files changed, 182 insertions(+), 1 deletion(-) create mode 100644 memgpt/migrations/versions/e1e15ff9ab6e_.py create mode 100644 memgpt/orm/job.py diff --git a/memgpt/migrations/versions/e1e15ff9ab6e_.py b/memgpt/migrations/versions/e1e15ff9ab6e_.py new file mode 100644 index 0000000000..ea925032b6 --- /dev/null +++ b/memgpt/migrations/versions/e1e15ff9ab6e_.py @@ -0,0 +1,151 @@ +"""empty message + +Revision ID: e1e15ff9ab6e +Revises: fcd6c014e6a8 +Create Date: 2024-06-26 20:23:47.395414 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'e1e15ff9ab6e' +down_revision: Union[str, None] = 'fcd6c014e6a8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('job', + sa.Column('status', sa.Enum('created', 'running', 'completed', 'failed', name='jobstatus'), nullable=False), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('metadata_', sa.JSON(), nullable=True), + sa.Column('_user_id', sa.UUID(), nullable=False), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(['_user_id'], ['user._id'], ), + sa.PrimaryKeyConstraint('_id') + ) + op.drop_table('agent_source_mapping') + op.drop_table('humanmodel') + op.drop_table('agents') + op.drop_table('toolmodel') + op.drop_table('presets') + op.drop_table('users') + op.drop_table('tokens') + op.drop_table('personamodel') + op.drop_table('jobmodel') + op.drop_table('sources') + op.drop_table('preset_source_mapping') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('preset_source_mapping', + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('preset_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('source_id', sa.UUID(), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('id', name='preset_source_mapping_pkey') + ) + op.create_table('sources', + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True), + sa.Column('embedding_dim', sa.BIGINT(), autoincrement=False, nullable=True), + sa.Column('embedding_model', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('description', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='sources_pkey') + ) + op.create_table('jobmodel', + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('status', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('created_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=False), + sa.Column('completed_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=True), + sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='jobmodel_pkey') + ) + op.create_table('personamodel', + sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='personamodel_pkey') + ) + op.create_table('tokens', + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('token', 
sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='tokens_pkey') + ) + op.create_table('users', + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('default_agent', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('policies_accepted', sa.BOOLEAN(), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('id', name='users_pkey') + ) + op.create_table('presets', + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('description', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('system', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('human', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('human_name', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('persona', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('persona_name', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('preset', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True), + sa.Column('functions_schema', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='presets_pkey') + ) + op.create_table('toolmodel', + sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('tags', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.Column('source_type', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('source_code', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('json_schema', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='toolmodel_pkey') + ) + op.create_table('agents', + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('persona', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('human', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('preset', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True), + sa.Column('llm_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.Column('state', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='agents_pkey') + ) + op.create_table('humanmodel', + sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='humanmodel_pkey') + ) + op.create_table('agent_source_mapping', + sa.Column('id', sa.UUID(), 
autoincrement=False, nullable=False),
+    sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False),
+    sa.Column('agent_id', sa.UUID(), autoincrement=False, nullable=False),
+    sa.Column('source_id', sa.UUID(), autoincrement=False, nullable=False),
+    sa.PrimaryKeyConstraint('id', name='agent_source_mapping_pkey')
+    )
+    op.drop_table('job')
+    # drop_table leaves the 'jobstatus' enum type behind on postgres; drop it
+    # explicitly so a later upgrade can recreate it (no-op where enums aren't native)
+    sa.Enum(name='jobstatus').drop(op.get_bind(), checkfirst=True)
+    # ### end Alembic commands ###
diff --git a/memgpt/orm/__all__.py b/memgpt/orm/__all__.py
index 3056dc4e26..c2cb353159 100644
--- a/memgpt/orm/__all__.py
+++ b/memgpt/orm/__all__.py
@@ -11,5 +11,6 @@
 from memgpt.orm.sources_presets import SourcesPresets
 from memgpt.orm.tools_agents import ToolsAgents
 from memgpt.orm.tools_presets import ToolsPresets
+from memgpt.orm.job import Job
 from memgpt.orm.base import Base
\ No newline at end of file
diff --git a/memgpt/orm/enums.py b/memgpt/orm/enums.py
index 8d26f7a8ed..5fb3ce6584 100644
--- a/memgpt/orm/enums.py
+++ b/memgpt/orm/enums.py
@@ -4,4 +4,10 @@ class ToolSourceType(str, Enum):
     """Defines what a tool was derived from"""
     python = "python"
-    json = "json"
\ No newline at end of file
+    json = "json"
+
+class JobStatus(str, Enum):
+    created = "created"
+    running = "running"
+    completed = "completed"
+    failed = "failed"
\ No newline at end of file
diff --git a/memgpt/orm/job.py b/memgpt/orm/job.py
new file mode 100644
index 0000000000..ecca0c24ff
--- /dev/null
+++ b/memgpt/orm/job.py
@@ -0,0 +1,23 @@
+from typing import Optional, TYPE_CHECKING
+from datetime import datetime
+from sqlalchemy import JSON
+from sqlalchemy.orm import relationship, Mapped, mapped_column
+from memgpt.orm.sqlalchemy_base import SqlalchemyBase
+from memgpt.orm.enums import JobStatus
+from memgpt.orm.mixins import UserMixin
+
+if TYPE_CHECKING:
+    from memgpt.orm.user import User
+
+class Job(UserMixin, SqlalchemyBase):
+    """Jobs run in the background and are owned by a user.
+    Typical jobs involve loading and processing sources etc.
+    """
+    __tablename__ = "job"
+
+    status: Mapped[JobStatus] = mapped_column(default=JobStatus.created, doc="The current status of the job.")
+    completed_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="The timestamp when the job was completed.")
+    metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The metadata of the job.")
+
+    # relationships
+    user: Mapped["User"] = relationship("User", back_populates="jobs")
\ No newline at end of file

From d11f87ba5c5cc9b47138a33a3e457b7168f04956 Mon Sep 17 00:00:00 2001
From: Ethan Knox
Date: Wed, 26 Jun 2024 16:51:54 -0400
Subject: [PATCH 33/45] configs pattern for now. these should be 1st class.
 also there needs to be helpers like palm to do migrations and such

---
 .../1ad46f6e4c2d_adding_configs_to_agent.py | 32 +++++++++++++++++++
 memgpt/orm/__all__.py                       |  1 +
 memgpt/orm/agent.py                         |  8 +++--
 3 files changed, 39 insertions(+), 2 deletions(-)
 create mode 100644 memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py

diff --git a/memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py b/memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py
new file mode 100644
index 0000000000..efc5d91191
--- /dev/null
+++ b/memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py
@@ -0,0 +1,32 @@
+"""adding configs to agent
+
+Revision ID: 1ad46f6e4c2d
+Revises: e1e15ff9ab6e
+Create Date: 2024-06-26 20:51:04.227418
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '1ad46f6e4c2d' +down_revision: Union[str, None] = 'e1e15ff9ab6e' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('agent', sa.Column('llm_config', sa.JSON(), nullable=False)) + op.add_column('agent', sa.Column('embedding_config', sa.JSON(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('agent', 'embedding_config') + op.drop_column('agent', 'llm_config') + # ### end Alembic commands ### diff --git a/memgpt/orm/__all__.py b/memgpt/orm/__all__.py index c2cb353159..a8dac16dd2 100644 --- a/memgpt/orm/__all__.py +++ b/memgpt/orm/__all__.py @@ -1,3 +1,4 @@ +"""__all__ acts as manual import management to avoid collisions and circular imports.""" from memgpt.orm.organization import Organization from memgpt.orm.user import User from memgpt.orm.agent import Agent diff --git a/memgpt/orm/agent.py b/memgpt/orm/agent.py index e0c761eb1b..082e24861a 100644 --- a/memgpt/orm/agent.py +++ b/memgpt/orm/agent.py @@ -1,10 +1,11 @@ from typing import Optional, List, TYPE_CHECKING -from sqlalchemy import String +from sqlalchemy import String, JSON from sqlalchemy.orm import Mapped, mapped_column, relationship from memgpt.orm.sqlalchemy_base import SqlalchemyBase from memgpt.orm.mixins import OrganizationMixin -from memgpt.orm.users_agents import UsersAgents +from memgpt.data_types import LLMConfig, EmbeddingConfig + if TYPE_CHECKING: from memgpt.orm.organization import Organization from memgpt.orm.source import Source @@ -20,6 +21,9 @@ class Agent(SqlalchemyBase, OrganizationMixin): human: Mapped[str] = mapped_column(doc="the human text for the agent and the current user, current state.") preset: Mapped[str] = mapped_column(doc="the preset text for the agent, current state.") + llm_config: Mapped[LLMConfig] = mapped_column(JSON, doc="the LLM backend configuration object for this agent.") + embedding_config: Mapped[EmbeddingConfig] = mapped_column(JSON, doc="the embedding configuration object for this agent.") + # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="agents") users: Mapped[List["User"]] = relationship("User", From c4e45b0c27a08305275c5842673563b433b41c9f Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 26 Jun 2024 17:05:41 -0400 Subject: [PATCH 34/45] finally time to start cutting --- memgpt/metadata.py | 540 +++++++++++++++++----------------- memgpt/orm/agent.py | 4 + memgpt/orm/sqlalchemy_base.py | 17 +- memgpt/orm/user.py | 4 +- 4 files changed, 293 insertions(+), 272 deletions(-) diff --git a/memgpt/metadata.py b/memgpt/metadata.py index cfdc7ad762..bbcecf3e27 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -6,25 +6,25 @@ import uuid from typing import List, Optional, Type -from sqlalchemy import ( - BIGINT, - CHAR, - JSON, - Boolean, - Column, - DateTime, - String, - TypeDecorator, - create_engine, - desc, - func, -) -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.exc import InterfaceError, OperationalError -from sqlalchemy.orm import declarative_base, sessionmaker -from sqlalchemy.sql import func - -from memgpt.config import MemGPTConfig +#from sqlalchemy import ( + #BIGINT, + #CHAR, + #JSON, + #Boolean, + #Column, + #DateTime, + #String, + #TypeDecorator, + #create_engine, + #desc, + #func, +#) +#from 
sqlalchemy.dialects.postgresql import UUID +#from sqlalchemy.exc import InterfaceError +#from sqlalchemy.orm import declarative_base, sessionmaker +#from sqlalchemy.sql import func + +#from memgpt.config import MemGPTConfig from memgpt.data_types import ( AgentState, EmbeddingConfig, @@ -34,290 +34,290 @@ Token, User, ) -from memgpt.models.pydantic_models import ( - HumanModel, - JobModel, - JobStatus, - PersonaModel, - ToolModel, -) +from memgpt.functions.functions import load_all_function_sets +from memgpt.orm.enums import JobStatus +#from memgpt.models.pydantic_models import ( + #HumanModel, + #JobModel, + #JobStatus, + #PersonaModel, + #ToolModel, +#) from memgpt.settings import settings -from memgpt.utils import enforce_types, get_utc_time, printd +from memgpt.utils import printd #enforce_types, get_utc_time, printd -Base = declarative_base() +#Base = declarative_base() # Custom UUID type -class CommonUUID(TypeDecorator): - impl = CHAR - cache_ok = True - - def load_dialect_impl(self, dialect): - if dialect.name == "postgresql": - return dialect.type_descriptor(UUID(as_uuid=True)) - else: - return dialect.type_descriptor(CHAR()) - - def process_bind_param(self, value, dialect): - if dialect.name == "postgresql" or value is None: - return value - else: - return str(value) # Convert UUID to string for SQLite - - def process_result_value(self, value, dialect): - if dialect.name == "postgresql" or value is None: - return value - else: - return uuid.UUID(value) - - -class LLMConfigColumn(TypeDecorator): - """Custom type for storing LLMConfig as JSON""" - - impl = JSON - cache_ok = True - - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - - def process_bind_param(self, value, dialect): - if value: - return vars(value) - return value - - def process_result_value(self, value, dialect): - if value: - return LLMConfig(**value) - return value - +#class CommonUUID(TypeDecorator): + #impl = CHAR + #cache_ok = True -class EmbeddingConfigColumn(TypeDecorator): - """Custom type for storing EmbeddingConfig as JSON""" + #def load_dialect_impl(self, dialect): + #if dialect.name == "postgresql": + #return dialect.type_descriptor(UUID(as_uuid=True)) + #else: + #return dialect.type_descriptor(CHAR()) - impl = JSON - cache_ok = True + #def process_bind_param(self, value, dialect): + #if dialect.name == "postgresql" or value is None: + #return value + #else: + #return str(value) # Convert UUID to string for SQLite - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) + #def process_result_value(self, value, dialect): + #if dialect.name == "postgresql" or value is None: + #return value + #else: + #return uuid.UUID(value) - def process_bind_param(self, value, dialect): - if value: - return vars(value) - return value - def process_result_value(self, value, dialect): - if value: - return EmbeddingConfig(**value) - return value +#class LLMConfigColumn(TypeDecorator): + #"""Custom type for storing LLMConfig as JSON""" + #impl = JSON + #cache_ok = True -class UserModel(Base): - __tablename__ = "users" - __table_args__ = {"extend_existing": True} + #def load_dialect_impl(self, dialect): + #return dialect.type_descriptor(JSON()) - id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - # name = Column(String, nullable=False) - default_agent = Column(String) + #def process_bind_param(self, value, dialect): + #if value: + #return vars(value) + #return value - policies_accepted = Column(Boolean, nullable=False, default=False) + #def 
process_result_value(self, value, dialect): + #if value: + #return LLMConfig(**value) + #return value - def __repr__(self) -> str: - return f"" - def to_record(self) -> User: - return User( - id=self.id, - # name=self.name - default_agent=self.default_agent, - policies_accepted=self.policies_accepted, - ) +#class EmbeddingConfigColumn(TypeDecorator): + #"""Custom type for storing EmbeddingConfig as JSON""" + #impl = JSON + #cache_ok = True -class TokenModel(Base): - """Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens).""" + #def load_dialect_impl(self, dialect): + #return dialect.type_descriptor(JSON()) + + #def process_bind_param(self, value, dialect): + #if value: + #return vars(value) + #return value - __tablename__ = "tokens" + #def process_result_value(self, value, dialect): + #if value: + #return EmbeddingConfig(**value) + #return value - id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - # each api key is tied to a user account (that it validates access for) - user_id = Column(CommonUUID, nullable=False) - # the api key - token = Column(String, nullable=False) - # extra (optional) metadata - name = Column(String) - def __repr__(self) -> str: - return f"" +#class UserModel(Base): + #__tablename__ = "users" + #__table_args__ = {"extend_existing": True} - def to_record(self) -> User: - return Token( - id=self.id, - user_id=self.user_id, - token=self.token, - name=self.name, - ) + #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + ## name = Column(String, nullable=False) + #default_agent = Column(String) + + #policies_accepted = Column(Boolean, nullable=False, default=False) + + #def __repr__(self) -> str: + #return f"" + + #def to_record(self) -> User: + #return User( + #id=self.id, + ## name=self.name + #default_agent=self.default_agent, + #policies_accepted=self.policies_accepted, + #) + + +#class TokenModel(Base): + #"""Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens).""" + + #__tablename__ = "tokens" + + #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + ## each api key is tied to a user account (that it validates access for) + #user_id = Column(CommonUUID, nullable=False) + ## the api key + #token = Column(String, nullable=False) + ## extra (optional) metadata + #name = Column(String) + + #def __repr__(self) -> str: + #return f"" + + #def to_record(self) -> User: + #return Token( + #id=self.id, + #user_id=self.user_id, + #token=self.token, + #name=self.name, + #) -def generate_api_key(prefix="sk-", length=51) -> str: - # Generate 'length // 2' bytes because each byte becomes two hex digits. Adjust length for prefix. - actual_length = max(length - len(prefix), 1) // 2 # Ensure at least 1 byte is generated - random_bytes = secrets.token_bytes(actual_length) - new_key = prefix + random_bytes.hex() - return new_key +#def generate_api_key(prefix="sk-", length=51) -> str: + ## Generate 'length // 2' bytes because each byte becomes two hex digits. Adjust length for prefix. 
+ #actual_length = max(length - len(prefix), 1) // 2 # Ensure at least 1 byte is generated + #random_bytes = secrets.token_bytes(actual_length) + #new_key = prefix + random_bytes.hex() + #return new_key -class AgentModel(Base): - """Defines data model for storing Passages (consisting of text, embedding)""" +#class AgentModel(Base): + #"""Defines data model for storing Passages (consisting of text, embedding)""" - __tablename__ = "agents" - __table_args__ = {"extend_existing": True} + #__tablename__ = "agents" + #__table_args__ = {"extend_existing": True} - id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - user_id = Column(CommonUUID, nullable=False) - name = Column(String, nullable=False) - system = Column(String) - created_at = Column(DateTime(timezone=True), server_default=func.now()) + #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + #user_id = Column(CommonUUID, nullable=False) + #name = Column(String, nullable=False) + #persona = Column(String) + #human = Column(String) + #preset = Column(String) + #created_at = Column(DateTime(timezone=True), server_default=func.now()) - # configs - llm_config = Column(LLMConfigColumn) - embedding_config = Column(EmbeddingConfigColumn) + ## configs + #llm_config = Column(LLMConfigColumn) + #embedding_config = Column(EmbeddingConfigColumn) - # state - state = Column(JSON) - _metadata = Column(JSON) + ## state + #state = Column(JSON) - # tools - tools = Column(JSON) + #def __repr__(self) -> str: + #return f"" - def __repr__(self) -> str: - return f"" - - def to_record(self) -> AgentState: - return AgentState( - id=self.id, - user_id=self.user_id, - name=self.name, - created_at=self.created_at, - llm_config=self.llm_config, - embedding_config=self.embedding_config, - state=self.state, - tools=self.tools, - system=self.system, - _metadata=self._metadata, - ) - - -class SourceModel(Base): - """Defines data model for storing Passages (consisting of text, embedding)""" - - __tablename__ = "sources" - __table_args__ = {"extend_existing": True} - - # Assuming passage_id is the primary key - # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - user_id = Column(CommonUUID, nullable=False) - name = Column(String, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - embedding_dim = Column(BIGINT) - embedding_model = Column(String) - description = Column(String) - - # TODO: add num passages - - def __repr__(self) -> str: - return f"" - - def to_record(self) -> Source: - return Source( - id=self.id, - user_id=self.user_id, - name=self.name, - created_at=self.created_at, - embedding_dim=self.embedding_dim, - embedding_model=self.embedding_model, - description=self.description, - ) - - -class AgentSourceMappingModel(Base): - """Stores mapping between agent -> source""" - - __tablename__ = "agent_source_mapping" - - id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - user_id = Column(CommonUUID, nullable=False) - agent_id = Column(CommonUUID, nullable=False) - source_id = Column(CommonUUID, nullable=False) - - def __repr__(self) -> str: - return f"" - - -class PresetSourceMapping(Base): - __tablename__ = "preset_source_mapping" - - id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - user_id = Column(CommonUUID, nullable=False) - preset_id = Column(CommonUUID, nullable=False) - source_id = Column(CommonUUID, nullable=False) - - def __repr__(self) -> str: - return f"" - - -# class 
PresetFunctionMapping(Base): -# __tablename__ = "preset_function_mapping" -# -# id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) -# user_id = Column(CommonUUID, nullable=False) -# preset_id = Column(CommonUUID, nullable=False) -# #function_id = Column(CommonUUID, nullable=False) -# function = Column(String, nullable=False) # TODO: convert to ID eventually -# -# def __repr__(self) -> str: -# return f"" - - -class PresetModel(Base): - """Defines data model for storing Preset objects""" - - __tablename__ = "presets" - __table_args__ = {"extend_existing": True} - - id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - user_id = Column(CommonUUID, nullable=False) - name = Column(String, nullable=False) - description = Column(String) - system = Column(String) - human = Column(String) - human_name = Column(String, nullable=False) - persona = Column(String) - persona_name = Column(String, nullable=False) - preset = Column(String) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - - functions_schema = Column(JSON) - - def __repr__(self) -> str: - return f"" - - def to_record(self) -> Preset: - return Preset( - id=self.id, - user_id=self.user_id, - name=self.name, - description=self.description, - system=self.system, - human=self.human, - persona=self.persona, - human_name=self.human_name, - persona_name=self.persona_name, - preset=self.preset, - created_at=self.created_at, - functions_schema=self.functions_schema, - ) + #def to_record(self) -> AgentState: + #return AgentState( + #id=self.id, + #user_id=self.user_id, + #name=self.name, + #persona=self.persona, + #human=self.human, + #preset=self.preset, + #created_at=self.created_at, + #llm_config=self.llm_config, + #embedding_config=self.embedding_config, + #state=self.state, + #) + + +#class SourceModel(Base): + #"""Defines data model for storing Passages (consisting of text, embedding)""" + + #__tablename__ = "sources" + #__table_args__ = {"extend_existing": True} + + ## Assuming passage_id is the primary key + ## id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + #user_id = Column(CommonUUID, nullable=False) + #name = Column(String, nullable=False) + #created_at = Column(DateTime(timezone=True), server_default=func.now()) + #embedding_dim = Column(BIGINT) + #embedding_model = Column(String) + #description = Column(String) + + ## TODO: add num passages + + #def __repr__(self) -> str: + #return f"" + + #def to_record(self) -> Source: + #return Source( + #id=self.id, + #user_id=self.user_id, + #name=self.name, + #created_at=self.created_at, + #embedding_dim=self.embedding_dim, + #embedding_model=self.embedding_model, + #description=self.description, + #) + + +#class AgentSourceMappingModel(Base): + #"""Stores mapping between agent -> source""" + + #__tablename__ = "agent_source_mapping" + + #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + #user_id = Column(CommonUUID, nullable=False) + #agent_id = Column(CommonUUID, nullable=False) + #source_id = Column(CommonUUID, nullable=False) + + #def __repr__(self) -> str: + #return f"" + + +#class PresetSourceMapping(Base): + #__tablename__ = "preset_source_mapping" + + #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + #user_id = Column(CommonUUID, nullable=False) + #preset_id = Column(CommonUUID, nullable=False) + #source_id = Column(CommonUUID, nullable=False) + + #def __repr__(self) -> str: + #return f"" + + +## class 
PresetFunctionMapping(Base): +## __tablename__ = "preset_function_mapping" +## +## id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) +## user_id = Column(CommonUUID, nullable=False) +## preset_id = Column(CommonUUID, nullable=False) +## #function_id = Column(CommonUUID, nullable=False) +## function = Column(String, nullable=False) # TODO: convert to ID eventually +## +## def __repr__(self) -> str: +## return f"" + + +#class PresetModel(Base): + #"""Defines data model for storing Preset objects""" + + #__tablename__ = "presets" + #__table_args__ = {"extend_existing": True} + + #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + #user_id = Column(CommonUUID, nullable=False) + #name = Column(String, nullable=False) + #description = Column(String) + #system = Column(String) + #human = Column(String) + #human_name = Column(String, nullable=False) + #persona = Column(String) + #persona_name = Column(String, nullable=False) + #preset = Column(String) + #created_at = Column(DateTime(timezone=True), server_default=func.now()) + + #functions_schema = Column(JSON) + + #def __repr__(self) -> str: + #return f"" + + #def to_record(self) -> Preset: + #return Preset( + #id=self.id, + #user_id=self.user_id, + #name=self.name, + #description=self.description, + #system=self.system, + #human=self.human, + #persona=self.persona, + #human_name=self.human_name, + #persona_name=self.persona_name, + #preset=self.preset, + #created_at=self.created_at, + #functions_schema=self.functions_schema, + #) class MetadataStore: diff --git a/memgpt/orm/agent.py b/memgpt/orm/agent.py index 082e24861a..e66ee6a2e9 100644 --- a/memgpt/orm/agent.py +++ b/memgpt/orm/agent.py @@ -17,6 +17,10 @@ class Agent(SqlalchemyBase, OrganizationMixin): name:Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="a human-readable identifier for an agent, non-unique.") persona: Mapped[str] = mapped_column(doc="the persona text for the agent, current state.") + # TODO: reconcile this with persona,human etc AND make this structured via pydantic! + # TODO: these are vague and need to be more specific and explained. WTF is state vs _metadata? + state: Mapped[dict] = mapped_column(JSON, doc="the state of the agent.") + _metadata: Mapped[dict] = mapped_column(JSON, doc="metadata for the agent.") # todo: this doesn't allign with 1:M agents to users! 
human: Mapped[str] = mapped_column(doc="the human text for the agent and the current user, current state.") preset: Mapped[str] = mapped_column(doc="the preset text for the agent, current state.") diff --git a/memgpt/orm/sqlalchemy_base.py b/memgpt/orm/sqlalchemy_base.py index b4629b7c3d..4021ee3283 100644 --- a/memgpt/orm/sqlalchemy_base.py +++ b/memgpt/orm/sqlalchemy_base.py @@ -11,6 +11,7 @@ from memgpt.orm.errors import NoResultFound if TYPE_CHECKING: + from pydantic import BaseModel from sqlalchemy.orm import Session from sqlalchemy import Select from memgpt.orm.user import User @@ -43,6 +44,7 @@ def id(self, value: str) -> None: prefix == self.__prefix__ ), f"{prefix} is not a valid id prefix for {self.__class__.__name__}" self._id = UUID(id_) + @classmethod def list(cls, db_session: "Session") -> list[Type["Base"]]: with db_session as session: @@ -100,4 +102,17 @@ def apply_access_predicate( ) if not org_uid: raise ValueError("object %s has no organization accessor", actor) - return query.where(cls._organization_id == org_uid) \ No newline at end of file + return query.where(cls._organization_id == org_uid) + + @property + def __pydantic_model__(self) -> Type["BaseModel"]: + raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.") + + def to_pydantic(self) -> Type["BaseModel"]: + """converts to the basic pydantic model counterpart""" + return self.__pydantic_model__.model_validate(self) + + def to_record(self) -> Type["BaseModel"]: + """Deprecated accessor for to_pydantic""" + logger.warning("to_record is deprecated, use to_pydantic instead.") + return self.to_pydantic() \ No newline at end of file diff --git a/memgpt/orm/user.py b/memgpt/orm/user.py index 2c64e51171..aaefac9a37 100644 --- a/memgpt/orm/user.py +++ b/memgpt/orm/user.py @@ -5,7 +5,8 @@ from memgpt.orm.sqlalchemy_base import SqlalchemyBase from memgpt.orm.mixins import OrganizationMixin -from memgpt.orm.users_agents import UsersAgents +from memgpt.data_types import User as PydanticUser + if TYPE_CHECKING: from memgpt.orm.agent import Agent from memgpt.orm.token import Token @@ -13,6 +14,7 @@ class User(SqlalchemyBase, OrganizationMixin): """User ORM class""" __tablename__ = "user" + __pydantic_model__ = PydanticUser name:Mapped[Optional[str]] = mapped_column(nullable=True, doc="The display name of the user.") email:Mapped[Optional[EmailStr]] = mapped_column(String, From 1c2ae946c4ceaba737d11bb5c4a240066c7ad257 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Thu, 27 Jun 2024 22:42:06 -0400 Subject: [PATCH 35/45] pg_uri is now the _only_ db setting. it will always have the correct scheme. 
the settings.backend object is self-contained, so no more external double-setting --- memgpt/metadata.py | 65 +++++----------------------------- memgpt/orm/utilities.py | 23 ++++-------- memgpt/settings.py | 77 +++++++++++++++++------------------------ tests/conftest.py | 16 +++++++-- 4 files changed, 61 insertions(+), 120 deletions(-) diff --git a/memgpt/metadata.py b/memgpt/metadata.py index bbcecf3e27..9b74d15185 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -5,7 +5,7 @@ import traceback import uuid from typing import List, Optional, Type - +from memgpt.orm.utilities import get_db_session #from sqlalchemy import ( #BIGINT, #CHAR, @@ -321,65 +321,16 @@ class MetadataStore: - uri: Optional[str] = None + db_session: "Session" = None - def __init__(self, config: MemGPTConfig): - # TODO: get DB URI or path - if config.metadata_storage_type == "postgres": - # construct URI from enviornment variables - self.uri = settings.pg_uri if settings.pg_uri else config.metadata_storage_uri + def __init__(self): + self.db_session = get_db_session() - elif config.metadata_storage_type == "sqlite": - path = os.path.join(config.metadata_storage_path, "sqlite.db") - self.uri = f"sqlite:///{path}" - else: - raise ValueError(f"Invalid metadata storage type: {config.metadata_storage_type}") - - # Ensure valid URI - assert self.uri, "Database URI is not provided or is invalid." - - # Check if tables need to be created - self.engine = create_engine(self.uri) - try: - Base.metadata.create_all( - self.engine, - tables=[ - UserModel.__table__, - AgentModel.__table__, - SourceModel.__table__, - AgentSourceMappingModel.__table__, - TokenModel.__table__, - PresetModel.__table__, - PresetSourceMapping.__table__, - HumanModel.__table__, - PersonaModel.__table__, - ToolModel.__table__, - JobModel.__table__, - ], - ) - except (InterfaceError, OperationalError) as e: - traceback.print_exc() - if config.metadata_storage_type == "postgres": - raise ValueError( - f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. " - + "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). " - + "\npostgres detected: Make sure the postgres database is running (https://memgpt.readme.io/docs/storage#postgres)." - ) - elif config.metadata_storage_type == "sqlite": - raise ValueError( - f"{str(e)}\n\nMemGPT failed to connect to the database at URI '{self.uri}'. " - + "Please make sure you configured your storage backend correctly (https://memgpt.readme.io/docs/storage). " - + "\nsqlite detected: Make sure that the sqlite.db file exists at the URI." 
-                )
-            else:
-                raise e
-        except:
-            raise
-        self.session_maker = sessionmaker(bind=self.engine)
-
-    @enforce_types
-    def create_api_key(self, user_id: uuid.UUID, name: Optional[str] = None) -> Token:
+    def create_api_key(self,
+                       user_id: uuid.UUID,
+                       name: Optional[str] = None) -> Token:
         """Create an API key for a user"""
+        # TODO: next - create token for user
         new_api_key = generate_api_key()
         with self.session_maker() as session:
             if session.query(TokenModel).filter(TokenModel.token == new_api_key).count() > 0:
diff --git a/memgpt/orm/utilities.py b/memgpt/orm/utilities.py
index a0486de9f6..a549d74b2f 100644
--- a/memgpt/orm/utilities.py
+++ b/memgpt/orm/utilities.py
@@ -4,33 +4,24 @@
 from sqlalchemy.orm import sessionmaker

-from memgpt.settings import settings
+from memgpt.settings import settings, BackendConfiguration

 if TYPE_CHECKING:
     from sqlalchemy.engine import Engine

 def create_engine(
-    storage_type: Optional[str] = None,
-    database: Optional[str] = None,
+    backend_configuration: Optional["BackendConfiguration"] = None
 ) -> "Engine":
     """creates an engine for the storage_type designated by settings
     Args:
-        storage_type: test hook to inject storage_type, you should not be setting this
-        database: test hook to inject database, you should not be setting this
+        backend_configuration: a BackendConfiguration object - this is a test hook, you should NOT be using this for application code!
+    Returns: a sqlalchemy engine
     """
-    storage_type = storage_type or settings.storage_type
-    match storage_type:
-        case "postgres":
-            url_parts = list(urlsplit(settings.pg_uri))
-            PATH_PARAM = 2  # avoid the magic number!
-            url_parts[PATH_PARAM] = f"/{database}" if database else url_parts.path
-            return sqlalchemy_create_engine(urlunsplit(url_parts))
-        case "sqlite_chroma":
-            return sqlalchemy_create_engine(f"sqlite:///{database}")
-        case _:
-            raise ValueError(f"Unsupported storage_type: {storage_type}")
+    backend = backend_configuration or settings.backend
+    return sqlalchemy_create_engine(backend.database_uri)
+

 def get_db_session() -> "Generator":
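Editorial note: the only supported way to point the ORM at a non-default database is now the `BackendConfiguration` test hook shown above. A minimal sketch of how a test would use it (the in-memory sqlite URI here is illustrative, not something this patch adds):

    from memgpt.orm.utilities import create_engine
    from memgpt.settings import BackendConfiguration

    # bypass global settings entirely; intended for tests only
    backend = BackendConfiguration(name="sqlite_chroma", database_uri="sqlite:///:memory:")
    engine = create_engine(backend)

The conftest change later in this patch does the same dance for the real test adapters.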
diff --git a/memgpt/settings.py b/memgpt/settings.py
index 5bd1fc487a..ad6736549c 100644
--- a/memgpt/settings.py
+++ b/memgpt/settings.py
@@ -1,26 +1,46 @@
 from pathlib import Path
-from typing import Optional
+from urllib.parse import urlsplit, urlunsplit
+from typing import Optional, Literal
 from enum import Enum

-from pydantic import Field
+from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings, SettingsConfigDict

-class StorageType(str, Enum):
-    sqlite = "sqlite"
-    postgres = "postgres"
+# this is the driver we use
+POSTGRES_SCHEME = "postgresql+pg8000"
+
+class BackendConfiguration(BaseModel):
+    name: Literal["postgres", "sqlite_chroma"]
+    database_uri: str
+
 class Settings(BaseSettings):
     model_config = SettingsConfigDict(env_prefix="memgpt_")
-    storage_type: Optional[StorageType] = Field(description="What is the RDBMS type associated with the database url?", default=StorageType.sqlite)
     memgpt_dir: Optional[Path] = Field(Path.home() / ".memgpt", env="MEMGPT_DIR")
     debug: Optional[bool] = False
     server_pass: Optional[str] = None
     cors_origins: Optional[list] = ["http://memgpt.localhost", "http://localhost:8283", "http://localhost:8083"]
+
+    # backend database settings
+    pg_uri: Optional[str] = Field(None, description="if set backend will use postgresql. Otherwise default to sqlite.")
+
+    @property
+    def backend(self) -> BackendConfiguration:
+        """Return an adjusted BackendConfiguration.
+        Note: defaults to sqlite_chroma if pg_uri is not set.
+        """
+        if self.pg_uri:
+            return BackendConfiguration(name="postgres", database_uri=self._correct_pg_uri(self.pg_uri))
+        return BackendConfiguration(name="sqlite_chroma", database_uri=f"sqlite:///{self.memgpt_dir}/memgpt.db")
+
+    def _correct_pg_uri(self, pg_uri: str) -> str:
+        """It is awkward to have users set a scheme for the uri (because why should they know anything about what drivers we use?)
+        So here we check (and correct) the provided uri to use the scheme we implement.
+        """
+        url_parts = list(urlsplit(pg_uri))
+        SCHEME = 0
+        url_parts[SCHEME] = POSTGRES_SCHEME
+        return urlunsplit(url_parts)

     # configurations
     config_path: Optional[Path] = Path("~/.memgpt/config").expanduser()
@@ -32,38 +52,5 @@ class Settings(BaseSettings):
     # TODO: extract to vendor plugin
     openai_api_key: Optional[str] = None

-    @property
-    def database_url(self) -> str:
-        if self.storage_type == StorageType.sqlite:
-            return f"sqlite:///{self.memgpt_dir}/memgpt.db"
-        return self.pg_uri
-
-    @property
-    def pg_uri(self) -> str:
-        if self._pg_uri:
-            return self._pg_uri
-        elif self.pg_db and self.pg_user and self.pg_password and self.pg_host and self.pg_port:
-            return f"postgresql+pg8000://{self.pg_user}:{self.pg_password}@{self.pg_host}:{self.pg_port}/{self.pg_db}"
-        else:
-            return f"postgresql+pg8000://memgpt:memgpt@localhost:5432/memgpt"
-
-    # add this property to avoid being returned the default
-    # reference: https://github.com/cpacker/MemGPT/issues/1362
-    @property
-    def memgpt_pg_uri_no_default(self) -> str:
-        if self.pg_uri:
-            return self.pg_uri
-        elif self.pg_db and self.pg_user and self.pg_password and self.pg_host and self.pg_port:
-            return f"postgresql+pg8000://{self.pg_user}:{self.pg_password}@{self.pg_host}:{self.pg_port}/{self.pg_db}"
-        else:
-            return None
-
-    @pg_uri.setter
-    def pg_uri(self, value: str):
-        self._pg_uri = value
-
-
-
 # singleton
 settings = Settings()
diff --git a/tests/conftest.py b/tests/conftest.py
index 69b163a4c3..b9ba0712d8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,11 @@
 from typing import TYPE_CHECKING
+from urllib.parse import urlsplit, urlunsplit

 import pytest
 from sqlalchemy import text
 from sqlalchemy.orm import sessionmaker

-from memgpt.settings import settings
+from memgpt.settings import settings, BackendConfiguration
 from memgpt.orm.utilities import create_engine
 from memgpt.orm.__all__ import Base
 from tests.utils import wipe_memgpt_home
@@ -115,7 +116,18 @@ def db_session(request) -> "Session":
         }
     }
     adapter = adapter_test_configurations[request.param]
-    engine = create_engine(storage_type=request.param, database=adapter["database"])
+    # update the db uri to reflect the test function and param
+    match request.param:
+        case "sqlite_chroma":
+            database_uri = f"sqlite:///{adapter['database']}"
+        case "postgres":
+            url_parts = list(urlsplit(settings.pg_uri))
+            PATH_PARAM = 2
+            url_parts[PATH_PARAM] = f"/{adapter['database']}"
+            database_uri = urlunsplit(url_parts)
+    backend = BackendConfiguration(name=request.param,
+                                   database_uri=database_uri)
+    engine = create_engine(backend)

     with engine.begin() as connection:
         for statement in adapter["statements"]:
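Editorial note: with this patch, backend resolution hinges on a single environment variable. A quick sanity check of the behavior sketched above (the `MEMGPT_PG_URI` name follows from the `memgpt_` env prefix; treat this as an illustration, not a test from the suite):

    import os
    os.environ["MEMGPT_PG_URI"] = "postgresql://memgpt:memgpt@localhost:5432/memgpt"

    from memgpt.settings import Settings

    backend = Settings().backend
    assert backend.name == "postgres"
    # the scheme is rewritten to the driver we actually ship
    assert backend.database_uri.startswith("postgresql+pg8000://")
    # with MEMGPT_PG_URI unset, .backend falls back to sqlite_chroma under memgpt_dir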
From 2c3ca0d153d47fd2db138aa8fb4112369fba878c Mon Sep 17 00:00:00 2001
From: Ethan Knox
Date: Fri, 28 Jun 2024 11:31:12 -0400
Subject: [PATCH 36/45] chewing through all the redundant crud methods in
 metadata to ideally stub everything over to ORM models.

---
 memgpt/metadata.py            | 213 ++++++++++++++++++----------------
 memgpt/orm/sqlalchemy_base.py |  34 +++++-
 memgpt/orm/token.py           |  36 +++++-
 3 files changed, 171 insertions(+), 112 deletions(-)

diff --git a/memgpt/metadata.py b/memgpt/metadata.py
index 9b74d15185..7e9005218e 100644
--- a/memgpt/metadata.py
+++ b/memgpt/metadata.py
@@ -1,11 +1,23 @@
 """ Metadata store for user/agent/data_source information"""
-
+from typing import TYPE_CHECKING
+import inspect as python_inspect
 import os
 import secrets
 import traceback
 import uuid
 from typing import List, Optional, Type
+from humps import pascalize
+
+from memgpt.log import get_logger
 from memgpt.orm.utilities import get_db_session
+from memgpt.orm.token import Token
+from memgpt.orm.agent import Agent
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+logger = get_logger(__name__)
+
@@ -328,92 +340,94 @@
     def create_api_key(self,
                        user_id: uuid.UUID,
-                       name: Optional[str] = None) -> Token:
-        """Create an API key for a user"""
-        # TODO: next - create token for user
-        new_api_key = generate_api_key()
-        with self.session_maker() as session:
-            if session.query(TokenModel).filter(TokenModel.token == new_api_key).count() > 0:
-                # NOTE duplicate API keys / tokens should never happen, but if it does don't allow it
-                raise ValueError(f"Token {new_api_key} already exists")
-            # TODO store the API keys as hashed
-            token = Token(user_id=user_id, token=new_api_key, name=name)
-            session.add(TokenModel(**vars(token)))
-            session.commit()
-        return self.get_api_key(api_key=new_api_key)
+                       name: Optional[str] = None,
+                       actor: Optional["User"] = None) -> str:
+        """Create an API key for a user
+        Args:
+            user_id: the user raw id as a UUID (legacy accessor)
+            name: the name of the token
+            actor: the user creating the API key, does not need to be the same as the user_id. will default to the user_id if not provided.
+        Returns:
+            api_key: the generated API key string starting with 'sk-'
+        """
+        token = Token(
+            _user_id=actor._id if actor else user_id,
+            name=name
+        ).create(self.db_session)
+        return token.api_key
+
+    def delete_api_key(self,
+                       api_key: str,
+                       actor: Optional["User"] = None) -> None:
+        """(soft) Delete an API key from the database
+        Args:
+            api_key: the API key to delete
+            actor: the user deleting the API key. TODO this will not be optional in the future!
+        Raises:
+            NotFoundError: if the API key does not exist or the user does not have access to it.
+        """
+        #TODO: this is a temporary shim. the long-term solution (next PR) will be to look up the token ID partial, check access, and soft delete.
+        logger.info("User %s is deleting API key %s", actor.id, api_key)
+        Token.get_by_api_key(self.db_session, api_key).delete(self.db_session)
+
+    def get_api_key(self,
+                    api_key: str,
+                    actor: Optional["User"] = None) -> Optional[Token]:
+        """legacy token lookup.
+        Note: auth should remove this completely - there is no reason to look up a token without a user context.
+        """
+        return Token.get_by_api_key(self.db_session, api_key).to_record()
+
+    def get_all_api_keys_for_user(self,
+                                  user_id: uuid.UUID) -> List[Token]:
+        """List every token for the given user id (legacy accessor)."""
+        user = User.read(self.db_session, user_id)
+        return [r.to_record() for r in user.tokens]

-    @enforce_types
-    def delete_api_key(self, api_key: str):
-        """Delete an API key from the database"""
-        with self.session_maker() as session:
-            session.query(TokenModel).filter(TokenModel.token == api_key).delete()
-            session.commit()
-
-    @enforce_types
-    def get_api_key(self, api_key: str) -> Optional[Token]:
-        with self.session_maker() as session:
-            results = session.query(TokenModel).filter(TokenModel.token == api_key).all()
-            if len(results) == 0:
-                return None
-            assert len(results) == 1, f"Expected 1 result, got {len(results)}"  # should only be one result
-            return results[0].to_record()
-
-    @enforce_types
-    def get_all_api_keys_for_user(self, user_id: uuid.UUID) -> List[Token]:
-        with self.session_maker() as session:
-            results = session.query(TokenModel).filter(TokenModel.user_id == user_id).all()
-            tokens = [r.to_record() for r in results]
-            return tokens
-
-    @enforce_types
     def get_user_from_api_key(self, api_key: str) -> Optional[User]:
         """Get the user associated with a given API key"""
-        token = self.get_api_key(api_key=api_key)
-        if token is None:
-            raise ValueError(f"Provided token does not exist")
-        else:
-            return self.get_user(user_id=token.user_id)
+        return Token.get_by_api_key(self.db_session, api_key).user.to_record()

-    @enforce_types
     def create_agent(self, agent: AgentState):
-        # insert into agent table
-        # make sure agent.name does not already exist for user user_id
-        assert agent.state is not None, "Agent state must be provided"
-        assert len(list(agent.state.keys())) > 0, "Agent state must not be empty"
-        with self.session_maker() as session:
-            if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
-                raise ValueError(f"Agent with name {agent.name} already exists")
-            session.add(AgentModel(**vars(agent)))
-            session.commit()
+        "here is one example longhand to demonstrate the meta pattern"
+        return Agent.create(self.db_session, agent.model_dump(exclude_none=True))
+
+    def __getattr__(self, name):
+        """temporary metaprogramming to clean up all the getters and setters here"""
+        if name.startswith("create_"):
+            _, raw_model_name = name.split("_", 1)
+            model = globals().get(pascalize(raw_model_name))  # gross, but necessary for now
+            def create(pydantic_instance, *args, **kwargs):
+                splatted_pydantic = pydantic_instance.model_dump(exclude_none=True)
+                return model.create(self.db_session, splatted_pydantic)
+            return create
+        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+    #@enforce_types
+    #def create_source(self, source: Source, exists_ok=False):
+        ## make sure source.name does not already exist for user
+        #with self.session_maker() as session:
+            #if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
+                #if not exists_ok:
+                    #raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
+                #else:
+                    #session.update(SourceModel(**vars(source)))
+            #else:
+                #session.add(SourceModel(**vars(source)))
+            #session.commit()
+
+    #@enforce_types
+    #def create_user(self, user: User):
+        #with self.session_maker() as session:
+            #if session.query(UserModel).filter(UserModel.id == user.id).count() > 0:
+                #raise ValueError(f"User with id {user.id} already exists")
+            #session.add(UserModel(**vars(user)))
+            #session.commit()
+
+    #@enforce_types
+    #def create_preset(self, preset: Preset):
+        #with self.session_maker() as session:
+            #if session.query(PresetModel).filter(PresetModel.id == preset.id).count() > 0:
+                #raise ValueError(f"User with id {preset.id} already exists")
+            #session.add(PresetModel(**vars(preset)))
+            #session.commit()

     @enforce_types
     def get_preset(
@@ -442,30 +456,23 @@ def get_preset(
 #            session.add(PresetFunctionMapping(user_id=user_id, preset_id=preset_id, function=function))
 #            session.commit()

-    @enforce_types
-    def set_preset_sources(self, preset_id: uuid.UUID, sources: List[uuid.UUID]):
-        preset = self.get_preset(preset_id)
-        if preset is None:
-            raise ValueError(f"Preset with id {preset_id} does not exist")
-        user_id = preset.user_id
-        with self.session_maker() as session:
-            for source_id in sources:
-                session.add(PresetSourceMapping(user_id=user_id, preset_id=preset_id, source_id=source_id))
-            session.commit()
+    def set_preset_sources(self,
+                           preset_id: uuid.UUID,
+                           sources: List[uuid.UUID],
+                           actor: Optional["User"] = None) -> None:
+        """Legacy assign sources to a preset. This should be a normal relationship collection in the future.
+        Args:
+            preset_id: the preset raw UUID to assign sources to, legacy support
+            sources: the source raw UUID for each source to assign to the preset, legacy support
+            actor: the user making the assignment. TODO: this will not be optional in the future!
+        """
+        preset = Preset.read(self.db_session, preset_id)
+        preset.sources = [Source.read(self.db_session, source_id) for source_id in sources]
+        preset.update(self.db_session)

-    # @enforce_types
-    # def get_preset_functions(self, preset_id: uuid.UUID) -> List[str]:
-    #     with self.session_maker() as session:
-    #         results = session.query(PresetFunctionMapping).filter(PresetFunctionMapping.preset_id == preset_id).all()
-    #         return [r.function for r in results]
-
-    @enforce_types
     def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]:
-        with self.session_maker() as session:
-            results = session.query(PresetSourceMapping).filter(PresetSourceMapping.preset_id == preset_id).all()
-            return [r.source_id for r in results]
+        return [s._id for s in Preset.read(self.db_session, preset_id).sources]

-    @enforce_types
     def update_agent(self, agent: AgentState):
         with self.session_maker() as session:
             session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent))
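Editorial note: the `__getattr__` dispatch above is easy to misread, so here is the name resolution spelled out on its own. This is a sketch of the mechanism only, and it assumes the target ORM class is importable in metadata.py's module globals (at this point in the series that is only true for `Token` and `Agent`):

    from humps import pascalize

    method = "create_memory_template"
    _, raw_model_name = method.split("_", 1)   # -> "memory_template"
    model_name = pascalize(raw_model_name)     # -> "MemoryTemplate"
    # MetadataStore then calls:
    #   globals()[model_name].create(db_session, payload.model_dump(exclude_none=True))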
+ """ + preset = Preset.read(self.db_session, preset_id) + preset.sources = [Source.read(self.db_session, source_id) for source_id in sources] + preset.update(self.db_session) - # @enforce_types - # def get_preset_functions(self, preset_id: uuid.UUID) -> List[str]: - # with self.session_maker() as session: - # results = session.query(PresetFunctionMapping).filter(PresetFunctionMapping.preset_id == preset_id).all() - # return [r.function for r in results] - - @enforce_types def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]: - with self.session_maker() as session: - results = session.query(PresetSourceMapping).filter(PresetSourceMapping.preset_id == preset_id).all() - return [r.source_id for r in results] + return [s._id for s in Preset.read(self.db_session, preset_id).sources] - @enforce_types def update_agent(self, agent: AgentState): with self.session_maker() as session: session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent)) diff --git a/memgpt/orm/sqlalchemy_base.py b/memgpt/orm/sqlalchemy_base.py index 4021ee3283..19318574b5 100644 --- a/memgpt/orm/sqlalchemy_base.py +++ b/memgpt/orm/sqlalchemy_base.py @@ -7,7 +7,7 @@ mapped_column ) from memgpt.log import get_logger -from memgpt.orm.base import CommonSqlalchemyMetaMixins +from memgpt.orm.base import Base, CommonSqlalchemyMetaMixins from memgpt.orm.errors import NoResultFound if TYPE_CHECKING: @@ -19,13 +19,15 @@ logger = get_logger(__name__) -class SqlalchemyBase(CommonSqlalchemyMetaMixins): +class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): __abstract__ = True __order_by_default__ = "created_at" _id: Mapped[UUID] = mapped_column(SQLUUID(), primary_key=True, default=uuid4) + deleted: Mapped[bool] = mapped_column(bool, default=False, doc="Is this record deleted? Used for universal soft deletes.") + @property def __prefix__(self) -> str: return depascalize(self.__class__.__name__) @@ -52,11 +54,31 @@ def list(cls, db_session: "Session") -> list[Type["Base"]]: @classmethod def read( - cls, db_session: "Session", identifier: Union[str, UUID], **kwargs + cls, + db_session: "Session", + identifier: Union[str, UUID], + actor: Optional["User"] = None, + access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], + **kwargs ) -> Type["SqlalchemyBase"]: - del kwargs + """The primary accessor for an ORM record. 
+ Args: + db_session: the database session to use when retrieving the record + identifier: the identifier of the record to read, can be the id string or the UUID object for backwards compatibility + actor: if specified, results will be scoped only to records the user is able to access + access: if actor is specified, records will be filtered to the minimum permission level for the actor + kwargs: additional arguments to pass to the read, used for more complex objects + Returns: + The matching object + Raises: + NoResultFound: if the object is not found + """ + del kwargs # arity for more complex reads identifier = cls.to_uid(identifier) - if found := db_session.get(cls, identifier): + query = select(cls).where(cls._id == identifier) + if actor: + query = cls.apply_access_predicate(query, actor, access) + if found := db_session.execute(query).scalar(): return found raise NoResultFound(f"{cls.__name__} with id {identifier} not found") @@ -102,7 +124,7 @@ def apply_access_predicate( ) if not org_uid: raise ValueError("object %s has no organization accessor", actor) - return query.where(cls._organization_id == org_uid) + return query.where(cls._organization_id == org_uid, cls.deleted == False) @property def __pydantic_model__(self) -> Type["BaseModel"]: diff --git a/memgpt/orm/token.py b/memgpt/orm/token.py index 7c382e0112..fbc11865f6 100644 --- a/memgpt/orm/token.py +++ b/memgpt/orm/token.py @@ -1,15 +1,45 @@ -from typing import Optional +from uuid import UUID, uuid4 +from typing import Optional, TYPE_CHECKING from sqlalchemy import String -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship, hybrid_property +from memgpt.data_types import Token as PydanticToken from memgpt.orm.sqlalchemy_base import SqlalchemyBase from memgpt.orm.mixins import UserMixin +from memgpt.log import get_logger +if TYPE_CHECKING: + from sqlalchemy.orm import Session + +logger = get_logger(__name__) class Token(SqlalchemyBase, UserMixin): __tablename__ = 'token' + __pydantic_model__ = PydanticToken + _temporary_shim_api_key: Mapped[Optional[str]] = mapped_column(String, + default=lambda: "sk-" + str(uuid4()), + doc="a temporary shim to get the ORM launched without refactoring downstream") hash:Mapped[str] = mapped_column(String, doc="the secured one-way hash of the token") name:Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="a name to identify the token") - user: Mapped["User"] = relationship("User", back_populates="tokens") \ No newline at end of file + user: Mapped["User"] = relationship("User", back_populates="tokens") + + @hybrid_property + def api_key(self) -> Optional[str]: + """enforce read-only on temporary shim api key""" + logger.warning("Token.api_key is a temporary shim to get the ORM launched. It is unsecure, dangerous, and will be replaced by token.authenticate() in the next PR!") + return self._temporary_shim_api_key + + @classmethod + def factory(cls, user_id: "UUID", name: Optional[str] = None) -> "Token": + """Note: this is a temporary shim to get the ORM launched. It will immediately be replaced with proper + secure token generation! + """ + return cls(user_id=user_id, + name=name) + + @classmethod + def get_by_api_key(cls, db_session:"Session", api_key:str) -> "Token": + """temporary lookup (insecure! replace!) 
+ return db_session.query(cls).filter(cls._temporary_shim_api_key == api_key, cls.deleted==False).one()
\ No newline at end of file

From e440f248c1a43f899b8cfba6c529236c099e5f96 Mon Sep 17 00:00:00 2001
From: Ethan Knox
Date: Fri, 28 Jun 2024 12:58:38 -0400
Subject: [PATCH 37/45] The big cutover to ORM begins!

What's happening here:
1. the metadata.py file is being updated to use the ORM
2. conflicting models are being sunset and/or quarantined for this PR
3. CRUD accessors stay in metadatastore but are now managed behind the
scenes by the ORM

This is going to break a lot of things (which is good). To get unbroken:
1. update the tests to no longer be aware of the backend configs
2. update the code to do the same
3. remove all the SQLModel and deprecated backend code
4. document (loom) how the ORM works, how to create migrations, how to
traverse the ORM tree etc etc.

Strategy here should be to merge this into a long-running branch and start CI against it,
then keep pulling main into it until we're ready for a major release (this will be a major).

Configs will be extremely thin after this PR. We should be set up to move docker dev to a single
stack and docker quickstart to a single image.
---
 memgpt/metadata.py | 653 ++++-----------------------------------------
 1 file changed, 59 insertions(+), 594 deletions(-)

diff --git a/memgpt/metadata.py b/memgpt/metadata.py
index 7e9005218e..e168c1f73b 100644
--- a/memgpt/metadata.py
+++ b/memgpt/metadata.py
@@ -1,338 +1,34 @@
 """ Metadata store for user/agent/data_source information"""
 from typing import TYPE_CHECKING
-import inspect as python_inspect
-import os
-import secrets
-import traceback
 import uuid
-from typing import List, Optional, Type
+from typing import List, Optional
 from humps import pascalize

 from memgpt.log import get_logger
 from memgpt.orm.utilities import get_db_session
 from memgpt.orm.token import Token
 from memgpt.orm.agent import Agent
+from memgpt.orm.job import Job

 if TYPE_CHECKING:
 from sqlalchemy.orm import Session

 logger = get_logger(__name__)

-#from sqlalchemy import (
- #BIGINT,
- #CHAR,
- #JSON,
- #Boolean,
- #Column,
- #DateTime,
- #String,
- #TypeDecorator,
- #create_engine,
- #desc,
- #func,
-#)
-#from sqlalchemy.dialects.postgresql import UUID
-#from sqlalchemy.exc import InterfaceError
-#from sqlalchemy.orm import declarative_base, sessionmaker
-#from sqlalchemy.sql import func
-
-#from memgpt.config import MemGPTConfig
 from memgpt.data_types import (
 AgentState,
- EmbeddingConfig,
- LLMConfig,
 Preset,
 Source,
 Token,
 User,
 )
-from memgpt.functions.functions import load_all_function_sets
 from memgpt.orm.enums import JobStatus
-#from memgpt.models.pydantic_models import (
- #HumanModel,
- #JobModel,
- #JobStatus,
- #PersonaModel,
- #ToolModel,
-#)
-from memgpt.settings import settings
-from memgpt.utils import printd #enforce_types, get_utc_time, printd
-
-
-#Base = declarative_base()
-
-
-# Custom UUID type
-#class CommonUUID(TypeDecorator):
- #impl = CHAR
- #cache_ok = True
-
- #def load_dialect_impl(self, dialect):
- #if dialect.name == "postgresql":
- #return dialect.type_descriptor(UUID(as_uuid=True))
- #else:
- #return dialect.type_descriptor(CHAR())
-
- #def process_bind_param(self, value, dialect):
- #if dialect.name == "postgresql" or value is None:
- #return value
- #else:
- #return str(value) # Convert UUID to string for SQLite
-
- #def process_result_value(self, value, dialect):
- #if dialect.name == "postgresql" or value is None:
- #return value
- #else:
- #return 
uuid.UUID(value) - - -#class LLMConfigColumn(TypeDecorator): - #"""Custom type for storing LLMConfig as JSON""" - - #impl = JSON - #cache_ok = True - - #def load_dialect_impl(self, dialect): - #return dialect.type_descriptor(JSON()) - - #def process_bind_param(self, value, dialect): - #if value: - #return vars(value) - #return value - - #def process_result_value(self, value, dialect): - #if value: - #return LLMConfig(**value) - #return value - - -#class EmbeddingConfigColumn(TypeDecorator): - #"""Custom type for storing EmbeddingConfig as JSON""" - - #impl = JSON - #cache_ok = True - - #def load_dialect_impl(self, dialect): - #return dialect.type_descriptor(JSON()) - - #def process_bind_param(self, value, dialect): - #if value: - #return vars(value) - #return value - - #def process_result_value(self, value, dialect): - #if value: - #return EmbeddingConfig(**value) - #return value - - -#class UserModel(Base): - #__tablename__ = "users" - #__table_args__ = {"extend_existing": True} - - #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - ## name = Column(String, nullable=False) - #default_agent = Column(String) - - #policies_accepted = Column(Boolean, nullable=False, default=False) - - #def __repr__(self) -> str: - #return f"" - - #def to_record(self) -> User: - #return User( - #id=self.id, - ## name=self.name - #default_agent=self.default_agent, - #policies_accepted=self.policies_accepted, - #) - - -#class TokenModel(Base): - #"""Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens).""" - - #__tablename__ = "tokens" - - #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - ## each api key is tied to a user account (that it validates access for) - #user_id = Column(CommonUUID, nullable=False) - ## the api key - #token = Column(String, nullable=False) - ## extra (optional) metadata - #name = Column(String) - - #def __repr__(self) -> str: - #return f"" - - #def to_record(self) -> User: - #return Token( - #id=self.id, - #user_id=self.user_id, - #token=self.token, - #name=self.name, - #) - - -#def generate_api_key(prefix="sk-", length=51) -> str: - ## Generate 'length // 2' bytes because each byte becomes two hex digits. Adjust length for prefix. 
- #actual_length = max(length - len(prefix), 1) // 2 # Ensure at least 1 byte is generated - #random_bytes = secrets.token_bytes(actual_length) - #new_key = prefix + random_bytes.hex() - #return new_key - - -#class AgentModel(Base): - #"""Defines data model for storing Passages (consisting of text, embedding)""" - - #__tablename__ = "agents" - #__table_args__ = {"extend_existing": True} - - #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - #user_id = Column(CommonUUID, nullable=False) - #name = Column(String, nullable=False) - #persona = Column(String) - #human = Column(String) - #preset = Column(String) - #created_at = Column(DateTime(timezone=True), server_default=func.now()) - - ## configs - #llm_config = Column(LLMConfigColumn) - #embedding_config = Column(EmbeddingConfigColumn) - - ## state - #state = Column(JSON) - - #def __repr__(self) -> str: - #return f"" - - #def to_record(self) -> AgentState: - #return AgentState( - #id=self.id, - #user_id=self.user_id, - #name=self.name, - #persona=self.persona, - #human=self.human, - #preset=self.preset, - #created_at=self.created_at, - #llm_config=self.llm_config, - #embedding_config=self.embedding_config, - #state=self.state, - #) - - -#class SourceModel(Base): - #"""Defines data model for storing Passages (consisting of text, embedding)""" - - #__tablename__ = "sources" - #__table_args__ = {"extend_existing": True} - - ## Assuming passage_id is the primary key - ## id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - #user_id = Column(CommonUUID, nullable=False) - #name = Column(String, nullable=False) - #created_at = Column(DateTime(timezone=True), server_default=func.now()) - #embedding_dim = Column(BIGINT) - #embedding_model = Column(String) - #description = Column(String) - - ## TODO: add num passages - - #def __repr__(self) -> str: - #return f"" - - #def to_record(self) -> Source: - #return Source( - #id=self.id, - #user_id=self.user_id, - #name=self.name, - #created_at=self.created_at, - #embedding_dim=self.embedding_dim, - #embedding_model=self.embedding_model, - #description=self.description, - #) - - -#class AgentSourceMappingModel(Base): - #"""Stores mapping between agent -> source""" - - #__tablename__ = "agent_source_mapping" - - #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - #user_id = Column(CommonUUID, nullable=False) - #agent_id = Column(CommonUUID, nullable=False) - #source_id = Column(CommonUUID, nullable=False) - - #def __repr__(self) -> str: - #return f"" - - -#class PresetSourceMapping(Base): - #__tablename__ = "preset_source_mapping" - - #id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - #user_id = Column(CommonUUID, nullable=False) - #preset_id = Column(CommonUUID, nullable=False) - #source_id = Column(CommonUUID, nullable=False) - - #def __repr__(self) -> str: - #return f"" - - -## class PresetFunctionMapping(Base): -## __tablename__ = "preset_function_mapping" -## -## id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) -## user_id = Column(CommonUUID, nullable=False) -## preset_id = Column(CommonUUID, nullable=False) -## #function_id = Column(CommonUUID, nullable=False) -## function = Column(String, nullable=False) # TODO: convert to ID eventually -## -## def __repr__(self) -> str: -## return f"" - - -#class PresetModel(Base): - #"""Defines data model for storing Preset objects""" - - #__tablename__ = "presets" - #__table_args__ = {"extend_existing": True} - - #id = 
Column(CommonUUID, primary_key=True, default=uuid.uuid4)
- #user_id = Column(CommonUUID, nullable=False)
- #name = Column(String, nullable=False)
- #description = Column(String)
- #system = Column(String)
- #human = Column(String)
- #human_name = Column(String, nullable=False)
- #persona = Column(String)
- #persona_name = Column(String, nullable=False)
- #preset = Column(String)
- #created_at = Column(DateTime(timezone=True), server_default=func.now())
-
- #functions_schema = Column(JSON)
-
- #def __repr__(self) -> str:
- #return f""
-
- #def to_record(self) -> Preset:
- #return Preset(
- #id=self.id,
- #user_id=self.user_id,
- #name=self.name,
- #description=self.description,
- #system=self.system,
- #human=self.human,
- #persona=self.persona,
- #human_name=self.human_name,
- #persona_name=self.persona_name,
- #preset=self.preset,
- #created_at=self.created_at,
- #functions_schema=self.functions_schema,
- #)

 class MetadataStore:
+ """MetadataStore acts as a bridge between the ORM and the rest of the application. Ideally it will be removed in coming PRs,
+ allowing requests to handle sessions atomically (this is how FastAPI really wants things to work, and it will drastically reduce the
+ mucking of the ORM layer). For now, all CRUD methods are invoked here instead of on the ORM layer directly.
+ """
 db_session: "Session" = None

 def __init__(self):
@@ -393,68 +89,43 @@ def create_agent(self, agent: AgentState):
 return Agent.create(self.db_session, agent.model_dump(exclude_none=True))

 def __getattr__(self, name, *args, **kwargs):
- """temporary metaprogramming to clean up all the getters and setters here"""
- if name.startswith("create_"):
- _, raw_model_name = name.split("_",1)
- model = globals().get(pascalize(raw_model_name)) # gross, but necessary for now
- splatted_pydantic = args[0].model_dump(exclude_none=True)
- return model.create(self.db_session, splatted_pydantic)
-
- #@enforce_types
- #def create_source(self, source: Source, exists_ok=False):
- ## make sure source.name does not already exist for user
- #with self.session_maker() as session:
- #if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
- #if not exists_ok:
- #raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
- #else:
- #session.update(SourceModel(**vars(source)))
- #else:
- #session.add(SourceModel(**vars(source)))
- #session.commit()
-
- #@enforce_types
- #def create_user(self, user: User):
- #with self.session_maker() as session:
- #if session.query(UserModel).filter(UserModel.id == user.id).count() > 0:
- #raise ValueError(f"User with id {user.id} already exists")
- #session.add(UserModel(**vars(user)))
- #session.commit()
-
- #@enforce_types
- #def create_preset(self, preset: Preset):
- #with self.session_maker() as session:
- #if session.query(PresetModel).filter(PresetModel.id == preset.id).count() > 0:
- #raise ValueError(f"User with id {preset.id} already exists")
- #session.add(PresetModel(**vars(preset)))
- #session.commit()
+ """temporary metaprogramming to clean up all the getters and setters here.
+
+ __getattr__ is always the last-ditch effort, so you can override it by declaring any method (i.e. `get_hamburger`) to handle the call instead.
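+ Example of the dispatch (a sketch): `ms.create_user(user)` arrives here as name="create_user";
+ the prefix selects the action and "user" is pascalized to the User ORM model.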
+ """ + action, raw_model_name = name.split("_",1) + Model = globals().get(pascalize(raw_model_name)) # gross, but nessary for now + match action: + case "get": + # this has no support for scoping, but we won't keep this pattern long + return Model.read(self.db_session, args[0]).to_record() + case "create": + splatted_pydantic = args[0].model_dump(exclude_none=True) + return Model.create(self.db_session, splatted_pydantic).to_record() + case "update": + instance = Model.read(self.db_session, args[0].id) + splatted_pydantic = args[0].model_dump(exclude_none=True, exclude=["id"]) + for k,v in splatted_pydantic.items(): + setattr(instance, k, v) + instance.update(self.db_session) + return instance.to_record() + case "delete": + instance = Model.read(self.db_session, args[0]) + instance.delete(self.db_session) + case "list": + # TODO: this has no scoping, no pagination, and no filtering. it's a placeholder. + return [r.to_record() for r in Model.list(self.db_session)] - @enforce_types def get_preset( self, preset_id: Optional[uuid.UUID] = None, name: Optional[str] = None, user_id: Optional[uuid.UUID] = None ) -> Optional[Preset]: - with self.session_maker() as session: - if preset_id: - results = session.query(PresetModel).filter(PresetModel.id == preset_id).all() - elif name and user_id: - results = session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).all() - else: - raise ValueError("Must provide either preset_id or (preset_name and user_id)") - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0].to_record() - - # @enforce_types - # def set_preset_functions(self, preset_id: uuid.UUID, functions: List[str]): - # preset = self.get_preset(preset_id) - # if preset is None: - # raise ValueError(f"Preset with id {preset_id} does not exist") - # user_id = preset.user_id - # with self.session_maker() as session: - # for function in functions: - # session.add(PresetFunctionMapping(user_id=user_id, preset_id=preset_id, function=function)) - # session.commit() + assert preset_id or (name and user_id), "Must provide either preset_id or (preset_name and user_id)" + + #TODO: pivot this to org scope - get by id or by name within org of actor + if preset_id: + return Preset.read(self.db_session, preset_id).to_record() + # Implement actor lookup + return Preset.read(self.db_session, name=name, actor=user_id).to_record() def set_preset_sources(self, preset_id: uuid.UUID, @@ -473,23 +144,7 @@ def set_preset_sources(self, def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]: return [s._id for s in Preset.read(self.db_session, preset_id).sources] - def update_agent(self, agent: AgentState): - with self.session_maker() as session: - session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent)) - session.commit() - - @enforce_types - def update_user(self, user: User): - with self.session_maker() as session: - session.query(UserModel).filter(UserModel.id == user.id).update(vars(user)) - session.commit() - - @enforce_types - def update_source(self, source: Source): - with self.session_maker() as session: - session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source)) - session.commit() - + ## TODO: update these to get rid of SQLModel completely! this is up next. 
@enforce_types def update_human(self, human: HumanModel): with self.session_maker() as session: @@ -504,192 +159,26 @@ def update_persona(self, persona: PersonaModel): session.commit() session.refresh(persona) - @enforce_types - def update_tool(self, tool: ToolModel): - with self.session_maker() as session: - session.add(tool) - session.commit() - session.refresh(tool) - - @enforce_types - def delete_agent(self, agent_id: uuid.UUID): - with self.session_maker() as session: - - # delete agents - session.query(AgentModel).filter(AgentModel.id == agent_id).delete() - - # delete mappings - session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).delete() - - session.commit() - - @enforce_types - def delete_source(self, source_id: uuid.UUID): - with self.session_maker() as session: - # delete from sources table - session.query(SourceModel).filter(SourceModel.id == source_id).delete() - - # delete any mappings - session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete() - - session.commit() - - @enforce_types - def delete_user(self, user_id: uuid.UUID): - with self.session_maker() as session: - # delete from users table - session.query(UserModel).filter(UserModel.id == user_id).delete() - - # delete associated agents - session.query(AgentModel).filter(AgentModel.user_id == user_id).delete() - - # delete associated sources - session.query(SourceModel).filter(SourceModel.user_id == user_id).delete() - - # delete associated mappings - session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete() - - session.commit() - - @enforce_types - def list_presets(self, user_id: uuid.UUID) -> List[Preset]: - with self.session_maker() as session: - results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all() - return [r.to_record() for r in results] - - @enforce_types - # def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools - def list_tools(self, user_id: Optional[uuid.UUID] = None) -> List[ToolModel]: - with self.session_maker() as session: - results = session.query(ToolModel).filter(ToolModel.user_id == None).all() - if user_id: - results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all() - return results - - @enforce_types - def list_agents(self, user_id: uuid.UUID) -> List[AgentState]: - with self.session_maker() as session: - results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all() - return [r.to_record() for r in results] - - @enforce_types - def list_sources(self, user_id: uuid.UUID) -> List[Source]: - with self.session_maker() as session: - results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all() - return [r.to_record() for r in results] - - @enforce_types - def get_agent( - self, agent_id: Optional[uuid.UUID] = None, agent_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None - ) -> Optional[AgentState]: - with self.session_maker() as session: - if agent_id: - results = session.query(AgentModel).filter(AgentModel.id == agent_id).all() - else: - assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name" - results = session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id).all() - - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result - return results[0].to_record() - - @enforce_types 
- def get_user(self, user_id: uuid.UUID) -> Optional[User]: - with self.session_maker() as session: - results = session.query(UserModel).filter(UserModel.id == user_id).all() - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0].to_record() - - @enforce_types def get_all_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50) -> (Optional[uuid.UUID], List[User]): - with self.session_maker() as session: - query = session.query(UserModel).order_by(desc(UserModel.id)) - if cursor: - query = query.filter(UserModel.id < cursor) - results = query.limit(limit).all() - if not results: - return None, [] - user_records = [r.to_record() for r in results] - next_cursor = user_records[-1].id - assert isinstance(next_cursor, uuid.UUID) - - return next_cursor, user_records - - @enforce_types - def get_source( - self, source_id: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None - ) -> Optional[Source]: - with self.session_maker() as session: - if source_id: - results = session.query(SourceModel).filter(SourceModel.id == source_id).all() - else: - assert user_id is not None and source_name is not None - results = session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all() - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0].to_record() - - @enforce_types - def get_tool(self, tool_name: str, user_id: Optional[uuid.UUID] = None) -> Optional[ToolModel]: - # TODO: add user_id when tools can eventually be added by users - with self.session_maker() as session: - results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all() - if user_id: - results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all() - - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0] + del limit # TODO: implement pagination as part of predicate + return None , [u.to_record() for u in User.list(self.db_session)] # agent source metadata - @enforce_types - def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID): - with self.session_maker() as session: - session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id)) - session.commit() + def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID) -> None: + agent = Agent.read(self.db_session, agent_id) + source = Source.read(self.db_session, source_id) + agent.sources.append(source) - @enforce_types def list_attached_sources(self, agent_id: uuid.UUID) -> List[uuid.UUID]: - with self.session_maker() as session: - results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all() - - source_ids = [] - # make sure source exists - for r in results: - source = self.get_source(source_id=r.source_id) - if source: - source_ids.append(r.source_id) - else: - printd(f"Warning: source {r.source_id} does not exist but exists in mapping database. 
This should never happen.") - return source_ids + return [s._id for s in Agent.read(self.db_session, agent_id).sources] - @enforce_types def list_attached_agents(self, source_id: uuid.UUID) -> List[uuid.UUID]: - with self.session_maker() as session: - results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all() - - agent_ids = [] - # make sure agent exists - for r in results: - agent = self.get_agent(agent_id=r.agent_id) - if agent: - agent_ids.append(r.agent_id) - else: - printd(f"Warning: agent {r.agent_id} does not exist but exists in mapping database. This should never happen.") - return agent_ids + return [a._id for a in Source.read(self.db_session, source_id).agents] - @enforce_types - def detach_source(self, agent_id: uuid.UUID, source_id: uuid.UUID): - with self.session_maker() as session: - session.query(AgentSourceMappingModel).filter( - AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id - ).delete() - session.commit() + def detach_source(self, agent_id: uuid.UUID, source_id: uuid.UUID) -> None: + agent = Agent.read(self.db_session, agent_id) + source = Source.read(self.db_session, source_id) + agent.sources.remove(source) @enforce_types def add_human(self, human: HumanModel): @@ -707,22 +196,14 @@ def add_persona(self, persona: PersonaModel): session.add(persona) session.commit() - @enforce_types - def add_preset(self, preset: PresetModel): # TODO: remove - with self.session_maker() as session: - session.add(preset) - session.commit() + def add_preset(self, preset: PresetModel) -> "PresetModel": + return self.create_preset(preset) - @enforce_types def add_tool(self, tool: ToolModel): - with self.session_maker() as session: - if self.get_tool(tool.name, tool.user_id): - raise ValueError(f"Tool with name {tool.name} already exists for user_id {tool.user_id}") - session.add(tool) - session.commit() + return self.create_tool(tool) - @enforce_types def get_human(self, name: str, user_id: uuid.UUID) -> Optional[HumanModel]: + # TODO: What? why does a getter take a 1:m id? 
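+ # (partial answer: template names are only unique per user, so user_id is needed to scope the name lookup below)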
with self.session_maker() as session: results = session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).all() if len(results) == 0: @@ -790,22 +271,6 @@ def create_job(self, job: JobModel): session.expunge_all() def update_job_status(self, job_id: uuid.UUID, status: JobStatus): - with self.session_maker() as session: - session.query(JobModel).filter(JobModel.id == job_id).update({"status": status}) - if status == JobStatus.COMPLETED: - session.query(JobModel).filter(JobModel.id == job_id).update({"completed_at": get_utc_time()}) - session.commit() - - def update_job(self, job: JobModel): - with self.session_maker() as session: - session.add(job) - session.commit() - session.refresh(job) - - def get_job(self, job_id: uuid.UUID) -> Optional[JobModel]: - with self.session_maker() as session: - results = session.query(JobModel).filter(JobModel.id == job_id).all() - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0] + job = Job.read(self.db_session, job_id) + job.status = status + job.update(self.db_session) From 7f9a7a3181fcd405f772be5cf1294023b42df42a Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 3 Jul 2024 11:42:04 -0400 Subject: [PATCH 38/45] hacking out SQLModel --- memgpt/models/pydantic_models.py | 55 +++++++++++++------------------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index c7d4ce97c0..1a22a2918a 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -1,5 +1,5 @@ # tool imports -import uuid +from uuid import UUID from datetime import datetime from enum import Enum from typing import Dict, List, Optional @@ -19,6 +19,15 @@ class MemGPTUsageStatistics(BaseModel): total_tokens: int step_count: int +class PersistedBase(BaseModel): + """shared elements that all models coming from the ORM will support""" + id: str = Field(description="The unique identifier of the object prefixed with the object type (Stripe pattern).") + uuid: UUID = Field(description="The unique identifier of the object stored as a raw uuid (for legacy support).") + deleted: bool = Field(default=False, description="Is this record deleted? 
Used for universal soft deletes.") + created_at: datetime = Field(description="The unix timestamp of when the object was created.") + updated_at: datetime = Field(description="The unix timestamp of when the object was last updated.") + created_by_id: Optional[str] = Field(description="The unique identifier of the user who created the object.") + last_updated_by_id: Optional[str] = Field(description="The unique identifier of the user who last updated the object.") class LLMConfigModel(BaseModel): model: Optional[str] = "gpt-4" @@ -38,6 +47,16 @@ class EmbeddingConfigModel(BaseModel): embedding_dim: Optional[int] = 1536 embedding_chunk_size: Optional[int] = 300 +class OrganizationSummary(PersistedBase): + """An Organization interface with minimal references, good when only the link is needed""" + name: str = Field(..., description="The name of the organization.") + +class UserSummary(PersistedBase): + """A User interface with minimal references, good when only the link is needed""" + name: Optional[str] = Field(default=None, description="The name of the user.") + email: Optional[str] = Field(default=None, description="The email of the user.") + organization: Optional[OrganizationSummary] = Field(None, description="The organization this user belongs to.") + class PresetModel(BaseModel): name: str = Field(..., description="The name of the preset.") @@ -53,10 +72,8 @@ class PresetModel(BaseModel): functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.") -class ToolModel(SQLModel, table=True): - # TODO move into database +class ToolModel(PersistedBase): name: str = Field(..., description="The name of the function.") - id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.", primary_key=True) tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.") source_type: Optional[str] = Field(None, description="The type of the source code.") source_code: Optional[str] = Field(..., description="The source code of the function.") @@ -64,37 +81,11 @@ class ToolModel(SQLModel, table=True): json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.") - # optional: user_id (user-specific tools) - user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the function.") - - # Needed for Column(JSON) - class Config: - arbitrary_types_allowed = True - - -class AgentToolMap(SQLModel, table=True): - # mapping between agents and tools - agent_id: uuid.UUID = Field(..., description="The unique identifier of the agent.") - tool_id: uuid.UUID = Field(..., description="The unique identifier of the tool.") - id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the agent-tool map.", primary_key=True) + organization: Optional[OrganizationSummary] = Field(None, description="The organization this function belongs to.") - -class PresetToolMap(SQLModel, table=True): - # mapping between presets and tools - preset_id: uuid.UUID = Field(..., description="The unique identifier of the preset.") - tool_id: uuid.UUID = Field(..., description="The unique identifier of the tool.") - id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset-tool map.", primary_key=True) - - -class AgentStateModel(BaseModel): - id: uuid.UUID = Field(..., description="The unique identifier of the agent.") +class AgentStateModel(PersistedBase): name: str = Field(..., 
description="The name of the agent.") description: Optional[str] = Field(None, description="The description of the agent.") - user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.") - - # timestamps - # created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the agent was created.") - created_at: int = Field(..., description="The unix timestamp of when the agent was created.") # preset information tools: List[str] = Field(..., description="The tools used by the agent.") From 585a3cb5bc2b0ad13cd65d5666e585fb5126b116 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Wed, 3 Jul 2024 16:30:18 -0400 Subject: [PATCH 39/45] last of the sqlmodel models --- memgpt/models/pydantic_models.py | 107 ++++++++++++------------------- memgpt/orm/document.py | 25 ++++++++ memgpt/orm/passage.py | 27 ++++++++ 3 files changed, 92 insertions(+), 67 deletions(-) create mode 100644 memgpt/orm/document.py create mode 100644 memgpt/orm/passage.py diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index 1a22a2918a..eddb0002c0 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -4,12 +4,10 @@ from enum import Enum from typing import Dict, List, Optional -from pydantic import BaseModel, ConfigDict, Field -from sqlalchemy import JSON, Column -from sqlalchemy_utils import ChoiceType -from sqlmodel import Field, SQLModel +from pydantic import BaseModel, Field, ConfigDict from memgpt.settings import settings +from memgpt.orm.enums import JobStatus from memgpt.utils import get_human_text, get_persona_text, get_utc_time @@ -23,13 +21,14 @@ class PersistedBase(BaseModel): """shared elements that all models coming from the ORM will support""" id: str = Field(description="The unique identifier of the object prefixed with the object type (Stripe pattern).") uuid: UUID = Field(description="The unique identifier of the object stored as a raw uuid (for legacy support).") - deleted: bool = Field(default=False, description="Is this record deleted? Used for universal soft deletes.") + deleted: Optional[bool] = Field(default=False, description="Is this record deleted? Used for universal soft deletes.") created_at: datetime = Field(description="The unix timestamp of when the object was created.") updated_at: datetime = Field(description="The unix timestamp of when the object was last updated.") created_by_id: Optional[str] = Field(description="The unique identifier of the user who created the object.") last_updated_by_id: Optional[str] = Field(description="The unique identifier of the user who last updated the object.") class LLMConfigModel(BaseModel): + # TODO: 🤮 don't default to a vendor! bug city! 
model: Optional[str] = "gpt-4" model_endpoint_type: Optional[str] = "openai" model_endpoint: Optional[str] = "https://api.openai.com/v1" @@ -57,29 +56,26 @@ class UserSummary(PersistedBase): email: Optional[str] = Field(default=None, description="The email of the user.") organization: Optional[OrganizationSummary] = Field(None, description="The organization this user belongs to.") - -class PresetModel(BaseModel): - name: str = Field(..., description="The name of the preset.") - id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.") - user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.") +class PresetModel(PersistedBase): + name: str = Field(description="The name of the preset.") description: Optional[str] = Field(None, description="The description of the preset.") - created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.") system: str = Field(..., description="The system prompt of the preset.") + # TODO: these should never default if the ORM manages defaults persona: str = Field(default=get_persona_text(settings.persona), description="The persona of the preset.") persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.") human: str = Field(default=get_human_text(settings.human), description="The human of the preset.") human_name: Optional[str] = Field(None, description="The name of the human of the preset.") - functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.") - + functions_schema: List[dict] = Field(..., description="The functions schema of the preset.") + organization: Optional[OrganizationSummary] = Field(None, description="The organization this Preset belongs to.") class ToolModel(PersistedBase): name: str = Field(..., description="The name of the function.") - tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.") + tags: List[str] = Field(description="Metadata tags.") source_type: Optional[str] = Field(None, description="The type of the source code.") - source_code: Optional[str] = Field(..., description="The source code of the function.") + source_code: Optional[str] = Field(None, description="The source code of the function.") module: Optional[str] = Field(None, description="The module of the function.") - json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.") + json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.") organization: Optional[OrganizationSummary] = Field(None, description="The organization this function belongs to.") @@ -90,7 +86,6 @@ class AgentStateModel(PersistedBase): # preset information tools: List[str] = Field(..., description="The tools used by the agent.") system: str = Field(..., description="The system prompt used by the agent.") - # functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.") # llm information llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.") @@ -98,78 +93,56 @@ class AgentStateModel(PersistedBase): # agent state state: Optional[Dict] = Field(None, description="The state of the agent.") - metadata: Optional[Dict] = Field(None, description="The metadata of the agent.") + metadata: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") class 
CoreMemory(BaseModel): human: str = Field(..., description="Human element of the core memory.") persona: str = Field(..., description="Persona element of the core memory.") +class MemorySection(PersistedBase): + """the common base for the legacy memory sections. + This is going away in favor of MemoryModule dynamic sections. + memgpt/memory.py + """ + text: Optional[str] = Field(default=get_human_text(settings.human), description="The content to be added to this section of core memory.") + name: str = Field(..., description="The name of the memory section.") + organization: Optional[OrganizationSummary] = Field(None, description="The organization this memory belongs to.") -class HumanModel(SQLModel, table=True): - text: str = Field(default=get_human_text(settings.human), description="The human text.") - name: str = Field(..., description="The name of the human.") - id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the human.", primary_key=True) - user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.") +class HumanMemory(MemorySection): + """Specifically for human, legacy""" -class PersonaModel(SQLModel, table=True): - text: str = Field(default=get_persona_text(settings.persona), description="The persona text.") - name: str = Field(..., description="The name of the persona.") - id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True) - user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.") +class PersonaModel(MemorySection): + """Specifically for persona, legacy""" -class SourceModel(SQLModel, table=True): +class SourceModel(PersistedBase): name: str = Field(..., description="The name of the source.") description: Optional[str] = Field(None, description="The description of the source.") - user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.") - created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the source was created.") - id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True) - description: Optional[str] = Field(None, description="The description of the source.") - # embedding info - # embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.") embedding_config: Optional[EmbeddingConfigModel] = Field( - None, sa_column=Column(JSON), description="The embedding configuration used by the passage." + None, description="The embedding configuration used by the passage." 
) # NOTE: .metadata is a reserved attribute on SQLModel - metadata_: Optional[dict] = Field(None, sa_column=Column(JSON), description="Metadata associated with the source.") - + metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") -class JobStatus(str, Enum): - created = "created" - running = "running" - completed = "completed" - failed = "failed" - - -class JobModel(SQLModel, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the job.", primary_key=True) - # status: str = Field(default="created", description="The status of the job.") - status: JobStatus = Field(default=JobStatus.created, description="The status of the job.", sa_column=Column(ChoiceType(JobStatus))) - created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.") +class JobModel(PersistedBase): + status: JobStatus = Field(default=JobStatus.created, description="The status of the job.") completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.") - user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the job.") - metadata_: Optional[dict] = Field({}, sa_column=Column(JSON), description="The metadata of the job.") + user: UserSummary = Field(description="The user associated with the job.") + metadata_: Optional[dict] = Field({}, description="The metadata of the job.") - -class PassageModel(BaseModel): - user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the passage.") - agent_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the agent associated with the passage.") +class PassageModel(PersistedBase): text: str = Field(..., description="The text of the passage.") embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.") embedding_config: Optional[EmbeddingConfigModel] = Field( - None, sa_column=Column(JSON), description="The embedding configuration used by the passage." + None, description="The embedding configuration used by the passage." 
)
- data_source: Optional[str] = Field(None, description="The data source of the passage.")
- doc_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the document associated with the passage.")
- id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the passage.", primary_key=True)
- metadata: Optional[Dict] = Field({}, description="The metadata of the passage.")
-
+ document: "DocumentModel" = Field(description="The document associated with the passage.")
+ metadata_: Optional[dict] = Field({}, description="The metadata of the passage.")

-class DocumentModel(BaseModel):
- user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the document.")
- text: str = Field(..., description="The text of the document.")
+class DocumentModel(PersistedBase):
+ organization: OrganizationSummary = Field(description="The organization this document belongs to.")
+ text: str = Field(..., description="The full text of the document.")
 data_source: str = Field(..., description="The data source of the document.")
- id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True)
- metadata: Optional[Dict] = Field({}, description="The metadata of the document.")
+ metadata_: Optional[Dict] = Field({}, description="The metadata of the document.")
diff --git a/memgpt/orm/document.py b/memgpt/orm/document.py
new file mode 100644
index 0000000000..c0194cba33
--- /dev/null
+++ b/memgpt/orm/document.py
@@ -0,0 +1,25 @@
+from typing import Optional, TYPE_CHECKING, List
+from sqlalchemy import JSON
+from sqlalchemy.orm import relationship, Mapped, mapped_column
+
+from memgpt.orm.sqlalchemy_base import SqlalchemyBase
+from memgpt.orm.mixins import OrganizationMixin
+
+if TYPE_CHECKING:
+ from memgpt.orm.organization import Organization
+ from memgpt.orm.passage import Passage
+
+
+class Document(OrganizationMixin, SqlalchemyBase):
+ """Represents a file or distinct, complete body of information.
+ """
+ __tablename__ = "document"
+
+ text: Mapped[str] = mapped_column(doc="The full text for the document.")
+ data_source: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Human readable description of where the document came from.")
+ metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="additional information about the document.")
+
+
+ # relationships
+ organization: Mapped["Organization"] = relationship("Organization", back_populates="documents")
+ passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="document")
\ No newline at end of file
diff --git a/memgpt/orm/passage.py b/memgpt/orm/passage.py
new file mode 100644
index 0000000000..89b439a980
--- /dev/null
+++ b/memgpt/orm/passage.py
@@ -0,0 +1,26 @@
+from typing import Optional, TYPE_CHECKING, List
+from sqlalchemy import JSON
+from sqlalchemy.orm import relationship, Mapped, mapped_column
+
+from memgpt.models.pydantic_models import EmbeddingConfigModel
+from memgpt.orm.sqlalchemy_base import SqlalchemyBase
+from memgpt.orm.mixins import DocumentMixin
+
+if TYPE_CHECKING:
+ from memgpt.orm.document import Document
+
+class Passage(DocumentMixin, SqlalchemyBase):
+ """A segment of text from a document.
+ """ + __tablename__ = "passage" + + text: Mapped[str] = mapped_column(doc="The text of the passage.") + embedding: Optional[List[float]] = mapped_column(JSON, doc="The embedding of the passage.", nullable=True) + embedding_config: Optional["EmbeddingConfigModel"] = mapped_column(JSON, doc="The embedding configuration used by the passage.", + nullable=True) + data_source: Optional[str] = mapped_column(nullable=True, doc="Human readable description of where the passage came from.") + metadata_: Optional[dict] = mapped_column(JSON, default_factory=lambda: {}, doc="additional information about the passage.") + + + # relationships + document: Mapped["Document"] = relationship("Document", back_populates="passages") \ No newline at end of file From ec0b328a160351273c23341a348a0aa664d459e3 Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Fri, 5 Jul 2024 09:47:46 -0400 Subject: [PATCH 40/45] stripped all the SQLModel, pydantic schemas vs dataclasses vs arb classes still a mess, but I think we can get around that to get things working --- memgpt/metadata.py | 128 +++++++------------------------ memgpt/models/pydantic_models.py | 14 ++-- memgpt/orm/organization.py | 11 ++- 3 files changed, 46 insertions(+), 107 deletions(-) diff --git a/memgpt/metadata.py b/memgpt/metadata.py index e168c1f73b..b206e05777 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -9,6 +9,9 @@ from memgpt.orm.token import Token from memgpt.orm.agent import Agent from memgpt.orm.job import Job +from memgpt.orm.preset import Preset +from memgpt.orm.memory_templates import HumanModelTemplate, PersonaModelTemplate + if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -22,6 +25,10 @@ Token, User, ) +from memgpt.models.pydantic_models import ( + HumanModel, + PersonaModel, +) from memgpt.orm.enums import JobStatus class MetadataStore: @@ -96,6 +103,8 @@ def __getattr__(self, name, *args, **kwargs): action, raw_model_name = name.split("_",1) Model = globals().get(pascalize(raw_model_name)) # gross, but nessary for now match action: + case "add": + return self.getattr("_".join(["create",raw_model_name])) case "get": # this has no support for scoping, but we won't keep this pattern long return Model.read(self.db_session, args[0]).to_record() @@ -110,9 +119,19 @@ def __getattr__(self, name, *args, **kwargs): instance.update(self.db_session) return instance.to_record() case "delete": + # hacky temp. look up the org for the user, get all the plural (related set) for that org and delete by name + if user_uuid := (args[1] if len(args) > 1 else None): + org = User.read(user_uuid).organization + related_set = getattr(org, (raw_model_name + "s")) + related_set.filter(name=name).scalar().delete() + return instance = Model.read(self.db_session, args[0]) instance.delete(self.db_session) case "list": + # hacky temp. look up the org for the user, get all the plural (related set) for that org + if user_uuid := (args[1] if len(args) > 1 else None): + org = User.read(user_uuid).organization + return [r.to_record() for r in getattr(org, (raw_model_name + "s"))] # TODO: this has no scoping, no pagination, and no filtering. it's a placeholder. return [r.to_record() for r in Model.list(self.db_session)] @@ -144,20 +163,13 @@ def set_preset_sources(self, def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]: return [s._id for s in Preset.read(self.db_session, preset_id).sources] - ## TODO: update these to get rid of SQLModel completely! this is up next. 
- @enforce_types - def update_human(self, human: HumanModel): - with self.session_maker() as session: - session.add(human) - session.commit() - session.refresh(human) + def update_human(self, human: HumanModel) -> "HumanModel": + sql_human = HumanModelTemplate(**human.model_dump(exclude_none=True)).create(self.db_session) + return sql_human.to_record() - @enforce_types - def update_persona(self, persona: PersonaModel): - with self.session_maker() as session: - session.add(persona) - session.commit() - session.refresh(persona) + def update_persona(self, persona: PersonaModel) -> "PersonaModel": + sql_persona = PersonaModelTemplate(**persona.model_dump(exclude_none=True)).create(self.db_session) + return sql_persona.to_record() def get_all_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50) -> (Optional[uuid.UUID], List[User]): del limit # TODO: implement pagination as part of predicate @@ -180,95 +192,13 @@ def detach_source(self, agent_id: uuid.UUID, source_id: uuid.UUID) -> None: source = Source.read(self.db_session, source_id) agent.sources.remove(source) - @enforce_types - def add_human(self, human: HumanModel): - with self.session_maker() as session: - if self.get_human(human.name, human.user_id): - raise ValueError(f"Human with name {human.name} already exists for user_id {human.user_id}") - session.add(human) - session.commit() - - @enforce_types - def add_persona(self, persona: PersonaModel): - with self.session_maker() as session: - if self.get_persona(persona.name, persona.user_id): - raise ValueError(f"Persona with name {persona.name} already exists for user_id {persona.user_id}") - session.add(persona) - session.commit() - - def add_preset(self, preset: PresetModel) -> "PresetModel": - return self.create_preset(preset) - - def add_tool(self, tool: ToolModel): - return self.create_tool(tool) - def get_human(self, name: str, user_id: uuid.UUID) -> Optional[HumanModel]: - # TODO: What? why does a getter take a 1:m id? 
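+ # (assumes the user's templates are exposed as filterable human_memory_templates / persona_memory_templates collections)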
- with self.session_maker() as session:
- results = session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).all()
- if len(results) == 0:
- return None
- assert len(results) == 1, f"Expected 1 result, got {len(results)}"
- return results[0]
+ def get_human(self, name: str, user_id: uuid.UUID) -> Optional[HumanModel]:
+ user = User.read(self.db_session, user_id)
+ return user.human_memory_templates.filter(name=name).scalar()

- @enforce_types
 def get_persona(self, name: str, user_id: uuid.UUID) -> Optional[PersonaModel]:
- with self.session_maker() as session:
- results = session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).all()
- if len(results) == 0:
- return None
- assert len(results) == 1, f"Expected 1 result, got {len(results)}"
- return results[0]
-
- @enforce_types
- def list_personas(self, user_id: uuid.UUID) -> List[PersonaModel]:
- with self.session_maker() as session:
- results = session.query(PersonaModel).filter(PersonaModel.user_id == user_id).all()
- return results
-
- @enforce_types
- def list_humans(self, user_id: uuid.UUID) -> List[HumanModel]:
- with self.session_maker() as session:
- # if user_id matches provided user_id or if user_id is None
- results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all()
- return results
-
- @enforce_types
- def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]:
- with self.session_maker() as session:
- results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
- return results
-
- @enforce_types
- def delete_human(self, name: str, user_id: uuid.UUID):
- with self.session_maker() as session:
- session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).delete()
- session.commit()
-
- @enforce_types
- def delete_persona(self, name: str, user_id: uuid.UUID):
- with self.session_maker() as session:
- session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete()
- session.commit()
-
- @enforce_types
- def delete_preset(self, name: str, user_id: uuid.UUID):
- with self.session_maker() as session:
- session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete()
- session.commit()
-
- @enforce_types
- def delete_tool(self, name: str, user_id: uuid.UUID):
- with self.session_maker() as session:
- session.query(ToolModel).filter(ToolModel.name == name).filter(ToolModel.user_id == user_id).delete()
- session.commit()
-
- # job related functions
- def create_job(self, job: JobModel):
- with self.session_maker() as session:
- session.add(job)
- session.commit()
- session.expunge_all()
+ user = User.read(self.db_session, user_id)
+ return user.persona_memory_templates.filter(name=name).scalar()

 def update_job_status(self, job_id: uuid.UUID, status: JobStatus):
 job = Job.read(self.db_session, job_id)
 job.status = status
 job.update(self.db_session)
diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py
index eddb0002c0..8e5fefdf2f 100644
--- a/memgpt/models/pydantic_models.py
+++ b/memgpt/models/pydantic_models.py
@@ -2,7 +2,7 @@
 from uuid import UUID
 from datetime import datetime
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Literal

 from pydantic import BaseModel, Field, ConfigDict
@@ -100,22 +100,24 @@ class CoreMemory(BaseModel):
 human: str = Field(..., description="Human element of the core memory.")
 persona: str = Field(..., description="Persona element of the core memory.")

-class MemorySection(PersistedBase):
+class 
MemoryTemplate(PersistedBase): """the common base for the legacy memory sections. This is going away in favor of MemoryModule dynamic sections. memgpt/memory.py """ text: Optional[str] = Field(default=get_human_text(settings.human), description="The content to be added to this section of core memory.") + type: Literal["human", "persona"] = Field(..., description="The type of memory section.") name: str = Field(..., description="The name of the memory section.") organization: Optional[OrganizationSummary] = Field(None, description="The organization this memory belongs to.") -class HumanMemory(MemorySection): +class HumanModel(MemoryTemplate): """Specifically for human, legacy""" + type: Literal["human"] = "human" -class PersonaModel(MemorySection): +class PersonaModel(MemoryTemplate): """Specifically for persona, legacy""" - + type: Literal["persona"] = "persona" class SourceModel(PersistedBase): name: str = Field(..., description="The name of the source.") @@ -145,4 +147,4 @@ class DocumentModel(PersistedBase): organization: OrganizationSummary = Field(description="The organization this document belongs to.") text: str = Field(..., description="The full text of the document.") data_source: str = Field(..., description="The data source of the document.") - metadata_: Optional[Dict] = Field({}, description="The metadata of the document.") + metadata_: Optional[Dict] = Field({}, description="The metadata of the document.") \ No newline at end of file diff --git a/memgpt/orm/organization.py b/memgpt/orm/organization.py index 7d1d4736e5..ca042a7855 100644 --- a/memgpt/orm/organization.py +++ b/memgpt/orm/organization.py @@ -1,20 +1,25 @@ -from typing import Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING, List from pydantic import EmailStr from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Mapped, relationship, mapped_column from memgpt.orm.sqlalchemy_base import SqlalchemyBase +from memgpt.models.pydantic_models import OrganizationSummary if TYPE_CHECKING: from memgpt.orm.user import User from memgpt.orm.agent import Agent from memgpt.orm.source import Source from memgpt.orm.tool import Tool from memgpt.orm.preset import Preset - from sqlalchemy.orm.session import Session + from memgpt.orm.memory_templates import HumanMemoryTemplate, PersonaMemoryTemplate + from sqlalchemy.orm import Session + class Organization(SqlalchemyBase): """The highest level of the object tree. 
All Entities belong to one and only one Organization."""
     __tablename__ = "organization"
+    __pydantic_model__ = OrganizationSummary
+
     name:Mapped[Optional[str]] = mapped_column(nullable=True, doc="The display name of the organization.")

     # relationships
@@ -23,6 +28,8 @@ class Organization(SqlalchemyBase):
     sources: Mapped["Source"] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
     tools: Mapped["Tool"] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
     presets: Mapped["Preset"] = relationship("Preset", back_populates="organization", cascade="all, delete-orphan")
+    personas: Mapped["PersonaMemoryTemplate"] = relationship("PersonaMemoryTemplate", back_populates="organization", cascade="all, delete-orphan")
+    humans: Mapped["HumanMemoryTemplate"] = relationship("HumanMemoryTemplate", back_populates="organization", cascade="all, delete-orphan")

     @classmethod
     def default(cls, db_session:"Session") -> "Organization":

From ed0971820724ddddef6d6aba4fa7c4d3bb4f1f2c Mon Sep 17 00:00:00 2001
From: Ethan Knox
Date: Fri, 5 Jul 2024 10:45:35 -0400
Subject: [PATCH 41/45] cleanup

---
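A rough sketch of the backend selection this cleans up, assuming `Settings`
is constructed with no `MEMGPT_PG_URI` in the environment (the postgres path
is exercised by the next patch):

    from memgpt.settings import Settings

    local = Settings()
    # with no pg_uri configured we fall back to the local backend
    assert local.backend.name == "sqlite_chroma"
    assert local.backend.database_uri.startswith("sqlite:///")
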
 memgpt/client/client.py       | 8 ++++----
 memgpt/metadata.py            | 6 +++---
 memgpt/migrations/env.py      | 2 +-
 memgpt/orm/sqlalchemy_base.py | 4 ++--
 memgpt/orm/token.py           | 3 ++-
 memgpt/settings.py            | 6 +++---
 6 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/memgpt/client/client.py b/memgpt/client/client.py
index 09b9a35925..3c93fec842 100644
--- a/memgpt/client/client.py
+++ b/memgpt/client/client.py
@@ -260,11 +260,11 @@ def create_agent(
         embedding_config: Optional[EmbeddingConfig] = None,
         llm_config: Optional[LLMConfig] = None,
         # memory
-        memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
+        memory: BaseMemory = ChatMemory(human=get_human_text(settings.human), persona=get_human_text(settings.persona)),
         # tools
         tools: Optional[List[str]] = None,
         include_base_tools: Optional[bool] = True,
-        metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
+        metadata: Optional[Dict] = {"human": settings.human, "persona": settings.persona},
     ) -> AgentState:
         """
         Create an agent
@@ -724,12 +724,12 @@ def create_agent(
         embedding_config: Optional[EmbeddingConfig] = None,
         llm_config: Optional[LLMConfig] = None,
         # memory
-        memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
+        memory: BaseMemory = ChatMemory(human=get_human_text(settings.human), persona=get_human_text(settings.persona)),
         # tools
         tools: Optional[List[str]] = None,
         include_base_tools: Optional[bool] = True,
         # metadata
-        metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
+        metadata: Optional[Dict] = {"human": settings.human, "persona": settings.persona},
     ) -> AgentState:
         if name and self.agent_exists(agent_name=name):
             raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})")
diff --git a/memgpt/metadata.py b/memgpt/metadata.py
index b206e05777..6013b87018 100644
--- a/memgpt/metadata.py
+++ b/memgpt/metadata.py
@@ -10,7 +10,7 @@
 from memgpt.orm.agent import Agent
 from memgpt.orm.job import Job
 from memgpt.orm.preset import Preset
-from memgpt.orm.memory_templates import HumanModelTemplate, PersonaModelTemplate
+from memgpt.orm.memory_templates import HumanMemoryTemplate, PersonaMemoryTemplate


 if TYPE_CHECKING:
@@ -164,11 +164,11 @@ def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]:
         return [s._id for s in Preset.read(self.db_session, preset_id).sources]

     def update_human(self, human: HumanModel) -> "HumanModel":
-        sql_human = HumanModelTemplate(**human.model_dump(exclude_none=True)).create(self.db_session)
+        sql_human = HumanMemoryTemplate(**human.model_dump(exclude_none=True)).create(self.db_session)
         return sql_human.to_record()

     def update_persona(self, persona: PersonaModel) -> "PersonaModel":
-        sql_persona = PersonaModelTemplate(**persona.model_dump(exclude_none=True)).create(self.db_session)
+        sql_persona = PersonaMemoryTemplate(**persona.model_dump(exclude_none=True)).create(self.db_session)
         return sql_persona.to_record()

     def get_all_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50) -> (Optional[uuid.UUID], List[User]):
diff --git a/memgpt/migrations/env.py b/memgpt/migrations/env.py
index 80835623c2..fbfee47442 100644
--- a/memgpt/migrations/env.py
+++ b/memgpt/migrations/env.py
@@ -13,7 +13,7 @@
 config = context.config
 section = config.config_ini_section
 # set the metadata database url from settings
-config.set_section_option(section, "MEMGPT_DATABASE_URL", settings.database_url)
+config.set_section_option(section, "MEMGPT_DATABASE_URL", settings.backend.database_uri)
 # Interpret the config file for Python logging.
 # This line sets up loggers basically.
 if config.config_file_name is not None:
diff --git a/memgpt/orm/sqlalchemy_base.py b/memgpt/orm/sqlalchemy_base.py
index 19318574b5..7c0809268a 100644
--- a/memgpt/orm/sqlalchemy_base.py
+++ b/memgpt/orm/sqlalchemy_base.py
@@ -1,7 +1,7 @@
 from uuid import uuid4, UUID
 from typing import Optional, TYPE_CHECKING,Type, Union, List, Literal
 from humps import depascalize
-from sqlalchemy import select, UUID as SQLUUID
+from sqlalchemy import select, UUID as SQLUUID, Boolean
 from sqlalchemy.orm import (
     Mapped,
     mapped_column
 )
@@ -26,7 +26,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):

     _id: Mapped[UUID] = mapped_column(SQLUUID(), primary_key=True, default=uuid4)

-    deleted: Mapped[bool] = mapped_column(bool, default=False, doc="Is this record deleted? Used for universal soft deletes.")
+    deleted: Mapped[bool] = mapped_column(Boolean, default=False, doc="Is this record deleted? Used for universal soft deletes.")

     @property
     def __prefix__(self) -> str:
diff --git a/memgpt/orm/token.py b/memgpt/orm/token.py
index fbc11865f6..6bdadc6b8f 100644
--- a/memgpt/orm/token.py
+++ b/memgpt/orm/token.py
@@ -1,7 +1,8 @@
 from uuid import UUID, uuid4
 from typing import Optional, TYPE_CHECKING
 from sqlalchemy import String
-from sqlalchemy.orm import Mapped, mapped_column, relationship, hybrid_property
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+from sqlalchemy.ext.hybrid import hybrid_property

 from memgpt.data_types import Token as PydanticToken
 from memgpt.orm.sqlalchemy_base import SqlalchemyBase
diff --git a/memgpt/settings.py b/memgpt/settings.py
index ad6736549c..9019b0744a 100644
--- a/memgpt/settings.py
+++ b/memgpt/settings.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from urlparse import urlsplit, urlunsplit
+from urllib.parse import urlsplit, urlunsplit
 from typing import Optional, Literal
 from enum import Enum
 from pydantic import BaseModel, Field
@@ -26,11 +26,11 @@ class Settings(BaseSettings):
     @property
     def backend(self) -> BackendConfiguration:
         """Return an adjusted BackendConfiguration.
-        Note: defaults to sqlite-chroma if pg_uri is not set.
+        Note: defaults to sqlite_chroma if pg_uri is not set.
""" if self.pg_uri: return BackendConfiguration(name="postgres", database_uri=self._correct_pg_uri(self.pg_uri)) - return BackendConfiguration(name="sqlite-chroma", database_uri=f"sqlite:///{self.memgpt_dir}/memgpt.db") + return BackendConfiguration(name="sqlite_chroma", database_uri=f"sqlite:///{self.memgpt_dir}/memgpt.db") def _correct_pg_uri(self) -> str: """It is awkward to have users set a scheme for the uri (because why should they know anything about what drivers we use?) From fe38a1f583ad2f97109890550a1b74876090da1c Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Fri, 5 Jul 2024 15:14:24 -0400 Subject: [PATCH 42/45] migrations now all the way up on both sqlite and pg --- development.compose.yml | 1 + .../1ad46f6e4c2d_adding_configs_to_agent.py | 32 ---- ...{fcd6c014e6a8_.py => 8ab5757fa7a1_init.py} | 121 ++++++++++---- memgpt/migrations/versions/e1e15ff9ab6e_.py | 151 ------------------ memgpt/orm/__all__.py | 2 + memgpt/orm/document.py | 4 +- memgpt/orm/mixins.py | 20 ++- memgpt/orm/passage.py | 8 +- memgpt/settings.py | 5 +- 9 files changed, 123 insertions(+), 221 deletions(-) delete mode 100644 memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py rename memgpt/migrations/versions/{fcd6c014e6a8_.py => 8ab5757fa7a1_init.py} (63%) delete mode 100644 memgpt/migrations/versions/e1e15ff9ab6e_.py diff --git a/development.compose.yml b/development.compose.yml index 324722f156..216c724b8c 100644 --- a/development.compose.yml +++ b/development.compose.yml @@ -15,6 +15,7 @@ services: # no value syntax to not set the env at all if it is not set in .env environment: - MEMGPT_SERVER_PASS=test_server_token + - MEMGPT_PG_URI=postgresql://memgpt:memgpt@memgpt-db:5432/memgpt - WATCHFILES_FORCE_POLLING=true volumes: diff --git a/memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py b/memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py deleted file mode 100644 index efc5d91191..0000000000 --- a/memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py +++ /dev/null @@ -1,32 +0,0 @@ -"""adding configs to agent - -Revision ID: 1ad46f6e4c2d -Revises: e1e15ff9ab6e -Create Date: 2024-06-26 20:51:04.227418 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = '1ad46f6e4c2d' -down_revision: Union[str, None] = 'e1e15ff9ab6e' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.add_column('agent', sa.Column('llm_config', sa.JSON(), nullable=False)) - op.add_column('agent', sa.Column('embedding_config', sa.JSON(), nullable=False)) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
 development.compose.yml                       |   1 +
 .../1ad46f6e4c2d_adding_configs_to_agent.py   |  32 ----
 ...{fcd6c014e6a8_.py => 8ab5757fa7a1_init.py} | 121 ++++++++++----
 memgpt/migrations/versions/e1e15ff9ab6e_.py   | 151 ------------------
 memgpt/orm/__all__.py                         |   2 +
 memgpt/orm/document.py                        |   4 +-
 memgpt/orm/mixins.py                          |  20 ++-
 memgpt/orm/passage.py                         |   8 +-
 memgpt/settings.py                            |   5 +-
 9 files changed, 123 insertions(+), 221 deletions(-)
 delete mode 100644 memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py
 rename memgpt/migrations/versions/{fcd6c014e6a8_.py => 8ab5757fa7a1_init.py} (63%)
 delete mode 100644 memgpt/migrations/versions/e1e15ff9ab6e_.py

diff --git a/development.compose.yml b/development.compose.yml
index 324722f156..216c724b8c 100644
--- a/development.compose.yml
+++ b/development.compose.yml
@@ -15,6 +15,7 @@ services:
     # no value syntax to not set the env at all if it is not set in .env
     environment:
       - MEMGPT_SERVER_PASS=test_server_token
+      - MEMGPT_PG_URI=postgresql://memgpt:memgpt@memgpt-db:5432/memgpt
       - WATCHFILES_FORCE_POLLING=true
     volumes:
diff --git a/memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py b/memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py
deleted file mode 100644
index efc5d91191..0000000000
--- a/memgpt/migrations/versions/1ad46f6e4c2d_adding_configs_to_agent.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""adding configs to agent
-
-Revision ID: 1ad46f6e4c2d
-Revises: e1e15ff9ab6e
-Create Date: 2024-06-26 20:51:04.227418
-
-"""
-from typing import Sequence, Union
-
-from alembic import op
-import sqlalchemy as sa
-
-
-# revision identifiers, used by Alembic.
-revision: str = '1ad46f6e4c2d'
-down_revision: Union[str, None] = 'e1e15ff9ab6e'
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    op.add_column('agent', sa.Column('llm_config', sa.JSON(), nullable=False))
-    op.add_column('agent', sa.Column('embedding_config', sa.JSON(), nullable=False))
-    # ### end Alembic commands ###
-
-
-def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    op.drop_column('agent', 'embedding_config')
-    op.drop_column('agent', 'llm_config')
-    # ### end Alembic commands ###
diff --git a/memgpt/migrations/versions/fcd6c014e6a8_.py b/memgpt/migrations/versions/8ab5757fa7a1_init.py
similarity index 63%
rename from memgpt/migrations/versions/fcd6c014e6a8_.py
rename to memgpt/migrations/versions/8ab5757fa7a1_init.py
index 10e6a501cc..95fb472210 100644
--- a/memgpt/migrations/versions/fcd6c014e6a8_.py
+++ b/memgpt/migrations/versions/8ab5757fa7a1_init.py
@@ -1,8 +1,8 @@
-"""empty message
+"""init

-Revision ID: fcd6c014e6a8
-Revises:
-Create Date: 2024-06-26 18:52:23.655166
+Revision ID: 8ab5757fa7a1
+Revises:
+Create Date: 2024-07-05 18:58:31.038011

 """
 from typing import Sequence, Union
@@ -12,7 +12,7 @@


 # revision identifiers, used by Alembic.
-revision: str = 'fcd6c014e6a8'
+revision: str = '8ab5757fa7a1'
 down_revision: Union[str, None] = None
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -23,9 +23,10 @@ def upgrade() -> None:
     op.create_table('organization',
     sa.Column('name', sa.String(), nullable=True),
     sa.Column('_id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
-    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
-    sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False),
+    sa.Column('deleted', sa.Boolean(), nullable=False),
+    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True),
+    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True),
+    sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
     sa.Column('_created_by_id', sa.UUID(), nullable=True),
     sa.Column('_last_updated_by_id', sa.UUID(), nullable=True),
     sa.PrimaryKeyConstraint('_id')
     )
@@ -33,27 +34,48 @@ def upgrade() -> None:
     op.create_table('agent',
     sa.Column('name', sa.String(), nullable=True),
     sa.Column('persona', sa.String(), nullable=False),
+    sa.Column('state', sa.JSON(), nullable=False),
+    sa.Column('_metadata', sa.JSON(), nullable=False),
     sa.Column('human', sa.String(), nullable=False),
     sa.Column('preset', sa.String(), nullable=False),
+    sa.Column('llm_config', sa.JSON(), nullable=False),
+    sa.Column('embedding_config', sa.JSON(), nullable=False),
     sa.Column('_id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
-    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
-    sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False),
+    sa.Column('deleted', sa.Boolean(), nullable=False),
+    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True),
+    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True),
+    sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
     sa.Column('_created_by_id', sa.UUID(), nullable=True),
     sa.Column('_last_updated_by_id', sa.UUID(), nullable=True),
     sa.Column('_organization_id', sa.UUID(), nullable=False),
     sa.ForeignKeyConstraint(['_organization_id'], ['organization._id'], ),
     sa.PrimaryKeyConstraint('_id')
     )
+    op.create_table('document',
+ sa.Column('text', sa.String(), nullable=False), + sa.Column('data_source', sa.String(), nullable=True), + sa.Column('metadata_', sa.JSON(), nullable=True), + sa.Column('_organization_id', sa.UUID(), nullable=False), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('deleted', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(['_organization_id'], ['organization._id'], ), + sa.PrimaryKeyConstraint('_id') + ) op.create_table('memory_template', sa.Column('name', sa.String(), nullable=False), sa.Column('description', sa.String(), nullable=False), sa.Column('type', sa.String(), nullable=False), sa.Column('text', sa.String(), nullable=False), sa.Column('_id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('deleted', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), sa.Column('_created_by_id', sa.UUID(), nullable=True), sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), sa.Column('_organization_id', sa.UUID(), nullable=False), @@ -70,9 +92,10 @@ def upgrade() -> None: sa.Column('persona_name', sa.String(), nullable=False), sa.Column('functions_schema', sa.JSON(), nullable=False), sa.Column('_id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('deleted', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), sa.Column('_created_by_id', sa.UUID(), nullable=True), sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), sa.Column('_organization_id', sa.UUID(), nullable=False), @@ -87,9 +110,10 @@ def upgrade() -> None: sa.Column('description', sa.String(), nullable=True), sa.Column('_organization_id', sa.UUID(), nullable=False), sa.Column('_id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('is_deleted', 
sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('deleted', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), sa.Column('_created_by_id', sa.UUID(), nullable=True), sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), sa.ForeignKeyConstraint(['_organization_id'], ['organization._id'], ), @@ -102,9 +126,10 @@ def upgrade() -> None: sa.Column('source_code', sa.String(), nullable=True), sa.Column('json_schema', sa.JSON(), nullable=False), sa.Column('_id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('deleted', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), sa.Column('_created_by_id', sa.UUID(), nullable=True), sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), sa.Column('_organization_id', sa.UUID(), nullable=False), @@ -115,15 +140,48 @@ def upgrade() -> None: sa.Column('name', sa.String(), nullable=True), sa.Column('email', sa.String(), nullable=True), sa.Column('_id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('deleted', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), sa.Column('_created_by_id', sa.UUID(), nullable=True), sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), sa.Column('_organization_id', sa.UUID(), nullable=False), sa.ForeignKeyConstraint(['_organization_id'], ['organization._id'], ), sa.PrimaryKeyConstraint('_id') ) + op.create_table('job', + sa.Column('status', sa.Enum('created', 'running', 'completed', 'failed', name='jobstatus'), nullable=False), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('metadata_', sa.JSON(), nullable=True), + sa.Column('_user_id', sa.UUID(), nullable=False), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('deleted', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), 
server_default=sa.text('FALSE'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(['_user_id'], ['user._id'], ), + sa.PrimaryKeyConstraint('_id') + ) + op.create_table('passage', + sa.Column('text', sa.String(), nullable=False), + sa.Column('embedding', sa.JSON(), nullable=True), + sa.Column('embedding_config', sa.JSON(), nullable=True), + sa.Column('data_source', sa.String(), nullable=True), + sa.Column('metadata_', sa.JSON(), nullable=True), + sa.Column('_document_id', sa.UUID(), nullable=False), + sa.Column('_id', sa.UUID(), nullable=False), + sa.Column('deleted', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), + sa.Column('_created_by_id', sa.UUID(), nullable=True), + sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(['_document_id'], ['document._id'], ), + sa.PrimaryKeyConstraint('_id') + ) op.create_table('sources_agents', sa.Column('_agent_id', sa.UUID(), nullable=False), sa.Column('_source_id', sa.UUID(), nullable=False), @@ -139,12 +197,14 @@ def upgrade() -> None: sa.PrimaryKeyConstraint('_preset_id', '_source_id') ) op.create_table('token', + sa.Column('_temporary_shim_api_key', sa.String(), nullable=True), sa.Column('hash', sa.String(), nullable=False), sa.Column('name', sa.String(), nullable=True), sa.Column('_id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('(FALSE)'), nullable=False), + sa.Column('deleted', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), sa.Column('_created_by_id', sa.UUID(), nullable=True), sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), sa.Column('_user_id', sa.UUID(), nullable=False), @@ -183,11 +243,14 @@ def downgrade() -> None: op.drop_table('token') op.drop_table('sources_presets') op.drop_table('sources_agents') + op.drop_table('passage') + op.drop_table('job') op.drop_table('user') op.drop_table('tool') op.drop_table('source') op.drop_table('preset') op.drop_table('memory_template') + op.drop_table('document') op.drop_table('agent') op.drop_table('organization') # ### end Alembic commands ### diff --git a/memgpt/migrations/versions/e1e15ff9ab6e_.py b/memgpt/migrations/versions/e1e15ff9ab6e_.py deleted file mode 100644 index ea925032b6..0000000000 --- a/memgpt/migrations/versions/e1e15ff9ab6e_.py +++ /dev/null @@ -1,151 +0,0 @@ -"""empty message - -Revision ID: e1e15ff9ab6e -Revises: fcd6c014e6a8 -Create Date: 2024-06-26 20:23:47.395414 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. 
-revision: str = 'e1e15ff9ab6e' -down_revision: Union[str, None] = 'fcd6c014e6a8' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('job', - sa.Column('status', sa.Enum('created', 'running', 'completed', 'failed', name='jobstatus'), nullable=False), - sa.Column('completed_at', sa.DateTime(), nullable=True), - sa.Column('metadata_', sa.JSON(), nullable=True), - sa.Column('_user_id', sa.UUID(), nullable=False), - sa.Column('_id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), - sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), - sa.Column('_created_by_id', sa.UUID(), nullable=True), - sa.Column('_last_updated_by_id', sa.UUID(), nullable=True), - sa.ForeignKeyConstraint(['_user_id'], ['user._id'], ), - sa.PrimaryKeyConstraint('_id') - ) - op.drop_table('agent_source_mapping') - op.drop_table('humanmodel') - op.drop_table('agents') - op.drop_table('toolmodel') - op.drop_table('presets') - op.drop_table('users') - op.drop_table('tokens') - op.drop_table('personamodel') - op.drop_table('jobmodel') - op.drop_table('sources') - op.drop_table('preset_source_mapping') - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('preset_source_mapping', - sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('preset_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('source_id', sa.UUID(), autoincrement=False, nullable=False), - sa.PrimaryKeyConstraint('id', name='preset_source_mapping_pkey') - ) - op.create_table('sources', - sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True), - sa.Column('embedding_dim', sa.BIGINT(), autoincrement=False, nullable=True), - sa.Column('embedding_model', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('description', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='sources_pkey') - ) - op.create_table('jobmodel', - sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('status', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('created_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=False), - sa.Column('completed_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=True), - sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='jobmodel_pkey') - ) - op.create_table('personamodel', - sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('user_id', sa.UUID(), autoincrement=False, 
nullable=True), - sa.PrimaryKeyConstraint('id', name='personamodel_pkey') - ) - op.create_table('tokens', - sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('token', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='tokens_pkey') - ) - op.create_table('users', - sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('default_agent', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('policies_accepted', sa.BOOLEAN(), autoincrement=False, nullable=False), - sa.PrimaryKeyConstraint('id', name='users_pkey') - ) - op.create_table('presets', - sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('description', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('system', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('human', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('human_name', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('persona', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('persona_name', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('preset', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True), - sa.Column('functions_schema', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='presets_pkey') - ) - op.create_table('toolmodel', - sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('tags', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.Column('source_type', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('source_code', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('json_schema', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='toolmodel_pkey') - ) - op.create_table('agents', - sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('persona', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('human', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('preset', sa.VARCHAR(), autoincrement=False, nullable=True), - sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True), - sa.Column('llm_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.Column('state', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='agents_pkey') - ) - op.create_table('humanmodel', - sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=False), - 
sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('id', name='humanmodel_pkey') - ) - op.create_table('agent_source_mapping', - sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('agent_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('source_id', sa.UUID(), autoincrement=False, nullable=False), - sa.PrimaryKeyConstraint('id', name='agent_source_mapping_pkey') - ) - op.drop_table('job') - # ### end Alembic commands ### diff --git a/memgpt/orm/__all__.py b/memgpt/orm/__all__.py index a8dac16dd2..3621e27600 100644 --- a/memgpt/orm/__all__.py +++ b/memgpt/orm/__all__.py @@ -7,6 +7,8 @@ from memgpt.orm.source import Source from memgpt.orm.tool import Tool from memgpt.orm.preset import Preset +from memgpt.orm.document import Document +from memgpt.orm.passage import Passage from memgpt.orm.memory_templates import MemoryTemplate, HumanMemoryTemplate, PersonaMemoryTemplate from memgpt.orm.sources_agents import SourcesAgents from memgpt.orm.sources_presets import SourcesPresets diff --git a/memgpt/orm/document.py b/memgpt/orm/document.py index c0194cba33..c2051daeb5 100644 --- a/memgpt/orm/document.py +++ b/memgpt/orm/document.py @@ -16,8 +16,8 @@ class Document(OrganizationMixin, SqlalchemyBase): __tablename__ = "document" text: Mapped[str] = mapped_column(doc="The full text for the document.") - data_source: Optional[str] = mapped_column(nullable=True, doc="Human readable description of where the passage came from.") - metadata_: Optional[dict] = mapped_column(JSON, default_factory=lambda: {}, doc="additional information about the passage.") + data_source: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Human readable description of where the passage came from.") + metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="additional information about the passage.") # relationships diff --git a/memgpt/orm/mixins.py b/memgpt/orm/mixins.py index d22c633e2a..68d815749e 100644 --- a/memgpt/orm/mixins.py +++ b/memgpt/orm/mixins.py @@ -92,4 +92,22 @@ def agent_id(self) -> str: @agent_id.setter def agent_id(self, value: str) -> None: - _relation_setter(self, "agent", value) \ No newline at end of file + _relation_setter(self, "agent", value) + + +class DocumentMixin(Base): + """Mixin for models that belong to a document.""" + + __abstract__ = True + + _document_id: Mapped[UUID] = mapped_column( + SQLUUID(), ForeignKey("document._id") + ) + + @property + def document_id(self) -> str: + return _relation_getter(self, "document") + + @document_id.setter + def document_id(self, value: str) -> None: + _relation_setter(self, "document", value) \ No newline at end of file diff --git a/memgpt/orm/passage.py b/memgpt/orm/passage.py index 89b439a980..d30525cb44 100644 --- a/memgpt/orm/passage.py +++ b/memgpt/orm/passage.py @@ -16,11 +16,11 @@ class Passage(DocumentMixin, SqlalchemyBase): __tablename__ = "passage" text: Mapped[str] = mapped_column(doc="The text of the passage.") - embedding: Optional[List[float]] = mapped_column(JSON, doc="The embedding of the passage.", nullable=True) - embedding_config: Optional["EmbeddingConfigModel"] = mapped_column(JSON, doc="The embedding configuration used by the passage.", + embedding: Mapped[Optional[List[float]]] = mapped_column(JSON, doc="The embedding of the passage.", nullable=True) + embedding_config: 
Mapped[Optional["EmbeddingConfigModel"]] = mapped_column(JSON, doc="The embedding configuration used by the passage.", nullable=True) - data_source: Optional[str] = mapped_column(nullable=True, doc="Human readable description of where the passage came from.") - metadata_: Optional[dict] = mapped_column(JSON, default_factory=lambda: {}, doc="additional information about the passage.") + data_source: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Human readable description of where the passage came from.") + metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="additional information about the passage.") # relationships diff --git a/memgpt/settings.py b/memgpt/settings.py index 9019b0744a..08a820cda1 100644 --- a/memgpt/settings.py +++ b/memgpt/settings.py @@ -32,11 +32,12 @@ def backend(self) -> BackendConfiguration: return BackendConfiguration(name="postgres", database_uri=self._correct_pg_uri(self.pg_uri)) return BackendConfiguration(name="sqlite_chroma", database_uri=f"sqlite:///{self.memgpt_dir}/memgpt.db") - def _correct_pg_uri(self) -> str: + @classmethod + def _correct_pg_uri(cls, uri:str) -> str: """It is awkward to have users set a scheme for the uri (because why should they know anything about what drivers we use?) So here we check (and correct) the provided uri to use the scheme we implement. """ - url_parts = list(urlsplit(settings.pg_uri)) + url_parts = list(urlsplit(uri)) SCHEME = 0 url_parts[SCHEME] = POSTGRES_SCHEME return urlunsplit(url_parts) From 3fb39314bbd37cb2b2ff76e3e6b2d58a1e06165f Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Mon, 8 Jul 2024 11:21:24 -0400 Subject: [PATCH 43/45] minimal orm test passes --- memgpt/orm/organization.py | 2 ++ memgpt/orm/user.py | 2 ++ tests/conftest.py | 2 +- tests/test_tools.py | 3 +-- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/memgpt/orm/organization.py b/memgpt/orm/organization.py index ca042a7855..365cf105f1 100644 --- a/memgpt/orm/organization.py +++ b/memgpt/orm/organization.py @@ -11,6 +11,7 @@ from memgpt.orm.source import Source from memgpt.orm.tool import Tool from memgpt.orm.preset import Preset + from memgpt.orm.document import Document from memgpt.orm.memory_templates import HumanMemoryTemplate, PersonaMemoryTemplate from sqlalchemy.orm import Session @@ -30,6 +31,7 @@ class Organization(SqlalchemyBase): presets: Mapped["Preset"] = relationship("Preset", back_populates="organization", cascade="all, delete-orphan") personas: Mapped["PersonaMemoryTemplate"] = relationship("PersonaMemoryTemplate", back_populates="organization", cascade="all, delete-orphan") humans: Mapped["HumanMemoryTemplate"] = relationship("HumanMemoryTemplate", back_populates="organization", cascade="all, delete-orphan") + documents: Mapped["Document"] = relationship("Document", back_populates="organization", cascade="all, delete-orphan") @classmethod def default(cls, db_session:"Session") -> "Organization": diff --git a/memgpt/orm/user.py b/memgpt/orm/user.py index aaefac9a37..0adccb09de 100644 --- a/memgpt/orm/user.py +++ b/memgpt/orm/user.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from memgpt.orm.agent import Agent from memgpt.orm.token import Token + from memgpt.orm.job import Job class User(SqlalchemyBase, OrganizationMixin): """User ORM class""" @@ -28,4 +29,5 @@ class User(SqlalchemyBase, OrganizationMixin): back_populates="users", doc="the agents associated with this user.") tokens: Mapped[List["Token"]] = relationship("Token", back_populates="user", doc="the tokens associated with this 
user.") + jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.") diff --git a/tests/conftest.py b/tests/conftest.py index b9ba0712d8..4f28a7b790 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -121,7 +121,7 @@ def db_session(request) -> "Session": case "sqlite_chroma": database_uri = f"sqlite:///{adapter['database']}" case "postgres": - url_parts = list(urlsplit(settings.postgres_uri)) + url_parts = list(urlsplit(settings.backend.database_uri)) PATH_PARAM = 2 url_parts[PATH_PARAM] = f"/{adapter['database']}" database_uri = urlunsplit(url_parts) diff --git a/tests/test_tools.py b/tests/test_tools.py index b0ac1b6266..a3428b734e 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -9,7 +9,6 @@ from memgpt import Admin, create_client from memgpt.agent import Agent from memgpt.config import MemGPTConfig -from memgpt.constants import DEFAULT_PRESET from memgpt.credentials import MemGPTCredentials from memgpt.memory import ChatMemory from memgpt.settings import settings @@ -17,7 +16,7 @@ test_agent_name = f"test_client_{str(uuid.uuid4())}" # test_preset_name = "test_preset" -test_preset_name = DEFAULT_PRESET +test_preset_name = settings.preset test_agent_state = None client = None From 37a04a2e666a0e1b55f7f86ffc13b0cacdec216f Mon Sep 17 00:00:00 2001 From: Ethan Knox Date: Mon, 8 Jul 2024 13:20:45 -0400 Subject: [PATCH 44/45] Conftest ALMOST has a scoped test app so we can stop spinning up servers in tests! But... 1. need to move the db_session all the way up to the request (where it belongs). 2. dep inject that thing at request time! 3. dep override it in conftest! --- memgpt/client/client.py | 111 +++++++++++++++++++++-------------- tests/conftest.py | 6 ++ tests/test_client.py | 127 +++++++++++++++++++++------------------- 3 files changed, 139 insertions(+), 105 deletions(-) diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 3c93fec842..a0186cdef4 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -1,11 +1,11 @@ import datetime import time import uuid -from typing import Dict, List, Optional, Tuple, Union - -import requests +from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING +import httpx from memgpt.config import MemGPTConfig +from memgpt.log import get_logger from memgpt.constants import BASE_TOOLS from memgpt.settings import settings from memgpt.data_sources.connectors import DataConnector @@ -54,8 +54,22 @@ from memgpt.server.server import SyncServer from memgpt.utils import get_human_text - -def create_client(base_url: Optional[str] = None, token: Optional[str] = None, config: Optional[MemGPTConfig] = None): +if TYPE_CHECKING: + from httpx import ASGITransport, WSGITransport + +logger = get_logger(__name__) + +def create_client(base_url: Optional[str] = None, + token: Optional[str] = None, + config: Optional[MemGPTConfig] = None, + app: Optional[str] = None): + """factory method to create either a local or rest api enabled client. + # TODO: link to docs on the difference between the two. + base_url: str if provided, the url to the rest api server + token: str if provided, the token to authenticate to the rest api server + config: MemGPTConfig if provided, the configuration settings to use for the local client + app: str if provided an ASGI compliant application to use instead of an actual http call. used for testing hook. 
+ """ if base_url is None: return LocalClient(config=config) else: @@ -228,17 +242,24 @@ def __init__( base_url: str, token: str, debug: bool = False, + app: Optional[Union["WSGITransport","ASGITransport"]] = None, ): super().__init__(debug=debug) - self.base_url = base_url - self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"} + httpx_client_args = { + "headers": {"accept": "application/json", "authorization": f"Bearer {token}"}, + "base_url": base_url, + } + if app: + logger.warning("Using supplied WSGI or ASGI app for RESTClient") + httpx_client_args["app"] = app + self.httpx_client = self.httpx_client.Client(**httpx_client_args) def list_agents(self): - response = requests.get(f"{self.base_url}/api/agents", headers=self.headers) + response = self.httpx_client.get("/api/agents") return ListAgentsResponse(**response.json()) def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: - response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/config", headers=self.headers) + response = self.httpx_client.get(f"/api/agents/{str(agent_id)}/config") if response.status_code == 404: # not found error return False @@ -248,7 +269,7 @@ def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] raise ValueError(f"Failed to check if agent exists: {response.text}") def get_tool(self, tool_name: str): - response = requests.get(f"{self.base_url}/api/tools/{tool_name}", headers=self.headers) + response = self.httpx_client.get(f"/api/tools/{tool_name}") if response.status_code != 200: raise ValueError(f"Failed to get tool: {response.text}") return ToolModel(**response.json()) @@ -305,7 +326,7 @@ def create_agent( "metadata": metadata, } } - response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers) + response = self.httpx_client.post("/api/agents", json=payload) if response.status_code != 200: raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}") response_obj = CreateAgentResponse(**response.json()) @@ -343,25 +364,25 @@ def get_agent_response_to_state(self, response: Union[GetAgentResponse, CreateAg return agent_state def rename_agent(self, agent_id: uuid.UUID, new_name: str): - response = requests.patch(f"{self.base_url}/api/agents/{str(agent_id)}/rename", json={"agent_name": new_name}, headers=self.headers) + response = self.httpx_client.patch("/api/agents/{str(agent_id)}/rename", json={"agent_name": new_name}) assert response.status_code == 200, f"Failed to rename agent: {response.text}" response_obj = GetAgentResponse(**response.json()) return self.get_agent_response_to_state(response_obj) def delete_agent(self, agent_id: uuid.UUID): """Delete the agent.""" - response = requests.delete(f"{self.base_url}/api/agents/{str(agent_id)}", headers=self.headers) + response = self.httpx_client.delete("/api/agents/{str(agent_id)}") assert response.status_code == 200, f"Failed to delete agent: {response.text}" def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> AgentState: - response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/config", headers=self.headers) + response = self.httpx_client.get("/api/agents/{str(agent_id)}/config") assert response.status_code == 200, f"Failed to get agent: {response.text}" response_obj = GetAgentResponse(**response.json()) return self.get_agent_response_to_state(response_obj) def get_preset(self, name: str) -> PresetModel: # TODO: remove - response = 
requests.get(f"{self.base_url}/api/presets/{name}", headers=self.headers) + response = self.httpx_client.get("/api/presets/{name}") assert response.status_code == 200, f"Failed to get preset: {response.text}" return PresetModel(**response.json()) @@ -417,25 +438,25 @@ def create_preset( human_name=human_name, functions_schema=schema, ) - response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers) + response = self.httpx_client.post("/api/presets", json=payload.model_dump()) assert response.status_code == 200, f"Failed to create preset: {response.text}" return CreatePresetResponse(**response.json()).preset def delete_preset(self, preset_id: uuid.UUID): - response = requests.delete(f"{self.base_url}/api/presets/{str(preset_id)}", headers=self.headers) + response = self.httpx_client.delete("/api/presets/{str(preset_id)}") assert response.status_code == 200, f"Failed to delete preset: {response.text}" def list_presets(self) -> List[PresetModel]: - response = requests.get(f"{self.base_url}/api/presets", headers=self.headers) + response = self.httpx_client.get("/api/presets") return ListPresetsResponse(**response.json()).presets # memory def get_agent_memory(self, agent_id: uuid.UUID) -> GetAgentMemoryResponse: - response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory", headers=self.headers) + response = self.httpx_client.get("/api/agents/{agent_id}/memory") return GetAgentMemoryResponse(**response.json()) def update_agent_core_memory(self, agent_id: str, new_memory_contents: Dict) -> UpdateAgentMemoryResponse: - response = requests.post(f"{self.base_url}/api/agents/{agent_id}/memory", json=new_memory_contents, headers=self.headers) + response = self.httpx_client.post("/api/agents/{agent_id}/memory", json=new_memory_contents) return UpdateAgentMemoryResponse(**response.json()) # agent interactions @@ -444,7 +465,7 @@ def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[L return self.send_message(agent_id, message, role="user") def run_command(self, agent_id: str, command: str) -> Union[str, None]: - response = requests.post(f"{self.base_url}/api/agents/{str(agent_id)}/command", json={"command": command}, headers=self.headers) + response = self.httpx_client.post("/api/agents/{str(agent_id)}/command", json={"command": command}) return CommandResponse(**response.json()) def save(self): @@ -461,18 +482,18 @@ def get_agent_archival_memory( params["before"] = str(before) if after: params["after"] = str(after) - response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/archival", params=params, headers=self.headers) + response = self.httpx_client.get("/api/agents/{str(agent_id)}/archival", params=params) assert response.status_code == 200, f"Failed to get archival memory: {response.text}" return GetAgentArchivalMemoryResponse(**response.json()) def insert_archival_memory(self, agent_id: uuid.UUID, memory: str) -> GetAgentArchivalMemoryResponse: - response = requests.post(f"{self.base_url}/api/agents/{agent_id}/archival", json={"content": memory}, headers=self.headers) + response = self.httpx_client.post("/api/agents/{agent_id}/archival", json={"content": memory}) if response.status_code != 200: raise ValueError(f"Failed to insert archival memory: {response.text}") return InsertAgentArchivalMemoryResponse(**response.json()) def delete_archival_memory(self, agent_id: uuid.UUID, memory_id: uuid.UUID): - response = requests.delete(f"{self.base_url}/api/agents/{agent_id}/archival?id={memory_id}", 
headers=self.headers) + response = self.httpx_client.delete("/api/agents/{agent_id}/archival?id={memory_id}") assert response.status_code == 200, f"Failed to delete archival memory: {response.text}" # messages (recall memory) @@ -481,14 +502,14 @@ def get_messages( self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 ) -> GetAgentMessagesResponse: params = {"before": before, "after": after, "limit": limit} - response = requests.get(f"{self.base_url}/api/agents/{agent_id}/messages-cursor", params=params, headers=self.headers) + response = self.httpx_client.get("/api/agents/{agent_id}/messages-cursor", params=params) if response.status_code != 200: raise ValueError(f"Failed to get messages: {response.text}") return GetAgentMessagesResponse(**response.json()) def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Optional[bool] = False) -> UserMessageResponse: data = {"message": message, "role": role, "stream": stream} - response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=data, headers=self.headers) + response = self.httpx_client.post("/api/agents/{agent_id}/messages", json=data) if response.status_code != 200: raise ValueError(f"Failed to send message: {response.text}") return UserMessageResponse(**response.json()) @@ -496,29 +517,29 @@ def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Opt # humans / personas def list_humans(self) -> ListHumansResponse: - response = requests.get(f"{self.base_url}/api/humans", headers=self.headers) + response = self.httpx_client.get("/api/humans") return ListHumansResponse(**response.json()) def create_human(self, name: str, human: str) -> HumanModel: data = {"name": name, "text": human} - response = requests.post(f"{self.base_url}/api/humans", json=data, headers=self.headers) + response = self.httpx_client.post("/api/humans", json=data) if response.status_code != 200: raise ValueError(f"Failed to create human: {response.text}") return HumanModel(**response.json()) def list_personas(self) -> ListPersonasResponse: - response = requests.get(f"{self.base_url}/api/personas", headers=self.headers) + response = self.httpx_client.get("/api/personas") return ListPersonasResponse(**response.json()) def create_persona(self, name: str, persona: str) -> PersonaModel: data = {"name": name, "text": persona} - response = requests.post(f"{self.base_url}/api/personas", json=data, headers=self.headers) + response = self.httpx_client.post("/api/personas", json=data) if response.status_code != 200: raise ValueError(f"Failed to create persona: {response.text}") return PersonaModel(**response.json()) def get_persona(self, name: str) -> PersonaModel: - response = requests.get(f"{self.base_url}/api/personas/{name}", headers=self.headers) + response = self.httpx_client.get("/api/personas/{name}") if response.status_code == 404: return None elif response.status_code != 200: @@ -526,7 +547,7 @@ def get_persona(self, name: str) -> PersonaModel: return PersonaModel(**response.json()) def get_human(self, name: str) -> HumanModel: - response = requests.get(f"{self.base_url}/api/humans/{name}", headers=self.headers) + response = self.httpx_client.get("/api/humans/{name}") if response.status_code == 404: return None elif response.status_code != 200: @@ -537,17 +558,17 @@ def get_human(self, name: str) -> HumanModel: def list_sources(self): """List loaded sources""" - response = requests.get(f"{self.base_url}/api/sources", headers=self.headers) 
+ response = self.httpx_client.get("/api/sources") response_json = response.json() return ListSourcesResponse(**response_json) def delete_source(self, source_id: uuid.UUID): """Delete a source and associated data (including attached to agents)""" - response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers) + response = self.httpx_client.delete("/api/sources/{str(source_id)}") assert response.status_code == 200, f"Failed to delete source: {response.text}" def get_job_status(self, job_id: uuid.UUID): - response = requests.get(f"{self.base_url}/api/sources/status/{str(job_id)}", headers=self.headers) + response = self.httpx_client.get("/api/sources/status/{str(job_id)}") return JobModel(**response.json()) def load_file_into_source(self, filename: str, source_id: uuid.UUID, blocking=True): @@ -555,7 +576,7 @@ def load_file_into_source(self, filename: str, source_id: uuid.UUID, blocking=Tr files = {"file": open(filename, "rb")} # create job - response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers) + response = self.httpx_client.post("/api/sources/{source_id}/upload", files=files) if response.status_code != 200: raise ValueError(f"Failed to upload file to source: {response.text}") @@ -574,7 +595,7 @@ def load_file_into_source(self, filename: str, source_id: uuid.UUID, blocking=Tr def create_source(self, name: str) -> Source: """Create a new source""" payload = {"name": name} - response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers) + response = self.httpx_client.post("/api/sources", json=payload) response_json = response.json() response_obj = SourceModel(**response_json) return Source( @@ -589,23 +610,23 @@ def create_source(self, name: str) -> Source: def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID): """Attach a source to an agent""" params = {"agent_id": agent_id} - response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers) + response = self.httpx_client.post("/api/sources/{source_id}/attach", params=params) assert response.status_code == 200, f"Failed to attach source to agent: {response.text}" def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID): """Detach a source from an agent""" params = {"agent_id": str(agent_id)} - response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers) + response = self.httpx_client.post("/api/sources/{source_id}/detach", params=params) assert response.status_code == 200, f"Failed to detach source from agent: {response.text}" # server configuration commands def list_models(self) -> ListModelsResponse: - response = requests.get(f"{self.base_url}/api/models", headers=self.headers) + response = self.httpx_client.get("/api/models") return ListModelsResponse(**response.json()) def get_config(self) -> ConfigResponse: - response = requests.get(f"{self.base_url}/api/config", headers=self.headers) + response = self.httpx_client.get("/api/config") return ConfigResponse(**response.json()) # tools @@ -644,25 +665,25 @@ def create_tool( raise ValueError(f"Failed to create tool: {e}, invalid input {data}") # make REST request - response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers) + response = self.httpx_client.post("/api/tools", json=data) if response.status_code != 200: raise ValueError(f"Failed to create tool: {response.text}") return ToolModel(**response.json()) def 
list_tools(self) -> ListToolsResponse: - response = requests.get(f"{self.base_url}/api/tools", headers=self.headers) + response = self.httpx_client.get("/api/tools") if response.status_code != 200: raise ValueError(f"Failed to list tools: {response.text}") return ListToolsResponse(**response.json()).tools def delete_tool(self, name: str): - response = requests.delete(f"{self.base_url}/api/tools/{name}", headers=self.headers) + response = self.httpx_client.delete(f"/api/tools/{name}") if response.status_code != 200: raise ValueError(f"Failed to delete tool: {response.text}") return response.json() def get_tool(self, name: str): - response = requests.get(f"{self.base_url}/api/tools/{name}", headers=self.headers) + response = self.httpx_client.get(f"/api/tools/{name}") if response.status_code == 404: return None elif response.status_code != 200: diff --git a/tests/conftest.py b/tests/conftest.py index 4f28a7b790..35f4a66a60 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from memgpt.data_types import EmbeddingConfig, LLMConfig from memgpt.credentials import MemGPTCredentials from memgpt.server.server import SyncServer +from memgpt.server.rest_api import app from tests.config import TestMGPTConfig @@ -136,3 +137,8 @@ def db_session(request) -> "Session": Base.metadata.create_all(bind=connection) with sessionmaker(bind=engine)() as session: yield session + +@pytest.fixture +def test_app(db_session): + """a per-test-function db scoped version of the rest api app""" + diff --git a/tests/test_client.py b/tests/test_client.py index f1af0bfdba..b518863d19 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,8 +2,10 @@ import threading import time import uuid +import httpx import pytest +from faker import Faker from dotenv import load_dotenv from memgpt.settings import settings @@ -11,6 +13,8 @@ from memgpt.credentials import MemGPTCredentials from memgpt.data_types import Preset # TODO move to PresetModel from memgpt.settings import settings +from memgpt.orm.user import User +from memgpt.orm.token import Token from tests.utils import create_config test_agent_name = f"test_client_{str(uuid.uuid4())}" @@ -25,85 +29,75 @@ # admin credentials test_server_token = "test_server_token" - +faker = Faker() def _reset_config(): - # Use os.getenv with a fallback to os.environ.get - db_url = settings.pg_uri - - if os.getenv("OPENAI_API_KEY"): - create_config("openai") - credentials = MemGPTCredentials( - openai_key=os.getenv("OPENAI_API_KEY"), - ) - else: # hosted - create_config("memgpt_hosted") - credentials = MemGPTCredentials() - - config = MemGPTConfig.load() - - # set to use postgres - config.archival_storage_uri = db_url - config.recall_storage_uri = db_url - config.metadata_storage_uri = db_url - config.archival_storage_type = "postgres" - config.recall_storage_type = "postgres" - config.metadata_storage_type = "postgres" - config.save() - credentials.save() - print("_reset_config :: ", config.config_path) + pass +# ## Use os.getenv with a fallback to os.environ.get + #db_url = settings.pg_uri + + #if os.getenv("OPENAI_API_KEY"): + #create_config("openai") + #credentials = MemGPTCredentials( + #openai_key=os.getenv("OPENAI_API_KEY"), + #) + #else: # hosted + #create_config("memgpt_hosted") + #credentials = MemGPTCredentials() + + #config = MemGPTConfig.load() + + ## set to use postgres + #config.archival_storage_uri = db_url + #config.recall_storage_uri = db_url + #config.metadata_storage_uri = db_url + #config.archival_storage_type = "postgres" + 
+    #config.recall_storage_type = "postgres"
+    #config.metadata_storage_type = "postgres"
+    #config.save()
+    #credentials.save()
+    #print("_reset_config :: ", config.config_path)


 def run_server():
-    load_dotenv()
+    pass
+    #load_dotenv()

-    _reset_config()
+    #_reset_config()

-    from memgpt.server.rest_api.server import start_server
+    #from memgpt.server.rest_api.server import start_server

-    print("Starting server...")
-    start_server(debug=True)
+    #print("Starting server...")
+    #start_server(debug=True)


 # Fixture to create clients with different configurations
 @pytest.fixture(
     params=[{"server": True}, {"server": False}],  # whether to use REST API server
-    scope="module",
 )
-def client(request):
+def client(request, db_session, test_app):
     if request.param["server"]:
-        # get URL from enviornment
-        server_url = os.getenv("MEMGPT_SERVER_URL")
-        if server_url is None:
-            # run server in thread
-            # NOTE: must set MEMGPT_SERVER_PASS enviornment variable
-            server_url = "http://localhost:8083"
-            print("Starting server thread")
-            thread = threading.Thread(target=run_server, daemon=True)
-            thread.start()
-            time.sleep(5)
-        print("Running client tests with server:", server_url)
-        # create user via admin client
-        admin = Admin(server_url, test_server_token)
-        response = admin.create_user(test_user_id)  # Adjust as per your client's method
-        token = response.api_key
-
+        # since we are not TESTING the admin client, we don't want to USE the admin client here.
+        # create the user directly
+        requesting_user = User.create(db_session)
+        api_token = Token(user=requesting_user, name="test_client_api_token").create(db_session)
+        token = api_token.api_key
+        client_args = {
+            "base_url": settings.server_url,
+            "token": token,
+            "debug": True,
+            "app": test_app
+        }
     else:
         # use local client (no server)
-        token = None
-        server_url = None
-
-    client = create_client(base_url=server_url, token=token)  # This yields control back to the test function
-    try:
-        yield client
-    finally:
-        # cleanup user
-        if server_url:
-            admin.delete_user(test_user_id)  # Adjust as per your client's method
-
+        client_args = {
+            "token": None,
+            "base_url": None
+        }
+    yield create_client(**client_args)

 # Fixture for test agent
-@pytest.fixture(scope="module")
+@pytest.fixture
 def agent(client):
     agent_state = client.create_agent(name=test_agent_name)
     print("AGENT ID", agent_state.id)
@@ -112,6 +106,19 @@ def agent(client):
     # delete agent
     client.delete_agent(agent_state.id)

+class TestClientAgent:
+    """CRUD for agents via the client"""
+
+    def test_create_agent(self, client):
+        expected_agent_name = faker.name()
+        assert not client.agent_exists(name=expected_agent_name)
+        created_agent_state = client.create_agent(name=expected_agent_name)
+        assert client.agent_exists(name=expected_agent_name)
+        assert created_agent_state.name == expected_agent_name
+
+    def test_rename_agent(self):
+        new_name = faker.name()
+
 def test_agent(client, agent):
     _reset_config()

From 7f4711e1de79ab0ae73d3c786971d5f509e3fe00 Mon Sep 17 00:00:00 2001
From: Ethan Knox
Date: Tue, 9 Jul 2024 18:50:12 -0400
Subject: [PATCH 45/45] breakpoint to rebase

---
 memgpt/metadata.py                   | 8 ++++++--
 memgpt/server/rest_api/auth/index.py | 8 ++++++--
 memgpt/server/server.py              | 5 +++--
 tests/conftest.py                    | 4 +++-
 4 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/memgpt/metadata.py b/memgpt/metadata.py
index 6013b87018..8cd9abdd85 100644
--- a/memgpt/metadata.py
+++ b/memgpt/metadata.py
@@ -38,7 +38,11 @@ class MetadataStore:
     """
     db_session: "Session" = None

-    def __init__(self):
+    def __init__(self, db_session: Optional["Session"] = None):
+        """
+        Args:
+            db_session: the database session to use.
+        """
-        self.db_session = get_db_session()
+        self.db_session = db_session or get_db_session()

     def create_api_key(self,
diff --git a/memgpt/server/rest_api/auth/index.py b/memgpt/server/rest_api/auth/index.py
index 6be07d888e..4f9f05ace8 100644
--- a/memgpt/server/rest_api/auth/index.py
+++ b/memgpt/server/rest_api/auth/index.py
@@ -1,9 +1,11 @@
 from uuid import UUID
+from typing import Annotated, Generator

-from fastapi import APIRouter
+from fastapi import APIRouter, Depends
 from pydantic import BaseModel, Field

 from memgpt.server.rest_api.interface import QueuingInterface
+from memgpt.orm.utilities import get_db_session
 from memgpt.server.server import SyncServer

 router = APIRouter()
@@ -20,7 +22,9 @@ class AuthRequest(BaseModel):

 def setup_auth_router(server: SyncServer, interface: QueuingInterface, password: str) -> APIRouter:
     @router.post("/auth", tags=["auth"], response_model=AuthResponse)
-    def authenticate_user(request: AuthRequest) -> AuthResponse:
+    def authenticate_user(request: AuthRequest,
+                          db_session: Annotated["Generator", Depends(get_db_session)],
+                          ) -> AuthResponse:
         """
         Authenticates the user and sends response with User related data.

diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index 6f27ad9440..8960be05c3 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -130,7 +130,8 @@ def __init__(
         max_chaining_steps: bool = None,
         default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
         # test hooks
-        config: Optional["MemGPTConfig"] = None
+        config: Optional["MemGPTConfig"] = None,
+        db_session: Optional["Session"] = None,
     ):
         """Server process holds in-memory agents that are being run"""

@@ -179,7 +180,7 @@ def __init__(
         assert self.server_embedding_config.embedding_model is not None, vars(self.server_embedding_config)

         # Initialize the metadata store
-        self.ms = MetadataStore(self.config)
+        self.ms = MetadataStore(db_session=db_session)

         # pre-fill database (users, presets, humans, personas)
         # TODO: figure out how to handle default users (server is technically multi-user)
diff --git a/tests/conftest.py b/tests/conftest.py
index 35f4a66a60..620a61aade 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -141,4 +141,6 @@ def db_session(request) -> "Session":
 @pytest.fixture
 def test_app(db_session):
     """a per-test-function db scoped version of the rest api app"""
-
+    from memgpt.orm.utilities import get_db_session
+    app.dependency_overrides[get_db_session] = lambda: db_session
+    return app
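
The call sites rewritten across these patches assume a `self.httpx_client` that
is constructed once with the server's base URL and default auth headers, which
is why each method can drop its f"{self.base_url}/..." prefix and
headers=self.headers argument. None of the hunks above show that constructor,
so the following is only a sketch under assumed names (the RESTClient class
name, parameter names, and header shape are illustrative, not taken from the
patches):

    import httpx

    class RESTClient:
        def __init__(self, base_url: str, token: str):
            # httpx resolves relative request paths such as "/api/humans"
            # against base_url and attaches the default headers to every
            # request, so neither needs to be repeated at each call site.
            self.httpx_client = httpx.Client(
                base_url=base_url,
                headers={
                    "accept": "application/json",
                    "authorization": f"Bearer {token}",
                },
            )

Two caveats on this pattern: paths that interpolate values still need the f
prefix (f"/api/tools/{name}"), since a plain "/api/tools/{name}" sends the
literal braces to the server. And the "app": test_app argument in the client
fixture implies requests can be routed to an in-process application; with
httpx that requires an explicit transport (httpx.ASGITransport(app=...)
paired with an AsyncClient), so that wiring remains an assumption here.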