From 3d0a134c96248a582db3b6503d104bf7c9296dda Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 11:08:45 -0400 Subject: [PATCH 01/20] Reorganize code for multi hypothesis graphs --- .../candidate_graph/graph_attributes.py | 11 ++- .../graph_from_segmentation.py | 86 ++----------------- src/motile_toolbox/candidate_graph/iou.py | 36 ++++++++ .../candidate_graph/multi_seg_graph.py | 67 +++++++++++++++ src/motile_toolbox/candidate_graph/utils.py | 31 +++++++ 5 files changed, 150 insertions(+), 81 deletions(-) create mode 100644 src/motile_toolbox/candidate_graph/iou.py create mode 100644 src/motile_toolbox/candidate_graph/multi_seg_graph.py create mode 100644 src/motile_toolbox/candidate_graph/utils.py diff --git a/src/motile_toolbox/candidate_graph/graph_attributes.py b/src/motile_toolbox/candidate_graph/graph_attributes.py index e6a9d49..eef4d07 100644 --- a/src/motile_toolbox/candidate_graph/graph_attributes.py +++ b/src/motile_toolbox/candidate_graph/graph_attributes.py @@ -6,8 +6,9 @@ class NodeAttr(Enum): Note: Motile can flexibly support any custom attributes. The toolbox provides implementations of commonly used ones, listed here. """ - - SEG_ID = "segmentation_id" + POS = "position" + SEG_ID = "segmentation_id" # TODO: Seg? + SEG_HYPOTHESIS = "seg_hypothesis" class EdgeAttr(Enum): @@ -18,3 +19,9 @@ class EdgeAttr(Enum): DISTANCE = "distance" IOU = "iou" + + +def add_iou(cand_graph, segmentation) -> None: + # TODO: implement + pass + diff --git a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py index 638d003..4e42bb6 100644 --- a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py +++ b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py @@ -7,35 +7,14 @@ from skimage.measure import regionprops from tqdm import tqdm -from .graph_attributes import EdgeAttr, NodeAttr +from .graph_attributes import EdgeAttr, NodeAttr, add_iou +from .utils import _get_location, _get_node_id logger = logging.getLogger(__name__) -def _get_location( - node_data: dict[str, Any], position_keys: tuple[str, ...] | list[str] -) -> list[Any]: - """Convenience function to get the location of a networkx node when each dimension - is stored in a different attribute. - - Args: - node_data (dict[str, Any]): Dictionary of attributes of a networkx node. - Assumes the provided position keys are in the dictionary. - position_keys (tuple[str, ...] | list[str], optional): Keys to use to get - location information from node_data (assumes they are present in node_data). - Defaults to ("z", "y", "x"). - - Returns: - list: _description_ - Raises: - KeyError if position keys not in node_data - """ - return [node_data[k] for k in position_keys] - - def nodes_from_segmentation( segmentation: np.ndarray, - attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), position_keys: tuple[str, ...] | list[str] = ("y", "x"), frame_key: str = "t", ) -> tuple[nx.DiGraph, dict[int, list[Any]]]: @@ -71,12 +50,11 @@ def nodes_from_segmentation( nodes_in_frame = [] props = regionprops(segmentation[t]) for regionprop in props: - node_id = f"{t}_{regionprop.label}" + node_id = _get_node_id(t, regionprop.label) attrs = { frame_key: t, } - if NodeAttr.SEG_ID in attributes: - attrs[NodeAttr.SEG_ID.value] = regionprop.label + attrs[NodeAttr.SEG_ID.value] = regionprop.label centroid = regionprop.centroid # [z,] y, x for label, value in zip(position_keys, centroid): attrs[label] = value @@ -112,11 +90,9 @@ def _compute_node_frame_dict( def add_cand_edges( cand_graph: nx.DiGraph, max_edge_distance: float, - attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,), position_keys: tuple[str, ...] | list[str] = ("y", "x"), frame_key: str = "t", node_frame_dict: None | dict[int, list[Any]] = None, - segmentation: None | np.ndarray = None, ) -> None: """Add candidate edges to a candidate graph by connecting all nodes in adjacent frames that are closer than max_edge_distance. Also adds attributes to the edges. @@ -152,63 +128,15 @@ def add_cand_edges( _get_location(cand_graph.nodes[n], position_keys=position_keys) for n in next_nodes ] - if EdgeAttr.IOU in attributes: - if segmentation is None: - raise ValueError("Can't compute IOU without segmentation.") - ious = compute_ious(segmentation[frame], segmentation[frame + 1]) for node in node_frame_dict[frame]: loc = _get_location(cand_graph.nodes[node], position_keys=position_keys) for next_id, next_loc in zip(next_nodes, next_locs): dist = math.dist(next_loc, loc) if dist <= max_edge_distance: - attrs = {} - if EdgeAttr.DISTANCE in attributes: - attrs[EdgeAttr.DISTANCE.value] = dist - if EdgeAttr.IOU in attributes: - node_seg_id = cand_graph.nodes[node][NodeAttr.SEG_ID.value] - next_seg_id = cand_graph.nodes[next_id][NodeAttr.SEG_ID.value] - attrs[EdgeAttr.IOU.value] = ious.get(node_seg_id, {}).get( - next_seg_id, 0 - ) + attrs = {EdgeAttr.DISTANCE.value: dist} cand_graph.add_edge(node, next_id, **attrs) -def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, float]]: - """Compute label IOUs between two label arrays of the same shape. Ignores background - (label 0). - - Args: - frame1 (np.ndarray): Array with integer labels - frame2 (np.ndarray): Array with integer labels - - Returns: - dict[int, dict[int, float]]: Dictionary from labels in frame 1 to labels in - frame 2 to iou values. Nodes that have no overlap are not included. - """ - frame1 = frame1.flatten() - frame2 = frame2.flatten() - # get indices where both are not zero (ignore background) - # this speeds up computation significantly - non_zero_indices = np.logical_and(frame1, frame2) - flattened_stacked = np.array([frame1[non_zero_indices], frame2[non_zero_indices]]) - - values, counts = np.unique(flattened_stacked, axis=1, return_counts=True) - frame1_values, frame1_counts = np.unique(frame1, return_counts=True) - frame1_label_sizes = dict(zip(frame1_values, frame1_counts)) - frame2_values, frame2_counts = np.unique(frame2, return_counts=True) - frame2_label_sizes = dict(zip(frame2_values, frame2_counts)) - iou_dict: dict[int, dict[int, float]] = {} - for index in range(values.shape[1]): - pair = values[:, index] - intersection = counts[index] - id1, id2 = pair - union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection - if id1 not in iou_dict: - iou_dict[id1] = {} - iou_dict[id1][id2] = intersection / union - return iou_dict - - def graph_from_segmentation( segmentation: np.ndarray, max_edge_distance: float, @@ -267,11 +195,11 @@ def graph_from_segmentation( add_cand_edges( cand_graph, max_edge_distance=max_edge_distance, - attributes=edge_attributes, position_keys=position_keys, node_frame_dict=node_frame_dict, - segmentation=segmentation, ) + if EdgeAttr.IOU in edge_attributes: + add_iou(cand_graph, segmentation) logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") return cand_graph diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py new file mode 100644 index 0000000..f013200 --- /dev/null +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -0,0 +1,36 @@ +import numpy as np + +def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, float]]: + """Compute label IOUs between two label arrays of the same shape. Ignores background + (label 0). + + Args: + frame1 (np.ndarray): Array with integer labels + frame2 (np.ndarray): Array with integer labels + + Returns: + dict[int, dict[int, float]]: Dictionary from labels in frame 1 to labels in + frame 2 to iou values. Nodes that have no overlap are not included. + """ + frame1 = frame1.flatten() + frame2 = frame2.flatten() + # get indices where both are not zero (ignore background) + # this speeds up computation significantly + non_zero_indices = np.logical_and(frame1, frame2) + flattened_stacked = np.array([frame1[non_zero_indices], frame2[non_zero_indices]]) + + values, counts = np.unique(flattened_stacked, axis=1, return_counts=True) + frame1_values, frame1_counts = np.unique(frame1, return_counts=True) + frame1_label_sizes = dict(zip(frame1_values, frame1_counts)) + frame2_values, frame2_counts = np.unique(frame2, return_counts=True) + frame2_label_sizes = dict(zip(frame2_values, frame2_counts)) + iou_dict: dict[int, dict[int, float]] = {} + for index in range(values.shape[1]): + pair = values[:, index] + intersection = counts[index] + id1, id2 = pair + union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection + if id1 not in iou_dict: + iou_dict[id1] = {} + iou_dict[id1][id2] = intersection / union + return iou_dict \ No newline at end of file diff --git a/src/motile_toolbox/candidate_graph/multi_seg_graph.py b/src/motile_toolbox/candidate_graph/multi_seg_graph.py new file mode 100644 index 0000000..baac64f --- /dev/null +++ b/src/motile_toolbox/candidate_graph/multi_seg_graph.py @@ -0,0 +1,67 @@ +import numpy as np +import networkx as nx +from typing import Any + +from .graph_attributes import EdgeAttr, NodeAttr, add_iou +from .graph_from_segmentation import nodes_from_segmentation, add_cand_edges + + +def compute_multi_seg_graph(segmentations: list[np.ndarray]) -> tuple[nx.DiGraph, list[set]]: + """Create a candidate graph from multi hypothesis segmentations. This is not + tailored for agglomeration approaches with hierarchical merge graphs, it simply + creates a conflict set for any nodes that overlap in the same time frame. + + Args: + segmentations (list[np.ndarray]): + + Returns: + nx.DiGraph: _description_ + """ + # for each segmentation, get nodes using same method as graph_from_segmentation + # add them all to one big graph + cand_graph, frame_dict = nodes_from_multi_segmentation(segmentations) # TODO: other args + + # Compute conflict sets between segmentations + # can use same method as IOU (without the U) to compute conflict sets + conflicts = [] + for time, segs in enumerate(segmentations): + conflicts.append(compute_conflict_sets(segs, time)) + + # add edges with same method as before, with slightly different implementation + add_cand_edges(cand_graph) # TODO: other args + if EdgeAttr.IOU in edge_attributes: + # TODO: cross product when calling (need to re-organize add_iou to not assume stuff) + add_iou(cand_graph, segmentation) + + return cand_graph + + + + + + +def nodes_from_multi_segmentation( + segmentations: list[np.ndarray], + attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), + position_keys: tuple[str, ...] | list[str] = ("y", "x"), + frame_key: str = "t", +) -> tuple[nx.DiGraph, dict[int, list[Any]]]: + multi_hypo_node_graph = nx.DiGraph() + multi_frame_dict = {} + for layer_id, segmentation in enumerate(segmentations): + node_graph, frame_dict = nodes_from_segmentation(segmentation, layer_id) + # TODO: pass attributes, etc. + # TODO: add multi segmentation attribute to nodes_from_segmentation + # (use in node id and add to attributes) + multi_hypo_node_graph.update(node_graph) + multi_frame_dict.update(frame_dict) + # TODO: Make sure there is no node-id collision + + return multi_hypo_node_graph, multi_frame_dict + + + +def compute_conflict_sets(segmenations: np.ndarray, time: int) -> list[set]: + """ Segmentations in one frame only. Return list of sets of node ids that conflict.""" + # This will look a lot like the IOU code + pass diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py new file mode 100644 index 0000000..36b161a --- /dev/null +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -0,0 +1,31 @@ +from typing import Any + + +def _get_location( + node_data: dict[str, Any], position_keys: tuple[str, ...] | list[str] +) -> list[Any]: + + # TODO: Remove this function by storing positions in one attribute called position + """Convenience function to get the location of a networkx node when each dimension + is stored in a different attribute. + + Args: + node_data (dict[str, Any]): Dictionary of attributes of a networkx node. + Assumes the provided position keys are in the dictionary. + position_keys (tuple[str, ...] | list[str], optional): Keys to use to get + location information from node_data (assumes they are present in node_data). + Defaults to ("z", "y", "x"). + + Returns: + list: _description_ + Raises: + KeyError if position keys not in node_data + """ + return [node_data[k] for k in position_keys] + +def _get_node_id(time: int, label_id: int, hypothesis_id: int | None) -> str: + + if hypothesis_id: + return f"{time}_{hypothesis_id}_{label_id}" + else: + return f"{time}_{label_id}" \ No newline at end of file From 3fa4757bcb15f24118f1dffc483c18ec2c30c3ea Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 11:09:27 -0400 Subject: [PATCH 02/20] Ruff formatting --- .../candidate_graph/graph_attributes.py | 1 - src/motile_toolbox/candidate_graph/iou.py | 3 ++- .../candidate_graph/multi_seg_graph.py | 19 ++++++++++--------- src/motile_toolbox/candidate_graph/utils.py | 6 +++--- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/graph_attributes.py b/src/motile_toolbox/candidate_graph/graph_attributes.py index eef4d07..7045c0a 100644 --- a/src/motile_toolbox/candidate_graph/graph_attributes.py +++ b/src/motile_toolbox/candidate_graph/graph_attributes.py @@ -24,4 +24,3 @@ class EdgeAttr(Enum): def add_iou(cand_graph, segmentation) -> None: # TODO: implement pass - diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index f013200..e605bf2 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -1,5 +1,6 @@ import numpy as np + def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, float]]: """Compute label IOUs between two label arrays of the same shape. Ignores background (label 0). @@ -33,4 +34,4 @@ def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, if id1 not in iou_dict: iou_dict[id1] = {} iou_dict[id1][id2] = intersection / union - return iou_dict \ No newline at end of file + return iou_dict diff --git a/src/motile_toolbox/candidate_graph/multi_seg_graph.py b/src/motile_toolbox/candidate_graph/multi_seg_graph.py index baac64f..eb13085 100644 --- a/src/motile_toolbox/candidate_graph/multi_seg_graph.py +++ b/src/motile_toolbox/candidate_graph/multi_seg_graph.py @@ -1,18 +1,19 @@ -import numpy as np -import networkx as nx from typing import Any +import networkx as nx +import numpy as np + from .graph_attributes import EdgeAttr, NodeAttr, add_iou -from .graph_from_segmentation import nodes_from_segmentation, add_cand_edges +from .graph_from_segmentation import add_cand_edges, nodes_from_segmentation def compute_multi_seg_graph(segmentations: list[np.ndarray]) -> tuple[nx.DiGraph, list[set]]: - """Create a candidate graph from multi hypothesis segmentations. This is not - tailored for agglomeration approaches with hierarchical merge graphs, it simply + """Create a candidate graph from multi hypothesis segmentations. This is not + tailored for agglomeration approaches with hierarchical merge graphs, it simply creates a conflict set for any nodes that overlap in the same time frame. Args: - segmentations (list[np.ndarray]): + segmentations (list[np.ndarray]): Returns: nx.DiGraph: _description_ @@ -34,8 +35,8 @@ def compute_multi_seg_graph(segmentations: list[np.ndarray]) -> tuple[nx.DiGraph add_iou(cand_graph, segmentation) return cand_graph - - + + @@ -62,6 +63,6 @@ def nodes_from_multi_segmentation( def compute_conflict_sets(segmenations: np.ndarray, time: int) -> list[set]: - """ Segmentations in one frame only. Return list of sets of node ids that conflict.""" + """Segmentations in one frame only. Return list of sets of node ids that conflict.""" # This will look a lot like the IOU code pass diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index 36b161a..6d0d656 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -4,7 +4,7 @@ def _get_location( node_data: dict[str, Any], position_keys: tuple[str, ...] | list[str] ) -> list[Any]: - + # TODO: Remove this function by storing positions in one attribute called position """Convenience function to get the location of a networkx node when each dimension is stored in a different attribute. @@ -24,8 +24,8 @@ def _get_location( return [node_data[k] for k in position_keys] def _get_node_id(time: int, label_id: int, hypothesis_id: int | None) -> str: - + if hypothesis_id: return f"{time}_{hypothesis_id}_{label_id}" else: - return f"{time}_{label_id}" \ No newline at end of file + return f"{time}_{label_id}" From c748922f40f2e65bb52929b59b058836c21ea0fe Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 12:05:42 -0400 Subject: [PATCH 03/20] Enforce standard attribute names in code and tests --- .../candidate_graph/graph_attributes.py | 12 +- .../graph_from_segmentation.py | 97 ++++----------- src/motile_toolbox/candidate_graph/iou.py | 25 ++++ src/motile_toolbox/candidate_graph/utils.py | 31 ----- src/motile_toolbox/utils/saving_utils.py | 3 +- .../test_graph_from_segmentation.py | 112 ++++-------------- tests/test_candidate_graph/test_iou.py | 82 +++++++++++++ tests/test_utils/test_saving_utils.py | 7 +- 8 files changed, 164 insertions(+), 205 deletions(-) delete mode 100644 src/motile_toolbox/candidate_graph/utils.py create mode 100644 tests/test_candidate_graph/test_iou.py diff --git a/src/motile_toolbox/candidate_graph/graph_attributes.py b/src/motile_toolbox/candidate_graph/graph_attributes.py index 7045c0a..d3c6894 100644 --- a/src/motile_toolbox/candidate_graph/graph_attributes.py +++ b/src/motile_toolbox/candidate_graph/graph_attributes.py @@ -6,9 +6,10 @@ class NodeAttr(Enum): Note: Motile can flexibly support any custom attributes. The toolbox provides implementations of commonly used ones, listed here. """ - POS = "position" - SEG_ID = "segmentation_id" # TODO: Seg? - SEG_HYPOTHESIS = "seg_hypothesis" + POS = "pos" + TIME = "time" + SEG_ID = "seg_id" + SEG_HYPOTHESIS = "seg_hypo" class EdgeAttr(Enum): @@ -19,8 +20,3 @@ class EdgeAttr(Enum): DISTANCE = "distance" IOU = "iou" - - -def add_iou(cand_graph, segmentation) -> None: - # TODO: implement - pass diff --git a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py index 4e42bb6..88a950f 100644 --- a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py +++ b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py @@ -7,16 +7,22 @@ from skimage.measure import regionprops from tqdm import tqdm -from .graph_attributes import EdgeAttr, NodeAttr, add_iou -from .utils import _get_location, _get_node_id +from .graph_attributes import EdgeAttr, NodeAttr +from .iou import add_iou logger = logging.getLogger(__name__) +def _get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> str: + + if hypothesis_id: + return f"{time}_{hypothesis_id}_{label_id}" + else: + return f"{time}_{label_id}" + + def nodes_from_segmentation( segmentation: np.ndarray, - position_keys: tuple[str, ...] | list[str] = ("y", "x"), - frame_key: str = "t", ) -> tuple[nx.DiGraph, dict[int, list[Any]]]: """Extract candidate nodes from a segmentation. Also computes specified attributes. Returns a networkx graph with only nodes, and also a dictionary from frames to @@ -25,18 +31,7 @@ def nodes_from_segmentation( Args: segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels (0 is background, all pixels with value 1 belong to one cell, etc.). The - time dimension is first, followed by two or three position dimensions. If - the position dims are not (y, x), use `position_keys` to specify the names - of the dimensions. - attributes (tuple[str, ...] | list[str] , optional): Set of attributes to - compute and add to graph nodes. Valid attributes are: "segmentation_id". - Defaults to ("segmentation_id",). - position_keys (tuple[str, ...]| list[str] , optional): What to label the - position dimensions in the candidate graph. The order of the names - corresponds to the order of the dimensions in `segmentation`. Defaults to - ("y", "x"). - frame_key (str, optional): What to label the time dimension in the candidate - graph. Defaults to 't'. + time dimension is first, followed by two or three position dimensions. Returns: tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, @@ -52,12 +47,11 @@ def nodes_from_segmentation( for regionprop in props: node_id = _get_node_id(t, regionprop.label) attrs = { - frame_key: t, + NodeAttr.TIME.value: t, } attrs[NodeAttr.SEG_ID.value] = regionprop.label centroid = regionprop.centroid # [z,] y, x - for label, value in zip(position_keys, centroid): - attrs[label] = value + attrs[NodeAttr.POS.value] = centroid cand_graph.add_node(node_id, **attrs) nodes_in_frame.append(node_id) if nodes_in_frame: @@ -66,21 +60,19 @@ def nodes_from_segmentation( def _compute_node_frame_dict( - cand_graph: nx.DiGraph, frame_key: str = "t" + cand_graph: nx.DiGraph ) -> dict[int, list[Any]]: """Compute dictionary from time frames to node ids for candidate graph. Args: cand_graph (nx.DiGraph): A networkx graph - frame_key (str, optional): Attribute key that holds the time frame of each - node in cand_graph. Defaults to "t". Returns: dict[int, list[Any]]: A mapping from time frames to lists of node ids. """ node_frame_dict: dict[int, list[Any]] = {} for node, data in cand_graph.nodes(data=True): - t = data[frame_key] + t = data[NodeAttr.TIME.value] if t not in node_frame_dict: node_frame_dict[t] = [] node_frame_dict[t].append(node) @@ -90,8 +82,6 @@ def _compute_node_frame_dict( def add_cand_edges( cand_graph: nx.DiGraph, max_edge_distance: float, - position_keys: tuple[str, ...] | list[str] = ("y", "x"), - frame_key: str = "t", node_frame_dict: None | dict[int, list[Any]] = None, ) -> None: """Add candidate edges to a candidate graph by connecting all nodes in adjacent @@ -103,33 +93,22 @@ def add_cand_edges( max_edge_distance (float): Maximum distance that objects can travel between frames. All nodes within this distance in adjacent frames will by connected with a candidate edge. - attributes (tuple[EdgeAttr, ...], optional): Set of attributes to compute and - add to graph. Defaults to (EdgeAttr.DISTANCE,). - position_keys (tuple[str, ...], optional): What the position dimensions of nodes - in the candidate graph are labeled. Defaults to ("y", "x"). - frame_key (str, optional): The label of the time dimension in the candidate - graph. Defaults to "t". node_frame_dict (dict[int, list[Any]] | None, optional): A mapping from frames to node ids. If not provided, it will be computed from cand_graph. Defaults to None. - segmentation (np.ndarray, optional): The segmentation array for optionally - computing attributes such as IOU. Defaults to None. """ print("Extracting candidate edges") if not node_frame_dict: - node_frame_dict = _compute_node_frame_dict(cand_graph, frame_key=frame_key) + node_frame_dict = _compute_node_frame_dict(cand_graph) frames = sorted(node_frame_dict.keys()) for frame in tqdm(frames): if frame + 1 not in node_frame_dict: continue next_nodes = node_frame_dict[frame + 1] - next_locs = [ - _get_location(cand_graph.nodes[n], position_keys=position_keys) - for n in next_nodes - ] + next_locs = [cand_graph.nodes[n][NodeAttr.POS.value] for n in next_nodes] for node in node_frame_dict[frame]: - loc = _get_location(cand_graph.nodes[node], position_keys=position_keys) + loc = cand_graph.nodes[node][NodeAttr.POS.value] for next_id, next_loc in zip(next_nodes, next_locs): dist = math.dist(next_loc, loc) if dist <= max_edge_distance: @@ -140,11 +119,8 @@ def add_cand_edges( def graph_from_segmentation( segmentation: np.ndarray, max_edge_distance: float, - node_attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), - edge_attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,), - position_keys: tuple[str, ...] | list[str] = ("y", "x"), - frame_key: str = "t", -): + iou: bool = False, +) -> nx.DiGraph: """Construct a candidate graph from a segmentation array. Nodes are placed at the centroid of each segmentation and edges are added for all nodes in adjacent frames within max_edge_distance. The specified attributes are computed during construction. @@ -159,47 +135,24 @@ def graph_from_segmentation( max_edge_distance (float): Maximum distance that objects can travel between frames. All nodes within this distance in adjacent frames will by connected with a candidate edge. - node_attributes (tuple[str, ...] | list[str], optional): Set of attributes to - compute and add to nodes in graph. Valid attributes are: "segmentation_id". - Defaults to ("segmentation_id",). - edge_attributes (tuple[str, ...] | list[str], optional): Set of attributes to - compute and add to edges in graph. Valid attributes are: "distance". - Defaults to ("distance",). - position_keys (tuple[str, ...], optional): What to label the position dimensions - in the candidate graph. The order of the names corresponds to the order of - the dimensions in `segmentation`. Defaults to ("y", "x"). - frame_key (str, optional): What to label the time dimension in the candidate - graph. Defaults to 't'. + iou (bool, optional): Whether to include IOU on the candidate graph. + Defaults to False. Returns: nx.DiGraph: A candidate graph that can be passed to the motile solver. - - Raises: - ValueError: if unsupported attribute strings are passed in to the attributes - arguments, or if the number of position keys provided does not match the - number of position dimensions. """ - if len(position_keys) != segmentation.ndim - 1: - raise ValueError( - f"Position labels {position_keys} does not match number of spatial dims " - f"({segmentation.ndim - 1})" - ) # add nodes - - cand_graph, node_frame_dict = nodes_from_segmentation( - segmentation, node_attributes, position_keys=position_keys, frame_key=frame_key - ) + cand_graph, node_frame_dict = nodes_from_segmentation(segmentation) logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") # add edges add_cand_edges( cand_graph, max_edge_distance=max_edge_distance, - position_keys=position_keys, node_frame_dict=node_frame_dict, ) - if EdgeAttr.IOU in edge_attributes: - add_iou(cand_graph, segmentation) + if iou: + add_iou(cand_graph, segmentation, node_frame_dict) logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") return cand_graph diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index e605bf2..599e636 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -1,4 +1,8 @@ +import networkx as nx import numpy as np +from tqdm import tqdm + +from .graph_attributes import EdgeAttr, NodeAttr def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, float]]: @@ -35,3 +39,24 @@ def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, iou_dict[id1] = {} iou_dict[id1][id2] = intersection / union return iou_dict + + +def add_iou(cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict) -> None: + """Add IOU to the candidate graph. + + Args: + cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated + segmentation (np.ndarray): segmentation that was used to create cand_graph + """ + frames = sorted(node_frame_dict.keys()) + for frame in tqdm(frames): + if frame + 1 not in node_frame_dict: + continue + ious = compute_ious(segmentation[frame], segmentation[frame + 1]) + next_nodes = node_frame_dict[frame + 1] + for node_id in node_frame_dict[frame]: + node_seg_id = cand_graph.nodes[node_id][NodeAttr.SEG_ID.value] + for next_id in next_nodes: + next_seg_id = cand_graph.nodes[next_id][NodeAttr.SEG_ID.value] + iou = ious.get(node_seg_id, {}).get( next_seg_id, 0) + cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou \ No newline at end of file diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py deleted file mode 100644 index 6d0d656..0000000 --- a/src/motile_toolbox/candidate_graph/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Any - - -def _get_location( - node_data: dict[str, Any], position_keys: tuple[str, ...] | list[str] -) -> list[Any]: - - # TODO: Remove this function by storing positions in one attribute called position - """Convenience function to get the location of a networkx node when each dimension - is stored in a different attribute. - - Args: - node_data (dict[str, Any]): Dictionary of attributes of a networkx node. - Assumes the provided position keys are in the dictionary. - position_keys (tuple[str, ...] | list[str], optional): Keys to use to get - location information from node_data (assumes they are present in node_data). - Defaults to ("z", "y", "x"). - - Returns: - list: _description_ - Raises: - KeyError if position keys not in node_data - """ - return [node_data[k] for k in position_keys] - -def _get_node_id(time: int, label_id: int, hypothesis_id: int | None) -> str: - - if hypothesis_id: - return f"{time}_{hypothesis_id}_{label_id}" - else: - return f"{time}_{label_id}" diff --git a/src/motile_toolbox/utils/saving_utils.py b/src/motile_toolbox/utils/saving_utils.py index 6bd98f2..868c755 100644 --- a/src/motile_toolbox/utils/saving_utils.py +++ b/src/motile_toolbox/utils/saving_utils.py @@ -7,7 +7,6 @@ def relabel_segmentation( solution_nx_graph: nx.DiGraph, segmentation: np.ndarray, - frame_key="t", ) -> np.ndarray: """Relabel a segmentation based on tracking results so that nodes in same track share the same id. IDs do change at division. @@ -33,7 +32,7 @@ def relabel_segmentation( soln_copy.remove_edges_from(out_edges) for node_set in nx.weakly_connected_components(soln_copy): for node in node_set: - time_frame = solution_nx_graph.nodes[node][frame_key] + time_frame = solution_nx_graph.nodes[node][NodeAttr.TIME.value] previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value] previous_seg_mask = segmentation[time_frame] == previous_seg_id tracked_masks[time_frame][previous_seg_mask] = id_counter diff --git a/tests/test_candidate_graph/test_graph_from_segmentation.py b/tests/test_candidate_graph/test_graph_from_segmentation.py index 2b5b7db..b235f06 100644 --- a/tests/test_candidate_graph/test_graph_from_segmentation.py +++ b/tests/test_candidate_graph/test_graph_from_segmentation.py @@ -6,7 +6,8 @@ from motile_toolbox.candidate_graph import EdgeAttr, NodeAttr from motile_toolbox.candidate_graph.graph_from_segmentation import ( add_cand_edges, - compute_ious, + _compute_node_frame_dict, + _get_node_id, graph_from_segmentation, nodes_from_segmentation, ) @@ -37,13 +38,13 @@ def segmentation_2d(): def graph_2d(): graph = nx.DiGraph() nodes = [ - ("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}), - ("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 1}), - ("1_2", {"y": 60, "x": 45, "t": 1, "segmentation_id": 2}), + ("0_1", {NodeAttr.POS.value: (50, 50), NodeAttr.TIME.value: 0, NodeAttr.SEG_ID.value: 1}), + ("1_1", {NodeAttr.POS.value: (20, 80), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 1}), + ("1_2", {NodeAttr.POS.value: (60, 45), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 2}), ] edges = [ - ("0_1", "1_1", {"distance": 42.43, "iou": 0.0}), - ("0_1", "1_2", {"distance": 11.18, "iou": 0.395}), + ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43, EdgeAttr.IOU.value: 0.0}), + ("0_1", "1_2", {EdgeAttr.DISTANCE.value: 11.18, EdgeAttr.IOU.value: 0.395}), ] graph.add_nodes_from(nodes) graph.add_edges_from(edges) @@ -82,15 +83,15 @@ def segmentation_3d(): def graph_3d(): graph = nx.DiGraph() nodes = [ - ("0_1", {"z": 50, "y": 50, "x": 50, "t": 0, "segmentation_id": 1}), - ("1_1", {"z": 20, "y": 50, "x": 80, "t": 1, "segmentation_id": 1}), - ("1_2", {"z": 60, "y": 50, "x": 45, "t": 1, "segmentation_id": 2}), + ("0_1", {NodeAttr.POS.value: (50, 50, 50), NodeAttr.TIME.value: 0, NodeAttr.SEG_ID.value: 1}), + ("1_1", {NodeAttr.POS.value: (20, 50, 80), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 1}), + ("1_2", {NodeAttr.POS.value: (60, 50, 45), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 2}), ] edges = [ # math.dist([50, 50], [20, 80]) - ("0_1", "1_1", {"distance": 42.43}), + ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43}), # math.dist([50, 50], [60, 45]) - ("0_1", "1_2", {"distance": 11.18}), + ("0_1", "1_2", {EdgeAttr.DISTANCE.value: 11.18}), ] graph.add_nodes_from(nodes) graph.add_edges_from(edges) @@ -113,36 +114,23 @@ def test_nodes_from_segmentation_2d(segmentation_2d): segmentation=segmentation_2d, ) assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert node_graph.nodes["1_1"]["segmentation_id"] == 1 - assert node_graph.nodes["1_1"]["t"] == 1 - assert node_graph.nodes["1_1"]["y"] == 20 - assert node_graph.nodes["1_1"]["x"] == 80 + assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 80) assert node_frame_dict[0] == ["0_1"] assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) - # remove attrs - node_graph, _ = nodes_from_segmentation( - segmentation=segmentation_2d, - attributes=[], - ) - assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert "segmentation_id" not in node_graph.nodes["0_1"] - def test_nodes_from_segmentation_3d(segmentation_3d): # test with 3D segmentation node_graph, node_frame_dict = nodes_from_segmentation( segmentation=segmentation_3d, - attributes=[NodeAttr.SEG_ID], - position_keys=("pos_z", "pos_y", "pos_x"), ) assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert node_graph.nodes["1_1"]["segmentation_id"] == 1 - assert node_graph.nodes["1_1"]["t"] == 1 - assert node_graph.nodes["1_1"]["pos_z"] == 20 - assert node_graph.nodes["1_1"]["pos_y"] == 50 - assert node_graph.nodes["1_1"]["pos_x"] == 80 + assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 50, 80) assert node_frame_dict[0] == ["0_1"] assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) @@ -162,47 +150,26 @@ def test_add_cand_edges_2d(graph_2d): def test_add_cand_edges_3d(graph_3d): cand_graph = nx.create_empty_copy(graph_3d) - add_cand_edges(cand_graph, max_edge_distance=15, position_keys=("z", "y", "x")) + add_cand_edges(cand_graph, max_edge_distance=15) graph_3d.remove_edge("0_1", "1_1") assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) for edge in cand_graph.edges: assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] -# graph_from_segmentation -def test_graph_from_segmentation_invalid(): - # test invalid attributes - with pytest.raises(ValueError): - graph_from_segmentation( - np.zeros((3, 10, 10, 10), dtype="int32"), - 10, - edge_attributes=["invalid"], - ) - with pytest.raises(ValueError): - graph_from_segmentation( - np.zeros((3, 10, 10, 10), dtype="int32"), - 10, - node_attributes=["invalid"], - ) - - with pytest.raises(ValueError): - graph_from_segmentation( - np.zeros((3, 10, 10), dtype="int32"), 100, position_keys=["z", "y", "x"] - ) - - def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): # test with 2D segmentation cand_graph = graph_from_segmentation( segmentation=segmentation_2d, max_edge_distance=100, - edge_attributes=[EdgeAttr.DISTANCE, EdgeAttr.IOU], + iou=True, ) assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes)) assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) for node in cand_graph.nodes: assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node]) for edge in cand_graph.edges: + print(cand_graph.edges[edge]) assert ( pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) == graph_2d.edges[edge][EdgeAttr.DISTANCE.value] @@ -219,7 +186,7 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): ) assert Counter(list(cand_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) assert Counter(list(cand_graph.edges)) == Counter([("0_1", "1_2")]) - assert cand_graph.edges[("0_1", "1_2")]["distance"] == pytest.approx( + assert cand_graph.edges[("0_1", "1_2")][EdgeAttr.DISTANCE.value] == pytest.approx( 11.18, abs=0.01 ) @@ -229,43 +196,10 @@ def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): cand_graph = graph_from_segmentation( segmentation=segmentation_3d, max_edge_distance=100, - position_keys=("z", "y", "x"), ) assert Counter(list(cand_graph.nodes)) == Counter(list(graph_3d.nodes)) assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) for node in cand_graph.nodes: assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) for edge in cand_graph.edges: - assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] - - -def test_compute_ious_2d(segmentation_2d): - ious = compute_ious(segmentation_2d[0], segmentation_2d[1]) - expected = {1: {2: 555.46 / 1408.0}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1) - - ious = compute_ious(segmentation_2d[1], segmentation_2d[1]) - expected = {1: {1: 1.0}, 2: {2: 1.0}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1) - assert ious[2].keys() == expected[2].keys() - assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) - - -def test_compute_ious_3d(segmentation_3d): - ious = compute_ious(segmentation_3d[0], segmentation_3d[1]) - expected = {1: {2: 0.30}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1) - - ious = compute_ious(segmentation_3d[1], segmentation_3d[1]) - expected = {1: {1: 1.0}, 2: {2: 1.0}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1) - assert ious[2].keys() == expected[2].keys() - assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] \ No newline at end of file diff --git a/tests/test_candidate_graph/test_iou.py b/tests/test_candidate_graph/test_iou.py new file mode 100644 index 0000000..10c1ab2 --- /dev/null +++ b/tests/test_candidate_graph/test_iou.py @@ -0,0 +1,82 @@ +from motile_toolbox.candidate_graph.iou import compute_ious +import pytest +import numpy as np +from skimage.draw import disk + +@pytest.fixture +def segmentation_2d(): + frame_shape = (100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 1 + # second cell centered at (60, 45) with label 2 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 1 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 2 + + return segmentation + +def sphere(center, radius, shape): + assert len(center) == len(shape) + indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index + distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) + mask = distance <= radius + return mask + + +@pytest.fixture +def segmentation_3d(): + frame_shape = (100, 100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) + segmentation[0][mask] = 1 + + # make frame with two cells + # first cell centered at (20, 50, 80) with label 1 + # second cell centered at (60, 50, 45) with label 2 + mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) + segmentation[1][mask] = 1 + mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) + segmentation[1][mask] = 2 + + return segmentation + + +def test_compute_ious_2d(segmentation_2d): + ious = compute_ious(segmentation_2d[0], segmentation_2d[1]) + expected = {1: {2: 555.46 / 1408.0}} + assert ious.keys() == expected.keys() + assert ious[1].keys() == expected[1].keys() + assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1) + + ious = compute_ious(segmentation_2d[1], segmentation_2d[1]) + expected = {1: {1: 1.0}, 2: {2: 1.0}} + assert ious.keys() == expected.keys() + assert ious[1].keys() == expected[1].keys() + assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1) + assert ious[2].keys() == expected[2].keys() + assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) + + +def test_compute_ious_3d(segmentation_3d): + ious = compute_ious(segmentation_3d[0], segmentation_3d[1]) + expected = {1: {2: 0.30}} + assert ious.keys() == expected.keys() + assert ious[1].keys() == expected[1].keys() + assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1) + + ious = compute_ious(segmentation_3d[1], segmentation_3d[1]) + expected = {1: {1: 1.0}, 2: {2: 1.0}} + assert ious.keys() == expected.keys() + assert ious[1].keys() == expected[1].keys() + assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1) + assert ious[2].keys() == expected[2].keys() + assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) \ No newline at end of file diff --git a/tests/test_utils/test_saving_utils.py b/tests/test_utils/test_saving_utils.py index c4ff2ac..4f6f0fb 100644 --- a/tests/test_utils/test_saving_utils.py +++ b/tests/test_utils/test_saving_utils.py @@ -4,6 +4,7 @@ from motile_toolbox.utils import relabel_segmentation from numpy.testing import assert_array_equal from skimage.draw import disk +from motile_toolbox.candidate_graph.graph_attributes import NodeAttr, EdgeAttr @pytest.fixture @@ -30,11 +31,11 @@ def segmentation_2d(): def graph_2d(): graph = nx.DiGraph() nodes = [ - ("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}), - ("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 2}), + ("0_1", {"y": 50, "x": 50, NodeAttr.TIME.value: 0, NodeAttr.SEG_ID.value: 1}), + ("1_1", {"y": 20, "x": 80, NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 2}), ] edges = [ - ("0_1", "1_1", {"distance": 42.43}), + ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43}), ] graph.add_nodes_from(nodes) graph.add_edges_from(edges) From ccae0b74cbd8f6a49df5239f46fe524c7c84dd36 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 13:04:32 -0400 Subject: [PATCH 04/20] Move pytest fixtures to conftest.py --- tests/conftest.py | 131 ++++++++++++++++++ .../test_graph_from_segmentation.py | 89 +----------- .../test_candidate_graph/test_graph_to_nx.py | 22 +-- tests/test_candidate_graph/test_iou.py | 52 +------ tests/test_utils/test_saving_utils.py | 39 +----- 5 files changed, 136 insertions(+), 197 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..4b200d2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,131 @@ +import networkx as nx +import numpy as np +import pytest +from motile_toolbox.candidate_graph import EdgeAttr, NodeAttr +from skimage.draw import disk + + +@pytest.fixture +def segmentation_2d(): + frame_shape = (100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 1 + # second cell centered at (60, 45) with label 2 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 1 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 2 + + return segmentation + + +@pytest.fixture +def graph_2d(): + graph = nx.DiGraph() + nodes = [ + ( + "0_1", + { + NodeAttr.POS.value: (50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_1", + { + NodeAttr.POS.value: (20, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_2", + { + NodeAttr.POS.value: (60, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 2, + }, + ), + ] + edges = [ + ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43, EdgeAttr.IOU.value: 0.0}), + ("0_1", "1_2", {EdgeAttr.DISTANCE.value: 11.18, EdgeAttr.IOU.value: 0.395}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def sphere(center, radius, shape): + assert len(center) == len(shape) + indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index + distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) + mask = distance <= radius + return mask + + +@pytest.fixture +def segmentation_3d(): + frame_shape = (100, 100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) + segmentation[0][mask] = 1 + + # make frame with two cells + # first cell centered at (20, 50, 80) with label 1 + # second cell centered at (60, 50, 45) with label 2 + mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) + segmentation[1][mask] = 1 + mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) + segmentation[1][mask] = 2 + + return segmentation + + +@pytest.fixture +def graph_3d(): + graph = nx.DiGraph() + nodes = [ + ( + "0_1", + { + NodeAttr.POS.value: (50, 50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_1", + { + NodeAttr.POS.value: (20, 50, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_2", + { + NodeAttr.POS.value: (60, 50, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 2, + }, + ), + ] + edges = [ + # math.dist([50, 50], [20, 80]) + ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43}), + # math.dist([50, 50], [60, 45]) + ("0_1", "1_2", {EdgeAttr.DISTANCE.value: 11.18}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph diff --git a/tests/test_candidate_graph/test_graph_from_segmentation.py b/tests/test_candidate_graph/test_graph_from_segmentation.py index b235f06..25aa299 100644 --- a/tests/test_candidate_graph/test_graph_from_segmentation.py +++ b/tests/test_candidate_graph/test_graph_from_segmentation.py @@ -6,96 +6,9 @@ from motile_toolbox.candidate_graph import EdgeAttr, NodeAttr from motile_toolbox.candidate_graph.graph_from_segmentation import ( add_cand_edges, - _compute_node_frame_dict, - _get_node_id, graph_from_segmentation, nodes_from_segmentation, ) -from skimage.draw import disk - - -@pytest.fixture -def segmentation_2d(): - frame_shape = (100, 100) - total_shape = (2, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) - segmentation[0][rr, cc] = 1 - - # make frame with two cells - # first cell centered at (20, 80) with label 1 - # second cell centered at (60, 45) with label 2 - rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) - segmentation[1][rr, cc] = 1 - rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) - segmentation[1][rr, cc] = 2 - - return segmentation - - -@pytest.fixture -def graph_2d(): - graph = nx.DiGraph() - nodes = [ - ("0_1", {NodeAttr.POS.value: (50, 50), NodeAttr.TIME.value: 0, NodeAttr.SEG_ID.value: 1}), - ("1_1", {NodeAttr.POS.value: (20, 80), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 1}), - ("1_2", {NodeAttr.POS.value: (60, 45), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 2}), - ] - edges = [ - ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43, EdgeAttr.IOU.value: 0.0}), - ("0_1", "1_2", {EdgeAttr.DISTANCE.value: 11.18, EdgeAttr.IOU.value: 0.395}), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - -def sphere(center, radius, shape): - assert len(center) == len(shape) - indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index - distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) - mask = distance <= radius - return mask - - -@pytest.fixture -def segmentation_3d(): - frame_shape = (100, 100, 100) - total_shape = (2, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) - segmentation[0][mask] = 1 - - # make frame with two cells - # first cell centered at (20, 50, 80) with label 1 - # second cell centered at (60, 50, 45) with label 2 - mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) - segmentation[1][mask] = 1 - mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) - segmentation[1][mask] = 2 - - return segmentation - - -@pytest.fixture -def graph_3d(): - graph = nx.DiGraph() - nodes = [ - ("0_1", {NodeAttr.POS.value: (50, 50, 50), NodeAttr.TIME.value: 0, NodeAttr.SEG_ID.value: 1}), - ("1_1", {NodeAttr.POS.value: (20, 50, 80), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 1}), - ("1_2", {NodeAttr.POS.value: (60, 50, 45), NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 2}), - ] - edges = [ - # math.dist([50, 50], [20, 80]) - ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43}), - # math.dist([50, 50], [60, 45]) - ("0_1", "1_2", {EdgeAttr.DISTANCE.value: 11.18}), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph # nodes_from_segmentation @@ -202,4 +115,4 @@ def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): for node in cand_graph.nodes: assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) for edge in cand_graph.edges: - assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] \ No newline at end of file + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] diff --git a/tests/test_candidate_graph/test_graph_to_nx.py b/tests/test_candidate_graph/test_graph_to_nx.py index b52d864..3d08e89 100644 --- a/tests/test_candidate_graph/test_graph_to_nx.py +++ b/tests/test_candidate_graph/test_graph_to_nx.py @@ -1,30 +1,10 @@ import networkx as nx -import pytest from motile import TrackGraph from motile_toolbox.candidate_graph import graph_to_nx from networkx.utils import graphs_equal -@pytest.fixture -def graph_3d(): - graph = nx.DiGraph() - nodes = [ - ("0_1", {"z": 50, "y": 50, "x": 50, "t": 0, "segmentation_id": 1}), - ("1_1", {"z": 20, "y": 50, "x": 80, "t": 1, "segmentation_id": 1}), - ("1_2", {"z": 60, "y": 50, "x": 45, "t": 1, "segmentation_id": 2}), - ] - edges = [ - # math.dist([50, 50], [20, 80]) - ("0_1", "1_1", {"distance": 42.43}), - # math.dist([50, 50], [60, 45]) - ("0_1", "1_2", {"distance": 11.18}), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - def test_graph_to_nx(graph_3d: nx.DiGraph): - track_graph = TrackGraph(nx_graph=graph_3d, frame_attribute="t") + track_graph = TrackGraph(nx_graph=graph_3d, frame_attribute="time") nx_graph = graph_to_nx(track_graph) assert graphs_equal(graph_3d, nx_graph) diff --git a/tests/test_candidate_graph/test_iou.py b/tests/test_candidate_graph/test_iou.py index 10c1ab2..a9dcb42 100644 --- a/tests/test_candidate_graph/test_iou.py +++ b/tests/test_candidate_graph/test_iou.py @@ -1,53 +1,5 @@ -from motile_toolbox.candidate_graph.iou import compute_ious import pytest -import numpy as np -from skimage.draw import disk - -@pytest.fixture -def segmentation_2d(): - frame_shape = (100, 100) - total_shape = (2, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) - segmentation[0][rr, cc] = 1 - - # make frame with two cells - # first cell centered at (20, 80) with label 1 - # second cell centered at (60, 45) with label 2 - rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) - segmentation[1][rr, cc] = 1 - rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) - segmentation[1][rr, cc] = 2 - - return segmentation - -def sphere(center, radius, shape): - assert len(center) == len(shape) - indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index - distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) - mask = distance <= radius - return mask - - -@pytest.fixture -def segmentation_3d(): - frame_shape = (100, 100, 100) - total_shape = (2, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) - segmentation[0][mask] = 1 - - # make frame with two cells - # first cell centered at (20, 50, 80) with label 1 - # second cell centered at (60, 50, 45) with label 2 - mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) - segmentation[1][mask] = 1 - mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) - segmentation[1][mask] = 2 - - return segmentation +from motile_toolbox.candidate_graph.iou import compute_ious def test_compute_ious_2d(segmentation_2d): @@ -79,4 +31,4 @@ def test_compute_ious_3d(segmentation_3d): assert ious[1].keys() == expected[1].keys() assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1) assert ious[2].keys() == expected[2].keys() - assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) \ No newline at end of file + assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) diff --git a/tests/test_utils/test_saving_utils.py b/tests/test_utils/test_saving_utils.py index 4f6f0fb..57d796f 100644 --- a/tests/test_utils/test_saving_utils.py +++ b/tests/test_utils/test_saving_utils.py @@ -1,45 +1,7 @@ -import networkx as nx import numpy as np -import pytest from motile_toolbox.utils import relabel_segmentation from numpy.testing import assert_array_equal from skimage.draw import disk -from motile_toolbox.candidate_graph.graph_attributes import NodeAttr, EdgeAttr - - -@pytest.fixture -def segmentation_2d(): - frame_shape = (100, 100) - total_shape = (2, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) - segmentation[0][rr, cc] = 1 - - # make frame with two cells - # first cell centered at (20, 80) with label 2 - # second cell centered at (60, 45) with label 3 - rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) - segmentation[1][rr, cc] = 2 - rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) - segmentation[1][rr, cc] = 3 - - return segmentation - - -@pytest.fixture -def graph_2d(): - graph = nx.DiGraph() - nodes = [ - ("0_1", {"y": 50, "x": 50, NodeAttr.TIME.value: 0, NodeAttr.SEG_ID.value: 1}), - ("1_1", {"y": 20, "x": 80, NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 2}), - ] - edges = [ - ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43}), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph def test_relabel_segmentation(segmentation_2d, graph_2d): @@ -53,6 +15,7 @@ def test_relabel_segmentation(segmentation_2d, graph_2d): rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) expected[1][rr, cc] = 1 + graph_2d.remove_node("1_2") relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d) print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}") print(f"Nonzero expected: {np.count_nonzero(expected)}") From d8a5f02ec228b16e04a2695f6b8997b5198b943a Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 13:19:02 -0400 Subject: [PATCH 05/20] Add hypothesis id to nodes_from_segmentation --- .../candidate_graph/graph_attributes.py | 3 ++- .../candidate_graph/graph_from_segmentation.py | 17 ++++++++++------- .../test_graph_from_segmentation.py | 15 +++++++++++++++ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/graph_attributes.py b/src/motile_toolbox/candidate_graph/graph_attributes.py index d3c6894..478c2b3 100644 --- a/src/motile_toolbox/candidate_graph/graph_attributes.py +++ b/src/motile_toolbox/candidate_graph/graph_attributes.py @@ -6,10 +6,11 @@ class NodeAttr(Enum): Note: Motile can flexibly support any custom attributes. The toolbox provides implementations of commonly used ones, listed here. """ + POS = "pos" TIME = "time" SEG_ID = "seg_id" - SEG_HYPOTHESIS = "seg_hypo" + SEG_HYPO = "seg_hypo" class EdgeAttr(Enum): diff --git a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py index 88a950f..10100ec 100644 --- a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py +++ b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py @@ -14,15 +14,15 @@ def _get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> str: - - if hypothesis_id: + if hypothesis_id is not None: + print(hypothesis_id) return f"{time}_{hypothesis_id}_{label_id}" else: return f"{time}_{label_id}" def nodes_from_segmentation( - segmentation: np.ndarray, + segmentation: np.ndarray, hypo_id: int | None = None ) -> tuple[nx.DiGraph, dict[int, list[Any]]]: """Extract candidate nodes from a segmentation. Also computes specified attributes. Returns a networkx graph with only nodes, and also a dictionary from frames to @@ -32,6 +32,9 @@ def nodes_from_segmentation( segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels (0 is background, all pixels with value 1 belong to one cell, etc.). The time dimension is first, followed by two or three position dimensions. + hypo_id (int | None, optional): An id to identify which layer of the multi- + hypothesis segmentation this is. Used to create node id, and is added + to each node if not None. Defaults to None. Returns: tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, @@ -45,11 +48,13 @@ def nodes_from_segmentation( nodes_in_frame = [] props = regionprops(segmentation[t]) for regionprop in props: - node_id = _get_node_id(t, regionprop.label) + node_id = _get_node_id(t, regionprop.label, hypothesis_id=hypo_id) attrs = { NodeAttr.TIME.value: t, } attrs[NodeAttr.SEG_ID.value] = regionprop.label + if hypo_id is not None: + attrs[NodeAttr.SEG_HYPO.value] = hypo_id centroid = regionprop.centroid # [z,] y, x attrs[NodeAttr.POS.value] = centroid cand_graph.add_node(node_id, **attrs) @@ -59,9 +64,7 @@ def nodes_from_segmentation( return cand_graph, node_frame_dict -def _compute_node_frame_dict( - cand_graph: nx.DiGraph -) -> dict[int, list[Any]]: +def _compute_node_frame_dict(cand_graph: nx.DiGraph) -> dict[int, list[Any]]: """Compute dictionary from time frames to node ids for candidate graph. Args: diff --git a/tests/test_candidate_graph/test_graph_from_segmentation.py b/tests/test_candidate_graph/test_graph_from_segmentation.py index 25aa299..96923e2 100644 --- a/tests/test_candidate_graph/test_graph_from_segmentation.py +++ b/tests/test_candidate_graph/test_graph_from_segmentation.py @@ -35,6 +35,21 @@ def test_nodes_from_segmentation_2d(segmentation_2d): assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) +def test_nodes_from_segmentation_2d_hypo(segmentation_2d): + # test with 2D segmentation + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_2d, hypo_id=0 + ) + assert Counter(list(node_graph.nodes)) == Counter(["0_0_1", "1_0_1", "1_0_2"]) + assert node_graph.nodes["1_0_1"][NodeAttr.SEG_ID.value] == 1 + assert node_graph.nodes["1_0_1"][NodeAttr.SEG_HYPO.value] == 0 + assert node_graph.nodes["1_0_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_0_1"][NodeAttr.POS.value] == (20, 80) + + assert node_frame_dict[0] == ["0_0_1"] + assert Counter(node_frame_dict[1]) == Counter(["1_0_1", "1_0_2"]) + + def test_nodes_from_segmentation_3d(segmentation_3d): # test with 3D segmentation node_graph, node_frame_dict = nodes_from_segmentation( From 8c401b88b0b9b3b1f25cfdb94c9f6ac929992134 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 1 Apr 2024 13:49:28 -0400 Subject: [PATCH 06/20] Add multihypothesis test fixtures --- tests/conftest.py | 216 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 216 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 4b200d2..9c7f55c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,40 @@ def segmentation_2d(): return segmentation +@pytest.fixture +def multi_hypothesis_segmentation_2d(): + """ + Creates a multi-hypothesis version of the `segmentation_2d` fixture defined above. + + """ + frame_shape = (100, 100) + total_shape = (2, 2, *frame_shape) # 2 time points, 2 hypotheses layers, H, W + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr0, cc0 = disk(center=(50, 50), radius=20, shape=frame_shape) + rr1, cc1 = disk(center=(45, 45), radius=15, shape=frame_shape) + + segmentation[0, 0][rr0, cc0] = 1 + segmentation[0, 1][rr1, cc1] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 1 + rr0, cc0 = disk(center=(20, 80), radius=10, shape=frame_shape) + rr1, cc1 = disk(center=(15, 75), radius=15, shape=frame_shape) + + segmentation[1, 0][rr0, cc0] = 1 + segmentation[1, 1][rr1, cc1] = 1 + + # second cell centered at (60, 45) with label 2 + rr0, cc0 = disk(center=(60, 45), radius=15, shape=frame_shape) + rr1, cc1 = disk(center=(55, 40), radius=20, shape=frame_shape) + + segmentation[1, 0][rr0, cc0] = 2 + segmentation[1, 1][rr1, cc1] = 2 + + return segmentation + + @pytest.fixture def graph_2d(): graph = nx.DiGraph() @@ -63,6 +97,93 @@ def graph_2d(): return graph +@pytest.fixture +def multi_hypothesis_graph_2d(): + graph = nx.DiGraph() + nodes = [ + ( + "0_0_1", + { + NodeAttr.POS.value: (50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "0_1_1", + { + NodeAttr.POS.value: (45, 45), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_HYPOTHESIS.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_0_1", + { + NodeAttr.POS.value: (20, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_1_1", + { + NodeAttr.POS.value: (15, 75), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPOTHESIS.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_0_2", + { + NodeAttr.POS.value: (60, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_ID.value: 2, + }, + ), + ( + "1_1_2", + { + NodeAttr.POS.value: (55, 40), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPOTHESIS.value: 1, + NodeAttr.SEG_ID.value: 2, + }, + ), + ] + + edges = [ + ("0_0_1", "1_0_1", {EdgeAttr.DISTANCE.value: 42.426, EdgeAttr.IOU.value: 0.0}), + ("0_0_1", "1_1_1", {EdgeAttr.DISTANCE.value: 43.011, EdgeAttr.IOU.value: 0.0}), + ( + "0_0_1", + "1_0_2", + {EdgeAttr.DISTANCE.value: 11.180, EdgeAttr.IOU.value: 0.3931}, + ), + ( + "0_0_1", + "1_1_2", + {EdgeAttr.DISTANCE.value: 11.180, EdgeAttr.IOU.value: 0.4768}, + ), + ("0_1_1", "1_0_1", {EdgeAttr.DISTANCE.value: 43.011, EdgeAttr.IOU.value: 0.0}), + ("0_1_1", "1_1_1", {EdgeAttr.DISTANCE.value: 42.426, EdgeAttr.IOU.value: 0.0}), + ("0_1_1", "1_0_2", {EdgeAttr.DISTANCE.value: 15.0, EdgeAttr.IOU.value: 0.2402}), + ( + "0_1_1", + "1_1_2", + {EdgeAttr.DISTANCE.value: 11.180, EdgeAttr.IOU.value: 0.3931}, + ), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + def sphere(center, radius, shape): assert len(center) == len(shape) indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index @@ -91,6 +212,38 @@ def segmentation_3d(): return segmentation +@pytest.fixture +def multi_hypothesis_segmentation_3d(): + """ + Creates a multi-hypothesis version of the `segmentation_3d` fixture defined above. + + """ + frame_shape = (100, 100, 100) + total_shape = (2, 2, *frame_shape) # 2 time points, 2 hypotheses + segmentation = np.zeros(total_shape, dtype="int32") + # make first frame with one cell in center with label 1 + mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) + segmentation[0, 0][mask] = 1 + mask = sphere(center=(45, 50, 55), radius=20, shape=frame_shape) + segmentation[0, 1][mask] = 1 + + # make second frame, first hypothesis with two cells + # first cell centered at (20, 50, 80) with label 1 + # second cell centered at (60, 50, 45) with label 2 + mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) + segmentation[1, 0][mask] = 1 + mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) + segmentation[1, 0][mask] = 2 + + # make second frame, second hypothesis with one cell + # first cell centered at (15, 50, 70) with label 1 + # second cell centered at (55, 55, 45) with label 2 + mask = sphere(center=(15, 50, 70), radius=10, shape=frame_shape) + segmentation[1, 1][mask] = 1 + + return segmentation + + @pytest.fixture def graph_3d(): graph = nx.DiGraph() @@ -129,3 +282,66 @@ def graph_3d(): graph.add_nodes_from(nodes) graph.add_edges_from(edges) return graph + + +@pytest.fixture +def multi_hypothesis_graph_3d(): + graph = nx.DiGraph() + nodes = [ + ( + "0_0_1", + { + NodeAttr.POS.value: (50, 50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "0_1_1", + { + NodeAttr.POS.value: (45, 50, 55), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPOTHESIS.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_0_1", + { + NodeAttr.POS.value: (20, 50, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_0_2", + { + NodeAttr.POS.value: (60, 50, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_ID.value: 2, + }, + ), + ( + "1_1_1", + { + NodeAttr.POS.value: (15, 50, 70), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPOTHESIS.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ] + edges = [ + ("0_0_1", "1_0_1", {EdgeAttr.DISTANCE.value: 42.4264}), + ("0_0_1", "1_0_2", {EdgeAttr.DISTANCE.value: 11.1803}), + ("0_1_1", "1_0_1", {EdgeAttr.DISTANCE.value: 35.3553}), + ("0_1_1", "1_0_2", {EdgeAttr.DISTANCE.value: 18.0277}), + ("0_0_1", "1_1_1", {EdgeAttr.DISTANCE.value: 40.3112}), + ("0_1_1", "1_1_1", {EdgeAttr.DISTANCE.value: 33.5410}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph From 8d12a6ac0ba2945cada99a56a43f30b6f65d4b37 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 14:11:21 -0400 Subject: [PATCH 07/20] Implement multi hypothesis candidate graph computation --- src/motile_toolbox/candidate_graph/iou.py | 41 ++++++- .../candidate_graph/multi_seg_graph.py | 101 +++++++++++------- 2 files changed, 99 insertions(+), 43 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index 599e636..ceb6431 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -1,8 +1,11 @@ +from itertools import combinations + import networkx as nx import numpy as np from tqdm import tqdm from .graph_attributes import EdgeAttr, NodeAttr +from .graph_from_segmentation import _get_node_id def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, float]]: @@ -58,5 +61,39 @@ def add_iou(cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict) - node_seg_id = cand_graph.nodes[node_id][NodeAttr.SEG_ID.value] for next_id in next_nodes: next_seg_id = cand_graph.nodes[next_id][NodeAttr.SEG_ID.value] - iou = ious.get(node_seg_id, {}).get( next_seg_id, 0) - cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou \ No newline at end of file + iou = ious.get(node_seg_id, {}).get(next_seg_id, 0) + cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou + + +def add_multihypo_iou( + cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict +) -> None: + """Add IOU to the candidate graph for multi-hypothesis segmentations. + + Args: + cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated + segmentation (np.ndarray): Multiple hypothesis segmentation. Dimensions + are (t, h, [z], y, x), where h is the number of hypotheses. + """ + frames = sorted(node_frame_dict.keys()) + num_hypotheses = segmentation.shape[1] + for frame in tqdm(frames): + if frame + 1 not in node_frame_dict: + continue + # construct dictionary of ious between node_ids in frame 1 and frame 2 + ious: dict[str, dict[str, float]] = {} + for hypo1, hypo2 in combinations(range(num_hypotheses), 2): + hypo_ious = compute_ious( + segmentation[frame][hypo1], segmentation[frame + 1][hypo2] + ) + for segid, intersecting_labels in hypo_ious.items(): + node_id = _get_node_id(frame, segid, hypo1) + ious[node_id] = {} + for segid2, iou in intersecting_labels.items(): + next_id = _get_node_id(frame + 1, segid2, hypo2) + ious[node_id][next_id] = iou + next_nodes = node_frame_dict[frame + 1] + for node_id in node_frame_dict[frame]: + for next_id in next_nodes: + iou = ious.get(node_id, {}).get(next_id, 0) + cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou diff --git a/src/motile_toolbox/candidate_graph/multi_seg_graph.py b/src/motile_toolbox/candidate_graph/multi_seg_graph.py index eb13085..4635044 100644 --- a/src/motile_toolbox/candidate_graph/multi_seg_graph.py +++ b/src/motile_toolbox/candidate_graph/multi_seg_graph.py @@ -1,68 +1,87 @@ -from typing import Any + +from itertools import combinations import networkx as nx import numpy as np -from .graph_attributes import EdgeAttr, NodeAttr, add_iou -from .graph_from_segmentation import add_cand_edges, nodes_from_segmentation +from .graph_from_segmentation import ( + _get_node_id, + add_cand_edges, + nodes_from_segmentation, +) +from .iou import add_multihypo_iou -def compute_multi_seg_graph(segmentations: list[np.ndarray]) -> tuple[nx.DiGraph, list[set]]: - """Create a candidate graph from multi hypothesis segmentations. This is not +def compute_multi_seg_graph( + segmentation: np.ndarray, + max_edge_distance: float, + iou: bool = False, +) -> tuple[nx.DiGraph, list[set]]: + """Create a candidate graph from multi hypothesis segmentation. This is not tailored for agglomeration approaches with hierarchical merge graphs, it simply creates a conflict set for any nodes that overlap in the same time frame. Args: - segmentations (list[np.ndarray]): + segmentations (np.ndarray): Multiple hypothesis segmentation. Dimensions + are (t, h, [z], y, x), where h is the number of hypotheses. Returns: nx.DiGraph: _description_ """ # for each segmentation, get nodes using same method as graph_from_segmentation # add them all to one big graph - cand_graph, frame_dict = nodes_from_multi_segmentation(segmentations) # TODO: other args + cand_graph = nx.DiGraph() + node_frame_dict = {} + num_hypotheses = segmentation.shape[1] + for hypo_id in range(num_hypotheses): + hypothesis = segmentation[:,hypo_id] + node_graph, frame_dict = nodes_from_segmentation(hypothesis, hypo_id=hypo_id) + cand_graph.update(node_graph) + node_frame_dict.update(frame_dict) # Compute conflict sets between segmentations # can use same method as IOU (without the U) to compute conflict sets conflicts = [] - for time, segs in enumerate(segmentations): - conflicts.append(compute_conflict_sets(segs, time)) + for time, segs in enumerate(segmentation): + conflicts.extend(compute_conflict_sets(segs, time)) # add edges with same method as before, with slightly different implementation - add_cand_edges(cand_graph) # TODO: other args - if EdgeAttr.IOU in edge_attributes: - # TODO: cross product when calling (need to re-organize add_iou to not assume stuff) - add_iou(cand_graph, segmentation) - - return cand_graph - - - + add_cand_edges(cand_graph, max_edge_distance, node_frame_dict) + if iou: + add_multihypo_iou(cand_graph, segmentation, node_frame_dict) + return cand_graph, conflicts -def nodes_from_multi_segmentation( - segmentations: list[np.ndarray], - attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), - position_keys: tuple[str, ...] | list[str] = ("y", "x"), - frame_key: str = "t", -) -> tuple[nx.DiGraph, dict[int, list[Any]]]: - multi_hypo_node_graph = nx.DiGraph() - multi_frame_dict = {} - for layer_id, segmentation in enumerate(segmentations): - node_graph, frame_dict = nodes_from_segmentation(segmentation, layer_id) - # TODO: pass attributes, etc. - # TODO: add multi segmentation attribute to nodes_from_segmentation - # (use in node id and add to attributes) - multi_hypo_node_graph.update(node_graph) - multi_frame_dict.update(frame_dict) - # TODO: Make sure there is no node-id collision - - return multi_hypo_node_graph, multi_frame_dict - +def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]: + """Segmentation in one frame only. Return + Args: + segmentation_frame (np.ndarray): One frame of the multiple hypothesis + segmentation. Dimensions are (h, [z], y, x), where h is the number of + hypotheses. + time (int): Time frame, for computing node_ids. -def compute_conflict_sets(segmenations: np.ndarray, time: int) -> list[set]: - """Segmentations in one frame only. Return list of sets of node ids that conflict.""" - # This will look a lot like the IOU code - pass + Returns: + list[set]: list of sets of node ids that overlap + """ + flattened_segs = [seg.flatten() for seg in segmentation_frame] + + # get locations where at least two hypotheses have labels + # This approach may be inefficient, but likely doesn't matter compared to np.unique + conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool) + for seg1, seg2 in combinations(flattened_segs, 2): + non_zero_indices = np.logical_and(seg1, seg2) + conflict_indices = np.logical_or(conflict_indices, non_zero_indices) + + flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs]) + values = np.unique(flattened_stacked, axis=1) + + conflict_sets = [] + for conflicting_labels in values: + id_set = set() + for hypo_id, label in enumerate(conflicting_labels): + if label != 0: + id_set.add(_get_node_id(time, label, hypo_id)) + conflict_sets.append(id_set) + return conflict_sets From 7e3f71c5bab2ad738388102a68cdadab7d25fb1a Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 14:32:51 -0400 Subject: [PATCH 08/20] Reorganize candidate graph module files --- .../candidate_graph/__init__.py | 4 +- .../candidate_graph/compute_graph.py | 93 +++++++++++++++++++ .../candidate_graph/conflict_sets.py | 41 ++++++++ src/motile_toolbox/candidate_graph/iou.py | 14 +-- .../candidate_graph/multi_seg_graph.py | 87 ----------------- .../{graph_from_segmentation.py => utils.py} | 48 +--------- tests/conftest.py | 25 ++--- .../test_graph_from_segmentation.py | 4 +- tests/test_candidate_graph/test_iou.py | 10 +- 9 files changed, 167 insertions(+), 159 deletions(-) create mode 100644 src/motile_toolbox/candidate_graph/compute_graph.py create mode 100644 src/motile_toolbox/candidate_graph/conflict_sets.py delete mode 100644 src/motile_toolbox/candidate_graph/multi_seg_graph.py rename src/motile_toolbox/candidate_graph/{graph_from_segmentation.py => utils.py} (69%) diff --git a/src/motile_toolbox/candidate_graph/__init__.py b/src/motile_toolbox/candidate_graph/__init__.py index d06fd4e..eab1b89 100644 --- a/src/motile_toolbox/candidate_graph/__init__.py +++ b/src/motile_toolbox/candidate_graph/__init__.py @@ -1,3 +1,5 @@ +from .compute_graph import compute_multi_seg_graph, graph_from_segmentation from .graph_attributes import EdgeAttr, NodeAttr -from .graph_from_segmentation import graph_from_segmentation from .graph_to_nx import graph_to_nx +from .iou import add_iou, add_multihypo_iou +from .utils import add_cand_edges, get_node_id, nodes_from_segmentation diff --git a/src/motile_toolbox/candidate_graph/compute_graph.py b/src/motile_toolbox/candidate_graph/compute_graph.py new file mode 100644 index 0000000..be5fe5c --- /dev/null +++ b/src/motile_toolbox/candidate_graph/compute_graph.py @@ -0,0 +1,93 @@ +import logging + +import networkx as nx +import numpy as np + +from .conflict_sets import compute_conflict_sets +from .iou import add_iou, add_multihypo_iou +from .utils import add_cand_edges, nodes_from_segmentation + +logger = logging.getLogger(__name__) + + +def graph_from_segmentation( + segmentation: np.ndarray, + max_edge_distance: float, + iou: bool = False, +) -> nx.DiGraph: + """Construct a candidate graph from a segmentation array. Nodes are placed at the + centroid of each segmentation and edges are added for all nodes in adjacent frames + within max_edge_distance. The specified attributes are computed during construction. + Node ids are strings with format "{time}_{label id}". + + Args: + segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels + (0 is background, all pixels with value 1 belong to one cell, etc.). The + time dimension is first, followed by two or three position dimensions. If + the position dims are not (y, x), use `position_keys` to specify the names + of the dimensions. + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes within this distance in adjacent frames will by connected + with a candidate edge. + iou (bool, optional): Whether to include IOU on the candidate graph. + Defaults to False. + + Returns: + nx.DiGraph: A candidate graph that can be passed to the motile solver. + """ + # add nodes + cand_graph, node_frame_dict = nodes_from_segmentation(segmentation) + logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") + + # add edges + add_cand_edges( + cand_graph, + max_edge_distance=max_edge_distance, + node_frame_dict=node_frame_dict, + ) + if iou: + add_iou(cand_graph, segmentation, node_frame_dict) + + logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") + return cand_graph + + +def compute_multi_seg_graph( + segmentation: np.ndarray, + max_edge_distance: float, + iou: bool = False, +) -> tuple[nx.DiGraph, list[set]]: + """Create a candidate graph from multi hypothesis segmentation. This is not + tailored for agglomeration approaches with hierarchical merge graphs, it simply + creates a conflict set for any nodes that overlap in the same time frame. + + Args: + segmentations (np.ndarray): Multiple hypothesis segmentation. Dimensions + are (t, h, [z], y, x), where h is the number of hypotheses. + + Returns: + nx.DiGraph: _description_ + """ + # for each segmentation, get nodes using same method as graph_from_segmentation + # add them all to one big graph + cand_graph = nx.DiGraph() + node_frame_dict = {} + num_hypotheses = segmentation.shape[1] + for hypo_id in range(num_hypotheses): + hypothesis = segmentation[:, hypo_id] + node_graph, frame_dict = nodes_from_segmentation(hypothesis, hypo_id=hypo_id) + cand_graph.update(node_graph) + node_frame_dict.update(frame_dict) + + # Compute conflict sets between segmentations + # can use same method as IOU (without the U) to compute conflict sets + conflicts = [] + for time, segs in enumerate(segmentation): + conflicts.extend(compute_conflict_sets(segs, time)) + + # add edges with same method as before, with slightly different implementation + add_cand_edges(cand_graph, max_edge_distance, node_frame_dict) + if iou: + add_multihypo_iou(cand_graph, segmentation, node_frame_dict) + + return cand_graph, conflicts diff --git a/src/motile_toolbox/candidate_graph/conflict_sets.py b/src/motile_toolbox/candidate_graph/conflict_sets.py new file mode 100644 index 0000000..16c8f59 --- /dev/null +++ b/src/motile_toolbox/candidate_graph/conflict_sets.py @@ -0,0 +1,41 @@ +from itertools import combinations + +import numpy as np + +from .utils import ( + get_node_id, +) + + +def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]: + """Segmentation in one frame only. Return + + Args: + segmentation_frame (np.ndarray): One frame of the multiple hypothesis + segmentation. Dimensions are (h, [z], y, x), where h is the number of + hypotheses. + time (int): Time frame, for computing node_ids. + + Returns: + list[set]: list of sets of node ids that overlap + """ + flattened_segs = [seg.flatten() for seg in segmentation_frame] + + # get locations where at least two hypotheses have labels + # This approach may be inefficient, but likely doesn't matter compared to np.unique + conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool) + for seg1, seg2 in combinations(flattened_segs, 2): + non_zero_indices = np.logical_and(seg1, seg2) + conflict_indices = np.logical_or(conflict_indices, non_zero_indices) + + flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs]) + values = np.unique(flattened_stacked, axis=1) + + conflict_sets = [] + for conflicting_labels in values: + id_set = set() + for hypo_id, label in enumerate(conflicting_labels): + if label != 0: + id_set.add(get_node_id(time, label, hypo_id)) + conflict_sets.append(id_set) + return conflict_sets diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index ceb6431..f86dfea 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -5,10 +5,12 @@ from tqdm import tqdm from .graph_attributes import EdgeAttr, NodeAttr -from .graph_from_segmentation import _get_node_id +from .utils import get_node_id -def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, float]]: +def _compute_ious( + frame1: np.ndarray, frame2: np.ndarray +) -> dict[int, dict[int, float]]: """Compute label IOUs between two label arrays of the same shape. Ignores background (label 0). @@ -55,7 +57,7 @@ def add_iou(cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict) - for frame in tqdm(frames): if frame + 1 not in node_frame_dict: continue - ious = compute_ious(segmentation[frame], segmentation[frame + 1]) + ious = _compute_ious(segmentation[frame], segmentation[frame + 1]) next_nodes = node_frame_dict[frame + 1] for node_id in node_frame_dict[frame]: node_seg_id = cand_graph.nodes[node_id][NodeAttr.SEG_ID.value] @@ -83,14 +85,14 @@ def add_multihypo_iou( # construct dictionary of ious between node_ids in frame 1 and frame 2 ious: dict[str, dict[str, float]] = {} for hypo1, hypo2 in combinations(range(num_hypotheses), 2): - hypo_ious = compute_ious( + hypo_ious = _compute_ious( segmentation[frame][hypo1], segmentation[frame + 1][hypo2] ) for segid, intersecting_labels in hypo_ious.items(): - node_id = _get_node_id(frame, segid, hypo1) + node_id = get_node_id(frame, segid, hypo1) ious[node_id] = {} for segid2, iou in intersecting_labels.items(): - next_id = _get_node_id(frame + 1, segid2, hypo2) + next_id = get_node_id(frame + 1, segid2, hypo2) ious[node_id][next_id] = iou next_nodes = node_frame_dict[frame + 1] for node_id in node_frame_dict[frame]: diff --git a/src/motile_toolbox/candidate_graph/multi_seg_graph.py b/src/motile_toolbox/candidate_graph/multi_seg_graph.py deleted file mode 100644 index 4635044..0000000 --- a/src/motile_toolbox/candidate_graph/multi_seg_graph.py +++ /dev/null @@ -1,87 +0,0 @@ - -from itertools import combinations - -import networkx as nx -import numpy as np - -from .graph_from_segmentation import ( - _get_node_id, - add_cand_edges, - nodes_from_segmentation, -) -from .iou import add_multihypo_iou - - -def compute_multi_seg_graph( - segmentation: np.ndarray, - max_edge_distance: float, - iou: bool = False, -) -> tuple[nx.DiGraph, list[set]]: - """Create a candidate graph from multi hypothesis segmentation. This is not - tailored for agglomeration approaches with hierarchical merge graphs, it simply - creates a conflict set for any nodes that overlap in the same time frame. - - Args: - segmentations (np.ndarray): Multiple hypothesis segmentation. Dimensions - are (t, h, [z], y, x), where h is the number of hypotheses. - - Returns: - nx.DiGraph: _description_ - """ - # for each segmentation, get nodes using same method as graph_from_segmentation - # add them all to one big graph - cand_graph = nx.DiGraph() - node_frame_dict = {} - num_hypotheses = segmentation.shape[1] - for hypo_id in range(num_hypotheses): - hypothesis = segmentation[:,hypo_id] - node_graph, frame_dict = nodes_from_segmentation(hypothesis, hypo_id=hypo_id) - cand_graph.update(node_graph) - node_frame_dict.update(frame_dict) - - # Compute conflict sets between segmentations - # can use same method as IOU (without the U) to compute conflict sets - conflicts = [] - for time, segs in enumerate(segmentation): - conflicts.extend(compute_conflict_sets(segs, time)) - - # add edges with same method as before, with slightly different implementation - add_cand_edges(cand_graph, max_edge_distance, node_frame_dict) - if iou: - add_multihypo_iou(cand_graph, segmentation, node_frame_dict) - - return cand_graph, conflicts - - -def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]: - """Segmentation in one frame only. Return - - Args: - segmentation_frame (np.ndarray): One frame of the multiple hypothesis - segmentation. Dimensions are (h, [z], y, x), where h is the number of - hypotheses. - time (int): Time frame, for computing node_ids. - - Returns: - list[set]: list of sets of node ids that overlap - """ - flattened_segs = [seg.flatten() for seg in segmentation_frame] - - # get locations where at least two hypotheses have labels - # This approach may be inefficient, but likely doesn't matter compared to np.unique - conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool) - for seg1, seg2 in combinations(flattened_segs, 2): - non_zero_indices = np.logical_and(seg1, seg2) - conflict_indices = np.logical_or(conflict_indices, non_zero_indices) - - flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs]) - values = np.unique(flattened_stacked, axis=1) - - conflict_sets = [] - for conflicting_labels in values: - id_set = set() - for hypo_id, label in enumerate(conflicting_labels): - if label != 0: - id_set.add(_get_node_id(time, label, hypo_id)) - conflict_sets.append(id_set) - return conflict_sets diff --git a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py b/src/motile_toolbox/candidate_graph/utils.py similarity index 69% rename from src/motile_toolbox/candidate_graph/graph_from_segmentation.py rename to src/motile_toolbox/candidate_graph/utils.py index 10100ec..817a808 100644 --- a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -8,14 +8,12 @@ from tqdm import tqdm from .graph_attributes import EdgeAttr, NodeAttr -from .iou import add_iou logger = logging.getLogger(__name__) -def _get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> str: +def get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> str: if hypothesis_id is not None: - print(hypothesis_id) return f"{time}_{hypothesis_id}_{label_id}" else: return f"{time}_{label_id}" @@ -48,7 +46,7 @@ def nodes_from_segmentation( nodes_in_frame = [] props = regionprops(segmentation[t]) for regionprop in props: - node_id = _get_node_id(t, regionprop.label, hypothesis_id=hypo_id) + node_id = get_node_id(t, regionprop.label, hypothesis_id=hypo_id) attrs = { NodeAttr.TIME.value: t, } @@ -117,45 +115,3 @@ def add_cand_edges( if dist <= max_edge_distance: attrs = {EdgeAttr.DISTANCE.value: dist} cand_graph.add_edge(node, next_id, **attrs) - - -def graph_from_segmentation( - segmentation: np.ndarray, - max_edge_distance: float, - iou: bool = False, -) -> nx.DiGraph: - """Construct a candidate graph from a segmentation array. Nodes are placed at the - centroid of each segmentation and edges are added for all nodes in adjacent frames - within max_edge_distance. The specified attributes are computed during construction. - Node ids are strings with format "{time}_{label id}". - - Args: - segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels - (0 is background, all pixels with value 1 belong to one cell, etc.). The - time dimension is first, followed by two or three position dimensions. If - the position dims are not (y, x), use `position_keys` to specify the names - of the dimensions. - max_edge_distance (float): Maximum distance that objects can travel between - frames. All nodes within this distance in adjacent frames will by connected - with a candidate edge. - iou (bool, optional): Whether to include IOU on the candidate graph. - Defaults to False. - - Returns: - nx.DiGraph: A candidate graph that can be passed to the motile solver. - """ - # add nodes - cand_graph, node_frame_dict = nodes_from_segmentation(segmentation) - logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") - - # add edges - add_cand_edges( - cand_graph, - max_edge_distance=max_edge_distance, - node_frame_dict=node_frame_dict, - ) - if iou: - add_iou(cand_graph, segmentation, node_frame_dict) - - logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") - return cand_graph diff --git a/tests/conftest.py b/tests/conftest.py index 9c7f55c..972d80d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import networkx as nx import numpy as np import pytest -from motile_toolbox.candidate_graph import EdgeAttr, NodeAttr +from motile_toolbox.candidate_graph.graph_attributes import EdgeAttr, NodeAttr from skimage.draw import disk @@ -34,8 +34,9 @@ def multi_hypothesis_segmentation_2d(): frame_shape = (100, 100) total_shape = (2, 2, *frame_shape) # 2 time points, 2 hypotheses layers, H, W segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 + # make frame with one cell in center with label 1 (hypo 1) rr0, cc0 = disk(center=(50, 50), radius=20, shape=frame_shape) + # make frame with one cell at (45, 45) with label 1 (hypo 2) rr1, cc1 = disk(center=(45, 45), radius=15, shape=frame_shape) segmentation[0, 0][rr0, cc0] = 1 @@ -106,7 +107,7 @@ def multi_hypothesis_graph_2d(): { NodeAttr.POS.value: (50, 50), NodeAttr.TIME.value: 0, - NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, }, ), @@ -115,7 +116,7 @@ def multi_hypothesis_graph_2d(): { NodeAttr.POS.value: (45, 45), NodeAttr.TIME.value: 0, - NodeAttr.SEG_HYPOTHESIS.value: 1, + NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, }, ), @@ -124,7 +125,7 @@ def multi_hypothesis_graph_2d(): { NodeAttr.POS.value: (20, 80), NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, }, ), @@ -142,7 +143,7 @@ def multi_hypothesis_graph_2d(): { NodeAttr.POS.value: (60, 45), NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 2, }, ), @@ -151,7 +152,7 @@ def multi_hypothesis_graph_2d(): { NodeAttr.POS.value: (55, 40), NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPOTHESIS.value: 1, + NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 2, }, ), @@ -293,7 +294,7 @@ def multi_hypothesis_graph_3d(): { NodeAttr.POS.value: (50, 50, 50), NodeAttr.TIME.value: 0, - NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, }, ), @@ -302,7 +303,7 @@ def multi_hypothesis_graph_3d(): { NodeAttr.POS.value: (45, 50, 55), NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPOTHESIS.value: 1, + NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, }, ), @@ -311,7 +312,7 @@ def multi_hypothesis_graph_3d(): { NodeAttr.POS.value: (20, 50, 80), NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, }, ), @@ -320,7 +321,7 @@ def multi_hypothesis_graph_3d(): { NodeAttr.POS.value: (60, 50, 45), NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPOTHESIS.value: 0, + NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 2, }, ), @@ -329,7 +330,7 @@ def multi_hypothesis_graph_3d(): { NodeAttr.POS.value: (15, 50, 70), NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPOTHESIS.value: 1, + NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, }, ), diff --git a/tests/test_candidate_graph/test_graph_from_segmentation.py b/tests/test_candidate_graph/test_graph_from_segmentation.py index 96923e2..c047dfa 100644 --- a/tests/test_candidate_graph/test_graph_from_segmentation.py +++ b/tests/test_candidate_graph/test_graph_from_segmentation.py @@ -3,8 +3,8 @@ import networkx as nx import numpy as np import pytest -from motile_toolbox.candidate_graph import EdgeAttr, NodeAttr -from motile_toolbox.candidate_graph.graph_from_segmentation import ( +from motile_toolbox.candidate_graph import ( + EdgeAttr, NodeAttr, add_cand_edges, graph_from_segmentation, nodes_from_segmentation, diff --git a/tests/test_candidate_graph/test_iou.py b/tests/test_candidate_graph/test_iou.py index a9dcb42..e797889 100644 --- a/tests/test_candidate_graph/test_iou.py +++ b/tests/test_candidate_graph/test_iou.py @@ -1,15 +1,15 @@ import pytest -from motile_toolbox.candidate_graph.iou import compute_ious +from motile_toolbox.candidate_graph.iou import _compute_ious def test_compute_ious_2d(segmentation_2d): - ious = compute_ious(segmentation_2d[0], segmentation_2d[1]) + ious = _compute_ious(segmentation_2d[0], segmentation_2d[1]) expected = {1: {2: 555.46 / 1408.0}} assert ious.keys() == expected.keys() assert ious[1].keys() == expected[1].keys() assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1) - ious = compute_ious(segmentation_2d[1], segmentation_2d[1]) + ious = _compute_ious(segmentation_2d[1], segmentation_2d[1]) expected = {1: {1: 1.0}, 2: {2: 1.0}} assert ious.keys() == expected.keys() assert ious[1].keys() == expected[1].keys() @@ -19,13 +19,13 @@ def test_compute_ious_2d(segmentation_2d): def test_compute_ious_3d(segmentation_3d): - ious = compute_ious(segmentation_3d[0], segmentation_3d[1]) + ious = _compute_ious(segmentation_3d[0], segmentation_3d[1]) expected = {1: {2: 0.30}} assert ious.keys() == expected.keys() assert ious[1].keys() == expected[1].keys() assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1) - ious = compute_ious(segmentation_3d[1], segmentation_3d[1]) + ious = _compute_ious(segmentation_3d[1], segmentation_3d[1]) expected = {1: {1: 1.0}, 2: {2: 1.0}} assert ious.keys() == expected.keys() assert ious[1].keys() == expected[1].keys() From b224f1ccd16a23fefd39c55d2764679d50d907f6 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 14:39:23 -0400 Subject: [PATCH 09/20] Reorganize candidate graph tests --- .../test_compute_graph.py | 55 +++++++++++++++++++ .../test_conflict_sets.py | 1 + ...aph_from_segmentation.py => test_utils.py} | 52 +----------------- 3 files changed, 58 insertions(+), 50 deletions(-) create mode 100644 tests/test_candidate_graph/test_compute_graph.py create mode 100644 tests/test_candidate_graph/test_conflict_sets.py rename tests/test_candidate_graph/{test_graph_from_segmentation.py => test_utils.py} (61%) diff --git a/tests/test_candidate_graph/test_compute_graph.py b/tests/test_candidate_graph/test_compute_graph.py new file mode 100644 index 0000000..9d055cf --- /dev/null +++ b/tests/test_candidate_graph/test_compute_graph.py @@ -0,0 +1,55 @@ +from collections import Counter + +import pytest +from motile_toolbox.candidate_graph import ( + EdgeAttr, + graph_from_segmentation, +) + + +def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): + # test with 2D segmentation + cand_graph = graph_from_segmentation( + segmentation=segmentation_2d, + max_edge_distance=100, + iou=True, + ) + assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes)) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node]) + for edge in cand_graph.edges: + print(cand_graph.edges[edge]) + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) + == graph_2d.edges[edge][EdgeAttr.DISTANCE.value] + ) + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01) + == graph_2d.edges[edge][EdgeAttr.IOU.value] + ) + + # lower edge distance + cand_graph = graph_from_segmentation( + segmentation=segmentation_2d, + max_edge_distance=15, + ) + assert Counter(list(cand_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) + assert Counter(list(cand_graph.edges)) == Counter([("0_1", "1_2")]) + assert cand_graph.edges[("0_1", "1_2")][EdgeAttr.DISTANCE.value] == pytest.approx( + 11.18, abs=0.01 + ) + + +def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): + # test with 3D segmentation + cand_graph = graph_from_segmentation( + segmentation=segmentation_3d, + max_edge_distance=100, + ) + assert Counter(list(cand_graph.nodes)) == Counter(list(graph_3d.nodes)) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) + for edge in cand_graph.edges: + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] diff --git a/tests/test_candidate_graph/test_conflict_sets.py b/tests/test_candidate_graph/test_conflict_sets.py new file mode 100644 index 0000000..4640904 --- /dev/null +++ b/tests/test_candidate_graph/test_conflict_sets.py @@ -0,0 +1 @@ +# TODO diff --git a/tests/test_candidate_graph/test_graph_from_segmentation.py b/tests/test_candidate_graph/test_utils.py similarity index 61% rename from tests/test_candidate_graph/test_graph_from_segmentation.py rename to tests/test_candidate_graph/test_utils.py index c047dfa..1374643 100644 --- a/tests/test_candidate_graph/test_graph_from_segmentation.py +++ b/tests/test_candidate_graph/test_utils.py @@ -4,9 +4,9 @@ import numpy as np import pytest from motile_toolbox.candidate_graph import ( - EdgeAttr, NodeAttr, + EdgeAttr, + NodeAttr, add_cand_edges, - graph_from_segmentation, nodes_from_segmentation, ) @@ -83,51 +83,3 @@ def test_add_cand_edges_3d(graph_3d): assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) for edge in cand_graph.edges: assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] - - -def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): - # test with 2D segmentation - cand_graph = graph_from_segmentation( - segmentation=segmentation_2d, - max_edge_distance=100, - iou=True, - ) - assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes)) - assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) - for node in cand_graph.nodes: - assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node]) - for edge in cand_graph.edges: - print(cand_graph.edges[edge]) - assert ( - pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) - == graph_2d.edges[edge][EdgeAttr.DISTANCE.value] - ) - assert ( - pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01) - == graph_2d.edges[edge][EdgeAttr.IOU.value] - ) - - # lower edge distance - cand_graph = graph_from_segmentation( - segmentation=segmentation_2d, - max_edge_distance=15, - ) - assert Counter(list(cand_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert Counter(list(cand_graph.edges)) == Counter([("0_1", "1_2")]) - assert cand_graph.edges[("0_1", "1_2")][EdgeAttr.DISTANCE.value] == pytest.approx( - 11.18, abs=0.01 - ) - - -def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): - # test with 3D segmentation - cand_graph = graph_from_segmentation( - segmentation=segmentation_3d, - max_edge_distance=100, - ) - assert Counter(list(cand_graph.nodes)) == Counter(list(graph_3d.nodes)) - assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) - for node in cand_graph.nodes: - assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) - for edge in cand_graph.edges: - assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] From 9edde1b3a6adac3faa77fc76d2e6f7a6f2b810a6 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 15:24:52 -0400 Subject: [PATCH 10/20] Test multi hypothesis iou --- src/motile_toolbox/candidate_graph/iou.py | 35 ++++++++++++++++++----- tests/conftest.py | 2 +- tests/test_candidate_graph/test_iou.py | 27 +++++++++++++++++ 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index f86dfea..e076720 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -1,11 +1,12 @@ -from itertools import combinations +from itertools import product +from typing import Any import networkx as nx import numpy as np from tqdm import tqdm from .graph_attributes import EdgeAttr, NodeAttr -from .utils import get_node_id +from .utils import _compute_node_frame_dict, get_node_id def _compute_ious( @@ -46,13 +47,23 @@ def _compute_ious( return iou_dict -def add_iou(cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict) -> None: +def add_iou( + cand_graph: nx.DiGraph, + segmentation: np.ndarray, + node_frame_dict: dict[int, list[Any]] | None = None, +) -> None: """Add IOU to the candidate graph. Args: cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated segmentation (np.ndarray): segmentation that was used to create cand_graph + node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from + time frames to nodes in that frame. Will be computed if not provided, + but can be provided for efficiency (e.g. after running + nodes_from_segmentation). Defaults to None. """ + if node_frame_dict is None: + node_frame_dict = _compute_node_frame_dict(cand_graph) frames = sorted(node_frame_dict.keys()) for frame in tqdm(frames): if frame + 1 not in node_frame_dict: @@ -68,7 +79,9 @@ def add_iou(cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict) - def add_multihypo_iou( - cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict + cand_graph: nx.DiGraph, + segmentation: np.ndarray, + node_frame_dict: dict[int, list[Any]] | None = None, ) -> None: """Add IOU to the candidate graph for multi-hypothesis segmentations. @@ -76,7 +89,13 @@ def add_multihypo_iou( cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated segmentation (np.ndarray): Multiple hypothesis segmentation. Dimensions are (t, h, [z], y, x), where h is the number of hypotheses. + node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from + time frames to nodes in that frame. Will be computed if not provided, + but can be provided for efficiency (e.g. after running + nodes_from_segmentation). Defaults to None. """ + if node_frame_dict is None: + node_frame_dict = _compute_node_frame_dict(cand_graph) frames = sorted(node_frame_dict.keys()) num_hypotheses = segmentation.shape[1] for frame in tqdm(frames): @@ -84,13 +103,14 @@ def add_multihypo_iou( continue # construct dictionary of ious between node_ids in frame 1 and frame 2 ious: dict[str, dict[str, float]] = {} - for hypo1, hypo2 in combinations(range(num_hypotheses), 2): + for hypo1, hypo2 in product(range(num_hypotheses), repeat=2): hypo_ious = _compute_ious( segmentation[frame][hypo1], segmentation[frame + 1][hypo2] ) for segid, intersecting_labels in hypo_ious.items(): node_id = get_node_id(frame, segid, hypo1) - ious[node_id] = {} + if node_id not in ious: + ious[node_id] = {} for segid2, iou in intersecting_labels.items(): next_id = get_node_id(frame + 1, segid2, hypo2) ious[node_id][next_id] = iou @@ -98,4 +118,5 @@ def add_multihypo_iou( for node_id in node_frame_dict[frame]: for next_id in next_nodes: iou = ious.get(node_id, {}).get(next_id, 0) - cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou + if (node_id, next_id) in cand_graph.edges: + cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou diff --git a/tests/conftest.py b/tests/conftest.py index 972d80d..a4643a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -134,7 +134,7 @@ def multi_hypothesis_graph_2d(): { NodeAttr.POS.value: (15, 75), NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPOTHESIS.value: 1, + NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, }, ), diff --git a/tests/test_candidate_graph/test_iou.py b/tests/test_candidate_graph/test_iou.py index e797889..17d0916 100644 --- a/tests/test_candidate_graph/test_iou.py +++ b/tests/test_candidate_graph/test_iou.py @@ -1,4 +1,6 @@ +import networkx as nx import pytest +from motile_toolbox.candidate_graph import EdgeAttr, add_iou, add_multihypo_iou from motile_toolbox.candidate_graph.iou import _compute_ious @@ -32,3 +34,28 @@ def test_compute_ious_3d(segmentation_3d): assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1) assert ious[2].keys() == expected[2].keys() assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) + + +def test_add_iou_2d(segmentation_2d, graph_2d): + expected = graph_2d + input_graph = graph_2d.copy() + nx.set_edge_attributes(input_graph, -1, name=EdgeAttr.IOU.value) + add_iou(input_graph, segmentation_2d) + for s, t, attrs in expected.edges(data=True): + assert ( + pytest.approx(attrs[EdgeAttr.IOU.value], abs=0.01) + == input_graph.edges[(s, t)][EdgeAttr.IOU.value] + ) + + +def test_multi_hypo_iou_2d(multi_hypothesis_segmentation_2d, multi_hypothesis_graph_2d): + expected = multi_hypothesis_graph_2d + input_graph = multi_hypothesis_graph_2d.copy() + nx.set_edge_attributes(input_graph, -1, name=EdgeAttr.IOU.value) + add_multihypo_iou(input_graph, multi_hypothesis_segmentation_2d) + for s, t, attrs in expected.edges(data=True): + print(s, t) + assert ( + pytest.approx(attrs[EdgeAttr.IOU.value], abs=0.01) + == input_graph.edges[(s, t)][EdgeAttr.IOU.value] + ) From d3fe5dba9693ad0eb8d01cc9f81bd2740d380897 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 15:45:27 -0400 Subject: [PATCH 11/20] Test helper functions in utils --- tests/test_candidate_graph/test_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_candidate_graph/test_utils.py b/tests/test_candidate_graph/test_utils.py index 1374643..46f1832 100644 --- a/tests/test_candidate_graph/test_utils.py +++ b/tests/test_candidate_graph/test_utils.py @@ -7,8 +7,10 @@ EdgeAttr, NodeAttr, add_cand_edges, + get_node_id, nodes_from_segmentation, ) +from motile_toolbox.candidate_graph.utils import _compute_node_frame_dict # nodes_from_segmentation @@ -83,3 +85,19 @@ def test_add_cand_edges_3d(graph_3d): assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) for edge in cand_graph.edges: assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] + + +def test_get_node_id(): + assert get_node_id(0, 2) == "0_2" + assert get_node_id(2, 10, 3) == "2_3_10" + + +def test_compute_node_frame_dict(graph_2d): + node_frame_dict = _compute_node_frame_dict(graph_2d) + expected = { + 0: [ + "0_1", + ], + 1: ["1_1", "1_2"], + } + assert node_frame_dict == expected From 148fc992c28c762f3bd6ebb548a2fc1b63b601d6 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 15:47:58 -0400 Subject: [PATCH 12/20] Update docstring for conflict sets --- src/motile_toolbox/candidate_graph/conflict_sets.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/conflict_sets.py b/src/motile_toolbox/candidate_graph/conflict_sets.py index 16c8f59..e15e6f4 100644 --- a/src/motile_toolbox/candidate_graph/conflict_sets.py +++ b/src/motile_toolbox/candidate_graph/conflict_sets.py @@ -8,7 +8,9 @@ def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]: - """Segmentation in one frame only. Return + """Compute all sets of node ids that conflict with each other. + Note: Results might include redundant sets, for example {a, b, c} and {a, b} + might both appear in the results. Args: segmentation_frame (np.ndarray): One frame of the multiple hypothesis @@ -17,7 +19,8 @@ def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set time (int): Time frame, for computing node_ids. Returns: - list[set]: list of sets of node ids that overlap + list[set]: list of sets of node ids that overlap. Might include some sets + that are subsets of others. """ flattened_segs = [seg.flatten() for seg in segmentation_frame] From 63c221987b8d83fb34ef789321e80bb4c9cccb05 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 16:37:55 -0400 Subject: [PATCH 13/20] Unite single and multi hypothesis graph creation into one function --- .../candidate_graph/__init__.py | 2 +- .../candidate_graph/compute_graph.py | 92 ++++++++----------- src/motile_toolbox/candidate_graph/utils.py | 17 +++- .../test_compute_graph.py | 56 +++++++++-- 4 files changed, 106 insertions(+), 61 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/__init__.py b/src/motile_toolbox/candidate_graph/__init__.py index eab1b89..51261a0 100644 --- a/src/motile_toolbox/candidate_graph/__init__.py +++ b/src/motile_toolbox/candidate_graph/__init__.py @@ -1,4 +1,4 @@ -from .compute_graph import compute_multi_seg_graph, graph_from_segmentation +from .compute_graph import get_candidate_graph from .graph_attributes import EdgeAttr, NodeAttr from .graph_to_nx import graph_to_nx from .iou import add_iou, add_multihypo_iou diff --git a/src/motile_toolbox/candidate_graph/compute_graph.py b/src/motile_toolbox/candidate_graph/compute_graph.py index be5fe5c..aca1b44 100644 --- a/src/motile_toolbox/candidate_graph/compute_graph.py +++ b/src/motile_toolbox/candidate_graph/compute_graph.py @@ -1,4 +1,5 @@ import logging +from typing import Any import networkx as nx import numpy as np @@ -10,33 +11,49 @@ logger = logging.getLogger(__name__) -def graph_from_segmentation( +def get_candidate_graph( segmentation: np.ndarray, max_edge_distance: float, iou: bool = False, -) -> nx.DiGraph: + multihypo: bool = False, +) -> tuple[nx.DiGraph, list[set[Any]] | None]: """Construct a candidate graph from a segmentation array. Nodes are placed at the centroid of each segmentation and edges are added for all nodes in adjacent frames - within max_edge_distance. The specified attributes are computed during construction. - Node ids are strings with format "{time}_{label id}". + within max_edge_distance. If segmentation contains multiple hypotheses, will also + return a list of conflicting node ids that cannot be selected together. Args: - segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels - (0 is background, all pixels with value 1 belong to one cell, etc.). The - time dimension is first, followed by two or three position dimensions. If - the position dims are not (y, x), use `position_keys` to specify the names - of the dimensions. + segmentation (np.ndarray): A numpy array with integer labels and dimensions + (t, [h], [z], y, x), where h is the number of hypotheses. max_edge_distance (float): Maximum distance that objects can travel between - frames. All nodes within this distance in adjacent frames will by connected - with a candidate edge. + frames. All nodes with centroids within this distance in adjacent frames + will by connected with a candidate edge. iou (bool, optional): Whether to include IOU on the candidate graph. Defaults to False. + multihypo (bool, optional): Whether the segmentation contains multiple + hypotheses. Defaults to False. Returns: - nx.DiGraph: A candidate graph that can be passed to the motile solver. + tuple[nx.DiGraph, list[set[Any]] | None]: A candidate graph that can be passed + to the motile solver, and a list of conflicting node ids. """ # add nodes - cand_graph, node_frame_dict = nodes_from_segmentation(segmentation) + if multihypo: + cand_graph = nx.DiGraph() + num_frames = segmentation.shape[0] + node_frame_dict = {t: [] for t in range(num_frames)} + num_hypotheses = segmentation.shape[1] + for hypo_id in range(num_hypotheses): + hypothesis = segmentation[:, hypo_id] + node_graph, frame_dict = nodes_from_segmentation( + hypothesis, hypo_id=hypo_id + ) + cand_graph.update(node_graph) + for t in range(num_frames): + if t in frame_dict: + node_frame_dict[t].extend(frame_dict[t]) + else: + cand_graph, node_frame_dict = nodes_from_segmentation(segmentation) logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") # add edges @@ -46,48 +63,19 @@ def graph_from_segmentation( node_frame_dict=node_frame_dict, ) if iou: - add_iou(cand_graph, segmentation, node_frame_dict) + if multihypo: + add_multihypo_iou(cand_graph, segmentation, node_frame_dict) + else: + add_iou(cand_graph, segmentation, node_frame_dict) logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") - return cand_graph - - -def compute_multi_seg_graph( - segmentation: np.ndarray, - max_edge_distance: float, - iou: bool = False, -) -> tuple[nx.DiGraph, list[set]]: - """Create a candidate graph from multi hypothesis segmentation. This is not - tailored for agglomeration approaches with hierarchical merge graphs, it simply - creates a conflict set for any nodes that overlap in the same time frame. - - Args: - segmentations (np.ndarray): Multiple hypothesis segmentation. Dimensions - are (t, h, [z], y, x), where h is the number of hypotheses. - - Returns: - nx.DiGraph: _description_ - """ - # for each segmentation, get nodes using same method as graph_from_segmentation - # add them all to one big graph - cand_graph = nx.DiGraph() - node_frame_dict = {} - num_hypotheses = segmentation.shape[1] - for hypo_id in range(num_hypotheses): - hypothesis = segmentation[:, hypo_id] - node_graph, frame_dict = nodes_from_segmentation(hypothesis, hypo_id=hypo_id) - cand_graph.update(node_graph) - node_frame_dict.update(frame_dict) # Compute conflict sets between segmentations - # can use same method as IOU (without the U) to compute conflict sets - conflicts = [] - for time, segs in enumerate(segmentation): - conflicts.extend(compute_conflict_sets(segs, time)) - - # add edges with same method as before, with slightly different implementation - add_cand_edges(cand_graph, max_edge_distance, node_frame_dict) - if iou: - add_multihypo_iou(cand_graph, segmentation, node_frame_dict) + if multihypo: + conflicts = [] + for time, segs in enumerate(segmentation): + conflicts.extend(compute_conflict_sets(segs, time)) + else: + conflicts = None return cand_graph, conflicts diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index 817a808..f75870b 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -13,6 +13,21 @@ def get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> str: + """Construct a node id given the time frame, segmentation label id, and + optionally the hypothesis id. This function is not designed for candidate graphs + that do not come from segmentations, but could be used if there is a similar + "detection id" that is unique for all cells detected in a given frame. + + Args: + time (int): The time frame the node is in + label_id (int): The label the node has in the segmentation. + hypothesis_id (int | None, optional): An integer representing which hypothesis + the segmentation came from, if applicable. Defaults to None. + + Returns: + str: A string to use as the node id in the candidate graph. Assuming that label + ids are not repeated in the same time frame and hypothesis, it is unique. + """ if hypothesis_id is not None: return f"{time}_{hypothesis_id}_{label_id}" else: @@ -41,7 +56,7 @@ def nodes_from_segmentation( cand_graph = nx.DiGraph() # also construct a dictionary from time frame to node_id for efficiency node_frame_dict = {} - print("Extracting nodes from segmentaiton") + print("Extracting nodes from segmentation") for t in tqdm(range(len(segmentation))): nodes_in_frame = [] props = regionprops(segmentation[t]) diff --git a/tests/test_candidate_graph/test_compute_graph.py b/tests/test_candidate_graph/test_compute_graph.py index 9d055cf..2d4a8a1 100644 --- a/tests/test_candidate_graph/test_compute_graph.py +++ b/tests/test_candidate_graph/test_compute_graph.py @@ -1,15 +1,12 @@ from collections import Counter import pytest -from motile_toolbox.candidate_graph import ( - EdgeAttr, - graph_from_segmentation, -) +from motile_toolbox.candidate_graph import EdgeAttr, get_candidate_graph def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): # test with 2D segmentation - cand_graph = graph_from_segmentation( + cand_graph, _ = get_candidate_graph( segmentation=segmentation_2d, max_edge_distance=100, iou=True, @@ -30,7 +27,7 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): ) # lower edge distance - cand_graph = graph_from_segmentation( + cand_graph, _ = get_candidate_graph( segmentation=segmentation_2d, max_edge_distance=15, ) @@ -43,7 +40,7 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): # test with 3D segmentation - cand_graph = graph_from_segmentation( + cand_graph, _ = get_candidate_graph( segmentation=segmentation_3d, max_edge_distance=100, ) @@ -53,3 +50,48 @@ def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) for edge in cand_graph.edges: assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] + + +def test_graph_from_multi_segmentation_2d( + multi_hypothesis_segmentation_2d, multi_hypothesis_graph_2d +): + # test with 2D segmentation + cand_graph, conflict_set = get_candidate_graph( + segmentation=multi_hypothesis_segmentation_2d, + max_edge_distance=100, + iou=True, + multihypo=True, + ) + assert Counter(list(cand_graph.nodes)) == Counter( + list(multi_hypothesis_graph_2d.nodes) + ) + assert Counter(list(cand_graph.edges)) == Counter( + list(multi_hypothesis_graph_2d.edges) + ) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter( + multi_hypothesis_graph_2d.nodes[node] + ) + for edge in cand_graph.edges: + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) + == multi_hypothesis_graph_2d.edges[edge][EdgeAttr.DISTANCE.value] + ) + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01) + == multi_hypothesis_graph_2d.edges[edge][EdgeAttr.IOU.value] + ) + # TODO: Test conflict set + + # lower edge distance + cand_graph, _ = get_candidate_graph( + segmentation=multi_hypothesis_segmentation_2d, + max_edge_distance=14, + multihypo=True, + ) + assert Counter(list(cand_graph.nodes)) == Counter( + list(multi_hypothesis_graph_2d.nodes) + ) + assert Counter(list(cand_graph.edges)) == Counter( + [("0_0_1", "1_0_2"), ("0_0_1", "1_1_2"), ("0_1_1", "1_1_2")] + ) From b1c738e433fdae9e5ecd2663fc3ed84572b61e60 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 16:40:45 -0400 Subject: [PATCH 14/20] Add type annotation for mypy --- src/motile_toolbox/candidate_graph/compute_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/motile_toolbox/candidate_graph/compute_graph.py b/src/motile_toolbox/candidate_graph/compute_graph.py index aca1b44..daa930c 100644 --- a/src/motile_toolbox/candidate_graph/compute_graph.py +++ b/src/motile_toolbox/candidate_graph/compute_graph.py @@ -41,7 +41,7 @@ def get_candidate_graph( if multihypo: cand_graph = nx.DiGraph() num_frames = segmentation.shape[0] - node_frame_dict = {t: [] for t in range(num_frames)} + node_frame_dict: dict[int, list[Any]] = {t: [] for t in range(num_frames)} num_hypotheses = segmentation.shape[1] for hypo_id in range(num_hypotheses): hypothesis = segmentation[:, hypo_id] From 6d2af9b9e49f8acbd9ad8f0ac79b0b3e03be6ceb Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 1 Apr 2024 17:00:39 -0400 Subject: [PATCH 15/20] Add `pytest-unordered` dependency. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index ac3310c..8706b10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dev = [ 'pdoc', 'pre-commit', 'types-tqdm', + 'pytest-unordered' ] [project.urls] From 8ea2d29066a5a5c8433f62083a01053ebf2a0f43 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 1 Apr 2024 17:01:02 -0400 Subject: [PATCH 16/20] Minor changes to `compute_conflict_sets`. --- src/motile_toolbox/candidate_graph/conflict_sets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/conflict_sets.py b/src/motile_toolbox/candidate_graph/conflict_sets.py index e15e6f4..4747c29 100644 --- a/src/motile_toolbox/candidate_graph/conflict_sets.py +++ b/src/motile_toolbox/candidate_graph/conflict_sets.py @@ -33,12 +33,12 @@ def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs]) values = np.unique(flattened_stacked, axis=1) - + values = np.transpose(values) conflict_sets = [] for conflicting_labels in values: id_set = set() for hypo_id, label in enumerate(conflicting_labels): if label != 0: id_set.add(get_node_id(time, label, hypo_id)) - conflict_sets.append(id_set) + conflict_sets.append(id_set) return conflict_sets From fc4985236aebc723effe5762ed795c0e007d84e9 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 1 Apr 2024 17:01:25 -0400 Subject: [PATCH 17/20] Add `test_conflict_sets`. --- .../test_conflict_sets.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/test_candidate_graph/test_conflict_sets.py b/tests/test_candidate_graph/test_conflict_sets.py index 4640904..c20f810 100644 --- a/tests/test_candidate_graph/test_conflict_sets.py +++ b/tests/test_candidate_graph/test_conflict_sets.py @@ -1 +1,36 @@ -# TODO + +import numpy as np +from motile_toolbox.candidate_graph.conflict_sets import compute_conflict_sets +from pytest_unordered import unordered + + +def test_conflict_sets_2d(multi_hypothesis_segmentation_2d): + for t in range(multi_hypothesis_segmentation_2d.shape[0]): + conflict_set =compute_conflict_sets(multi_hypothesis_segmentation_2d[t], t) + if t==0: + expected= [{'0_1_1', '0_0_1'}] + assert len(conflict_set) == 1 + assert conflict_set==unordered(expected) + elif t==1: + expected= [{'1_0_2', '1_1_2'}, {'1_0_1', '1_1_1'}] + assert len(conflict_set) == 2 + assert conflict_set == unordered(expected) + + +def test_conflict_sets_2d_reshaped(multi_hypothesis_segmentation_2d): + """Reshape segmentation array just to provide a slightly difficult example. + """ + + + reshaped = np.asarray([multi_hypothesis_segmentation_2d[0, 0], # hypothesis 0 + multi_hypothesis_segmentation_2d[1, 0], # hypothesis 1 + multi_hypothesis_segmentation_2d[1, 1]]) # hypothesis 2 + conflict_set = compute_conflict_sets(reshaped, 0) + # note the expected ids are not really there since the + # reshaped array is artifically constructed + expected = [{'0_0_1', '0_1_2', '0_2_2'}, + {'0_1_1', '0_2_1'}, + {'0_0_1', '0_1_2'}, + {'0_1_2', '0_2_2'}, + {'0_0_1', '0_2_2'}] + assert conflict_set == unordered(expected) From 61eabb7fa61c4b9749ffa24bbbbefc0e2d4bf26a Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 1 Apr 2024 17:29:18 -0400 Subject: [PATCH 18/20] Refactor IOU code to work for single and multi hypothesis --- .../candidate_graph/__init__.py | 2 +- .../candidate_graph/compute_graph.py | 7 +- src/motile_toolbox/candidate_graph/iou.py | 103 ++++++++---------- tests/test_candidate_graph/test_iou.py | 38 +++---- 4 files changed, 67 insertions(+), 83 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/__init__.py b/src/motile_toolbox/candidate_graph/__init__.py index 51261a0..8b67fe2 100644 --- a/src/motile_toolbox/candidate_graph/__init__.py +++ b/src/motile_toolbox/candidate_graph/__init__.py @@ -1,5 +1,5 @@ from .compute_graph import get_candidate_graph from .graph_attributes import EdgeAttr, NodeAttr from .graph_to_nx import graph_to_nx -from .iou import add_iou, add_multihypo_iou +from .iou import add_iou from .utils import add_cand_edges, get_node_id, nodes_from_segmentation diff --git a/src/motile_toolbox/candidate_graph/compute_graph.py b/src/motile_toolbox/candidate_graph/compute_graph.py index daa930c..7649c62 100644 --- a/src/motile_toolbox/candidate_graph/compute_graph.py +++ b/src/motile_toolbox/candidate_graph/compute_graph.py @@ -5,7 +5,7 @@ import numpy as np from .conflict_sets import compute_conflict_sets -from .iou import add_iou, add_multihypo_iou +from .iou import add_iou from .utils import add_cand_edges, nodes_from_segmentation logger = logging.getLogger(__name__) @@ -63,10 +63,7 @@ def get_candidate_graph( node_frame_dict=node_frame_dict, ) if iou: - if multihypo: - add_multihypo_iou(cand_graph, segmentation, node_frame_dict) - else: - add_iou(cand_graph, segmentation, node_frame_dict) + add_iou(cand_graph, segmentation, node_frame_dict, multihypo=multihypo) logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index e076720..0f3d4f7 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -5,13 +5,13 @@ import numpy as np from tqdm import tqdm -from .graph_attributes import EdgeAttr, NodeAttr +from .graph_attributes import EdgeAttr from .utils import _compute_node_frame_dict, get_node_id def _compute_ious( frame1: np.ndarray, frame2: np.ndarray -) -> dict[int, dict[int, float]]: +) -> list[tuple[int, int, float]]: """Compute label IOUs between two label arrays of the same shape. Ignores background (label 0). @@ -20,8 +20,8 @@ def _compute_ious( frame2 (np.ndarray): Array with integer labels Returns: - dict[int, dict[int, float]]: Dictionary from labels in frame 1 to labels in - frame 2 to iou values. Nodes that have no overlap are not included. + list[tuple[int, int, float]]: List of tuples of label in frame 1, label in + frame 2, and iou values. Labels that have no overlap are not included. """ frame1 = frame1.flatten() frame2 = frame2.flatten() @@ -35,88 +35,81 @@ def _compute_ious( frame1_label_sizes = dict(zip(frame1_values, frame1_counts)) frame2_values, frame2_counts = np.unique(frame2, return_counts=True) frame2_label_sizes = dict(zip(frame2_values, frame2_counts)) - iou_dict: dict[int, dict[int, float]] = {} + ious: list[tuple[int, int, float]] = [] for index in range(values.shape[1]): pair = values[:, index] intersection = counts[index] id1, id2 = pair union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection - if id1 not in iou_dict: - iou_dict[id1] = {} - iou_dict[id1][id2] = intersection / union - return iou_dict + ious.append((id1, id2, intersection / union)) + return ious -def add_iou( - cand_graph: nx.DiGraph, - segmentation: np.ndarray, - node_frame_dict: dict[int, list[Any]] | None = None, -) -> None: - """Add IOU to the candidate graph. +def _get_iou_dict(segmentation, multihypo=False) -> dict[str, dict[str, float]]: + """Get all ious values for the provided segmentation (all frames). + Will return as map from node_id -> dict[node_id] -> iou for easy + navigation when adding to candidate graph. Args: - cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated - segmentation (np.ndarray): segmentation that was used to create cand_graph - node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from - time frames to nodes in that frame. Will be computed if not provided, - but can be provided for efficiency (e.g. after running - nodes_from_segmentation). Defaults to None. + segmentation (np.ndarray): Segmentation that was used to create cand_graph. + Has shape (t, [h], [z], y, x), where h is the number of hypotheses. + multihypo (bool, optional): Whether or not the segmentation is multi hypothesis. + Defaults to False. + + Returns: + dict[str, dict[str, float]]: A map from node id to another dictionary, which + contains node_ids to iou values. """ - if node_frame_dict is None: - node_frame_dict = _compute_node_frame_dict(cand_graph) - frames = sorted(node_frame_dict.keys()) - for frame in tqdm(frames): - if frame + 1 not in node_frame_dict: - continue - ious = _compute_ious(segmentation[frame], segmentation[frame + 1]) - next_nodes = node_frame_dict[frame + 1] - for node_id in node_frame_dict[frame]: - node_seg_id = cand_graph.nodes[node_id][NodeAttr.SEG_ID.value] - for next_id in next_nodes: - next_seg_id = cand_graph.nodes[next_id][NodeAttr.SEG_ID.value] - iou = ious.get(node_seg_id, {}).get(next_seg_id, 0) - cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou + iou_dict: dict[str, dict[str, float]] = {} + hypo_pairs: list[tuple[int | None, ...]] + if multihypo: + num_hypotheses = segmentation.shape[1] + hypo_pairs = list(product(range(num_hypotheses), repeat=2)) + else: + hypo_pairs = [(None, None)] + + for frame in range(len(segmentation) - 1): + for hypo1, hypo2 in hypo_pairs: + seg1 = segmentation[frame][hypo1] + seg2 = segmentation[frame + 1][hypo2] + ious = _compute_ious(seg1, seg2) + for label1, label2, iou in ious: + node_id1 = get_node_id(frame, label1, hypo1) + if node_id1 not in iou_dict: + iou_dict[node_id1] = {} + node_id2 = get_node_id(frame + 1, label2, hypo2) + iou_dict[node_id1][node_id2] = iou + return iou_dict -def add_multihypo_iou( +def add_iou( cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict: dict[int, list[Any]] | None = None, + multihypo: bool = False, ) -> None: - """Add IOU to the candidate graph for multi-hypothesis segmentations. + """Add IOU to the candidate graph. Args: cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated - segmentation (np.ndarray): Multiple hypothesis segmentation. Dimensions - are (t, h, [z], y, x), where h is the number of hypotheses. + segmentation (np.ndarray): segmentation that was used to create cand_graph. + Has shape (t, [h], [z], y, x), where h is the number of hypotheses. node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from time frames to nodes in that frame. Will be computed if not provided, but can be provided for efficiency (e.g. after running nodes_from_segmentation). Defaults to None. + multihypo (bool, optional): Whether the segmentation contains multiple + hypotheses. Defaults to False. """ if node_frame_dict is None: node_frame_dict = _compute_node_frame_dict(cand_graph) frames = sorted(node_frame_dict.keys()) - num_hypotheses = segmentation.shape[1] + ious = _get_iou_dict(segmentation, multihypo=multihypo) for frame in tqdm(frames): if frame + 1 not in node_frame_dict: continue - # construct dictionary of ious between node_ids in frame 1 and frame 2 - ious: dict[str, dict[str, float]] = {} - for hypo1, hypo2 in product(range(num_hypotheses), repeat=2): - hypo_ious = _compute_ious( - segmentation[frame][hypo1], segmentation[frame + 1][hypo2] - ) - for segid, intersecting_labels in hypo_ious.items(): - node_id = get_node_id(frame, segid, hypo1) - if node_id not in ious: - ious[node_id] = {} - for segid2, iou in intersecting_labels.items(): - next_id = get_node_id(frame + 1, segid2, hypo2) - ious[node_id][next_id] = iou next_nodes = node_frame_dict[frame + 1] for node_id in node_frame_dict[frame]: for next_id in next_nodes: iou = ious.get(node_id, {}).get(next_id, 0) - if (node_id, next_id) in cand_graph.edges: - cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou + cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou diff --git a/tests/test_candidate_graph/test_iou.py b/tests/test_candidate_graph/test_iou.py index 17d0916..d5d88d4 100644 --- a/tests/test_candidate_graph/test_iou.py +++ b/tests/test_candidate_graph/test_iou.py @@ -1,39 +1,33 @@ import networkx as nx import pytest -from motile_toolbox.candidate_graph import EdgeAttr, add_iou, add_multihypo_iou +from motile_toolbox.candidate_graph import EdgeAttr, add_iou from motile_toolbox.candidate_graph.iou import _compute_ious def test_compute_ious_2d(segmentation_2d): ious = _compute_ious(segmentation_2d[0], segmentation_2d[1]) - expected = {1: {2: 555.46 / 1408.0}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1) + expected = [ + (1, 2, 555.46 / 1408.0), + ] + for iou, expected_iou in zip(ious, expected): + assert iou == pytest.approx(expected_iou, abs=0.01) ious = _compute_ious(segmentation_2d[1], segmentation_2d[1]) - expected = {1: {1: 1.0}, 2: {2: 1.0}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1) - assert ious[2].keys() == expected[2].keys() - assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) + expected = [(1, 1, 1.0), (2, 2, 1.0)] + for iou, expected_iou in zip(ious, expected): + assert iou == pytest.approx(expected_iou, abs=0.01) def test_compute_ious_3d(segmentation_3d): ious = _compute_ious(segmentation_3d[0], segmentation_3d[1]) - expected = {1: {2: 0.30}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1) + expected = [(1, 2, 0.30)] + for iou, expected_iou in zip(ious, expected): + assert iou == pytest.approx(expected_iou, abs=0.01) ious = _compute_ious(segmentation_3d[1], segmentation_3d[1]) - expected = {1: {1: 1.0}, 2: {2: 1.0}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1) - assert ious[2].keys() == expected[2].keys() - assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) + expected = [(1, 1, 1.0), (2, 2, 1.0)] + for iou, expected_iou in zip(ious, expected): + assert iou == pytest.approx(expected_iou, abs=0.01) def test_add_iou_2d(segmentation_2d, graph_2d): @@ -52,7 +46,7 @@ def test_multi_hypo_iou_2d(multi_hypothesis_segmentation_2d, multi_hypothesis_gr expected = multi_hypothesis_graph_2d input_graph = multi_hypothesis_graph_2d.copy() nx.set_edge_attributes(input_graph, -1, name=EdgeAttr.IOU.value) - add_multihypo_iou(input_graph, multi_hypothesis_segmentation_2d) + add_iou(input_graph, multi_hypothesis_segmentation_2d, multihypo=True) for s, t, attrs in expected.edges(data=True): print(s, t) assert ( From ef602415ecc3d4b907f6fa839a89851e754e057c Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 3 Apr 2024 11:00:26 -0400 Subject: [PATCH 19/20] Black format tests --- .../test_conflict_sets.py | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/tests/test_candidate_graph/test_conflict_sets.py b/tests/test_candidate_graph/test_conflict_sets.py index c20f810..c07e4e8 100644 --- a/tests/test_candidate_graph/test_conflict_sets.py +++ b/tests/test_candidate_graph/test_conflict_sets.py @@ -1,4 +1,3 @@ - import numpy as np from motile_toolbox.candidate_graph.conflict_sets import compute_conflict_sets from pytest_unordered import unordered @@ -6,31 +5,35 @@ def test_conflict_sets_2d(multi_hypothesis_segmentation_2d): for t in range(multi_hypothesis_segmentation_2d.shape[0]): - conflict_set =compute_conflict_sets(multi_hypothesis_segmentation_2d[t], t) - if t==0: - expected= [{'0_1_1', '0_0_1'}] + conflict_set = compute_conflict_sets(multi_hypothesis_segmentation_2d[t], t) + if t == 0: + expected = [{"0_1_1", "0_0_1"}] assert len(conflict_set) == 1 - assert conflict_set==unordered(expected) - elif t==1: - expected= [{'1_0_2', '1_1_2'}, {'1_0_1', '1_1_1'}] + assert conflict_set == unordered(expected) + elif t == 1: + expected = [{"1_0_2", "1_1_2"}, {"1_0_1", "1_1_1"}] assert len(conflict_set) == 2 assert conflict_set == unordered(expected) def test_conflict_sets_2d_reshaped(multi_hypothesis_segmentation_2d): - """Reshape segmentation array just to provide a slightly difficult example. - """ - + """Reshape segmentation array just to provide a slightly difficult example.""" - reshaped = np.asarray([multi_hypothesis_segmentation_2d[0, 0], # hypothesis 0 - multi_hypothesis_segmentation_2d[1, 0], # hypothesis 1 - multi_hypothesis_segmentation_2d[1, 1]]) # hypothesis 2 + reshaped = np.asarray( + [ + multi_hypothesis_segmentation_2d[0, 0], # hypothesis 0 + multi_hypothesis_segmentation_2d[1, 0], # hypothesis 1 + multi_hypothesis_segmentation_2d[1, 1], + ] + ) # hypothesis 2 conflict_set = compute_conflict_sets(reshaped, 0) # note the expected ids are not really there since the # reshaped array is artifically constructed - expected = [{'0_0_1', '0_1_2', '0_2_2'}, - {'0_1_1', '0_2_1'}, - {'0_0_1', '0_1_2'}, - {'0_1_2', '0_2_2'}, - {'0_0_1', '0_2_2'}] + expected = [ + {"0_0_1", "0_1_2", "0_2_2"}, + {"0_1_1", "0_2_1"}, + {"0_0_1", "0_1_2"}, + {"0_1_2", "0_2_2"}, + {"0_0_1", "0_2_2"}, + ] assert conflict_set == unordered(expected) From ab75d9ac90ce6f42360cd8f3db799e352ddc1f0f Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 3 Apr 2024 11:09:55 -0400 Subject: [PATCH 20/20] Add black back into pre-commit --- .pre-commit-config.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ce0cab7..e0f6442 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,11 @@ repos: - id: check-yaml - id: check-added-large-files + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.2.2 hooks: