Skip to content

Commit

Permalink
fix: enable source desc and allowing editing source name and desc (#1599
Browse files Browse the repository at this point in the history
)
  • Loading branch information
sarahwooders authored Aug 1, 2024
2 parents ebf59f9 + 5702939 commit 1f0e8fc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
30 changes: 28 additions & 2 deletions memgpt/server/rest_api/sources/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,36 @@ async def create_source(
interface.clear()
try:
# TODO: don't use Source and just use SourceModel once pydantic migration is complete
source = server.create_source(name=request.name, user_id=user_id)
source = server.create_source(name=request.name, user_id=user_id, description=request.description)
return SourceModel(
name=source.name,
description=None, # TODO: actually store descriptions
description=source.description,
user_id=source.user_id,
id=source.id,
embedding_config=server.server_embedding_config,
created_at=source.created_at.timestamp(),
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")

@router.post("/sources/{source_id}", tags=["sources"], response_model=SourceModel)
async def update_source(
source_id: uuid.UUID,
request: CreateSourceRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Update the name or documentation of an existing data source.
"""
interface.clear()
try:
# TODO: don't use Source and just use SourceModel once pydantic migration is complete
source = server.update_source(source_id=source_id, name=request.name, user_id=user_id, description=request.description)
return SourceModel(
name=source.name,
description=source.description,
user_id=source.user_id,
id=source.id,
embedding_config=server.server_embedding_config,
Expand Down
18 changes: 16 additions & 2 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,18 +1375,32 @@ def create_api_key_for_user(self, user_id: uuid.UUID) -> Token:
token = self.ms.create_api_key(user_id=user_id)
return token

def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add other fields
def create_source(self, name: str, user_id: uuid.UUID, description: str = None) -> Source: # TODO: add other fields
"""Create a new data source"""
source = Source(
name=name,
user_id=user_id,
description=description,
embedding_model=self.config.default_embedding_config.embedding_model,
embedding_dim=self.config.default_embedding_config.embedding_dim,
)
self.ms.create_source(source)
assert self.ms.get_source(source_name=name, user_id=user_id) is not None, f"Failed to create source {name}"
return source

def update_source(self, source_id: uuid.UUID, name: str, user_id: uuid.UUID, description: str = None) -> Source:
"""Updates a data source"""
source = Source(
id=source_id,
name=name,
user_id=user_id,
description=description,
embedding_model=self.config.default_embedding_config.embedding_model,
embedding_dim=self.config.default_embedding_config.embedding_dim,
)
self.ms.update_source(source)
return source

def delete_source(self, source_id: uuid.UUID, user_id: uuid.UUID):
"""Delete a data source"""
source = self.ms.get_source(source_id=source_id, user_id=user_id)
Expand Down Expand Up @@ -1475,7 +1489,7 @@ def list_all_sources(self, user_id: uuid.UUID) -> List[SourceModel]:
sources = [
SourceModel(
name=source.name,
description=None, # TODO: actually store descriptions
description=source.description,
user_id=source.user_id,
id=source.id,
embedding_config=self.server_embedding_config,
Expand Down

0 comments on commit 1f0e8fc

Please sign in to comment.