Skip to content

Commit

Permalink
Merge pull request #242 from topoteretes/main-cognify-fix
Browse files Browse the repository at this point in the history
fix: fixes cognify duplicated edges and resets the methods to an olde…
  • Loading branch information
Vasilije1990 authored Dec 2, 2024
2 parents 1c47870 + 6841c83 commit 42ab601
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 110 deletions.
149 changes: 67 additions & 82 deletions cognee/modules/graph/utils/get_graph_from_model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
from datetime import datetime, timezone

from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model


def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None):

if not added_nodes:
added_nodes = {}
if not added_edges:
added_edges = {}

def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
nodes = []
edges = []

Expand All @@ -20,94 +12,87 @@ def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=No
for field_name, field_value in data_point:
if field_name == "_metadata":
continue
elif isinstance(field_value, DataPoint):

if isinstance(field_value, DataPoint):
excluded_properties.add(field_name)
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
)

elif (
isinstance(field_value, list)
and len(field_value) > 0
and isinstance(field_value[0], DataPoint)
):

property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True
continue

if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
excluded_properties.add(field_name)

for item in field_value:
n_edges_before = len(edges)
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point, field_name, item, nodes, edges, added_nodes, added_edges
)
edges = edges[:n_edges_before] + [
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
for edge in edges[n_edges_before:]
]
else:
data_point_properties[field_name] = field_value
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[edge_key] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
"metadata": {
"type": "list"
},
}))
added_edges[edge_key] = True
continue

data_point_properties[field_name] = field_value

SimpleDataPointModel = copy_model(
type(data_point),
include_fields={
include_fields = {
"_metadata": (dict, data_point._metadata),
},
exclude_fields=excluded_properties,
exclude_fields = excluded_properties,
)

nodes.append(SimpleDataPointModel(**data_point_properties))
if include_root:
nodes.append(SimpleDataPointModel(**data_point_properties))

return nodes, edges


def add_nodes_and_edges(
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
):

property_nodes, property_edges = get_graph_from_model(
field_value, dict(added_nodes), dict(added_edges)
)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append(
(
data_point.id,
property_node.id,
field_name,
{
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S"
),
},
)
)
added_edges[str(edge_key)] = True

return (nodes, edges, added_nodes, added_edges)


def get_own_properties(property_nodes, property_edges):
own_properties = []

Expand Down
44 changes: 16 additions & 28 deletions cognee/modules/graph/utils/get_model_instance_from_graph.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,29 @@
from typing import Callable

from pydantic_core import PydanticUndefined

from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model


def get_model_instance_from_graph(
nodes: list[DataPoint],
edges: list[tuple[str, str, str, dict[str, str]]],
entity_id: str,
):
node_map = {node.id: node for node in nodes}
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
node_map = {}

for source_node_id, target_node_id, edge_label, edge_properties in edges:
source_node = node_map[source_node_id]
target_node = node_map[target_node_id]
for node in nodes:
node_map[node.id] = node

for edge in edges:
source_node = node_map[edge[0]]
target_node = node_map[edge[1]]
edge_label = edge[2]
edge_properties = edge[3] if len(edge) == 4 else {}
edge_metadata = edge_properties.get("metadata", {})
edge_type = edge_metadata.get("type", "default")
edge_type = edge_metadata.get("type")

if edge_type == "list":
NewModel = copy_model(
type(source_node),
{edge_label: (list[type(target_node)], PydanticUndefined)},
)
source_node_dict = source_node.model_dump()
source_node_edge_label_values = source_node_dict.get(edge_label, [])
source_node_dict[edge_label] = source_node_edge_label_values + [target_node]

node_map[source_node_id] = NewModel(**source_node_dict)
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })

node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
else:
NewModel = copy_model(
type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
)
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })

node_map[target_node_id] = NewModel(
**source_node.model_dump(), **{edge_label: target_node}
)
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })

return node_map[entity_id]

0 comments on commit 42ab601

Please sign in to comment.