Skip to content

Commit

Permalink
Merge pull request #196 from topoteretes/feat/COG-553-graph-memory-pr…
Browse files Browse the repository at this point in the history
…ojection

Feat/cog 553 graph memory projection
  • Loading branch information
hajdul88 authored Nov 14, 2024
2 parents 7a72aa4 + 867e18d commit c100709
Show file tree
Hide file tree
Showing 11 changed files with 710 additions and 7 deletions.
5 changes: 5 additions & 0 deletions cognee/infrastructure/databases/graph/graph_db_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,8 @@ async def add_edges(
async def delete_graph(self):
    """Interface stub with no default behaviour.

    Awaiting this coroutine always raises ``NotImplementedError``; a working
    implementation has to be supplied by an overriding definition.
    """
    raise NotImplementedError

@abstractmethod
async def get_graph_data(self):
    """Abstract hook for fetching the stored graph; no default behaviour.

    Awaiting this base definition always raises ``NotImplementedError``.

    NOTE(review): the caller visible in this change
    (``CogneeGraph.project_graph_from_db``) unpacks the awaited result as
    ``(nodes_data, edges_data)``, iterating nodes as ``(node_id, properties)``
    and edges as ``(source_id, target_id, relationship_type, properties)`` --
    confirm that concrete adapters return that shape.
    """
    raise NotImplementedError
3 changes: 0 additions & 3 deletions cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ def __init__(
max_connection_lifetime = 120
)

async def close(self) -> None:
    """Close the driver stored on ``self.driver``.

    Awaits ``self.driver.close()`` and returns nothing.
    """
    driver = self.driver
    await driver.close()

@asynccontextmanager
async def get_session(self) -> AsyncSession:
async with self.driver.session() as session:
Expand Down
16 changes: 12 additions & 4 deletions cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,18 @@ def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> Lance
for (data_point_index, data_point) in enumerate(data_points)
]

await collection.merge_insert("id") \
.when_matched_update_all() \
.when_not_matched_insert_all() \
.execute(lance_data_points)
# TODO: This enables us to work with pydantic version but shouldn't
# stay like this, existing rows should be updated

await collection.delete("id IS NOT NULL")

original_size = await collection.count_rows()
await collection.add(lance_data_points)
new_size = await collection.count_rows()

if new_size <= original_size:
raise ValueError(
"LanceDB create_datapoints error: data points did not get added.")


async def retrieve(self, collection_name: str, data_point_ids: list[str]):
Expand Down
1 change: 1 addition & 0 deletions cognee/modules/chunking/TextChunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, document, get_text: callable, chunk_size: int = 1024):
self.get_text = get_text

def read(self):
self.paragraph_chunks = []
for content_text in self.get_text():
for chunk_data in chunk_by_paragraph(
content_text,
Expand Down
35 changes: 35 additions & 0 deletions cognee/modules/graph/cognee_graph/CogneeAbstractGraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Union
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface

class CogneeAbstractGraph(ABC):
    """
    Interface for graph containers built from ``Node`` and ``Edge`` objects.

    Every method is abstract, so a subclass can only be instantiated once it
    implements all of them.
    """

    @abstractmethod
    def add_node(self, node: Node) -> None:
        """Insert ``node`` into the graph."""

    @abstractmethod
    def add_edge(self, edge: Edge) -> None:
        """Insert ``edge`` into the graph."""

    @abstractmethod
    def get_node(self, node_id: str) -> Node:
        """Return the node identified by ``node_id``."""

    @abstractmethod
    def get_edges(self, node_id: str) -> List[Edge]:
        """Return the edges attached to the node identified by ``node_id``."""

    @abstractmethod
    async def project_graph_from_db(
        self,
        adapter: GraphDBInterface,
        directed: bool,
        dimension: int,
    ) -> None:
        """
        Populate the graph from a database reached through ``adapter``.

        NOTE(review): the ``CogneeGraph`` implementation takes
        ``node_properties_to_project`` / ``edge_properties_to_project`` and
        separate ``node_dimension`` / ``edge_dimension`` arguments instead of
        a single ``dimension`` -- the two signatures should be reconciled.
        """
91 changes: 91 additions & 0 deletions cognee/modules/graph/cognee_graph/CogneeGraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import List, Dict, Union

from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
from cognee.infrastructure.databases.graph import get_graph_engine

class CogneeGraph(CogneeAbstractGraph):
    """
    Concrete, in-memory implementation of ``CogneeAbstractGraph``.

    Nodes are stored in a dict keyed by node id and edges in a list. The graph
    can be filled manually through ``add_node`` / ``add_edge`` or projected
    from a graph database adapter with ``project_graph_from_db``.
    """

    nodes: Dict[str, Node]
    edges: List[Edge]
    directed: bool

    def __init__(self, directed: bool = True):
        """Create an empty graph and remember its ``directed`` flag."""
        self.nodes = {}
        self.edges = []
        self.directed = directed

    def add_node(self, node: Node) -> None:
        """
        Register ``node`` under its id.

        Raises:
            ValueError: if a node with the same id is already stored.
        """
        if node.id in self.nodes:
            raise ValueError(f"Node with id {node.id} already exists.")
        self.nodes[node.id] = node

    def add_edge(self, edge: Edge) -> None:
        """
        Store ``edge`` and attach it to both of its endpoint nodes.

        Raises:
            ValueError: if an edge comparing equal to ``edge`` is already stored.
        """
        if edge in self.edges:
            raise ValueError(f"Edge {edge} already exists in the graph.")
        self.edges.append(edge)
        # Each endpoint records the edge (and the opposite endpoint as neighbour).
        edge.node1.add_skeleton_edge(edge)
        edge.node2.add_skeleton_edge(edge)

    def get_node(self, node_id: str) -> Union[Node, None]:
        """Return the node stored under ``node_id``, or ``None`` if there is none."""
        return self.nodes.get(node_id)

    def get_edges(self, node_id: str) -> List[Edge]:
        """
        Return the edges attached to the node stored under ``node_id``.

        Raises:
            ValueError: if no node is stored under ``node_id``.
        """
        node = self.get_node(node_id)
        if node is None:
            raise ValueError(f"Node with id {node_id} does not exist.")
        return node.skeleton_edges

    async def project_graph_from_db(self,
                                    adapter: GraphDBInterface,
                                    node_properties_to_project: List[str],
                                    edge_properties_to_project: List[str],
                                    directed: bool = True,
                                    node_dimension: int = 1,
                                    edge_dimension: int = 1) -> None:
        """
        Fill this graph with the nodes and edges returned by ``adapter.get_graph_data()``.

        Args:
            adapter: object whose awaited ``get_graph_data()`` result unpacks
                into ``(nodes_data, edges_data)``.
            node_properties_to_project: property keys copied onto each node;
                keys missing from a node's properties are stored as ``None``.
            edge_properties_to_project: property keys copied onto each edge;
                the relationship type is always stored as ``relationship_type``.
            directed: ``directed`` flag given to every created edge.
            node_dimension: length of each node's status vector.
            edge_dimension: length of each edge's status vector.

        Raises:
            ValueError: if ``node_dimension`` or ``edge_dimension`` is below 1.

        Any other failure (no data, duplicate node or edge, edge pointing at an
        unknown node, malformed rows, adapter errors) is printed and NOT
        re-raised; elements added before the failure stay in the graph.
        """
        if node_dimension < 1 or edge_dimension < 1:
            raise ValueError("Dimensions must be positive integers")

        try:
            nodes_data, edges_data = await adapter.get_graph_data()

            if not nodes_data:
                raise ValueError("No node data retrieved from the database.")
            if not edges_data:
                raise ValueError("No edge data retrieved from the database.")

            for node_id, properties in nodes_data:
                node_attributes = {key: properties.get(key) for key in node_properties_to_project}
                self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))

            for source_id, target_id, relationship_type, properties in edges_data:
                source_node = self.get_node(str(source_id))
                target_node = self.get_node(str(target_id))
                if source_node is None or target_node is None:
                    raise ValueError(f"Edge references nonexistent nodes: {source_id} -> {target_id}")

                edge_attributes = {key: properties.get(key) for key in edge_properties_to_project}
                edge_attributes['relationship_type'] = relationship_type

                # add_edge already attaches the edge to both endpoints, so no
                # further add_skeleton_edge calls are needed here.
                self.add_edge(
                    Edge(
                        source_node,
                        target_node,
                        attributes=edge_attributes,
                        directed=directed,
                        dimension=edge_dimension,
                    )
                )

        except (ValueError, TypeError) as e:
            # NOTE(review): errors are only printed, so callers cannot tell a
            # failed projection from a successful one -- kept for compatibility.
            print(f"Error projecting graph: {e}")
        except Exception as ex:
            print(f"Unexpected error: {ex}")
114 changes: 114 additions & 0 deletions cognee/modules/graph/cognee_graph/CogneeGraphElements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np
from typing import List, Dict, Optional, Any

class Node:
    """
    Represents a node in a graph.

    Attributes:
        id (str): A unique identifier for the node; equality and hashing use only this.
        attributes (Dict[str, Any]): A dictionary of attributes associated with the node.
        skeleton_neighbours (List[Node]): Nodes reachable through ``skeleton_edges``.
        skeleton_edges (List[Edge]): Edges attached to this node.
        status (np.ndarray): Integer vector with one entry per dimension,
            initialised to ones; an entry equal to 1 means "alive".
    """
    id: str
    attributes: Dict[str, Any]
    skeleton_neighbours: List["Node"]
    skeleton_edges: List["Edge"]
    status: np.ndarray

    def __init__(self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1):
        """
        Create a node with no edges and an all-ones status vector.

        Raises:
            ValueError: if ``dimension`` is not positive.
        """
        if dimension <= 0:
            raise ValueError("Dimension must be a positive integer")
        self.id = node_id
        self.attributes = attributes if attributes is not None else {}
        self.skeleton_neighbours = []
        self.skeleton_edges = []
        self.status = np.ones(dimension, dtype=int)

    def add_skeleton_neighbor(self, neighbor: "Node") -> None:
        """Record ``neighbor`` unless it is already listed."""
        if neighbor not in self.skeleton_neighbours:
            self.skeleton_neighbours.append(neighbor)

    def remove_skeleton_neighbor(self, neighbor: "Node") -> None:
        """Forget ``neighbor`` if it is listed; otherwise do nothing."""
        if neighbor in self.skeleton_neighbours:
            self.skeleton_neighbours.remove(neighbor)

    def add_skeleton_edge(self, edge: "Edge") -> None:
        """Attach ``edge`` and record its opposite endpoint as a neighbour."""
        if edge in self.skeleton_edges:
            return
        self.skeleton_edges.append(edge)
        # Only the endpoint opposite to this node becomes a neighbour.
        if edge.node1 == self:
            self.add_skeleton_neighbor(edge.node2)
        elif edge.node2 == self:
            self.add_skeleton_neighbor(edge.node1)

    def remove_skeleton_edge(self, edge: "Edge") -> None:
        """Detach ``edge``; drop its other endpoint as neighbour if no edge still reaches it."""
        if edge not in self.skeleton_edges:
            return
        self.skeleton_edges.remove(edge)
        neighbor = edge.node2 if edge.node1 == self else edge.node1
        # Keep the neighbour while any remaining edge still touches it.
        still_connected = any(
            e.node1 == neighbor or e.node2 == neighbor for e in self.skeleton_edges
        )
        if not still_connected:
            self.remove_skeleton_neighbor(neighbor)

    def is_node_alive_in_dimension(self, dimension: int) -> bool:
        """
        Tell whether the status entry at index ``dimension`` equals 1.

        Raises:
            ValueError: if ``dimension`` is outside ``0 .. len(status) - 1``.
        """
        if dimension < 0 or dimension >= len(self.status):
            raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
        # Convert the NumPy comparison result to a built-in bool, as annotated.
        return bool(self.status[dimension] == 1)

    def __repr__(self) -> str:
        return f"Node({self.id}, attributes={self.attributes})"

    def __hash__(self) -> int:
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Node) and self.id == other.id


class Edge:
    """
    Represents an edge in a graph, connecting two nodes.

    Attributes:
        node1 (Node): The starting node of the edge.
        node2 (Node): The ending node of the edge.
        attributes (Dict[str, Any]): A dictionary of attributes associated with the edge.
        directed (bool): A flag indicating whether the edge is directed or undirected.
        status (np.ndarray): Integer vector with one entry per dimension,
            initialised to ones; an entry equal to 1 means "alive".

    Equality and hashing depend only on the endpoints and the ``directed``
    flag; ``attributes`` are ignored. Undirected edges are equal regardless
    of endpoint order.
    """

    node1: "Node"
    node2: "Node"
    attributes: Dict[str, Any]
    directed: bool
    status: np.ndarray

    def __init__(self, node1: "Node", node2: "Node", attributes: Optional[Dict[str, Any]] = None, directed: bool = True, dimension: int = 1):
        """
        Create an edge between ``node1`` and ``node2`` with an all-ones status vector.

        Raises:
            ValueError: if ``dimension`` is not positive.
        """
        if dimension <= 0:
            raise ValueError("Dimensions must be a positive integer.")
        self.node1 = node1
        self.node2 = node2
        self.attributes = attributes if attributes is not None else {}
        self.directed = directed
        self.status = np.ones(dimension, dtype=int)

    def is_edge_alive_in_dimension(self, dimension: int) -> bool:
        """
        Tell whether the status entry at index ``dimension`` equals 1.

        Raises:
            ValueError: if ``dimension`` is outside ``0 .. len(status) - 1``.
        """
        if dimension < 0 or dimension >= len(self.status):
            raise ValueError(f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
        # Convert the NumPy comparison result to a built-in bool, as annotated.
        return bool(self.status[dimension] == 1)

    def __repr__(self) -> str:
        direction = "->" if self.directed else "--"
        return f"Edge({self.node1.id} {direction} {self.node2.id}, attributes={self.attributes})"

    def __hash__(self) -> int:
        if self.directed:
            return hash((self.node1, self.node2))
        # Order-insensitive hash so that (a, b) and (b, a) collide when undirected.
        return hash(frozenset({self.node1, self.node2}))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Edge):
            return False
        # A directed and an undirected edge hash differently, so they must not
        # compare equal; this also keeps equality symmetric.
        if self.directed != other.directed:
            return False
        if self.directed:
            return self.node1 == other.node1 and self.node2 == other.node2
        return {self.node1, self.node2} == {other.node1, other.node2}
Empty file.
Loading

0 comments on commit c100709

Please sign in to comment.