feat: deletes on the fly embeddings and uses edge collections #436

Merged
38 changes: 6 additions & 32 deletions cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -8,7 +8,7 @@
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
import heapq
from graphistry import edges
import asyncio


class CogneeGraph(CogneeAbstractGraph):
@@ -127,51 +127,25 @@ async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
else:
print(f"Node with id {node_id} not found in the graph.")

async def map_vector_distances_to_graph_edges(
self, vector_engine, query
) -> None: # :TODO: When we calculate edge embeddings in vector db change this similarly to node mapping
async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
try:
# Step 1: Generate the query embedding
query_vector = await vector_engine.embed_data([query])
query_vector = query_vector[0]
if query_vector is None or len(query_vector) == 0:
raise ValueError("Failed to generate query embedding.")

# Step 2: Collect all unique relationship types
unique_relationship_types = set()
for edge in self.edges:
relationship_type = edge.attributes.get("relationship_type")
if relationship_type:
unique_relationship_types.add(relationship_type)

# Step 3: Embed all unique relationship types
unique_relationship_types = list(unique_relationship_types)
relationship_type_embeddings = await vector_engine.embed_data(unique_relationship_types)

# Step 4: Map relationship types to their embeddings and calculate distances
embedding_map = {}
for relationship_type, embedding in zip(
unique_relationship_types, relationship_type_embeddings
):
edge_vector = np.array(embedding)

# Calculate cosine similarity
similarity = np.dot(query_vector, edge_vector) / (
np.linalg.norm(query_vector) * np.linalg.norm(edge_vector)
)
distance = 1 - similarity
edge_distances = await vector_engine.get_distance_from_collection_elements(
"edge_type_relationship_name", query_text=query
)
Comment on lines +130 to +139 (Contributor):
🛠️ Refactor suggestion

Enhance error handling and logging

The method has several areas that could benefit from improved error handling:

  1. Query vector validation could raise a custom exception
  2. Collection name should be configurable
  3. Consider using proper logging instead of print statements

Consider this improvement:

 async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
+    EDGE_TYPE_COLLECTION = "edge_type_relationship_name"  # Move to config
     try:
         query_vector = await vector_engine.embed_data([query])
         query_vector = query_vector[0]
         if query_vector is None or len(query_vector) == 0:
-            raise ValueError("Failed to generate query embedding.")
+            raise InvalidValueError("Failed to generate query embedding: empty or null vector")

         edge_distances = await vector_engine.get_distance_from_collection_elements(
-            "edge_type_relationship_name", query_text=query
+            EDGE_TYPE_COLLECTION, query_text=query
         )
+        if not edge_distances:
+            raise InvalidValueError(f"No distances retrieved from {EDGE_TYPE_COLLECTION}")

Committable suggestion skipped: line range outside the PR's diff.
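
Pulling those points together, a minimal standalone sketch of what the suggestion amounts to could look like the following. It assumes the vector_engine interface visible in the diff (embed_data, and get_distance_from_collection_elements returning results with payload["text"] and a numeric score), and uses a plain ValueError where the reviewer's project-specific InvalidValueError would go:

import logging

logger = logging.getLogger(__name__)

# Hypothetical constant; the reviewer suggests moving the collection name into config.
EDGE_TYPE_COLLECTION = "edge_type_relationship_name"


async def map_edge_distances(graph, vector_engine, query: str) -> None:
    # Sketch only: `graph` stands in for the CogneeGraph instance (self in the PR),
    # and the vector_engine calls mirror the ones shown in the diff.
    query_vector = (await vector_engine.embed_data([query]))[0]
    if query_vector is None or len(query_vector) == 0:
        raise ValueError("Failed to generate query embedding: empty or null vector")

    # Distances come from the precomputed edge-type collection instead of
    # embedding relationship types on the fly.
    edge_distances = await vector_engine.get_distance_from_collection_elements(
        EDGE_TYPE_COLLECTION, query_text=query
    )
    if not edge_distances:
        raise ValueError(f"No distances retrieved from {EDGE_TYPE_COLLECTION}")

    embedding_map = {result.payload["text"]: result.score for result in edge_distances}

    for edge in graph.edges:
        relationship_type = edge.attributes.get("relationship_type")
        if not relationship_type or relationship_type not in embedding_map:
            logger.warning("Edge %s has an unknown or missing relationship type", edge)
            continue
        edge.attributes["vector_distance"] = embedding_map[relationship_type]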


# Round the distance to 4 decimal places and store it
embedding_map[relationship_type] = round(distance, 4)
embedding_map = {result.payload["text"]: result.score for result in edge_distances}

# Step 4: Assign precomputed distances to edges
for edge in self.edges:
relationship_type = edge.attributes.get("relationship_type")
if not relationship_type or relationship_type not in embedding_map:
print(f"Edge {edge} has an unknown or missing relationship type.")
continue

# Assign the precomputed distance
edge.attributes["vector_distance"] = embedding_map[relationship_type]
Comment on lines +141 to 149 (Contributor):

🛠️ Refactor suggestion

Improve edge distance mapping robustness

The current implementation has potential issues:

  1. Silent failures with print statements
  2. No validation of embedding_map values
  3. No handling of edge cases (empty results, invalid scores)

Consider this improvement:

-        embedding_map = {result.payload["text"]: result.score for result in edge_distances}
+        embedding_map = {}
+        for result in edge_distances:
+            if "text" not in result.payload:
+                raise InvalidValueError(f"Missing 'text' in payload: {result.payload}")
+            if not isinstance(result.score, (int, float)):
+                raise InvalidValueError(f"Invalid score type: {type(result.score)}")
+            embedding_map[result.payload["text"]] = result.score

         for edge in self.edges:
             relationship_type = edge.attributes.get("relationship_type")
             if not relationship_type or relationship_type not in embedding_map:
-                print(f"Edge {edge} has an unknown or missing relationship type.")
+                logging.warning("Edge %s has an unknown or missing relationship type", edge)
                 continue

             edge.attributes["vector_distance"] = embedding_map[relationship_type]

Committable suggestion skipped: line range outside the PR's diff.
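
If only the mapping step is hardened, the validation could also be factored into a small helper. This is a sketch under the same assumptions about the result objects (a payload dict with a "text" key and a numeric score); the ValueError stands in for whatever exception type the project actually uses:

from typing import Any, Dict, Iterable


def build_embedding_map(edge_distances: Iterable[Any]) -> Dict[str, float]:
    # Reject malformed results instead of silently producing a partial map.
    embedding_map: Dict[str, float] = {}
    for result in edge_distances:
        if "text" not in result.payload:
            raise ValueError(f"Missing 'text' in payload: {result.payload}")
        if not isinstance(result.score, (int, float)):
            raise ValueError(f"Invalid score type: {type(result.score)}")
        embedding_map[result.payload["text"]] = float(result.score)
    return embedding_map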


except Exception as ex:
25 changes: 1 addition & 24 deletions cognee/modules/retrieval/brute_force_triplet_search.py
@@ -62,24 +62,6 @@ async def brute_force_triplet_search(
return retrieved_results


def delete_duplicated_vector_db_elements(
collections, results
): #:TODO: This is just for now to fix vector db duplicates
results_dict = {}
for collection, results in zip(collections, results):
seen_ids = set()
unique_results = []
for result in results:
if result.id not in seen_ids:
unique_results.append(result)
seen_ids.add(result.id)
else:
print(f"Duplicate found in collection '{collection}': {result.id}")
results_dict[collection] = unique_results

return results_dict


async def brute_force_search(
query: str, user: User, top_k: int, collections: List[str] = None
) -> list:
@@ -125,10 +107,7 @@ async def brute_force_search(
]
)

############################################# :TODO: Change when vector db does not contain duplicates
node_distances = delete_duplicated_vector_db_elements(collections, results)
# node_distances = {collection: result for collection, result in zip(collections, results)}
##############################################
node_distances = {collection: result for collection, result in zip(collections, results)}

memory_fragment = CogneeGraph()

@@ -140,14 +119,12 @@

await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)

#:TODO: Change when vectordb contains edge embeddings
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)

results = await memory_fragment.calculate_top_triplet_importances(k=top_k)

send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)

#:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
return results

except Exception as e:
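
For context, once both files are changed the retrieval path sketched by these diffs reduces to: take one result list per searched collection, zip collections with results (the dedup pass is gone), then let the in-memory graph fragment map node and edge distances before ranking triplets. A condensed sketch follows; the per-collection search call and the projection of the stored graph into the fragment sit in collapsed parts of the diff and are left out here:

from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph


async def rank_triplets(query, vector_engine, collections, results, top_k=5):
    # `results` holds one list of scored elements per collection in `collections`;
    # duplicates are no longer filtered out, so the mapping is a plain zip.
    node_distances = {collection: result for collection, result in zip(collections, results)}

    memory_fragment = CogneeGraph()
    # Projection of nodes/edges from the graph store into memory_fragment omitted.
    await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
    await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
    return await memory_fragment.calculate_top_triplet_importances(k=top_k)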