Skip to content

Commit

Permalink
all but one test passes for storageconnector
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Woodson committed Sep 3, 2024
1 parent 611efca commit a3103e1
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 34 deletions.
42 changes: 20 additions & 22 deletions memgpt/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None
def get_filters(self, filters: Optional[Dict] = {}):
filter_conditions = {**self.filters, **(filters or {})}
all_filters = [getattr(self.SQLModel, key) == value for key, value in filter_conditions.items() if hasattr(self.SQLModel, key)]
breakpoint()
return all_filters

def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0):
Expand Down Expand Up @@ -65,7 +64,7 @@ def get_all_cursor(

# generate query
with self.db_session as session:
query = session.query(self.SQLModel).filter(*filters)
query = select(self.SQLModel).filter(*filters).limit(limit)
# query = query.order_by(asc(self.SQLModel.id))

# records are sorted by the order_by field first, and then by the ID if two fields are the same
Expand All @@ -87,7 +86,7 @@ def get_all_cursor(
query = query.filter(or_(sort_exp, and_(getattr(self.SQLModel, order_by) == before_value, self.SQLModel.id < before)))

# get records
db_record_chunk = query.limit(limit).all()
db_record_chunk = session.execute(query).scalars()
if not db_record_chunk:
return (None, [])
records = [record.to_pydantic() for record in db_record_chunk]
Expand All @@ -103,10 +102,10 @@ def get_all(self, filters: Optional[Dict] = {}, limit=None):
query = select(self.SQLModel).filter(*filters)
if limit:
query = query.limit(limit)
breakpoint()
db_records = session.execute(query).all()
db_records = session.execute(query).scalars()

return [record.to_pydantic() for record in db_records]

return [record.to_pydantic() for record in db_records]

def get(self, id: str):
try:
Expand Down Expand Up @@ -159,13 +158,10 @@ def insert_many(self, records, exists_ok=True, show_progress=False):
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
filters = self.get_filters(filters)
with self.db_session as session:
results = session.scalars(
select(self.SQLModel).filter(*filters).order_by(self.SQLModel.embedding.l2_distance(query_vec)).limit(top_k)
).all()

# Convert the results into Pydantic objects
records = [result.to_pydantic() for result in results]
return records
query = select(self.SQLModel).filter(*filters).order_by(self.SQLModel.embedding.l2_distance(query_vec)).limit(top_k)
results = session.execute(query).scalars()

return [result.to_pydantic() for result in results]

def update(self, record: MemGPTBase):
"""Updates a record in the database based on the provided Pydantic Record object."""
Expand All @@ -181,7 +177,7 @@ def query_date(self, start_date, end_date, limit=None, offset=0):
filters = self.get_filters({})
with self.db_session as session:
query = (
session.query(self.SQLModel)
select(self.SQLModel)
.filter(*filters)
.filter(self.SQLModel.created_at >= start_date)
.filter(self.SQLModel.created_at <= end_date)
Expand All @@ -191,15 +187,15 @@ def query_date(self, start_date, end_date, limit=None, offset=0):
)
if limit:
query = query.limit(limit)
results = query.all()
return [result.to_pydantic() for result in results]
results = session.execute(query).scalars()
return [result.to_pydantic() for result in results]

def query_text(self, query, limit=None, offset=0):
# todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
filters = self.get_filters({})
with self.db_session as session:
query = (
session.query(self.SQLModel)
select(self.SQLModel)
.filter(*filters)
.filter(func.lower(self.SQLModel.text).contains(func.lower(query)))
.filter(self.SQLModel.role != "system")
Expand All @@ -208,8 +204,9 @@ def query_text(self, query, limit=None, offset=0):
)
if limit:
query = query.limit(limit)
results = query.all()
return [result.to_pydantic() for result in results]
results = session.execute(query).scalars()

return [result.to_pydantic() for result in results]


def delete(self, filters: Optional[Dict] = {}):
Expand Down Expand Up @@ -244,7 +241,7 @@ def query_date(self, start_date, end_date, limit=None, offset=0):
_end_date = self.str_to_datetime(end_date) if isinstance(end_date, str) else end_date
with self.db_session as session:
query = (
session.query(self.SQLModel)
select(self.SQLModel)
.filter(*filters)
.filter(self.SQLModel.created_at >= _start_date)
.filter(self.SQLModel.created_at <= _end_date)
Expand All @@ -254,8 +251,9 @@ def query_date(self, start_date, end_date, limit=None, offset=0):
)
if limit:
query = query.limit(limit)
results = query.all()
return [result.to_pydantic() for result in results]
results = session.execute(query).scalars()

return [result.to_pydantic() for result in results]


class SQLLiteStorageConnector(SQLStorageConnector):
Expand Down
6 changes: 3 additions & 3 deletions memgpt/agent_store/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from memgpt.config import MemGPTConfig

from memgpt.orm.base import Base as SQLBase
from memgpt.orm.sqlalchemy_base import SqlalchemyBase as SQLBase
from memgpt.orm.message import Message as SQLMessage
from memgpt.orm.passage import Passage as SQLPassage
from memgpt.orm.document import Document as SQLDocument
Expand All @@ -19,7 +19,7 @@

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from memgpt.orm.base import Base as SQLBase
from memgpt.orm.sqlalchemy_base import SqlalchemyBase as SQLBase



Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
if self.table_type == TableType.ARCHIVAL_MEMORY or self.table_type == TableType.RECALL_MEMORY:
# agent-specific table
assert agent_id is not None, "Agent ID must be provided for agent-specific tables"
self.filters = {"user_id": self.user_id, "_agent_id": self.agent_id}
self.filters = {"user_id": self.user_id, "_agent_id": SQLBase.to_uid(self.agent_id, True)}
elif self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS:
# setup base filters for user-specific tables
self.filters = {"user_id": self.user_id}
Expand Down
2 changes: 1 addition & 1 deletion memgpt/orm/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def list(cls, *, db_session: "Session", **kwargs) -> List[Type["SqlalchemyBase"]
return list(session.execute(query).scalars())

@classmethod
def to_uid(cls, identifier, indifferent:Optional[bool] = False) -> "UUID":
def to_uid(cls, identifier, indifferent: Optional[bool] = False) -> "UUID":
"""converts the id into a uuid object
Args:
indifferent: if True, will not enforce the prefix check
Expand Down
4 changes: 0 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,5 @@ def patch_local_db_calls(test_get_db_session: Callable):
"memgpt.server.rest_api.app.get_db_session",
test_get_db_session,
),
patch(
"memgpt.server.rest_api.utils.get_db_session",
test_get_db_session,
),
):
yield
4 changes: 0 additions & 4 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import pytest
from faker import Faker

from memgpt.settings import settings
from memgpt import Admin, create_client
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.enums import JobStatus
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics

Expand Down

0 comments on commit a3103e1

Please sign in to comment.