Skip to content

Commit

Permalink
Merge pull request #2 from funkelab/graph_to_nx
Browse files Browse the repository at this point in the history
Add function to turn TrackGraph to networkx graph
  • Loading branch information
cmalinmayor authored Mar 11, 2024
2 parents af405b4 + 3f97ee0 commit b1d2909
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/motile_toolbox/candidate_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .graph_from_segmentation import graph_from_segmentation
from .graph_to_nx import graph_to_nx
21 changes: 21 additions & 0 deletions src/motile_toolbox/candidate_graph/graph_to_nx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import networkx as nx
from motile import TrackGraph


def graph_to_nx(graph: TrackGraph) -> nx.DiGraph:
"""Convert a motile TrackGraph into a networkx DiGraph.
Args:
graph (TrackGraph): TrackGraph to be converted to networkx
Returns:
nx.DiGraph: Directed networkx graph with same nodes, edges, and attributes.
"""
nx_graph = nx.DiGraph()
nodes_list = list(graph.nodes.items())
nx_graph.add_nodes_from(nodes_list)
edges_list = [
(edge_id[0], edge_id[1], data) for edge_id, data in graph.edges.items()
]
nx_graph.add_edges_from(edges_list)
return nx_graph
30 changes: 30 additions & 0 deletions tests/test_candidate_graph/test_graph_to_nx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import networkx as nx
import pytest
from motile import TrackGraph
from motile_toolbox.candidate_graph import graph_to_nx
from networkx.utils import graphs_equal


@pytest.fixture
def graph_3d():
graph = nx.DiGraph()
nodes = [
("0_1", {"z": 50, "y": 50, "x": 50, "t": 0, "segmentation_id": 1}),
("1_1", {"z": 20, "y": 50, "x": 80, "t": 1, "segmentation_id": 1}),
("1_2", {"z": 60, "y": 50, "x": 45, "t": 1, "segmentation_id": 2}),
]
edges = [
# math.dist([50, 50], [20, 80])
("0_1", "1_1", {"distance": 42.43}),
# math.dist([50, 50], [60, 45])
("0_1", "1_2", {"distance": 11.18}),
]
graph.add_nodes_from(nodes)
graph.add_edges_from(edges)
return graph


def test_graph_to_nx(graph_3d: nx.DiGraph):
track_graph = TrackGraph(nx_graph=graph_3d, frame_attribute="t")
nx_graph = graph_to_nx(track_graph)
assert graphs_equal(graph_3d, nx_graph)

0 comments on commit b1d2909

Please sign in to comment.