diff --git a/src/motile_toolbox/candidate_graph/__init__.py b/src/motile_toolbox/candidate_graph/__init__.py index f3eec97..efd4cbb 100644 --- a/src/motile_toolbox/candidate_graph/__init__.py +++ b/src/motile_toolbox/candidate_graph/__init__.py @@ -1 +1,2 @@ from .graph_from_segmentation import graph_from_segmentation +from .graph_to_nx import graph_to_nx diff --git a/src/motile_toolbox/candidate_graph/graph_to_nx.py b/src/motile_toolbox/candidate_graph/graph_to_nx.py new file mode 100644 index 0000000..ac722a3 --- /dev/null +++ b/src/motile_toolbox/candidate_graph/graph_to_nx.py @@ -0,0 +1,21 @@ +import networkx as nx +from motile import TrackGraph + + +def graph_to_nx(graph: TrackGraph) -> nx.DiGraph: + """Convert a motile TrackGraph into a networkx DiGraph. + + Args: + graph (TrackGraph): TrackGraph to be converted to networkx + + Returns: + nx.DiGraph: Directed networkx graph with same nodes, edges, and attributes. + """ + nx_graph = nx.DiGraph() + nodes_list = list(graph.nodes.items()) + nx_graph.add_nodes_from(nodes_list) + edges_list = [ + (edge_id[0], edge_id[1], data) for edge_id, data in graph.edges.items() + ] + nx_graph.add_edges_from(edges_list) + return nx_graph diff --git a/tests/test_candidate_graph/test_graph_to_nx.py b/tests/test_candidate_graph/test_graph_to_nx.py new file mode 100644 index 0000000..b52d864 --- /dev/null +++ b/tests/test_candidate_graph/test_graph_to_nx.py @@ -0,0 +1,30 @@ +import networkx as nx +import pytest +from motile import TrackGraph +from motile_toolbox.candidate_graph import graph_to_nx +from networkx.utils import graphs_equal + + +@pytest.fixture +def graph_3d(): + graph = nx.DiGraph() + nodes = [ + ("0_1", {"z": 50, "y": 50, "x": 50, "t": 0, "segmentation_id": 1}), + ("1_1", {"z": 20, "y": 50, "x": 80, "t": 1, "segmentation_id": 1}), + ("1_2", {"z": 60, "y": 50, "x": 45, "t": 1, "segmentation_id": 2}), + ] + edges = [ + # math.dist([50, 50], [20, 80]) + ("0_1", "1_1", {"distance": 42.43}), + # math.dist([50, 50], [60, 45]) + ("0_1", "1_2", {"distance": 11.18}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def test_graph_to_nx(graph_3d: nx.DiGraph): + track_graph = TrackGraph(nx_graph=graph_3d, frame_attribute="t") + nx_graph = graph_to_nx(track_graph) + assert graphs_equal(graph_3d, nx_graph)