From bcaabe3c120a5203c4b2ca6a904041518c2e7bff Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 11 Mar 2024 20:09:19 -0400 Subject: [PATCH 1/5] Enumerate node and edge attributes --- .../candidate_graph/__init__.py | 1 + .../candidate_graph/graph_attributes.py | 19 +++++++++++ .../graph_from_segmentation.py | 34 ++++++------------- .../test_loading_utils.py | 0 4 files changed, 30 insertions(+), 24 deletions(-) create mode 100644 src/motile_toolbox/candidate_graph/graph_attributes.py rename tests/{utils => test_utils}/test_loading_utils.py (100%) diff --git a/src/motile_toolbox/candidate_graph/__init__.py b/src/motile_toolbox/candidate_graph/__init__.py index efd4cbb..d06fd4e 100644 --- a/src/motile_toolbox/candidate_graph/__init__.py +++ b/src/motile_toolbox/candidate_graph/__init__.py @@ -1,2 +1,3 @@ +from .graph_attributes import EdgeAttr, NodeAttr from .graph_from_segmentation import graph_from_segmentation from .graph_to_nx import graph_to_nx diff --git a/src/motile_toolbox/candidate_graph/graph_attributes.py b/src/motile_toolbox/candidate_graph/graph_attributes.py new file mode 100644 index 0000000..767e023 --- /dev/null +++ b/src/motile_toolbox/candidate_graph/graph_attributes.py @@ -0,0 +1,19 @@ +from enum import Enum + + +class NodeAttr(Enum): + """Node attributes that can be added to candidate graph using the toolbox. + Note: Motile can flexibly support any custom attributes. The toolbox provides + implementations of commonly used ones, listed here. + """ + + SEG_ID = "segmentation_id" + + +class EdgeAttr(Enum): + """Edge attributes that can be added to candidate graph using the toolbox. + Note: Motile can flexibly support any custom attributes. The toolbox provides + implementations of commonly used ones, listed here. + """ + + DISTANCE = "distance" diff --git a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py index 96c0879..42fe480 100644 --- a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py +++ b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py @@ -7,6 +7,8 @@ from skimage.measure import regionprops from tqdm import tqdm +from .graph_attributes import EdgeAttr, NodeAttr + logger = logging.getLogger(__name__) @@ -33,7 +35,7 @@ def _get_location( def nodes_from_segmentation( segmentation: np.ndarray, - attributes: tuple[str, ...] | list[str] = ("segmentation_id",), + attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), position_keys: tuple[str, ...] | list[str] = ("y", "x"), frame_key: str = "t", ) -> tuple[nx.DiGraph, dict[int, list[Any]]]: @@ -73,8 +75,8 @@ def nodes_from_segmentation( attrs = { frame_key: t, } - if "segmentation_id" in attributes: - attrs["segmentation_id"] = regionprop.label + if NodeAttr.SEG_ID in attributes: + attrs[NodeAttr.SEG_ID.value] = regionprop.label centroid = regionprop.centroid # [z,] y, x for label, value in zip(position_keys, centroid): attrs[label] = value @@ -88,7 +90,7 @@ def nodes_from_segmentation( def add_cand_edges( cand_graph: nx.DiGraph, max_edge_distance: float, - attributes: tuple[str, ...] | list[str] = ("distance",), + attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,), position_keys: tuple[str, ...] | list[str] = ("y", "x"), frame_key: str = "t", node_frame_dict: None | dict[int, list[Any]] = None, @@ -133,8 +135,8 @@ def add_cand_edges( for next_id, next_loc in zip(next_nodes, next_locs): dist = math.dist(next_loc, loc) attrs = {} - if "distance" in attributes: - attrs["distance"] = dist + if EdgeAttr.DISTANCE in attributes: + attrs[EdgeAttr.DISTANCE.value] = dist if dist <= max_edge_distance: cand_graph.add_edge(node, next_id, **attrs) @@ -142,8 +144,8 @@ def add_cand_edges( def graph_from_segmentation( segmentation: np.ndarray, max_edge_distance: float, - node_attributes: tuple[str, ...] | list[str] = ("segmentation_id",), - edge_attributes: tuple[str, ...] | list[str] = ("distance",), + node_attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), + edge_attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,), position_keys: tuple[str, ...] | list[str] = ("y", "x"), frame_key: str = "t", ): @@ -181,22 +183,6 @@ def graph_from_segmentation( arguments, or if the number of position keys provided does not match the number of position dimensions. """ - valid_edge_attributes = [ - "distance", - ] - for attr in edge_attributes: - if attr not in valid_edge_attributes: - raise ValueError( - f"Invalid attribute {attr} (supported attrs: {valid_edge_attributes})" - ) - valid_node_attributes = [ - "segmentation_id", - ] - for attr in node_attributes: - if attr not in valid_node_attributes: - raise ValueError( - f"Invalid attribute {attr} (supported attrs: {valid_node_attributes})" - ) if len(position_keys) != segmentation.ndim - 1: raise ValueError( f"Position labels {position_keys} does not match number of spatial dims " diff --git a/tests/utils/test_loading_utils.py b/tests/test_utils/test_loading_utils.py similarity index 100% rename from tests/utils/test_loading_utils.py rename to tests/test_utils/test_loading_utils.py From 6af0109cc0a48cbb6f216819f7df688b7e6c81e3 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 11 Mar 2024 20:09:38 -0400 Subject: [PATCH 2/5] Add napari tracks layer helper function --- src/motile_toolbox/visualization/__init__.py | 1 + .../visualization/napari_utils.py | 87 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 src/motile_toolbox/visualization/__init__.py create mode 100644 src/motile_toolbox/visualization/napari_utils.py diff --git a/src/motile_toolbox/visualization/__init__.py b/src/motile_toolbox/visualization/__init__.py new file mode 100644 index 0000000..cb720fa --- /dev/null +++ b/src/motile_toolbox/visualization/__init__.py @@ -0,0 +1 @@ +from .napari_utils import to_napari_tracks_layer diff --git a/src/motile_toolbox/visualization/napari_utils.py b/src/motile_toolbox/visualization/napari_utils.py new file mode 100644 index 0000000..74859b2 --- /dev/null +++ b/src/motile_toolbox/visualization/napari_utils.py @@ -0,0 +1,87 @@ +import networkx as nx +import numpy as np + + +def assign_tracklet_ids(graph: nx.DiGraph) -> nx.DiGraph: + """Add a tracklet_id attribute to a graph by removing division edges, + assigning one id to each connected component. + Designed as a helper for visualizing the graph in the napari Tracks layer. + + Args: + graph (nx.DiGraph): A networkx graph with a tracking solution + + Returns: + nx.DiGraph: The same graph with the tracklet_id assigned. Probably + occurrs in place but returned just to be clear. + """ + graph_copy = graph.copy() + + parents = [node for node, degree in graph.out_degree() if degree >= 2] + intertrack_edges = [] + + # Remove all intertrack edges from a copy of the original graph + for parent in parents: + daughters = [child for p, child in graph.out_edges(parent)] + for daughter in daughters: + graph_copy.remove_edge(parent, daughter) + intertrack_edges.append((parent, daughter)) + + track_id = 0 + for tracklet in nx.weakly_connected_components(graph_copy): + nx.set_node_attributes( + graph, {node: {"tracklet_id": track_id} for node in tracklet} + ) + track_id += 1 + return graph, intertrack_edges + + +def to_napari_tracks_layer( + graph, frame_key="t", location_keys=("y", "x"), properties=() +): + """Function to take a networkx graph and return the data needed to add to + a napari tracks layer. + + Args: + graph (nx.DiGraph): _description_ + frame_key (str, optional): Key in graph attributes containing time frame. + Defaults to "t". + location_keys (tuple, optional): Keys in graph node attributes containing + location. Should be in order: (Z), Y, X. Defaults to ("y", "x"). + properties (tuple, optional): Keys in graph node attributes to add + to the visualization layer. Defaults to (). NOTE: not working now :( + + Returns: + data : array (N, D+1) + Coordinates for N points in D+1 dimensions. ID,T,(Z),Y,X. The first + axis is the integer ID of the track. D is either 3 or 4 for planar + or volumetric timeseries respectively. + properties : dict {str: array (N,)} + Properties for each point. Each property should be an array of length N, + where N is the number of points. + graph : dict {int: list} + Graph representing associations between tracks. Dictionary defines the + mapping between a track ID and the parents of the track. This can be + one (the track has one parent, and the parent has >=1 child) in the + case of track splitting, or more than one (the track has multiple + parents, but only one child) in the case of track merging. + """ + napari_data = np.zeros((graph.number_of_nodes(), len(location_keys) + 2)) + napari_properties = {prop: np.zeros(graph.number_of_nodes()) for prop in properties} + napari_edges = {} + graph, intertrack_edges = assign_tracklet_ids(graph) + for index, node in enumerate(graph.nodes(data=True)): + node_id, data = node + location = [data[loc_key] for loc_key in location_keys] + napari_data[index] = [data["tracklet_id"], data[frame_key], *location] + for prop in properties: + if prop in data: + napari_properties[prop][index] = data[prop] + napari_edges = {} + for parent, child in intertrack_edges: + parent_track_id = graph.nodes[parent]["tracklet_id"] + child_track_id = graph.nodes[child]["tracklet_id"] + if child_track_id in napari_edges: + napari_edges[child_track_id].append(parent_track_id) + else: + napari_edges[child_track_id] = [parent_track_id] + return napari_data, napari_properties, napari_edges From 05c0a8e06958a141441b1924b38a56581d9880fd Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 11 Mar 2024 20:10:03 -0400 Subject: [PATCH 3/5] Add function to relabel segmentation from tracking solution --- src/motile_toolbox/utils/__init__.py | 1 + src/motile_toolbox/utils/saving_utils.py | 40 ++++++++++++++++ tests/test_utils/test_saving_utils.py | 61 ++++++++++++++++++++++++ 3 files changed, 102 insertions(+) create mode 100644 src/motile_toolbox/utils/saving_utils.py create mode 100644 tests/test_utils/test_saving_utils.py diff --git a/src/motile_toolbox/utils/__init__.py b/src/motile_toolbox/utils/__init__.py index e69de29..eebaf83 100644 --- a/src/motile_toolbox/utils/__init__.py +++ b/src/motile_toolbox/utils/__init__.py @@ -0,0 +1 @@ +from .saving_utils import relabel_segmentation diff --git a/src/motile_toolbox/utils/saving_utils.py b/src/motile_toolbox/utils/saving_utils.py new file mode 100644 index 0000000..7419bf3 --- /dev/null +++ b/src/motile_toolbox/utils/saving_utils.py @@ -0,0 +1,40 @@ +import networkx as nx +import numpy as np + +from motile_toolbox.candidate_graph import NodeAttr + + +def relabel_segmentation( + solution_nx_graph: nx.DiGraph, + segmentation: np.array, + frame_key="t", +) -> np.array: + """Relabel a segmentation based on tracking results so that nodes in same + track share the same id. IDs do change at division. + + Args: + solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use + for relabeling. Nodes not in graph will be removed from seg. Original + segmentation ids have to be stored in the graph so we can map them back. + segmentation (np.array): Original segmentation with labels ids that correspond + to segmentation id in graph. + frame_key (str, optional): Time frame key in networkx graph. Defaults to "t". + + Returns: + np.array: Relabeled segmentation array where nodes in same track share same id. + """ + tracked_masks = np.zeros_like(segmentation) + id_counter = 1 + parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1] + soln_copy = solution_nx_graph.copy() + for parent_node in parent_nodes: + out_edges = solution_nx_graph.out_edges(parent_node) + soln_copy.remove_edges_from(out_edges) + for node_set in nx.weakly_connected_components(soln_copy): + for node in node_set: + time_frame = solution_nx_graph.nodes[node][frame_key] + previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value] + previous_seg_mask = segmentation[time_frame] == previous_seg_id + tracked_masks[time_frame][previous_seg_mask] = id_counter + id_counter += 1 + return tracked_masks diff --git a/tests/test_utils/test_saving_utils.py b/tests/test_utils/test_saving_utils.py new file mode 100644 index 0000000..c4ff2ac --- /dev/null +++ b/tests/test_utils/test_saving_utils.py @@ -0,0 +1,61 @@ +import networkx as nx +import numpy as np +import pytest +from motile_toolbox.utils import relabel_segmentation +from numpy.testing import assert_array_equal +from skimage.draw import disk + + +@pytest.fixture +def segmentation_2d(): + frame_shape = (100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 2 + # second cell centered at (60, 45) with label 3 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 2 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 3 + + return segmentation + + +@pytest.fixture +def graph_2d(): + graph = nx.DiGraph() + nodes = [ + ("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}), + ("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 2}), + ] + edges = [ + ("0_1", "1_1", {"distance": 42.43}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def test_relabel_segmentation(segmentation_2d, graph_2d): + frame_shape = (100, 100) + expected = np.zeros(segmentation_2d.shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + expected[0][rr, cc] = 1 + + # make frame with cell centered at (20, 80) with label 1 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + expected[1][rr, cc] = 1 + + relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d) + print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}") + print(f"Nonzero expected: {np.count_nonzero(expected)}") + print(f"Max relabeled: {np.max(relabeled_seg)}") + print(f"Max expected: {np.max(expected)}") + + assert_array_equal(relabeled_seg, expected) From d5cfc41f206b394436d97d807b0f0ad60eda42e3 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 13 Mar 2024 16:17:06 -0400 Subject: [PATCH 4/5] Use ndarray for type annotations --- src/motile_toolbox/utils/saving_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/motile_toolbox/utils/saving_utils.py b/src/motile_toolbox/utils/saving_utils.py index 7419bf3..6bd98f2 100644 --- a/src/motile_toolbox/utils/saving_utils.py +++ b/src/motile_toolbox/utils/saving_utils.py @@ -6,9 +6,9 @@ def relabel_segmentation( solution_nx_graph: nx.DiGraph, - segmentation: np.array, + segmentation: np.ndarray, frame_key="t", -) -> np.array: +) -> np.ndarray: """Relabel a segmentation based on tracking results so that nodes in same track share the same id. IDs do change at division. @@ -16,12 +16,13 @@ def relabel_segmentation( solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use for relabeling. Nodes not in graph will be removed from seg. Original segmentation ids have to be stored in the graph so we can map them back. - segmentation (np.array): Original segmentation with labels ids that correspond + segmentation (np.ndarray): Original segmentation with labels ids that correspond to segmentation id in graph. frame_key (str, optional): Time frame key in networkx graph. Defaults to "t". Returns: - np.array: Relabeled segmentation array where nodes in same track share same id. + np.ndarray: Relabeled segmentation array where nodes in same track share same + id. """ tracked_masks = np.zeros_like(segmentation) id_counter = 1 From 8fcf48a3b4c4f0426bd248a77fb272e51e9accac Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 13 Mar 2024 16:21:13 -0400 Subject: [PATCH 5/5] Use NodeAttr in tests --- tests/test_candidate_graph/test_graph_from_segmentation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_candidate_graph/test_graph_from_segmentation.py b/tests/test_candidate_graph/test_graph_from_segmentation.py index 022b5fb..55d5741 100644 --- a/tests/test_candidate_graph/test_graph_from_segmentation.py +++ b/tests/test_candidate_graph/test_graph_from_segmentation.py @@ -3,6 +3,7 @@ import networkx as nx import numpy as np import pytest +from motile_toolbox.candidate_graph import NodeAttr from motile_toolbox.candidate_graph.graph_from_segmentation import ( add_cand_edges, graph_from_segmentation, @@ -132,7 +133,7 @@ def test_nodes_from_segmentation_3d(segmentation_3d): # test with 3D segmentation node_graph, node_frame_dict = nodes_from_segmentation( segmentation=segmentation_3d, - attributes=["segmentation_id"], + attributes=[NodeAttr.SEG_ID], position_keys=("pos_z", "pos_y", "pos_x"), ) assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"])