From 07fcce7970408afb9b28d777b8295d39cb9ec619 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Mon, 13 Jan 2025 11:46:46 +0100
Subject: [PATCH] feat: delete on-the-fly edge embeddings and use edge
 collections instead

---
 .../modules/graph/cognee_graph/CogneeGraph.py      | 38 +++----------
 .../retrieval/brute_force_triplet_search.py        | 25 +-----------
 2 files changed, 7 insertions(+), 56 deletions(-)

diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py
index 279a73b19..491f83b5a 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraph.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -8,7 +8,7 @@
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
 from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
 import heapq
-from graphistry import edges
+import asyncio
 
 
 class CogneeGraph(CogneeAbstractGraph):
@@ -127,51 +127,25 @@ async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
             else:
                 print(f"Node with id {node_id} not found in the graph.")
 
-    async def map_vector_distances_to_graph_edges(
-        self, vector_engine, query
-    ) -> None:  # :TODO: When we calculate edge embeddings in vector db change this similarly to node mapping
+    async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
         try:
-            # Step 1: Generate the query embedding
             query_vector = await vector_engine.embed_data([query])
             query_vector = query_vector[0]
             if query_vector is None or len(query_vector) == 0:
                 raise ValueError("Failed to generate query embedding.")
 
-            # Step 2: Collect all unique relationship types
-            unique_relationship_types = set()
-            for edge in self.edges:
-                relationship_type = edge.attributes.get("relationship_type")
-                if relationship_type:
-                    unique_relationship_types.add(relationship_type)
-
-            # Step 3: Embed all unique relationship types
-            unique_relationship_types = list(unique_relationship_types)
-            relationship_type_embeddings = await vector_engine.embed_data(unique_relationship_types)
-
-            # Step 4: Map relationship types to their embeddings and calculate distances
-            embedding_map = {}
-            for relationship_type, embedding in zip(
-                unique_relationship_types, relationship_type_embeddings
-            ):
-                edge_vector = np.array(embedding)
-
-                # Calculate cosine similarity
-                similarity = np.dot(query_vector, edge_vector) / (
-                    np.linalg.norm(query_vector) * np.linalg.norm(edge_vector)
-                )
-                distance = 1 - similarity
+            edge_distances = await vector_engine.get_distance_from_collection_elements(
+                "edge_type_relationship_name", query_text=query
+            )
 
-                # Round the distance to 4 decimal places and store it
-                embedding_map[relationship_type] = round(distance, 4)
+            embedding_map = {result.payload["text"]: result.score for result in edge_distances}
 
-            # Step 4: Assign precomputed distances to edges
             for edge in self.edges:
                 relationship_type = edge.attributes.get("relationship_type")
                 if not relationship_type or relationship_type not in embedding_map:
                     print(f"Edge {edge} has an unknown or missing relationship type.")
                     continue
 
-                # Assign the precomputed distance
                 edge.attributes["vector_distance"] = embedding_map[relationship_type]
 
         except Exception as ex:
diff --git a/cognee/modules/retrieval/brute_force_triplet_search.py b/cognee/modules/retrieval/brute_force_triplet_search.py
index 9c778505d..c27e90766 100644
--- a/cognee/modules/retrieval/brute_force_triplet_search.py
+++ b/cognee/modules/retrieval/brute_force_triplet_search.py
@@ -62,24 +62,6 @@ async def brute_force_triplet_search(
     return retrieved_results
 
 
-def delete_duplicated_vector_db_elements(
-    collections, results
-):  #:TODO: This is just for now to fix vector db duplicates
-    results_dict = {}
-    for collection, results in zip(collections, results):
-        seen_ids = set()
-        unique_results = []
-        for result in results:
-            if result.id not in seen_ids:
-                unique_results.append(result)
-                seen_ids.add(result.id)
-            else:
-                print(f"Duplicate found in collection '{collection}': {result.id}")
-        results_dict[collection] = unique_results
-
-    return results_dict
-
-
 async def brute_force_search(
     query: str, user: User, top_k: int, collections: List[str] = None
 ) -> list:
@@ -125,10 +107,7 @@ async def brute_force_search(
             ]
         )
 
-        ############################################# :TODO: Change when vector db does not contain duplicates
-        node_distances = delete_duplicated_vector_db_elements(collections, results)
-        # node_distances = {collection: result for collection, result in zip(collections, results)}
-        ##############################################
+        node_distances = {collection: result for collection, result in zip(collections, results)}
 
         memory_fragment = CogneeGraph()
 
@@ -140,14 +119,12 @@ async def brute_force_search(
 
         await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
 
-        #:TODO: Change when vectordb contains edge embeddings
         await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
 
         results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
 
         send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
 
-        #:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
         return results
 
     except Exception as e:
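
Below is a minimal, self-contained sketch of the edge-distance mapping this patch switches to. It is illustrative only: ScoredPoint, SimpleEdge, FakeVectorEngine, map_edge_distances, and the canned scores are invented for the example and are not cognee APIs; only the collection name "edge_type_relationship_name", the get_distance_from_collection_elements(collection, query_text=...) call shape, and the payload["text"] -> score mapping mirror the patched code. Compared with the removed path, this avoids embedding every unique relationship type at query time and instead reads distances from a collection the vector db already maintains.

import asyncio
from dataclasses import dataclass, field


@dataclass
class ScoredPoint:
    # Assumed result shape, inferred from the patch: .payload["text"] and .score.
    payload: dict
    score: float


@dataclass
class SimpleEdge:
    # Stand-in for a graph edge that carries an attributes dict.
    attributes: dict = field(default_factory=dict)


class FakeVectorEngine:
    # Stand-in engine returning canned distances for the edge-type collection.
    async def get_distance_from_collection_elements(self, collection_name, query_text):
        return [
            ScoredPoint(payload={"text": "is_a"}, score=0.12),
            ScoredPoint(payload={"text": "part_of"}, score=0.34),
        ]


async def map_edge_distances(edges, vector_engine, query):
    # Same idea as the patched map_vector_distances_to_graph_edges: fetch
    # per-relationship-type distances from the vector db collection instead of
    # embedding relationship types on the fly, then copy them onto the edges.
    edge_distances = await vector_engine.get_distance_from_collection_elements(
        "edge_type_relationship_name", query_text=query
    )
    embedding_map = {result.payload["text"]: result.score for result in edge_distances}

    for edge in edges:
        relationship_type = edge.attributes.get("relationship_type")
        if not relationship_type or relationship_type not in embedding_map:
            print(f"Edge {edge} has an unknown or missing relationship type.")
            continue
        edge.attributes["vector_distance"] = embedding_map[relationship_type]


if __name__ == "__main__":
    edges = [
        SimpleEdge({"relationship_type": "is_a"}),
        SimpleEdge({"relationship_type": "located_in"}),  # not in the collection: gets skipped
    ]
    asyncio.run(map_edge_distances(edges, FakeVectorEngine(), "what is a graph?"))
    print([edge.attributes for edge in edges])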