-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add function to turn TrackGraph to networkx graph
- Loading branch information
1 parent
af405b4
commit 3f97ee0
Showing
3 changed files
with
52 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .graph_from_segmentation import graph_from_segmentation | ||
from .graph_to_nx import graph_to_nx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import networkx as nx | ||
from motile import TrackGraph | ||
|
||
|
||
def graph_to_nx(graph: TrackGraph) -> nx.DiGraph: | ||
"""Convert a motile TrackGraph into a networkx DiGraph. | ||
Args: | ||
graph (TrackGraph): TrackGraph to be converted to networkx | ||
Returns: | ||
nx.DiGraph: Directed networkx graph with same nodes, edges, and attributes. | ||
""" | ||
nx_graph = nx.DiGraph() | ||
nodes_list = list(graph.nodes.items()) | ||
nx_graph.add_nodes_from(nodes_list) | ||
edges_list = [ | ||
(edge_id[0], edge_id[1], data) for edge_id, data in graph.edges.items() | ||
] | ||
nx_graph.add_edges_from(edges_list) | ||
return nx_graph |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import networkx as nx | ||
import pytest | ||
from motile import TrackGraph | ||
from motile_toolbox.candidate_graph import graph_to_nx | ||
from networkx.utils import graphs_equal | ||
|
||
|
||
@pytest.fixture | ||
def graph_3d(): | ||
graph = nx.DiGraph() | ||
nodes = [ | ||
("0_1", {"z": 50, "y": 50, "x": 50, "t": 0, "segmentation_id": 1}), | ||
("1_1", {"z": 20, "y": 50, "x": 80, "t": 1, "segmentation_id": 1}), | ||
("1_2", {"z": 60, "y": 50, "x": 45, "t": 1, "segmentation_id": 2}), | ||
] | ||
edges = [ | ||
# math.dist([50, 50], [20, 80]) | ||
("0_1", "1_1", {"distance": 42.43}), | ||
# math.dist([50, 50], [60, 45]) | ||
("0_1", "1_2", {"distance": 11.18}), | ||
] | ||
graph.add_nodes_from(nodes) | ||
graph.add_edges_from(edges) | ||
return graph | ||
|
||
|
||
def test_graph_to_nx(graph_3d: nx.DiGraph): | ||
track_graph = TrackGraph(nx_graph=graph_3d, frame_attribute="t") | ||
nx_graph = graph_to_nx(track_graph) | ||
assert graphs_equal(graph_3d, nx_graph) |