-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from funkelab/save_labels
Relabel segmentation based on tracking results
- Loading branch information
Showing
10 changed files
with
223 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .graph_attributes import EdgeAttr, NodeAttr | ||
from .graph_from_segmentation import graph_from_segmentation | ||
from .graph_to_nx import graph_to_nx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from enum import Enum | ||
|
||
|
||
class NodeAttr(Enum): | ||
"""Node attributes that can be added to candidate graph using the toolbox. | ||
Note: Motile can flexibly support any custom attributes. The toolbox provides | ||
implementations of commonly used ones, listed here. | ||
""" | ||
|
||
SEG_ID = "segmentation_id" | ||
|
||
|
||
class EdgeAttr(Enum): | ||
"""Edge attributes that can be added to candidate graph using the toolbox. | ||
Note: Motile can flexibly support any custom attributes. The toolbox provides | ||
implementations of commonly used ones, listed here. | ||
""" | ||
|
||
DISTANCE = "distance" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .saving_utils import relabel_segmentation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import networkx as nx | ||
import numpy as np | ||
|
||
from motile_toolbox.candidate_graph import NodeAttr | ||
|
||
|
||
def relabel_segmentation( | ||
solution_nx_graph: nx.DiGraph, | ||
segmentation: np.ndarray, | ||
frame_key="t", | ||
) -> np.ndarray: | ||
"""Relabel a segmentation based on tracking results so that nodes in same | ||
track share the same id. IDs do change at division. | ||
Args: | ||
solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use | ||
for relabeling. Nodes not in graph will be removed from seg. Original | ||
segmentation ids have to be stored in the graph so we can map them back. | ||
segmentation (np.ndarray): Original segmentation with labels ids that correspond | ||
to segmentation id in graph. | ||
frame_key (str, optional): Time frame key in networkx graph. Defaults to "t". | ||
Returns: | ||
np.ndarray: Relabeled segmentation array where nodes in same track share same | ||
id. | ||
""" | ||
tracked_masks = np.zeros_like(segmentation) | ||
id_counter = 1 | ||
parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1] | ||
soln_copy = solution_nx_graph.copy() | ||
for parent_node in parent_nodes: | ||
out_edges = solution_nx_graph.out_edges(parent_node) | ||
soln_copy.remove_edges_from(out_edges) | ||
for node_set in nx.weakly_connected_components(soln_copy): | ||
for node in node_set: | ||
time_frame = solution_nx_graph.nodes[node][frame_key] | ||
previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value] | ||
previous_seg_mask = segmentation[time_frame] == previous_seg_id | ||
tracked_masks[time_frame][previous_seg_mask] = id_counter | ||
id_counter += 1 | ||
return tracked_masks |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .napari_utils import to_napari_tracks_layer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import networkx as nx | ||
import numpy as np | ||
|
||
|
||
def assign_tracklet_ids(graph: nx.DiGraph) -> nx.DiGraph: | ||
"""Add a tracklet_id attribute to a graph by removing division edges, | ||
assigning one id to each connected component. | ||
Designed as a helper for visualizing the graph in the napari Tracks layer. | ||
Args: | ||
graph (nx.DiGraph): A networkx graph with a tracking solution | ||
Returns: | ||
nx.DiGraph: The same graph with the tracklet_id assigned. Probably | ||
occurrs in place but returned just to be clear. | ||
""" | ||
graph_copy = graph.copy() | ||
|
||
parents = [node for node, degree in graph.out_degree() if degree >= 2] | ||
intertrack_edges = [] | ||
|
||
# Remove all intertrack edges from a copy of the original graph | ||
for parent in parents: | ||
daughters = [child for p, child in graph.out_edges(parent)] | ||
for daughter in daughters: | ||
graph_copy.remove_edge(parent, daughter) | ||
intertrack_edges.append((parent, daughter)) | ||
|
||
track_id = 0 | ||
for tracklet in nx.weakly_connected_components(graph_copy): | ||
nx.set_node_attributes( | ||
graph, {node: {"tracklet_id": track_id} for node in tracklet} | ||
) | ||
track_id += 1 | ||
return graph, intertrack_edges | ||
|
||
|
||
def to_napari_tracks_layer( | ||
graph, frame_key="t", location_keys=("y", "x"), properties=() | ||
): | ||
"""Function to take a networkx graph and return the data needed to add to | ||
a napari tracks layer. | ||
Args: | ||
graph (nx.DiGraph): _description_ | ||
frame_key (str, optional): Key in graph attributes containing time frame. | ||
Defaults to "t". | ||
location_keys (tuple, optional): Keys in graph node attributes containing | ||
location. Should be in order: (Z), Y, X. Defaults to ("y", "x"). | ||
properties (tuple, optional): Keys in graph node attributes to add | ||
to the visualization layer. Defaults to (). NOTE: not working now :( | ||
Returns: | ||
data : array (N, D+1) | ||
Coordinates for N points in D+1 dimensions. ID,T,(Z),Y,X. The first | ||
axis is the integer ID of the track. D is either 3 or 4 for planar | ||
or volumetric timeseries respectively. | ||
properties : dict {str: array (N,)} | ||
Properties for each point. Each property should be an array of length N, | ||
where N is the number of points. | ||
graph : dict {int: list} | ||
Graph representing associations between tracks. Dictionary defines the | ||
mapping between a track ID and the parents of the track. This can be | ||
one (the track has one parent, and the parent has >=1 child) in the | ||
case of track splitting, or more than one (the track has multiple | ||
parents, but only one child) in the case of track merging. | ||
""" | ||
napari_data = np.zeros((graph.number_of_nodes(), len(location_keys) + 2)) | ||
napari_properties = {prop: np.zeros(graph.number_of_nodes()) for prop in properties} | ||
napari_edges = {} | ||
graph, intertrack_edges = assign_tracklet_ids(graph) | ||
for index, node in enumerate(graph.nodes(data=True)): | ||
node_id, data = node | ||
location = [data[loc_key] for loc_key in location_keys] | ||
napari_data[index] = [data["tracklet_id"], data[frame_key], *location] | ||
for prop in properties: | ||
if prop in data: | ||
napari_properties[prop][index] = data[prop] | ||
napari_edges = {} | ||
for parent, child in intertrack_edges: | ||
parent_track_id = graph.nodes[parent]["tracklet_id"] | ||
child_track_id = graph.nodes[child]["tracklet_id"] | ||
if child_track_id in napari_edges: | ||
napari_edges[child_track_id].append(parent_track_id) | ||
else: | ||
napari_edges[child_track_id] = [parent_track_id] | ||
return napari_data, napari_properties, napari_edges |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import networkx as nx | ||
import numpy as np | ||
import pytest | ||
from motile_toolbox.utils import relabel_segmentation | ||
from numpy.testing import assert_array_equal | ||
from skimage.draw import disk | ||
|
||
|
||
@pytest.fixture | ||
def segmentation_2d(): | ||
frame_shape = (100, 100) | ||
total_shape = (2, *frame_shape) | ||
segmentation = np.zeros(total_shape, dtype="int32") | ||
# make frame with one cell in center with label 1 | ||
rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) | ||
segmentation[0][rr, cc] = 1 | ||
|
||
# make frame with two cells | ||
# first cell centered at (20, 80) with label 2 | ||
# second cell centered at (60, 45) with label 3 | ||
rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) | ||
segmentation[1][rr, cc] = 2 | ||
rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) | ||
segmentation[1][rr, cc] = 3 | ||
|
||
return segmentation | ||
|
||
|
||
@pytest.fixture | ||
def graph_2d(): | ||
graph = nx.DiGraph() | ||
nodes = [ | ||
("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}), | ||
("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 2}), | ||
] | ||
edges = [ | ||
("0_1", "1_1", {"distance": 42.43}), | ||
] | ||
graph.add_nodes_from(nodes) | ||
graph.add_edges_from(edges) | ||
return graph | ||
|
||
|
||
def test_relabel_segmentation(segmentation_2d, graph_2d): | ||
frame_shape = (100, 100) | ||
expected = np.zeros(segmentation_2d.shape, dtype="int32") | ||
# make frame with one cell in center with label 1 | ||
rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) | ||
expected[0][rr, cc] = 1 | ||
|
||
# make frame with cell centered at (20, 80) with label 1 | ||
rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) | ||
expected[1][rr, cc] = 1 | ||
|
||
relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d) | ||
print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}") | ||
print(f"Nonzero expected: {np.count_nonzero(expected)}") | ||
print(f"Max relabeled: {np.max(relabeled_seg)}") | ||
print(f"Max expected: {np.max(expected)}") | ||
|
||
assert_array_equal(relabeled_seg, expected) |