Commit

Merge branch 'dev' into feat/cog-958-run-eval-on-paramset
alekszievr authored Jan 16, 2025
2 parents d4ca141 + 1c4a605 commit e30675d
Showing 5 changed files with 412 additions and 7 deletions.
9 changes: 9 additions & 0 deletions cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -426,6 +426,15 @@ def serialize_properties(self, properties=dict()):
 
         return serialized_properties
 
+    async def get_model_independent_graph_data(self):
+        query_nodes = "MATCH (n) RETURN collect(n) AS nodes"
+        nodes = await self.query(query_nodes)
+
+        query_edges = "MATCH (n)-[r]->(m) RETURN collect([n, r, m]) AS elements"
+        edges = await self.query(query_edges)
+
+        return (nodes, edges)
+
     async def get_graph_data(self):
         query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
 
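A minimal consumption sketch for the new get_model_independent_graph_data() method, assuming query() returns one record per statement keyed by the RETURN aliases above ("nodes" and "elements"), which is the shape the indexing task further down relies on. dump_graph_summary is a hypothetical helper name, not part of the commit.

import asyncio

from cognee.infrastructure.databases.graph import get_graph_engine


async def dump_graph_summary():
    # Hypothetical helper: fetch the raw graph and report its size.
    graph_engine = await get_graph_engine()
    nodes_result, edges_result = await graph_engine.get_model_independent_graph_data()

    raw_nodes = nodes_result[0]["nodes"]      # collect(n) -> list of node property maps
    raw_edges = edges_result[0]["elements"]   # collect([n, r, m]) -> [source, rel, target] triples

    print(f"{len(raw_nodes)} nodes, {len(raw_edges)} edges")


if __name__ == "__main__":
    asyncio.run(dump_graph_summary())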
12 changes: 12 additions & 0 deletions cognee/tasks/temporal_awareness/graphiti_model.py
@@ -0,0 +1,12 @@
+from cognee.infrastructure.engine import DataPoint
+from typing import ClassVar, Optional
+
+
+class GraphitiNode(DataPoint):
+    __tablename__ = "graphitinode"
+    content: Optional[str] = None
+    name: Optional[str] = None
+    summary: Optional[str] = None
+    pydantic_type: str = "GraphitiNode"
+
+    _metadata: dict = {"index_fields": ["name", "summary", "content"], "type": "GraphitiNode"}
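A short sketch of how a GraphitiNode is populated and how its _metadata["index_fields"] entries drive the per-field vector indexes, mirroring the task in the next file. The id and field values here are made-up placeholders.

from uuid import uuid4

from cognee.tasks.temporal_awareness.graphiti_model import GraphitiNode

# Placeholder values; in the indexing task the id comes from the graph node's
# uuid and content/name/summary are copied from the node's properties.
node = GraphitiNode(
    id=str(uuid4()),
    name="Kamala Harris",
    summary="Attorney General of California",
)

# Each entry in _metadata["index_fields"] gets its own vector index
# (graphitinode.name, graphitinode.summary, graphitinode.content).
for field_name in node._metadata["index_fields"]:
    print(field_name, "->", getattr(node, field_name, None))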
84 changes: 84 additions & 0 deletions cognee/tasks/temporal_awareness/index_graphiti_objects.py
@@ -0,0 +1,84 @@
+import logging
+from collections import Counter
+
+from cognee.tasks.temporal_awareness.graphiti_model import GraphitiNode
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.modules.graph.models.EdgeType import EdgeType
+
+
+async def index_and_transform_graphiti_nodes_and_edges():
+    try:
+        created_indexes = {}
+        index_points = {}
+
+        vector_engine = get_vector_engine()
+        graph_engine = await get_graph_engine()
+    except Exception as e:
+        logging.error("Failed to initialize engines: %s", e)
+        raise RuntimeError("Initialization error") from e
+
+    await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""")
+    await graph_engine.query("""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
+                                r.target_node_id = target.id,
+                                r.relationship_name = type(r) RETURN r""")
+    await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""")
+
+    nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()
+
+    for node_data in nodes_data[0]["nodes"]:
+        graphiti_node = GraphitiNode(
+            **{key: node_data[key] for key in ("content", "name", "summary") if key in node_data},
+            id=node_data.get("uuid"),
+        )
+
+        data_point_type = type(graphiti_node)
+
+        for field_name in graphiti_node._metadata["index_fields"]:
+            index_name = f"{data_point_type.__tablename__}.{field_name}"
+
+            if index_name not in created_indexes:
+                await vector_engine.create_vector_index(data_point_type.__tablename__, field_name)
+                created_indexes[index_name] = True
+
+            if index_name not in index_points:
+                index_points[index_name] = []
+
+            if getattr(graphiti_node, field_name, None) is not None:
+                indexed_data_point = graphiti_node.model_copy()
+                indexed_data_point._metadata["index_fields"] = [field_name]
+                index_points[index_name].append(indexed_data_point)
+
+    for index_name, indexable_points in index_points.items():
+        index_name, field_name = index_name.split(".")
+        await vector_engine.index_data_points(index_name, field_name, indexable_points)
+
+    edge_types = Counter(
+        edge[1][1]
+        for edge in edges_data[0]["elements"]
+        if isinstance(edge, list) and len(edge) == 3
+    )
+
+    for text, count in edge_types.items():
+        edge = EdgeType(relationship_name=text, number_of_edges=count)
+        data_point_type = type(edge)
+
+        for field_name in edge._metadata["index_fields"]:
+            index_name = f"{data_point_type.__tablename__}.{field_name}"
+
+            if index_name not in created_indexes:
+                await vector_engine.create_vector_index(data_point_type.__tablename__, field_name)
+                created_indexes[index_name] = True
+
+            if index_name not in index_points:
+                index_points[index_name] = []
+
+            indexed_data_point = edge.model_copy()
+            indexed_data_point._metadata["index_fields"] = [field_name]
+            index_points[index_name].append(indexed_data_point)
+
+    for index_name, indexable_points in index_points.items():
+        index_name, field_name = index_name.split(".")
+        await vector_engine.index_data_points(index_name, field_name, indexable_points)
+
+    return None
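A condensed usage sketch, assuming a Graphiti graph has already been built and the relational tables exist (as in the example below): run the transformation and indexing pass, then search the per-field collections it creates. reindex_and_search is a hypothetical wrapper name; the collection names follow the "graphitinode_<field>" pattern the example passes to brute_force_triplet_search.

import asyncio

from cognee.tasks.temporal_awareness.index_graphiti_objects import (
    index_and_transform_graphiti_nodes_and_edges,
)
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search


async def reindex_and_search(question: str):
    # Align ids, relationship names and text on the existing Graphiti graph,
    # then build or refresh the GraphitiNode and EdgeType vector indexes.
    await index_and_transform_graphiti_nodes_and_edges()

    # Search the freshly indexed per-field collections for relevant triplets.
    return await brute_force_triplet_search(
        query=question,
        top_k=3,
        collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"],
    )


if __name__ == "__main__":
    print(asyncio.run(reindex_and_search("When was Kamala Harris in office?")))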
53 changes: 46 additions & 7 deletions examples/python/graphiti_example.py
@@ -1,12 +1,20 @@
 import asyncio
 
 import cognee
-from cognee.api.v1.search import SearchType
+import logging
 from cognee.modules.pipelines import Task, run_tasks
-from cognee.tasks.temporal_awareness import (
-    build_graph_with_temporal_awareness,
-    search_graph_with_temporal_awareness,
+from cognee.shared.utils import setup_logging
+from cognee.tasks.temporal_awareness import build_graph_with_temporal_awareness
+from cognee.infrastructure.databases.relational import (
+    create_db_and_tables as create_relational_db_and_tables,
 )
+from cognee.tasks.temporal_awareness.index_graphiti_objects import (
+    index_and_transform_graphiti_nodes_and_edges,
+)
+from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
+from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
+from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
+from cognee.infrastructure.llm.get_llm_client import get_llm_client
 
 text_list = [
     "Kamala Harris is the Attorney General of California. She was previously "
@@ -16,18 +24,49 @@
 
 
 async def main():
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+    await create_relational_db_and_tables()
+
+    for text in text_list:
+        await cognee.add(text)
+
     tasks = [
         Task(build_graph_with_temporal_awareness, text_list=text_list),
-        Task(
-            search_graph_with_temporal_awareness, query="Who was the California Attorney General?"
-        ),
     ]
 
     pipeline = run_tasks(tasks)
 
     async for result in pipeline:
         print(result)
 
+    await index_and_transform_graphiti_nodes_and_edges()
+
+    query = "When was Kamala Harris in office?"
+    triplets = await brute_force_triplet_search(
+        query=query,
+        top_k=3,
+        collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"],
+    )
+
+    args = {
+        "question": query,
+        "context": retrieved_edges_to_string(triplets),
+    }
+
+    user_prompt = render_prompt("graph_context_for_question.txt", args)
+    system_prompt = read_query_prompt("answer_simple_question_restricted.txt")
+
+    llm_client = get_llm_client()
+    computed_answer = await llm_client.acreate_structured_output(
+        text_input=user_prompt,
+        system_prompt=system_prompt,
+        response_model=str,
+    )
+
+    print(computed_answer)
+
 
 if __name__ == "__main__":
+    setup_logging(logging.ERROR)
     asyncio.run(main())
