From 6841c83566dac3354d2fa127aa3c6fc9713e0586 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:18:55 +0100 Subject: [PATCH] fix: fixes cognify duplicated edges and resets the methods to an older version --- .../graph/utils/get_graph_from_model.py | 149 ++++++++---------- .../utils/get_model_instance_from_graph.py | 44 ++---- 2 files changed, 83 insertions(+), 110 deletions(-) diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 770e63d0..29137ddc 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -1,16 +1,8 @@ from datetime import datetime, timezone - from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model - -def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None): - - if not added_nodes: - added_nodes = {} - if not added_edges: - added_edges = {} - +def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}): nodes = [] edges = [] @@ -20,94 +12,87 @@ def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=No for field_name, field_value in data_point: if field_name == "_metadata": continue - elif isinstance(field_value, DataPoint): + + if isinstance(field_value, DataPoint): excluded_properties.add(field_name) - nodes, edges, added_nodes, added_edges = add_nodes_and_edges( - data_point, - field_name, - field_value, - nodes, - edges, - added_nodes, - added_edges, - ) - - elif ( - isinstance(field_value, list) - and len(field_value) > 0 - and isinstance(field_value[0], DataPoint) - ): + + property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) + + for node in property_nodes: + if str(node.id) not in added_nodes: + nodes.append(node) + added_nodes[str(node.id)] = True + + for edge in property_edges: + edge_key = str(edge[0]) + str(edge[1]) + edge[2] + + if str(edge_key) not in added_edges: + edges.append(edge) + added_edges[str(edge_key)] = True + + for property_node in get_own_properties(property_nodes, property_edges): + edge_key = str(data_point.id) + str(property_node.id) + field_name + + if str(edge_key) not in added_edges: + edges.append((data_point.id, property_node.id, field_name, { + "source_node_id": data_point.id, + "target_node_id": property_node.id, + "relationship_name": field_name, + "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + })) + added_edges[str(edge_key)] = True + continue + + if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): excluded_properties.add(field_name) for item in field_value: - n_edges_before = len(edges) - nodes, edges, added_nodes, added_edges = add_nodes_and_edges( - data_point, field_name, item, nodes, edges, added_nodes, added_edges - ) - edges = edges[:n_edges_before] + [ - (*edge[:3], {**edge[3], "metadata": {"type": "list"}}) - for edge in edges[n_edges_before:] - ] - else: - data_point_properties[field_name] = field_value + property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges) + + for node in property_nodes: + if str(node.id) not in added_nodes: + nodes.append(node) + added_nodes[str(node.id)] = True + + for edge in property_edges: + edge_key = str(edge[0]) + str(edge[1]) + edge[2] + + if str(edge_key) not in added_edges: + edges.append(edge) + added_edges[edge_key] = True + + for property_node in get_own_properties(property_nodes, property_edges): + edge_key = str(data_point.id) + str(property_node.id) + field_name + + if str(edge_key) not in added_edges: + edges.append((data_point.id, property_node.id, field_name, { + "source_node_id": data_point.id, + "target_node_id": property_node.id, + "relationship_name": field_name, + "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + "metadata": { + "type": "list" + }, + })) + added_edges[edge_key] = True + continue + + data_point_properties[field_name] = field_value SimpleDataPointModel = copy_model( type(data_point), - include_fields={ + include_fields = { "_metadata": (dict, data_point._metadata), }, - exclude_fields=excluded_properties, + exclude_fields = excluded_properties, ) - nodes.append(SimpleDataPointModel(**data_point_properties)) + if include_root: + nodes.append(SimpleDataPointModel(**data_point_properties)) return nodes, edges -def add_nodes_and_edges( - data_point, field_name, field_value, nodes, edges, added_nodes, added_edges -): - - property_nodes, property_edges = get_graph_from_model( - field_value, dict(added_nodes), dict(added_edges) - ) - - for node in property_nodes: - if str(node.id) not in added_nodes: - nodes.append(node) - added_nodes[str(node.id)] = True - - for edge in property_edges: - edge_key = str(edge[0]) + str(edge[1]) + edge[2] - - if str(edge_key) not in added_edges: - edges.append(edge) - added_edges[str(edge_key)] = True - - for property_node in get_own_properties(property_nodes, property_edges): - edge_key = str(data_point.id) + str(property_node.id) + field_name - - if str(edge_key) not in added_edges: - edges.append( - ( - data_point.id, - property_node.id, - field_name, - { - "source_node_id": data_point.id, - "target_node_id": property_node.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime( - "%Y-%m-%d %H:%M:%S" - ), - }, - ) - ) - added_edges[str(edge_key)] = True - - return (nodes, edges, added_nodes, added_edges) - - def get_own_properties(property_nodes, property_edges): own_properties = [] diff --git a/cognee/modules/graph/utils/get_model_instance_from_graph.py b/cognee/modules/graph/utils/get_model_instance_from_graph.py index 16658d74..82cdfa15 100644 --- a/cognee/modules/graph/utils/get_model_instance_from_graph.py +++ b/cognee/modules/graph/utils/get_model_instance_from_graph.py @@ -1,41 +1,29 @@ -from typing import Callable - from pydantic_core import PydanticUndefined - from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import copy_model -def get_model_instance_from_graph( - nodes: list[DataPoint], - edges: list[tuple[str, str, str, dict[str, str]]], - entity_id: str, -): - node_map = {node.id: node for node in nodes} +def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str): + node_map = {} - for source_node_id, target_node_id, edge_label, edge_properties in edges: - source_node = node_map[source_node_id] - target_node = node_map[target_node_id] + for node in nodes: + node_map[node.id] = node + + for edge in edges: + source_node = node_map[edge[0]] + target_node = node_map[edge[1]] + edge_label = edge[2] + edge_properties = edge[3] if len(edge) == 4 else {} edge_metadata = edge_properties.get("metadata", {}) - edge_type = edge_metadata.get("type", "default") + edge_type = edge_metadata.get("type") if edge_type == "list": - NewModel = copy_model( - type(source_node), - {edge_label: (list[type(target_node)], PydanticUndefined)}, - ) - source_node_dict = source_node.model_dump() - source_node_edge_label_values = source_node_dict.get(edge_label, []) - source_node_dict[edge_label] = source_node_edge_label_values + [target_node] - - node_map[source_node_id] = NewModel(**source_node_dict) + NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) }) + + node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] }) else: - NewModel = copy_model( - type(source_node), {edge_label: (type(target_node), PydanticUndefined)} - ) + NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) }) - node_map[target_node_id] = NewModel( - **source_node.model_dump(), **{edge_label: target_node} - ) + node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node }) return node_map[entity_id]