Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: adds cognee node and edge embeddings for graphiti graph #437

Merged
merged 10 commits into from
Jan 16, 2025
9 changes: 9 additions & 0 deletions cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,15 @@ def serialize_properties(self, properties=dict()):

return serialized_properties

async def get_model_independent_graph_data(self):
    """Fetch all nodes and all directed relationships as two raw query results.

    Two Cypher statements are sent one after the other through ``self.query``.

    Returns:
        A 2-tuple ``(nodes, edges)``: the result of the node query (nodes
        aggregated under the ``nodes`` column) and the result of the
        relationship query (``[n, r, m]`` triples aggregated under the
        ``elements`` column), exactly as ``self.query`` returned them.
    """
    node_rows = await self.query("MATCH (n) RETURN collect(n) AS nodes")
    edge_rows = await self.query(
        "MATCH (n)-[r]->(m) RETURN collect([n, r, m]) AS elements"
    )
    return node_rows, edge_rows

async def get_graph_data(self):
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"

Expand Down
12 changes: 12 additions & 0 deletions cognee/tasks/temporal_awareness/graphiti_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from cognee.infrastructure.engine import DataPoint
from typing import ClassVar, Optional


class GraphitiNode(DataPoint):
    """Data point describing a single node of a Graphiti graph.

    All three text fields are optional; ``_metadata["index_fields"]`` lists
    them as the fields to be indexed.
    """

    # Used as the table part of the "<table>.<field>" index names.
    __tablename__ = "graphitinode"

    content: Optional[str] = None
    name: Optional[str] = None
    summary: Optional[str] = None
    pydantic_type: str = "GraphitiNode"

    _metadata: dict = {
        "index_fields": ["name", "summary", "content"],
        "type": "GraphitiNode",
    }
84 changes: 84 additions & 0 deletions cognee/tasks/temporal_awareness/index_graphiti_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import logging
from collections import Counter

from cognee.tasks.temporal_awareness.graphiti_model import GraphitiNode
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.models.EdgeType import EdgeType


async def _collect_index_points(
    data_point, vector_engine, created_indexes, index_points, *, skip_missing
):
    """Queue one single-field copy of ``data_point`` per index field it declares.

    Args:
        data_point: Object exposing ``__tablename__`` on its type,
            ``_metadata["index_fields"]`` and ``model_copy()``.
        vector_engine: Engine used to create a vector index the first time a
            ``(table, field)`` pair is seen.
        created_indexes: Set of ``(table, field)`` pairs whose index has already
            been created during this run; updated in place.
        index_points: Mapping ``(table, field) -> list of copies``; updated in
            place. A bucket is created even when nothing is appended to it.
        skip_missing: When true, a field whose value on ``data_point`` is
            ``None`` (or absent) is not queued.
    """
    table_name = type(data_point).__tablename__

    # Snapshot the field list: the copies made below rewrite `index_fields`, and
    # NOTE(review): model_copy() may share the `_metadata` dict with the source
    # object -- confirm whether each copy needs its own dict.
    for field_name in list(data_point._metadata["index_fields"]):
        index_key = (table_name, field_name)

        if index_key not in created_indexes:
            await vector_engine.create_vector_index(table_name, field_name)
            created_indexes.add(index_key)

        bucket = index_points.setdefault(index_key, [])

        if skip_missing and getattr(data_point, field_name, None) is None:
            continue

        indexed_data_point = data_point.model_copy()
        indexed_data_point._metadata["index_fields"] = [field_name]
        bucket.append(indexed_data_point)


async def _flush_index_points(vector_engine, index_points):
    """Send every queued bucket to the vector engine, one call per (table, field)."""
    for (table_name, field_name), points in index_points.items():
        await vector_engine.index_data_points(table_name, field_name, points)


async def index_and_transform_graphiti_nodes_and_edges() -> None:
    """Normalise the Graphiti graph in place and index its nodes and edge types.

    Steps:
        1. Copy ``uuid`` into ``id`` on every node, store source/target ids and
           the relationship type on every relationship, and set ``text`` on
           every node to ``summary`` (falling back to ``content``).
        2. Read the whole graph back and index each node's non-null ``name``,
           ``summary`` and ``content`` fields as ``GraphitiNode`` points.
        3. Count relationships per type and index one ``EdgeType`` per type.

    Raises:
        RuntimeError: If the vector or graph engine cannot be obtained.
    """
    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as e:
        logging.error("Failed to initialize engines: %s", e)
        raise RuntimeError("Initialization error") from e

    await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""")
    await graph_engine.query("""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
                             r.target_node_id = target.id,
                             r.relationship_name = type(r) RETURN r""")
    await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""")

    nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()

    # Shared across both passes so each vector index is created at most once.
    created_indexes = set()

    node_index_points = {}
    for node_data in nodes_data[0]["nodes"]:
        graphiti_node = GraphitiNode(
            **{key: node_data[key] for key in ("content", "name", "summary") if key in node_data},
            id=node_data.get("uuid"),
        )
        await _collect_index_points(
            graphiti_node,
            vector_engine,
            created_indexes,
            node_index_points,
            skip_missing=True,
        )
    await _flush_index_points(vector_engine, node_index_points)

    # edge is an [n, r, m] triple; edge[1][1] is used as the relationship name.
    # NOTE(review): presumably the driver serialises r as (start, type, end) --
    # confirm against the adapter's query() output.
    edge_types = Counter(
        edge[1][1]
        for edge in edges_data[0]["elements"]
        if isinstance(edge, list) and len(edge) == 3
    )

    # A separate mapping for edges: reusing the node mapping here would send
    # every node point to the vector engine a second time on the flush below.
    edge_index_points = {}
    for text, count in edge_types.items():
        edge = EdgeType(relationship_name=text, number_of_edges=count)
        await _collect_index_points(
            edge,
            vector_engine,
            created_indexes,
            edge_index_points,
            skip_missing=False,
        )
    await _flush_index_points(vector_engine, edge_index_points)

    return None
53 changes: 46 additions & 7 deletions examples/python/graphiti_example.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import asyncio

import cognee
from cognee.api.v1.search import SearchType
import logging
from cognee.modules.pipelines import Task, run_tasks
from cognee.tasks.temporal_awareness import (
build_graph_with_temporal_awareness,
search_graph_with_temporal_awareness,
from cognee.shared.utils import setup_logging
from cognee.tasks.temporal_awareness import build_graph_with_temporal_awareness
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
)
from cognee.tasks.temporal_awareness.index_graphiti_objects import (
index_and_transform_graphiti_nodes_and_edges,
)
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client

text_list = [
"Kamala Harris is the Attorney General of California. She was previously "
Expand All @@ -16,18 +24,49 @@


async def main():
    """Run the Graphiti example end to end and print the LLM's answer.

    Calls cognee's prune routines and creates the relational tables, adds each
    text in ``text_list``, runs the temporal-awareness pipeline (printing each
    result), indexes the Graphiti nodes and edges, performs a triplet search
    for a fixed question and asks the LLM to answer it from that context.
    """
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await create_relational_db_and_tables()

    for document in text_list:
        await cognee.add(document)

    # NOTE(review): this file is shown as a diff without +/- markers; the
    # search_graph_with_temporal_awareness task (and its import) may be a
    # removed line -- confirm against the merged file before relying on it.
    pipeline_tasks = [
        Task(build_graph_with_temporal_awareness, text_list=text_list),
        Task(
            search_graph_with_temporal_awareness,
            query="Who was the California Attorney General?",
        ),
    ]

    async for result in run_tasks(pipeline_tasks):
        print(result)

    await index_and_transform_graphiti_nodes_and_edges()

    query = "When was Kamala Harris in office?"
    triplets = await brute_force_triplet_search(
        query=query,
        top_k=3,
        collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"],
    )

    # The retrieved triplets become the context of the user prompt.
    user_prompt = render_prompt(
        "graph_context_for_question.txt",
        {"question": query, "context": retrieved_edges_to_string(triplets)},
    )
    system_prompt = read_query_prompt("answer_simple_question_restricted.txt")

    computed_answer = await get_llm_client().acreate_structured_output(
        text_input=user_prompt,
        system_prompt=system_prompt,
        response_model=str,
    )

    print(computed_answer)


# Script entry point: set up logging at INFO level, then run the async example.
if __name__ == "__main__":
    setup_logging(logging.INFO)
    asyncio.run(main())
Loading