diff --git a/cognee/__init__.py b/cognee/__init__.py index 241dd1ad..bafa972b 100644 --- a/cognee/__init__.py +++ b/cognee/__init__.py @@ -5,7 +5,9 @@ from .api.v1.prune import prune from .api.v1.search import SearchType, get_search_history, search from .api.v1.visualize import visualize_graph -from .shared.utils import create_cognee_style_network_with_logo +from cognee.infrastructure.visualization.cognee_network_visualization import ( + create_cognee_style_network_with_logo, +) # Pipelines from .modules import pipelines diff --git a/cognee/api/v1/visualize/visualize.py b/cognee/api/v1/visualize/visualize.py index fd0cb4ce..f9b51d46 100644 --- a/cognee/api/v1/visualize/visualize.py +++ b/cognee/api/v1/visualize/visualize.py @@ -1,4 +1,6 @@ -from cognee.shared.utils import create_cognee_style_network_with_logo +from cognee.infrastructure.visualization.cognee_network_visualization import ( + create_cognee_style_network_with_logo, +) from cognee.infrastructure.databases.graph import get_graph_engine import logging diff --git a/cognee/infrastructure/visualization/cognee_network_visualization.py b/cognee/infrastructure/visualization/cognee_network_visualization.py new file mode 100644 index 00000000..19b24952 --- /dev/null +++ b/cognee/infrastructure/visualization/cognee_network_visualization.py @@ -0,0 +1,180 @@ +import networkx as nx +import json +import os + + +async def create_cognee_style_network_with_logo(graph_data): + nodes_data, edges_data = graph_data + + G = nx.DiGraph() + + nodes_list = [] + color_map = { + "Entity": "#f47710", + "EntityType": "#6510f4", + "DocumentChunk": "#801212", + "default": "#D3D3D3", + } + + for node_id, node_info in nodes_data: + node_info = node_info.copy() + node_info["id"] = str(node_id) + node_info["color"] = color_map.get(node_info.get("pydantic_type", "default"), "#D3D3D3") + node_info["name"] = node_info.get("name", str(node_id)) + del node_info[ + "updated_at" + ] #:TODO: We should decide what properties to show on the nodes and edges, we dont necessarily need all. + del node_info["created_at"] + nodes_list.append(node_info) + G.add_node(node_id, **node_info) + + edge_labels = {} + links_list = [] + for source, target, relation, edge_info in edges_data: + source = str(source) + target = str(target) + G.add_edge(source, target) + edge_labels[(source, target)] = relation + links_list.append({"source": source, "target": target, "relation": relation}) + + html_template = """ + + + + + + + + + + + + + + + + + """ + + html_content = html_template.replace("{nodes}", json.dumps(nodes_list)) + html_content = html_content.replace("{links}", json.dumps(links_list)) + + home_dir = os.path.expanduser("~") + output_file = os.path.join(home_dir, "graph_visualization.html") + + with open(output_file, "w") as f: + f.write(html_content) + + print(f"Graph visualization saved as {output_file}") + + return html_content diff --git a/cognee/tests/unit/processing/utils/utils_test.py b/cognee/tests/unit/processing/utils/utils_test.py index bd7e0e52..0731cf92 100644 --- a/cognee/tests/unit/processing/utils/utils_test.py +++ b/cognee/tests/unit/processing/utils/utils_test.py @@ -7,6 +7,9 @@ from uuid import uuid4 from datetime import datetime, timezone from cognee.shared.exceptions import IngestionError +from cognee.infrastructure.visualization.cognee_network_visualization import ( + create_cognee_style_network_with_logo, +) from cognee.shared.utils import ( get_anonymous_id, @@ -14,7 +17,6 @@ get_file_content_hash, prepare_edges, prepare_nodes, - create_cognee_style_network_with_logo, graph_to_tuple, )