Skip to content

Commit

Permalink
fix: Resolve issue regarding not having Vector column type defined wh…
Browse files Browse the repository at this point in the history
…en using vector search

The issue occurs when search is called in a session that has not previously added data or created tables, because the import of the Vector column type was missing in that code path.

Fix
  • Loading branch information
dexters1 committed Dec 12, 2024
1 parent 92ecd8a commit 599e1d4
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,6 @@ class IndexSchema(DataPoint):
"index_fields": ["text"]
}

def singleton(class_):
    """Class decorator that makes every construction return one shared instance.

    The first call to the decorated name builds ``class_`` with the given
    arguments and caches the result; every later call returns that same
    object. Arguments passed on later calls are ignored.

    Note: the decorated name is bound to a plain function, not to the class,
    so class methods (and anything else that needs the class object, such as
    ``isinstance`` checks against the decorated name) are not available
    through it. The original class stays reachable as ``__wrapped__``.

    Args:
        class_: The class to turn into a singleton.

    Returns:
        A callable that accepts the constructor arguments of ``class_`` and
        returns the single cached instance.
    """
    from functools import update_wrapper

    instances = {}

    def getinstance(*args, **kwargs):
        # Only the first call constructs; later calls reuse the cached object.
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]

    # Copy __name__/__qualname__/__doc__/__module__ and set __wrapped__ so the
    # factory is identifiable in tracebacks and introspection. ``updated=()``
    # prevents the class __dict__ from being merged into the function's
    # attributes, which would expose unbound methods on the factory.
    return update_wrapper(getinstance, class_, updated=())

@singleton
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):

def __init__(
Expand All @@ -51,6 +38,11 @@ def __init__(
self.engine = create_async_engine(self.db_uri)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)

# Imported in the constructor and kept on the instance so it is available to the whole class:
# functions reading tables from the database need to know what a Vector column type is
from pgvector.sqlalchemy import Vector
self.Vector = Vector

async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)

Expand All @@ -70,7 +62,6 @@ async def create_collection(self, collection_name: str, payload_schema=None):

if not await self.has_collection(collection_name):

from pgvector.sqlalchemy import Vector
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
Expand All @@ -80,7 +71,7 @@ class PGVectorDataPoint(Base):
)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(Vector(vector_size))
vector = Column(self.Vector(vector_size))

def __init__(self, id, payload, vector):
self.id = id
Expand Down Expand Up @@ -108,7 +99,6 @@ async def create_data_points(

vector_size = self.embedding_engine.get_vector_size()

from pgvector.sqlalchemy import Vector
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
Expand All @@ -118,7 +108,7 @@ class PGVectorDataPoint(Base):
)
id: Mapped[type(data_points[0].id)]
payload = Column(JSON)
vector = Column(Vector(vector_size))
vector = Column(self.Vector(vector_size))

def __init__(self, id, payload, vector):
self.id = id
Expand Down

0 comments on commit 599e1d4

Please sign in to comment.