Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relabel segmentation based on tracking results #3

Merged
merged 5 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/motile_toolbox/candidate_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .graph_attributes import EdgeAttr, NodeAttr
from .graph_from_segmentation import graph_from_segmentation
from .graph_to_nx import graph_to_nx
19 changes: 19 additions & 0 deletions src/motile_toolbox/candidate_graph/graph_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from enum import Enum


class NodeAttr(Enum):
"""Node attributes that can be added to candidate graph using the toolbox.
Note: Motile can flexibly support any custom attributes. The toolbox provides
implementations of commonly used ones, listed here.
"""

SEG_ID = "segmentation_id"


class EdgeAttr(Enum):
"""Edge attributes that can be added to candidate graph using the toolbox.
Note: Motile can flexibly support any custom attributes. The toolbox provides
implementations of commonly used ones, listed here.
"""

DISTANCE = "distance"
34 changes: 10 additions & 24 deletions src/motile_toolbox/candidate_graph/graph_from_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from skimage.measure import regionprops
from tqdm import tqdm

from .graph_attributes import EdgeAttr, NodeAttr

logger = logging.getLogger(__name__)


Expand All @@ -33,7 +35,7 @@ def _get_location(

def nodes_from_segmentation(
segmentation: np.ndarray,
attributes: tuple[str, ...] | list[str] = ("segmentation_id",),
attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,),
position_keys: tuple[str, ...] | list[str] = ("y", "x"),
frame_key: str = "t",
) -> tuple[nx.DiGraph, dict[int, list[Any]]]:
Expand Down Expand Up @@ -73,8 +75,8 @@ def nodes_from_segmentation(
attrs = {
frame_key: t,
}
if "segmentation_id" in attributes:
attrs["segmentation_id"] = regionprop.label
if NodeAttr.SEG_ID in attributes:
attrs[NodeAttr.SEG_ID.value] = regionprop.label
centroid = regionprop.centroid # [z,] y, x
for label, value in zip(position_keys, centroid):
attrs[label] = value
Expand All @@ -88,7 +90,7 @@ def nodes_from_segmentation(
def add_cand_edges(
cand_graph: nx.DiGraph,
max_edge_distance: float,
attributes: tuple[str, ...] | list[str] = ("distance",),
attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,),
position_keys: tuple[str, ...] | list[str] = ("y", "x"),
frame_key: str = "t",
node_frame_dict: None | dict[int, list[Any]] = None,
Expand Down Expand Up @@ -133,17 +135,17 @@ def add_cand_edges(
for next_id, next_loc in zip(next_nodes, next_locs):
dist = math.dist(next_loc, loc)
attrs = {}
if "distance" in attributes:
attrs["distance"] = dist
if EdgeAttr.DISTANCE in attributes:
attrs[EdgeAttr.DISTANCE.value] = dist
if dist <= max_edge_distance:
cand_graph.add_edge(node, next_id, **attrs)


def graph_from_segmentation(
segmentation: np.ndarray,
max_edge_distance: float,
node_attributes: tuple[str, ...] | list[str] = ("segmentation_id",),
edge_attributes: tuple[str, ...] | list[str] = ("distance",),
node_attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,),
edge_attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,),
position_keys: tuple[str, ...] | list[str] = ("y", "x"),
frame_key: str = "t",
):
Expand Down Expand Up @@ -181,22 +183,6 @@ def graph_from_segmentation(
arguments, or if the number of position keys provided does not match the
number of position dimensions.
"""
valid_edge_attributes = [
"distance",
]
for attr in edge_attributes:
if attr not in valid_edge_attributes:
raise ValueError(
f"Invalid attribute {attr} (supported attrs: {valid_edge_attributes})"
)
valid_node_attributes = [
"segmentation_id",
]
for attr in node_attributes:
if attr not in valid_node_attributes:
raise ValueError(
f"Invalid attribute {attr} (supported attrs: {valid_node_attributes})"
)
if len(position_keys) != segmentation.ndim - 1:
raise ValueError(
f"Position labels {position_keys} does not match number of spatial dims "
Expand Down
1 change: 1 addition & 0 deletions src/motile_toolbox/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .saving_utils import relabel_segmentation
41 changes: 41 additions & 0 deletions src/motile_toolbox/utils/saving_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import networkx as nx
import numpy as np

from motile_toolbox.candidate_graph import NodeAttr


def relabel_segmentation(
solution_nx_graph: nx.DiGraph,
segmentation: np.ndarray,
frame_key="t",
) -> np.ndarray:
"""Relabel a segmentation based on tracking results so that nodes in same
track share the same id. IDs do change at division.

Args:
solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use
for relabeling. Nodes not in graph will be removed from seg. Original
segmentation ids have to be stored in the graph so we can map them back.
segmentation (np.ndarray): Original segmentation with labels ids that correspond
to segmentation id in graph.
frame_key (str, optional): Time frame key in networkx graph. Defaults to "t".

Returns:
np.ndarray: Relabeled segmentation array where nodes in same track share same
id.
"""
tracked_masks = np.zeros_like(segmentation)
id_counter = 1
parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1]
soln_copy = solution_nx_graph.copy()
for parent_node in parent_nodes:
out_edges = solution_nx_graph.out_edges(parent_node)
soln_copy.remove_edges_from(out_edges)
for node_set in nx.weakly_connected_components(soln_copy):
for node in node_set:
time_frame = solution_nx_graph.nodes[node][frame_key]
previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value]
previous_seg_mask = segmentation[time_frame] == previous_seg_id
tracked_masks[time_frame][previous_seg_mask] = id_counter
id_counter += 1
return tracked_masks
1 change: 1 addition & 0 deletions src/motile_toolbox/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .napari_utils import to_napari_tracks_layer
87 changes: 87 additions & 0 deletions src/motile_toolbox/visualization/napari_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import networkx as nx
import numpy as np


def assign_tracklet_ids(graph: nx.DiGraph) -> nx.DiGraph:
"""Add a tracklet_id attribute to a graph by removing division edges,
assigning one id to each connected component.
Designed as a helper for visualizing the graph in the napari Tracks layer.

Args:
graph (nx.DiGraph): A networkx graph with a tracking solution

Returns:
nx.DiGraph: The same graph with the tracklet_id assigned. Probably
occurrs in place but returned just to be clear.
"""
graph_copy = graph.copy()

parents = [node for node, degree in graph.out_degree() if degree >= 2]
intertrack_edges = []

# Remove all intertrack edges from a copy of the original graph
for parent in parents:
daughters = [child for p, child in graph.out_edges(parent)]
for daughter in daughters:
graph_copy.remove_edge(parent, daughter)
intertrack_edges.append((parent, daughter))

track_id = 0
for tracklet in nx.weakly_connected_components(graph_copy):
nx.set_node_attributes(
graph, {node: {"tracklet_id": track_id} for node in tracklet}
)
track_id += 1
return graph, intertrack_edges


def to_napari_tracks_layer(
graph, frame_key="t", location_keys=("y", "x"), properties=()
):
"""Function to take a networkx graph and return the data needed to add to
a napari tracks layer.

Args:
graph (nx.DiGraph): _description_
frame_key (str, optional): Key in graph attributes containing time frame.
Defaults to "t".
location_keys (tuple, optional): Keys in graph node attributes containing
location. Should be in order: (Z), Y, X. Defaults to ("y", "x").
properties (tuple, optional): Keys in graph node attributes to add
to the visualization layer. Defaults to (). NOTE: not working now :(

Returns:
data : array (N, D+1)
Coordinates for N points in D+1 dimensions. ID,T,(Z),Y,X. The first
axis is the integer ID of the track. D is either 3 or 4 for planar
or volumetric timeseries respectively.
properties : dict {str: array (N,)}
Properties for each point. Each property should be an array of length N,
where N is the number of points.
graph : dict {int: list}
Graph representing associations between tracks. Dictionary defines the
mapping between a track ID and the parents of the track. This can be
one (the track has one parent, and the parent has >=1 child) in the
case of track splitting, or more than one (the track has multiple
parents, but only one child) in the case of track merging.
"""
napari_data = np.zeros((graph.number_of_nodes(), len(location_keys) + 2))
napari_properties = {prop: np.zeros(graph.number_of_nodes()) for prop in properties}
napari_edges = {}
graph, intertrack_edges = assign_tracklet_ids(graph)
for index, node in enumerate(graph.nodes(data=True)):
node_id, data = node
location = [data[loc_key] for loc_key in location_keys]
napari_data[index] = [data["tracklet_id"], data[frame_key], *location]
for prop in properties:
if prop in data:
napari_properties[prop][index] = data[prop]
napari_edges = {}
for parent, child in intertrack_edges:
parent_track_id = graph.nodes[parent]["tracklet_id"]
child_track_id = graph.nodes[child]["tracklet_id"]
if child_track_id in napari_edges:
napari_edges[child_track_id].append(parent_track_id)
else:
napari_edges[child_track_id] = [parent_track_id]
return napari_data, napari_properties, napari_edges
3 changes: 2 additions & 1 deletion tests/test_candidate_graph/test_graph_from_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import networkx as nx
import numpy as np
import pytest
from motile_toolbox.candidate_graph import NodeAttr
from motile_toolbox.candidate_graph.graph_from_segmentation import (
add_cand_edges,
graph_from_segmentation,
Expand Down Expand Up @@ -132,7 +133,7 @@ def test_nodes_from_segmentation_3d(segmentation_3d):
# test with 3D segmentation
node_graph, node_frame_dict = nodes_from_segmentation(
segmentation=segmentation_3d,
attributes=["segmentation_id"],
attributes=[NodeAttr.SEG_ID],
position_keys=("pos_z", "pos_y", "pos_x"),
)
assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"])
Expand Down
61 changes: 61 additions & 0 deletions tests/test_utils/test_saving_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import networkx as nx
import numpy as np
import pytest
from motile_toolbox.utils import relabel_segmentation
from numpy.testing import assert_array_equal
from skimage.draw import disk


@pytest.fixture
def segmentation_2d():
frame_shape = (100, 100)
total_shape = (2, *frame_shape)
segmentation = np.zeros(total_shape, dtype="int32")
# make frame with one cell in center with label 1
rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100))
segmentation[0][rr, cc] = 1

# make frame with two cells
# first cell centered at (20, 80) with label 2
# second cell centered at (60, 45) with label 3
rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape)
segmentation[1][rr, cc] = 2
rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape)
segmentation[1][rr, cc] = 3

return segmentation


@pytest.fixture
def graph_2d():
graph = nx.DiGraph()
nodes = [
("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}),
("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 2}),
]
edges = [
("0_1", "1_1", {"distance": 42.43}),
]
graph.add_nodes_from(nodes)
graph.add_edges_from(edges)
return graph


def test_relabel_segmentation(segmentation_2d, graph_2d):
frame_shape = (100, 100)
expected = np.zeros(segmentation_2d.shape, dtype="int32")
# make frame with one cell in center with label 1
rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100))
expected[0][rr, cc] = 1

# make frame with cell centered at (20, 80) with label 1
rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape)
expected[1][rr, cc] = 1

relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d)
print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}")
print(f"Nonzero expected: {np.count_nonzero(expected)}")
print(f"Max relabeled: {np.max(relabeled_seg)}")
print(f"Max expected: {np.max(expected)}")

assert_array_equal(relabeled_seg, expected)
Loading