diff --git a/.env.template b/.env.template
index acdac04c..77c845d4 100644
--- a/.env.template
+++ b/.env.template
@@ -7,24 +7,26 @@ GRAPHISTRY_PASSWORD=
 
 SENTRY_REPORTING_URL=
 
-GRAPH_DATABASE_PROVIDER="neo4j" # or "networkx"
+# "neo4j" or "networkx"
+GRAPH_DATABASE_PROVIDER="neo4j"
 # Not needed if using networkx
 GRAPH_DATABASE_URL=
 GRAPH_DATABASE_USERNAME=
 GRAPH_DATABASE_PASSWORD=
 
-VECTOR_DB_PROVIDER="qdrant" # or "weaviate" or "lancedb"
-# Not needed if using "lancedb"
+# "qdrant", "pgvector", "weaviate" or "lancedb"
+VECTOR_DB_PROVIDER="qdrant"
+# Not needed if using "lancedb" or "pgvector"
 VECTOR_DB_URL=
 VECTOR_DB_KEY=
 
-# Database provider
-DB_PROVIDER="sqlite" # or "postgres"
+# Relational Database provider "sqlite" or "postgres"
+DB_PROVIDER="sqlite"
 
 # Database name
 DB_NAME=cognee_db
 
-# Postgres specific parameters (Only if Postgres is run)
+# Postgres specific parameters (Only if Postgres or PGVector is used)
 DB_HOST=127.0.0.1
 DB_PORT=5432
 DB_USERNAME=cognee
diff --git a/.github/workflows/test_pgvector.yml b/.github/workflows/test_pgvector.yml
new file mode 100644
index 00000000..913d249e
--- /dev/null
+++ b/.github/workflows/test_pgvector.yml
@@ -0,0 +1,67 @@
+name: test | pgvector
+
+on:
+  pull_request:
+    branches:
+      - main
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+env:
+  RUNTIME__LOG_LEVEL: ERROR
+
+jobs:
+  get_docs_changes:
+    name: docs changes
+    uses: ./.github/workflows/get_docs_changes.yml
+
+  run_pgvector_integration_test:
+    name: test
+    needs: get_docs_changes
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
+    runs-on: ubuntu-latest
+    defaults:
+      run:
+        shell: bash
+    services:
+      postgres:
+        image: pgvector/pgvector:pg17
+        env:
+          POSTGRES_USER: cognee
+          POSTGRES_PASSWORD: cognee
+          POSTGRES_DB: cognee_db
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          - 5432:5432
+
+    steps:
+      - name: Check out
+        uses: actions/checkout@master
+
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11.x'
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1.3.2
+        with:
+          virtualenvs-create: true
+          virtualenvs-in-project: true
+          installer-parallel: true
+
+      - name: Install dependencies
+        run: poetry install -E postgres --no-interaction
+
+      - name: Run default PGVector
+        env:
+          ENV: 'dev'
+          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+        run: poetry run python ./cognee/tests/test_pgvector.py
diff --git a/README.md b/README.md
index 76cd833b..350de80c 100644
--- a/README.md
+++ b/README.md
@@ -190,11 +190,11 @@ Cognee supports a variety of tools and services for different operations:
 
 - **Local Setup**: By default, LanceDB runs locally with NetworkX and OpenAI.
 
-- **Vector Stores**: Cognee supports Qdrant and Weaviate for vector storage.
+- **Vector Stores**: Cognee supports LanceDB, Qdrant, PGVector and Weaviate for vector storage.
 
 - **Language Models (LLMs)**: You can use either Anyscale or Ollama as your LLM provider.
 
-- **Graph Stores**: In addition to LanceDB, Neo4j is also supported for graph storage.
+- **Graph Stores**: In addition to NetworkX, Neo4j is also supported for graph storage.
 
 - **User management**: Create individual user graphs and manage permissions
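Reviewer note (not part of the patch): the configuration surface touched above — the new "pgvector" VECTOR_DB_PROVIDER value, the Postgres settings it reuses, and the poetry install -E postgres extra exercised by the new workflow — can also be driven programmatically. The following is a minimal sketch adapted from cognee/tests/test_pgvector.py later in this diff; it assumes the pgvector/pgvector:pg17 container from docker-compose.yml is running locally with the cognee/cognee credentials, and the dataset name and sample text are placeholders.

import asyncio
import cognee
from cognee.api.v1.search import SearchType


async def demo():
    # Assumption: a pgvector-enabled Postgres is reachable at 127.0.0.1:5432
    # with the credentials used by docker-compose.yml and the CI workflow.
    cognee.config.set_vector_db_config({
        "vector_db_url": "",
        "vector_db_key": "",
        "vector_db_provider": "pgvector",  # provider value added by this change
    })
    cognee.config.set_relational_db_config({
        "db_name": "cognee_db",
        "db_host": "127.0.0.1",
        "db_port": "5432",
        "db_username": "cognee",
        "db_password": "cognee",
        "db_provider": "postgres",
    })

    # Ingest a snippet, build the knowledge graph, then query it back through pgvector.
    await cognee.add(["Natural language processing is a subfield of computer science."], "demo_dataset")
    await cognee.cognify(["demo_dataset"])
    print(await cognee.search(SearchType.INSIGHTS, query="computer science"))


if __name__ == "__main__":
    asyncio.run(demo())

The same selection can be made via .env instead, using the VECTOR_DB_PROVIDER and DB_* variables shown in the .env.template hunk above.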
diff --git a/cognee/api/client.py b/cognee/api/client.py
index b8d56bc5..4b41f057 100644
--- a/cognee/api/client.py
+++ b/cognee/api/client.py
@@ -374,7 +374,7 @@ class LLMConfigDTO(InDTO):
     api_key: str
 
 class VectorDBConfigDTO(InDTO):
-    provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
+    provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"], Literal["pgvector"]]
     url: str
     api_key: str
 
diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py
index 85f0688a..10430ed8 100644
--- a/cognee/api/v1/add/add.py
+++ b/cognee/api/v1/add/add.py
@@ -8,15 +8,18 @@
 from cognee.modules.ingestion import get_matched_datasets, save_data_to_file
 from cognee.shared.utils import send_telemetry
 from cognee.base_config import get_base_config
-from cognee.infrastructure.databases.relational import get_relational_engine, create_db_and_tables
+from cognee.infrastructure.databases.relational import get_relational_engine
 from cognee.modules.users.methods import get_default_user
 from cognee.tasks.ingestion import get_dlt_destination
 from cognee.modules.users.permissions.methods import give_permission_on_document
 from cognee.modules.users.models import User
 from cognee.modules.data.methods import create_dataset
+from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
+from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables
 
 async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = "main_dataset", user: User = None):
-    await create_db_and_tables()
+    await create_relational_db_and_tables()
+    await create_pgvector_db_and_tables()
 
     if isinstance(data, str):
         if "data://" in data:
diff --git a/cognee/api/v1/add/add_v2.py b/cognee/api/v1/add/add_v2.py
index 291ec5f4..4d43dd65 100644
--- a/cognee/api/v1/add/add_v2.py
+++ b/cognee/api/v1/add/add_v2.py
@@ -3,10 +3,12 @@
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.pipelines import run_tasks, Task
 from cognee.tasks.ingestion import save_data_to_storage, ingest_data
-from cognee.infrastructure.databases.relational import create_db_and_tables
+from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
+from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables
 
 async def add(data: Union[BinaryIO, list[BinaryIO], str, list[str]], dataset_name: str = "main_dataset", user: User = None):
-    await create_db_and_tables()
+    await create_relational_db_and_tables()
+    await create_pgvector_db_and_tables()
 
     if user is None:
         user = await get_default_user()
diff --git a/cognee/api/v1/config/config.py b/cognee/api/v1/config/config.py
index 225f6781..2f4167b7 100644
--- a/cognee/api/v1/config/config.py
+++ b/cognee/api/v1/config/config.py
@@ -95,6 +95,30 @@ def set_vector_db_provider(vector_db_provider: str):
         vector_db_config = get_vectordb_config()
         vector_db_config.vector_db_provider = vector_db_provider
 
+    @staticmethod
+    def set_relational_db_config(config_dict: dict):
+        """
+        Updates the relational db config with values from config_dict.
+ """ + relational_db_config = get_relational_config() + for key, value in config_dict.items(): + if hasattr(relational_db_config, key): + object.__setattr__(relational_db_config, key, value) + else: + raise AttributeError(f"'{key}' is not a valid attribute of the config.") + + @staticmethod + def set_vector_db_config(config_dict: dict): + """ + Updates the vector db config with values from config_dict. + """ + vector_db_config = get_vectordb_config() + for key, value in config_dict.items(): + if hasattr(vector_db_config, key): + object.__setattr__(vector_db_config, key, value) + else: + raise AttributeError(f"'{key}' is not a valid attribute of the config.") + @staticmethod def set_vector_db_key(db_key: str): vector_db_config = get_vectordb_config() diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 36302bce..81a828bd 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -119,6 +119,8 @@ async def delete_database(self): self.db_path = None else: async with self.engine.begin() as connection: + # Load the schema information into the MetaData object + await connection.run_sync(Base.metadata.reflect) for table in Base.metadata.sorted_tables: drop_table_query = text(f"DROP TABLE IF EXISTS {table.name} CASCADE") await connection.execute(drop_table_query) diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index 19859ae4..f0cbfcd5 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -1,5 +1,7 @@ from typing import Dict +from ..relational.config import get_relational_config + class VectorConfig(Dict): vector_db_url: str vector_db_key: str @@ -26,6 +28,25 @@ def create_vector_engine(config: VectorConfig, embedding_engine): api_key = config["vector_db_key"], embedding_engine = embedding_engine ) + elif config["vector_db_provider"] == "pgvector": + from .pgvector.PGVectorAdapter import PGVectorAdapter + + # Get configuration for postgres database + relational_config = get_relational_config() + db_username = relational_config.db_username + db_password = relational_config.db_password + db_host = relational_config.db_host + db_port = relational_config.db_port + db_name = relational_config.db_name + + connection_string: str = ( + f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}" + ) + + return PGVectorAdapter(connection_string, + config["vector_db_key"], + embedding_engine + ) else: from .lancedb.LanceDBAdapter import LanceDBAdapter diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 3bb47fcc..40463448 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -152,7 +152,7 @@ async def batch_search( ): query_vectors = await self.embedding_engine.embed_text(query_texts) - return asyncio.gather( + return await asyncio.gather( *[self.search( collection_name = collection_name, query_vector = query_vector, diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py new file mode 100644 
index 00000000..b13346cf
--- /dev/null
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -0,0 +1,222 @@
+import asyncio
+from pgvector.sqlalchemy import Vector
+from typing import List, Optional, get_type_hints
+from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy import JSON, Column, Table, select, delete
+from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
+
+from .serialize_datetime import serialize_datetime
+from ..models.ScoredResult import ScoredResult
+from ..vector_db_interface import VectorDBInterface, DataPoint
+from ..embeddings.EmbeddingEngine import EmbeddingEngine
+from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
+from ...relational.ModelBase import Base
+
+
+class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
+
+    def __init__(
+        self,
+        connection_string: str,
+        api_key: Optional[str],
+        embedding_engine: EmbeddingEngine,
+    ):
+        self.api_key = api_key
+        self.embedding_engine = embedding_engine
+        self.db_uri: str = connection_string
+
+        self.engine = create_async_engine(self.db_uri)
+        self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
+
+    async def embed_data(self, data: list[str]) -> list[list[float]]:
+        return await self.embedding_engine.embed_text(data)
+
+    async def has_collection(self, collection_name: str) -> bool:
+        async with self.engine.begin() as connection:
+            # Load the schema information into the MetaData object
+            await connection.run_sync(Base.metadata.reflect)
+
+            if collection_name in Base.metadata.tables:
+                return True
+            else:
+                return False
+
+    async def create_collection(self, collection_name: str, payload_schema=None):
+        data_point_types = get_type_hints(DataPoint)
+        vector_size = self.embedding_engine.get_vector_size()
+
+        if not await self.has_collection(collection_name):
+
+            class PGVectorDataPoint(Base):
+                __tablename__ = collection_name
+                __table_args__ = {"extend_existing": True}
+                # PGVector requires one column to be the primary key
+                primary_key: Mapped[int] = mapped_column(
+                    primary_key=True, autoincrement=True
+                )
+                id: Mapped[data_point_types["id"]]
+                payload = Column(JSON)
+                vector = Column(Vector(vector_size))
+
+                def __init__(self, id, payload, vector):
+                    self.id = id
+                    self.payload = payload
+                    self.vector = vector
+
+            async with self.engine.begin() as connection:
+                if len(Base.metadata.tables.keys()) > 0:
+                    await connection.run_sync(
+                        Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
+                    )
+
+    async def create_data_points(
+        self, collection_name: str, data_points: List[DataPoint]
+    ):
+        async with self.get_async_session() as session:
+            if not await self.has_collection(collection_name):
+                await self.create_collection(
+                    collection_name=collection_name,
+                    payload_schema=type(data_points[0].payload),
+                )
+
+            data_vectors = await self.embed_data(
+                [data_point.get_embeddable_data() for data_point in data_points]
+            )
+
+            vector_size = self.embedding_engine.get_vector_size()
+
+            class PGVectorDataPoint(Base):
+                __tablename__ = collection_name
+                __table_args__ = {"extend_existing": True}
+                # PGVector requires one column to be the primary key
+                primary_key: Mapped[int] = mapped_column(
+                    primary_key=True, autoincrement=True
+                )
+                id: Mapped[type(data_points[0].id)]
+                payload = Column(JSON)
+                vector = Column(Vector(vector_size))
+
+                def __init__(self, id, payload, vector):
+                    self.id = id
+                    self.payload = payload
+                    self.vector = vector
+
+            pgvector_data_points = [
+                PGVectorDataPoint(
+                    id=data_point.id,
+                    vector=data_vectors[data_index],
+                    payload=serialize_datetime(data_point.payload.dict()),
+                )
+                for (data_index, data_point) in enumerate(data_points)
+            ]
+
+            session.add_all(pgvector_data_points)
+            await session.commit()
+
+    async def get_table(self, collection_name: str) -> Table:
+        """
+        Dynamically loads a table using the given collection name
+        with an async engine.
+        """
+        async with self.engine.begin() as connection:
+            # Load the schema information into the MetaData object
+            await connection.run_sync(Base.metadata.reflect)
+            if collection_name in Base.metadata.tables:
+                return Base.metadata.tables[collection_name]
+            else:
+                raise ValueError(f"Table '{collection_name}' not found.")
+
+    async def retrieve(self, collection_name: str, data_point_ids: List[str]):
+        async with self.get_async_session() as session:
+            # Get PGVectorDataPoint Table from database
+            PGVectorDataPoint = await self.get_table(collection_name)
+
+            results = await session.execute(
+                select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
+            )
+            results = results.all()
+
+            return [
+                ScoredResult(id=result.id, payload=result.payload, score=0)
+                for result in results
+            ]
+
+    async def search(
+        self,
+        collection_name: str,
+        query_text: Optional[str] = None,
+        query_vector: Optional[List[float]] = None,
+        limit: int = 5,
+        with_vector: bool = False,
+    ) -> List[ScoredResult]:
+        if query_text is None and query_vector is None:
+            raise ValueError("One of query_text or query_vector must be provided!")
+
+        if query_text and not query_vector:
+            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
+
+        # Use async session to connect to the database
+        async with self.get_async_session() as session:
+            # Get PGVectorDataPoint Table from database
+            PGVectorDataPoint = await self.get_table(collection_name)
+
+            # Find closest vectors to query_vector
+            closest_items = await session.execute(
+                select(
+                    PGVectorDataPoint,
+                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
+                        "similarity"
+                    ),
+                )
+                .order_by("similarity")
+                .limit(limit)
+            )
+
+            vector_list = []
+            # Extract distances and find min/max for normalization
+            for vector in closest_items:
+                # TODO: Add normalization of similarity score
+                vector_list.append(vector)
+
+            # Create and return ScoredResult objects
+            return [
+                ScoredResult(
+                    id=str(row.id), payload=row.payload, score=row.similarity
+                )
+                for row in vector_list
+            ]
+
+    async def batch_search(
+        self,
+        collection_name: str,
+        query_texts: List[str],
+        limit: int = None,
+        with_vectors: bool = False,
+    ):
+        query_vectors = await self.embedding_engine.embed_text(query_texts)
+
+        return await asyncio.gather(
+            *[
+                self.search(
+                    collection_name=collection_name,
+                    query_vector=query_vector,
+                    limit=limit,
+                    with_vector=with_vectors,
+                )
+                for query_vector in query_vectors
+            ]
+        )
+
+    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
+        async with self.get_async_session() as session:
+            # Get PGVectorDataPoint Table from database
+            PGVectorDataPoint = await self.get_table(collection_name)
+            results = await session.execute(
+                delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
+            )
+            await session.commit()
+            return results
+
+    async def prune(self):
+        # Clean up the database if it was set up as temporary
+        await self.delete_database()
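Reviewer note (not part of the patch): the adapter above stores each collection as a table with id/payload/vector columns and ranks results with pgvector's cosine_distance, ordered ascending (lower distance means closer; normalization is still a TODO in the code). Callers are not expected to construct PGVectorAdapter directly; create_vector_engine builds it from the relational settings, and it is reached through get_vector_engine(), as the new test does. A minimal sketch, assuming the provider is already configured as "pgvector" and a dataset has been cognified so an "entities" collection exists:

from cognee.infrastructure.databases.vector import get_vector_engine


async def lookup(query: str):
    # get_vector_engine() returns the PGVectorAdapter when
    # VECTOR_DB_PROVIDER is set to "pgvector" (see create_vector_engine.py).
    vector_engine = get_vector_engine()

    # Nearest-neighbour search over the "entities" collection; scores are
    # cosine distances produced by the search() method above.
    results = await vector_engine.search("entities", query_text=query, limit=5)
    for result in results:
        print(result.id, result.score, result.payload.get("name"))
    return results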
diff --git a/cognee/infrastructure/databases/vector/pgvector/__init__.py b/cognee/infrastructure/databases/vector/pgvector/__init__.py
new file mode 100644
index 00000000..130246a3
--- /dev/null
+++ b/cognee/infrastructure/databases/vector/pgvector/__init__.py
@@ -0,0 +1,2 @@
+from .PGVectorAdapter import PGVectorAdapter
+from .create_db_and_tables import create_db_and_tables
\ No newline at end of file
diff --git a/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py b/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
new file mode 100644
index 00000000..99d53d69
--- /dev/null
+++ b/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
@@ -0,0 +1,14 @@
+from ...relational.ModelBase import Base
+from ..get_vector_engine import get_vector_engine, get_vectordb_config
+from sqlalchemy import text
+
+async def create_db_and_tables():
+    vector_config = get_vectordb_config()
+    vector_engine = get_vector_engine()
+
+    if vector_config.vector_db_provider == "pgvector":
+        vector_engine.create_database()
+        async with vector_engine.engine.begin() as connection:
+            await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+
+
diff --git a/cognee/infrastructure/databases/vector/pgvector/serialize_datetime.py b/cognee/infrastructure/databases/vector/pgvector/serialize_datetime.py
new file mode 100644
index 00000000..9cb979e2
--- /dev/null
+++ b/cognee/infrastructure/databases/vector/pgvector/serialize_datetime.py
@@ -0,0 +1,12 @@
+from datetime import datetime
+
+def serialize_datetime(data):
+    """Recursively convert datetime objects in dictionaries/lists to ISO format."""
+    if isinstance(data, dict):
+        return {key: serialize_datetime(value) for key, value in data.items()}
+    elif isinstance(data, list):
+        return [serialize_datetime(item) for item in data]
+    elif isinstance(data, datetime):
+        return data.isoformat()  # Convert datetime to ISO 8601 string
+    else:
+        return data
\ No newline at end of file
diff --git a/cognee/modules/settings/get_settings.py b/cognee/modules/settings/get_settings.py
index fccbc316..95f2f592 100644
--- a/cognee/modules/settings/get_settings.py
+++ b/cognee/modules/settings/get_settings.py
@@ -41,6 +41,9 @@ def get_settings() -> SettingsDict:
     }, {
         "value": "lancedb",
        "label": "LanceDB",
+    }, {
+        "value": "pgvector",
+        "label": "PGVector",
     }]
 
     vector_config = get_vectordb_config()
diff --git a/cognee/modules/settings/save_vector_db_config.py b/cognee/modules/settings/save_vector_db_config.py
index 1a5895bc..1e0b683e 100644
--- a/cognee/modules/settings/save_vector_db_config.py
+++ b/cognee/modules/settings/save_vector_db_config.py
@@ -5,7 +5,7 @@ class VectorDBConfig(BaseModel):
     url: str
     api_key: str
-    provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
+    provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"], Literal["pgvector"]]
 
 
 async def save_vector_db_config(vector_db_config: VectorDBConfig):
     vector_config = get_vectordb_config()
diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py
new file mode 100644
index 00000000..02d292d6
--- /dev/null
+++ b/cognee/tests/test_pgvector.py
@@ -0,0 +1,93 @@
+import os
+import logging
+import pathlib
+import cognee
+from cognee.api.v1.search import SearchType
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+async def main():
+    cognee.config.set_vector_db_config(
+        {
+            "vector_db_url": "",
+            "vector_db_key": "",
+            "vector_db_provider": "pgvector"
+        }
+    )
+    cognee.config.set_relational_db_config(
+        {
+            "db_path": "",
+            "db_name": "cognee_db",
+            "db_host": "127.0.0.1",
+            "db_port": "5432",
+            "db_username": "cognee",
+            "db_password": "cognee",
+            "db_provider": "postgres",
"postgres", + } + ) + + data_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_pgvector") + ).resolve() + ) + cognee.config.data_root_directory(data_directory_path) + cognee_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_pgvector") + ).resolve() + ) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + dataset_name = "cs_explanations" + + explanation_file_path = os.path.join( + pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt" + ) + await cognee.add([explanation_file_path], dataset_name) + + text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena. + At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states. + Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible. + The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly. + Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate. + In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. 
+    """
+
+    await cognee.add([text], dataset_name)
+
+    await cognee.cognify([dataset_name])
+
+    from cognee.infrastructure.databases.vector import get_vector_engine
+
+    vector_engine = get_vector_engine()
+    random_node = (await vector_engine.search("entities", "AI"))[0]
+    random_node_name = random_node.payload["name"]
+
+    search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
+    assert len(search_results) != 0, "The search results list is empty."
+    print("\n\nExtracted sentences are:\n")
+    for result in search_results:
+        print(f"{result}\n")
+
+    search_results = await cognee.search(SearchType.CHUNKS, query=random_node_name)
+    assert len(search_results) != 0, "The search results list is empty."
+    print("\n\nExtracted chunks are:\n")
+    for result in search_results:
+        print(f"{result}\n")
+
+    search_results = await cognee.search(SearchType.SUMMARIES, query=random_node_name)
+    assert len(search_results) != 0, "Query related summaries don't exist."
+    print("\n\nExtracted summaries are:\n")
+    for result in search_results:
+        print(f"{result}\n")
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(main())
diff --git a/docker-compose.yml b/docker-compose.yml
index 2ef05170..426b178a 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -62,7 +62,7 @@ services:
       - cognee-network
 
   postgres:
-    image: postgres:latest
+    image: pgvector/pgvector:pg17
     container_name: postgres
     environment:
       POSTGRES_USER: cognee
diff --git a/poetry.lock b/poetry.lock
index 03dcc023..acd56e02 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -4656,6 +4656,20 @@ files = [
 [package.dependencies]
 ptyprocess = ">=0.5"
 
+[[package]]
+name = "pgvector"
+version = "0.3.5"
+description = "pgvector support for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pgvector-0.3.5-py3-none-any.whl", hash = "sha256:56cca90392e596ea18873c593ec858a1984a77d16d1f82b8d0c180e79ef1018f"},
+    {file = "pgvector-0.3.5.tar.gz", hash = "sha256:e876c9ee382c4c2f7ee57691a4c4015d688c7222e47448ce310ded03ecfafe2f"},
+]
+
+[package.dependencies]
+numpy = "*"
+
 [[package]]
 name = "pillow"
 version = "10.4.0"
@@ -4918,7 +4932,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
 name = "psycopg2"
 version = "2.9.10"
 description = "psycopg2 - Python-PostgreSQL Database Adapter"
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "psycopg2-2.9.10-cp310-cp310-win32.whl", hash = "sha256:5df2b672140f95adb453af93a7d669d7a7bf0a56bcd26f1502329166f4a61716"},
@@ -7745,10 +7759,11 @@ cli = []
 filesystem = []
 neo4j = ["neo4j"]
 notebook = ["overrides"]
+postgres = ["psycopg2"]
 qdrant = ["qdrant-client"]
 weaviate = ["weaviate-client"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9.0,<3.12"
-content-hash = "4cba654100a455c8691dd3d4e1b588f00bbb2acca89168954037017b3a6aced9"
+content-hash = "70a0072dce8de95d64b862f9a9df48aaec84c8d8515ae018fce4426a0dcacf88"
diff --git a/pyproject.toml b/pyproject.toml
index ab686cb8..22074959 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,8 +70,8 @@ sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
 fastapi-users = { version = "*", extras = ["sqlalchemy"] }
 asyncpg = "^0.29.0"
 alembic = "^1.13.3"
-psycopg2 = "^2.9.10"
-
+pgvector = "^0.3.5"
+psycopg2 = {version = "^2.9.10", optional = true}
 
 [tool.poetry.extras]
 filesystem = ["s3fs", "botocore"]
@@ -79,6 +79,7 @@ cli = ["pipdeptree", "cron-descriptor"]
 weaviate = ["weaviate-client"]
 qdrant = ["qdrant-client"]
 neo4j = ["neo4j"]
+postgres = ["psycopg2"]
 notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
 
 [tool.poetry.group.dev.dependencies]
@@ -104,7 +105,6 @@ diskcache = "^5.6.3"
 pandas = "2.0.3"
 tabulate = "^0.9.0"
 
-
 [tool.ruff] # https://beta.ruff.rs/docs/
 line-length = 100
 ignore = ["F401"]