diff --git a/memgpt/server/rest_api/sources/index.py b/memgpt/server/rest_api/sources/index.py index b38365b397..51d8156774 100644 --- a/memgpt/server/rest_api/sources/index.py +++ b/memgpt/server/rest_api/sources/index.py @@ -141,10 +141,36 @@ async def create_source( interface.clear() try: # TODO: don't use Source and just use SourceModel once pydantic migration is complete - source = server.create_source(name=request.name, user_id=user_id) + source = server.create_source(name=request.name, user_id=user_id, description=request.description) return SourceModel( name=source.name, - description=None, # TODO: actually store descriptions + description=source.description, + user_id=source.user_id, + id=source.id, + embedding_config=server.server_embedding_config, + created_at=source.created_at.timestamp(), + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + + @router.post("/sources/{source_id}", tags=["sources"], response_model=SourceModel) + async def update_source( + source_id: uuid.UUID, + request: CreateSourceRequest = Body(...), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + """ + Update the name or documentation of an existing data source. + """ + interface.clear() + try: + # TODO: don't use Source and just use SourceModel once pydantic migration is complete + source = server.update_source(source_id=source_id, name=request.name, user_id=user_id, description=request.description) + return SourceModel( + name=source.name, + description=source.description, user_id=source.user_id, id=source.id, embedding_config=server.server_embedding_config, diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 377546df76..3104866614 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1375,11 +1375,12 @@ def create_api_key_for_user(self, user_id: uuid.UUID) -> Token: token = self.ms.create_api_key(user_id=user_id) return token - def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add other fields + def create_source(self, name: str, user_id: uuid.UUID, description: str = None) -> Source: # TODO: add other fields """Create a new data source""" source = Source( name=name, user_id=user_id, + description=description, embedding_model=self.config.default_embedding_config.embedding_model, embedding_dim=self.config.default_embedding_config.embedding_dim, ) @@ -1387,6 +1388,19 @@ def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add o assert self.ms.get_source(source_name=name, user_id=user_id) is not None, f"Failed to create source {name}" return source + def update_source(self, source_id: uuid.UUID, name: str, user_id: uuid.UUID, description: str = None) -> Source: + """Updates a data source""" + source = Source( + id=source_id, + name=name, + user_id=user_id, + description=description, + embedding_model=self.config.default_embedding_config.embedding_model, + embedding_dim=self.config.default_embedding_config.embedding_dim, + ) + self.ms.update_source(source) + return source + def delete_source(self, source_id: uuid.UUID, user_id: uuid.UUID): """Delete a data source""" source = self.ms.get_source(source_id=source_id, user_id=user_id) @@ -1475,7 +1489,7 @@ def list_all_sources(self, user_id: uuid.UUID) -> List[SourceModel]: sources = [ SourceModel( name=source.name, - description=None, # TODO: actually store descriptions + description=source.description, user_id=source.user_id, id=source.id, embedding_config=self.server_embedding_config,