Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds cognee node and edge embeddings for graphiti graph #437

Merged
merged 10 commits into from
Jan 16, 2025
4 changes: 4 additions & 0 deletions cognee/infrastructure/databases/graph/graph_db_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ async def delete_graph(
):
raise NotImplementedError

@abstractmethod
async def get_model_independent_graph_data(self):
    """Abstract hook that every concrete graph adapter must override.

    NOTE(review): the Neo4j adapter's implementation returns a
    ``(nodes, edges)`` pair of raw query results; presumably other
    backends are expected to return the same shape -- confirm.

    Raises:
        NotImplementedError: always, when the base implementation is invoked.
    """
    raise NotImplementedError

hajdul88 marked this conversation as resolved.
Show resolved Hide resolved
@abstractmethod
async def get_graph_data(self):
    """Abstract hook that every concrete graph adapter must override.

    The return shape is defined by the implementing adapter and is not
    specified here.

    Raises:
        NotImplementedError: always, when the base implementation is invoked.
    """
    raise NotImplementedError
9 changes: 9 additions & 0 deletions cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,15 @@ def serialize_properties(self, properties=dict()):

return serialized_properties

async def get_model_independent_graph_data(self):
    """Fetch every node and every directed relationship stored in the graph.

    Two Cypher queries are issued through ``self.query``: one aggregating all
    nodes under the ``nodes`` column and one aggregating all relationships
    under the ``relationships`` column.

    Returns:
        tuple: ``(node_rows, relationship_rows)`` -- the two query results,
        exactly as returned by ``self.query``.
    """
    node_rows = await self.query("MATCH (n) RETURN collect(n) AS nodes")
    relationship_rows = await self.query(
        "MATCH ()-[r]->() RETURN collect(r) AS relationships"
    )

    return node_rows, relationship_rows

hajdul88 marked this conversation as resolved.
Show resolved Hide resolved
async def get_graph_data(self):
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"

Expand Down
71 changes: 71 additions & 0 deletions cognee/tasks/storage/index_graph_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,77 @@
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.models.EdgeType import EdgeType
from cognee.tasks.temporal_awareness.graphiti_model import GraphitiNode


async def index_graphiti_nodes_and_edges():
    """Embed Graphiti node fields and edge types from the graph into the vector store.

    The graph engine is asked for its raw nodes and relationships. Each node
    is wrapped in a ``GraphitiNode`` (keeping only ``content``, ``name`` and
    ``summary``) and each distinct relationship name is wrapped in an
    ``EdgeType`` carrying its occurrence count. For every index field of those
    data points a vector index is created on first use and the data points are
    then sent to the vector engine.

    Node and edge batches are kept in separate collections so that node
    batches are indexed exactly once; previously the shared collection caused
    them to be re-sent during the edge phase.

    Returns:
        None

    Raises:
        RuntimeError: if the vector or graph engine cannot be obtained.
    """
    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as e:
        logging.error("Failed to initialize engines: %s", e)
        raise RuntimeError("Initialization error") from e

    # (table, field) pairs whose vector index has already been created;
    # shared by both phases so no index is created twice.
    created_indexes = set()

    nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()

    # ---- Phase 1: nodes -------------------------------------------------
    node_index_points = {}
    for node_data in nodes_data[0]["nodes"]:
        graphiti_node = GraphitiNode(
            **{key: node_data[key] for key in ("content", "name", "summary") if key in node_data}
        )
        await _stage_for_indexing(
            graphiti_node,
            vector_engine,
            created_indexes,
            node_index_points,
            skip_missing=True,
        )

    await _flush_index_points(vector_engine, node_index_points)

    # ---- Phase 2: edge types -------------------------------------------
    # Only 3-tuples are counted and their middle element is used as the
    # relationship name (presumably the driver renders a relationship as
    # (start, type, end) -- confirm against the adapter's query output).
    edge_types = Counter(
        edge[1]
        for edge in edges_data[0]["relationships"]
        if isinstance(edge, tuple) and len(edge) == 3
    )

    edge_index_points = {}
    for text, count in edge_types.items():
        edge = EdgeType(relationship_name=text, number_of_edges=count)
        await _stage_for_indexing(
            edge,
            vector_engine,
            created_indexes,
            edge_index_points,
            skip_missing=False,
        )

    await _flush_index_points(vector_engine, edge_index_points)

    return None


async def _stage_for_indexing(
    data_point, vector_engine, created_indexes, index_points, skip_missing
):
    """Queue one copy of ``data_point`` per index field, creating indexes on first use.

    Args:
        data_point: object exposing ``__tablename__`` on its type,
            ``_metadata["index_fields"]`` and ``model_copy()``.
        vector_engine: engine providing ``create_vector_index``.
        created_indexes: set of ``(table, field)`` pairs already created; updated in place.
        index_points: dict mapping ``(table, field)`` to the queued copies; updated in place.
        skip_missing: when true, a field whose value is ``None`` is not queued.
    """
    table_name = type(data_point).__tablename__

    for field_name in data_point._metadata["index_fields"]:
        index_key = (table_name, field_name)

        if index_key not in created_indexes:
            await vector_engine.create_vector_index(table_name, field_name)
            created_indexes.add(index_key)

        batch = index_points.setdefault(index_key, [])

        if skip_missing and getattr(data_point, field_name, None) is None:
            continue

        # Each queued copy is tagged with the single field it is indexed by.
        indexed_data_point = data_point.model_copy()
        indexed_data_point._metadata["index_fields"] = [field_name]
        batch.append(indexed_data_point)


async def _flush_index_points(vector_engine, index_points):
    """Send every non-empty queued batch to ``vector_engine.index_data_points``."""
    for (table_name, field_name), batch in index_points.items():
        if batch:  # nothing to embed for this field
            await vector_engine.index_data_points(table_name, field_name, batch)


async def index_graph_edges():
Expand Down
12 changes: 12 additions & 0 deletions cognee/tasks/temporal_awareness/graphiti_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from cognee.infrastructure.engine import DataPoint
from typing import ClassVar, Optional


class GraphitiNode(DataPoint):
    # Data point wrapping the textual fields of a node read from the graph.
    # All three text fields are optional and default to None.
    __tablename__ = "graphitinode"
    content: Optional[str] = None
    name: Optional[str] = None
    summary: Optional[str] = None
    pydantic_type: str = "GraphitiNode"

    # "index_fields" lists the attributes that get a vector index; the
    # indexing task reads this list to decide which fields to embed.
    _metadata: dict = {"index_fields": ["name", "summary", "content"], "type": "GraphitiNode"}
16 changes: 13 additions & 3 deletions examples/python/graphiti_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
build_graph_with_temporal_awareness,
search_graph_with_temporal_awareness,
)
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
)
from cognee.tasks.storage.index_graph_edges import index_graphiti_nodes_and_edges

text_list = [
"Kamala Harris is the Attorney General of California. She was previously "
Expand All @@ -16,18 +20,24 @@


async def main():
    """Run the Graphiti example end to end.

    Steps, in order: prune existing data and system metadata, create the
    relational tables, add every entry of ``text_list`` to cognee, run the
    build and search tasks while printing each pipeline result, and finally
    index the Graphiti nodes and edges.
    """
    # Start from a clean state before ingesting anything.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await create_relational_db_and_tables()

    for text in text_list:
        await cognee.add(text)

    build_task = Task(build_graph_with_temporal_awareness, text_list=text_list)
    search_task = Task(
        search_graph_with_temporal_awareness,
        query="Who was the California Attorney General?",
    )

    async for outcome in run_tasks([build_task, search_task]):
        print(outcome)

    await index_graphiti_nodes_and_edges()
hajdul88 marked this conversation as resolved.
Show resolved Hide resolved


# Script entry point: run the async example only when executed directly.
if __name__ == "__main__":
    asyncio.run(main())
Loading