diff --git a/src/motile_toolbox/candidate_graph/__init__.py b/src/motile_toolbox/candidate_graph/__init__.py index fd4201f..5ff945d 100644 --- a/src/motile_toolbox/candidate_graph/__init__.py +++ b/src/motile_toolbox/candidate_graph/__init__.py @@ -6,4 +6,4 @@ from .graph_attributes import EdgeAttr, NodeAttr from .graph_to_nx import graph_to_nx from .iou import add_iou -from .utils import add_cand_edges, get_node_id, nodes_from_segmentation +from .utils import add_cand_edges, nodes_from_segmentation diff --git a/src/motile_toolbox/candidate_graph/compute_graph.py b/src/motile_toolbox/candidate_graph/compute_graph.py index e1e40e3..7593523 100644 --- a/src/motile_toolbox/candidate_graph/compute_graph.py +++ b/src/motile_toolbox/candidate_graph/compute_graph.py @@ -114,7 +114,7 @@ def compute_graph_from_multiseg( conflicts = [] for time in range(segmentations.shape[1]): segs = segmentations[:, time] - conflicts.extend(compute_conflict_sets(segs, time)) + conflicts.extend(compute_conflict_sets(segs)) return cand_graph, conflicts diff --git a/src/motile_toolbox/candidate_graph/conflict_sets.py b/src/motile_toolbox/candidate_graph/conflict_sets.py index 4747c29..cb94075 100644 --- a/src/motile_toolbox/candidate_graph/conflict_sets.py +++ b/src/motile_toolbox/candidate_graph/conflict_sets.py @@ -2,12 +2,8 @@ import numpy as np -from .utils import ( - get_node_id, -) - -def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]: +def compute_conflict_sets(segmentation_frame: np.ndarray) -> list[set]: """Compute all sets of node ids that conflict with each other. Note: Results might include redundant sets, for example {a, b, c} and {a, b} might both appear in the results. @@ -36,9 +32,6 @@ def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set values = np.transpose(values) conflict_sets = [] for conflicting_labels in values: - id_set = set() - for hypo_id, label in enumerate(conflicting_labels): - if label != 0: - id_set.add(get_node_id(time, label, hypo_id)) + id_set = {label for label in conflicting_labels if label != 0} conflict_sets.append(id_set) return conflict_sets diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index d134ebc..5de9a8c 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -1,12 +1,11 @@ from itertools import product -from typing import Any import networkx as nx import numpy as np from tqdm import tqdm from .graph_attributes import EdgeAttr -from .utils import _compute_node_frame_dict, get_node_id +from .utils import _compute_node_frame_dict def _compute_ious( @@ -45,7 +44,7 @@ def _compute_ious( return ious -def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]: +def _get_iou_dict(segmentation, multiseg=False) -> dict[int, dict[int, float]]: """Get all ious values for the provided segmentations (all frames). Will return as map from node_id -> dict[node_id] -> iou for easy navigation when adding to candidate graph. @@ -58,10 +57,10 @@ def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]: multiple hypothesis segmentations. Defaults to False. Returns: - dict[str, dict[str, float]]: A map from node id to another dictionary, which + dict[int, dict[int, float]]: A map from node id to another dictionary, which contains node_ids to iou values. """ - iou_dict: dict[str, dict[str, float]] = {} + iou_dict: dict[int, dict[int, float]] = {} hypo_pairs: list[tuple[int, ...]] = [(0, 0)] if multiseg: num_hypotheses = segmentation.shape[0] @@ -76,23 +75,16 @@ def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]: seg2 = segmentation[hypo2][frame + 1] ious = _compute_ious(seg1, seg2) for label1, label2, iou in ious: - if multiseg: - node_id1 = get_node_id(frame, label1, hypo1) - node_id2 = get_node_id(frame + 1, label2, hypo2) - else: - node_id1 = get_node_id(frame, label1) - node_id2 = get_node_id(frame + 1, label2) - - if node_id1 not in iou_dict: - iou_dict[node_id1] = {} - iou_dict[node_id1][node_id2] = iou + if label1 not in iou_dict: + iou_dict[label1] = {} + iou_dict[label1][label2] = iou return iou_dict def add_iou( cand_graph: nx.DiGraph, segmentation: np.ndarray, - node_frame_dict: dict[int, list[Any]] | None = None, + node_frame_dict: dict[int, list[int]] | None = None, multiseg=False, ) -> None: """Add IOU to the candidate graph. diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index 9a430d6..970f9e5 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -13,28 +13,6 @@ logger = logging.getLogger(__name__) -def get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> str: - """Construct a node id given the time frame, segmentation label id, and - optionally the hypothesis id. This function is not designed for candidate graphs - that do not come from segmentations, but could be used if there is a similar - "detection id" that is unique for all cells detected in a given frame. - - Args: - time (int): The time frame the node is in - label_id (int): The label the node has in the segmentation. - hypothesis_id (int | None, optional): An integer representing which hypothesis - the segmentation came from, if applicable. Defaults to None. - - Returns: - str: A string to use as the node id in the candidate graph. Assuming that label - ids are not repeated in the same time frame and hypothesis, it is unique. - """ - if hypothesis_id is not None: - return f"{time}_{hypothesis_id}_{label_id}" - else: - return f"{time}_{label_id}" - - def nodes_from_segmentation( segmentation: np.ndarray, scale: list[float] | None = None, @@ -52,7 +30,9 @@ def nodes_from_segmentation( Args: segmentation (np.ndarray): A numpy array with integer labels and dimensions - (t, [z], y, x). + (t, [z], y, x). Labels must be unique across time, and the label + will be used as the node id. If the labels are not unique, preprocess + with motile_toolbox.utils.ensure_unqiue_ids before calling this function. scale (list[float] | None, optional): The scale of the segmentation data in all dimensions (including time, which should have a dummy 1 value). Will be used to rescale the point locations and attribute computations. @@ -82,7 +62,7 @@ def nodes_from_segmentation( nodes_in_frame = [] props = regionprops(segs, spacing=tuple(scale[1:])) for regionprop in props: - node_id = get_node_id(t, regionprop.label, hypothesis_id=seg_hypo) + node_id = regionprop.label attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area} attrs[NodeAttr.SEG_ID.value] = regionprop.label if seg_hypo: diff --git a/src/motile_toolbox/utils/__init__.py b/src/motile_toolbox/utils/__init__.py index 4d3f858..6a1d4b4 100644 --- a/src/motile_toolbox/utils/__init__.py +++ b/src/motile_toolbox/utils/__init__.py @@ -1 +1,4 @@ -from .relabel_segmentation import relabel_segmentation +from .relabel_segmentation import ( + ensure_unique_labels, + relabel_segmentation_with_track_id, +) diff --git a/src/motile_toolbox/utils/relabel_segmentation.py b/src/motile_toolbox/utils/relabel_segmentation.py index 5fd5a71..885c26a 100644 --- a/src/motile_toolbox/utils/relabel_segmentation.py +++ b/src/motile_toolbox/utils/relabel_segmentation.py @@ -4,7 +4,7 @@ from motile_toolbox.candidate_graph import NodeAttr -def relabel_segmentation( +def relabel_segmentation_with_track_id( solution_nx_graph: nx.DiGraph, segmentation: np.ndarray, ) -> np.ndarray: @@ -37,3 +37,32 @@ def relabel_segmentation( tracked_masks[time_frame][previous_seg_mask] = id_counter id_counter += 1 return tracked_masks + + +def ensure_unique_labels( + segmentation: np.ndarray, + multiseg: bool = False, +) -> np.ndarray: + """Relabels the segmentation in place to ensure that label ids are unique across + time. This means that every detection will have a unique label id. + Useful for combining predictions made in each frame independently, or multiple + segmentation outputs that repeat label IDs. + + Args: + segmentation (np.ndarray): Segmentation with dimensions ([h], t, [z], y, x). + multiseg (bool, optional): Flag indicating if the segmentation contains + multiple hypotheses in the first dimension. Defaults to False. + """ + segmentation = segmentation.astype(np.uint64) + orig_shape = segmentation.shape + if multiseg: + segmentation = segmentation.reshape((-1, *orig_shape[2:])) + curr_max = 0 + for idx in range(segmentation.shape[0]): + frame = segmentation[idx] + frame[frame != 0] += curr_max + curr_max = int(np.max(frame)) + segmentation[idx] = frame + if multiseg: + segmentation = segmentation.reshape(orig_shape) + return segmentation diff --git a/tests/conftest.py b/tests/conftest.py index f2e9958..9c7713c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,12 +16,12 @@ def segmentation_2d(): segmentation[0][rr, cc] = 1 # make frame with two cells - # first cell centered at (20, 80) with label 1 - # second cell centered at (60, 45) with label 2 + # first cell centered at (20, 80) with label 2 + # second cell centered at (60, 45) with label 3 rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) - segmentation[1][rr, cc] = 1 - rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) segmentation[1][rr, cc] = 2 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 3 return segmentation @@ -35,28 +35,28 @@ def multi_hypothesis_segmentation_2d(): frame_shape = (100, 100) total_shape = (2, 2, *frame_shape) # 2 hypotheses, 2 time points, H, W segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 (hypo 1) + # make frame with one cell in center with label 1 (hypo 0) rr0, cc0 = disk(center=(50, 50), radius=20, shape=frame_shape) - # make frame with one cell at (45, 45) with label 1 (hypo 2) + # make frame with one cell at (45, 45) with label 2 (hypo 1) rr1, cc1 = disk(center=(45, 45), radius=15, shape=frame_shape) segmentation[0, 0][rr0, cc0] = 1 - segmentation[1, 0][rr1, cc1] = 1 + segmentation[1, 0][rr1, cc1] = 2 # make frame with two cells - # first cell centered at (20, 80) with label 1 + # first cell centered at (20, 80) with label 3 (hypo0) and 4 (hypo1) rr0, cc0 = disk(center=(20, 80), radius=10, shape=frame_shape) rr1, cc1 = disk(center=(15, 75), radius=15, shape=frame_shape) - segmentation[0, 1][rr0, cc0] = 1 - segmentation[1, 1][rr1, cc1] = 1 + segmentation[0, 1][rr0, cc0] = 3 + segmentation[1, 1][rr1, cc1] = 4 - # second cell centered at (60, 45) with label 2 + # second cell centered at (60, 45) with label 5(hypo0) and 6 (hypo1) rr0, cc0 = disk(center=(60, 45), radius=15, shape=frame_shape) rr1, cc1 = disk(center=(55, 40), radius=20, shape=frame_shape) - segmentation[0, 1][rr0, cc0] = 2 - segmentation[1, 1][rr1, cc1] = 2 + segmentation[0, 1][rr0, cc0] = 5 + segmentation[1, 1][rr1, cc1] = 6 return segmentation @@ -66,7 +66,7 @@ def graph_2d(): graph = nx.DiGraph() nodes = [ ( - "0_1", + 1, { NodeAttr.POS.value: (50, 50), NodeAttr.TIME.value: 0, @@ -75,27 +75,27 @@ def graph_2d(): }, ), ( - "1_1", + 2, { NodeAttr.POS.value: (20, 80), NodeAttr.TIME.value: 1, - NodeAttr.SEG_ID.value: 1, + NodeAttr.SEG_ID.value: 2, NodeAttr.AREA.value: 305, }, ), ( - "1_2", + 3, { NodeAttr.POS.value: (60, 45), NodeAttr.TIME.value: 1, - NodeAttr.SEG_ID.value: 2, + NodeAttr.SEG_ID.value: 3, NodeAttr.AREA.value: 697, }, ), ] edges = [ - ("0_1", "1_1", {EdgeAttr.IOU.value: 0.0}), - ("0_1", "1_2", {EdgeAttr.IOU.value: 0.395}), + (1, 2, {EdgeAttr.IOU.value: 0.0}), + (1, 3, {EdgeAttr.IOU.value: 0.395}), ] graph.add_nodes_from(nodes) graph.add_edges_from(edges) @@ -107,7 +107,7 @@ def multi_hypothesis_graph_2d(): graph = nx.DiGraph() nodes = [ ( - "0_0_1", + 1, { NodeAttr.POS.value: (50, 50), NodeAttr.TIME.value: 0, @@ -117,76 +117,76 @@ def multi_hypothesis_graph_2d(): }, ), ( - "0_1_1", + 2, { NodeAttr.POS.value: (45, 45), NodeAttr.TIME.value: 0, NodeAttr.SEG_HYPO.value: 1, - NodeAttr.SEG_ID.value: 1, + NodeAttr.SEG_ID.value: 2, NodeAttr.AREA.value: 697, }, ), ( - "1_0_1", + 3, { NodeAttr.POS.value: (20, 80), NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 0, - NodeAttr.SEG_ID.value: 1, + NodeAttr.SEG_ID.value: 3, NodeAttr.AREA.value: 305, }, ), ( - "1_1_1", + 4, { NodeAttr.POS.value: (15, 75), NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 1, - NodeAttr.SEG_ID.value: 1, + NodeAttr.SEG_ID.value: 4, NodeAttr.AREA.value: 697, }, ), ( - "1_0_2", + 5, { NodeAttr.POS.value: (60, 45), NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 0, - NodeAttr.SEG_ID.value: 2, + NodeAttr.SEG_ID.value: 5, NodeAttr.AREA.value: 697, }, ), ( - "1_1_2", + 6, { NodeAttr.POS.value: (55, 40), NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 1, - NodeAttr.SEG_ID.value: 2, + NodeAttr.SEG_ID.value: 6, NodeAttr.AREA.value: 1245, }, ), ] edges = [ - ("0_0_1", "1_0_1", {EdgeAttr.IOU.value: 0.0}), - ("0_0_1", "1_1_1", {EdgeAttr.IOU.value: 0.0}), + (1, 3, {EdgeAttr.IOU.value: 0.0}), + (1, 4, {EdgeAttr.IOU.value: 0.0}), ( - "0_0_1", - "1_0_2", + 1, + 5, {EdgeAttr.IOU.value: 0.3931}, ), ( - "0_0_1", - "1_1_2", + 1, + 6, {EdgeAttr.IOU.value: 0.4768}, ), - ("0_1_1", "1_0_1", {EdgeAttr.IOU.value: 0.0}), - ("0_1_1", "1_1_1", {EdgeAttr.IOU.value: 0.0}), - ("0_1_1", "1_0_2", {EdgeAttr.IOU.value: 0.2402}), + (2, 3, {EdgeAttr.IOU.value: 0.0}), + (2, 4, {EdgeAttr.IOU.value: 0.0}), + (2, 5, {EdgeAttr.IOU.value: 0.2402}), ( - "0_1_1", - "1_1_2", + 2, + 6, {EdgeAttr.IOU.value: 0.3931}, ), ] @@ -213,12 +213,12 @@ def segmentation_3d(): segmentation[0][mask] = 1 # make frame with two cells - # first cell centered at (20, 50, 80) with label 1 - # second cell centered at (60, 50, 45) with label 2 + # first cell centered at (20, 50, 80) with label 2 + # second cell centered at (60, 50, 45) with label 3 mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) - segmentation[1][mask] = 1 - mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) segmentation[1][mask] = 2 + mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) + segmentation[1][mask] = 3 return segmentation @@ -261,7 +261,7 @@ def graph_3d(): graph = nx.DiGraph() nodes = [ ( - "0_1", + 1, { NodeAttr.POS.value: (50, 50, 50), NodeAttr.TIME.value: 0, @@ -270,97 +270,29 @@ def graph_3d(): }, ), ( - "1_1", + 2, { NodeAttr.POS.value: (20, 50, 80), NodeAttr.TIME.value: 1, - NodeAttr.SEG_ID.value: 1, + NodeAttr.SEG_ID.value: 2, NodeAttr.AREA.value: 4169, }, ), ( - "1_2", + 3, { NodeAttr.POS.value: (60, 50, 45), NodeAttr.TIME.value: 1, - NodeAttr.SEG_ID.value: 2, + NodeAttr.SEG_ID.value: 3, NodeAttr.AREA.value: 14147, }, ), ] edges = [ # math.dist([50, 50], [20, 80]) - ("0_1", "1_1"), + (1, 2), # math.dist([50, 50], [60, 45]) - ("0_1", "1_2"), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - -@pytest.fixture -def multi_hypothesis_graph_3d(): - graph = nx.DiGraph() - nodes = [ - ( - "0_0_1", - { - NodeAttr.POS.value: (50, 50, 50), - NodeAttr.TIME.value: 0, - NodeAttr.SEG_HYPO.value: 0, - NodeAttr.SEG_ID.value: 1, - NodeAttr.AREA.value: 305, - }, - ), - ( - "0_1_1", - { - NodeAttr.POS.value: (45, 50, 55), - NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPO.value: 1, - NodeAttr.SEG_ID.value: 1, - NodeAttr.AREA.value: 305, - }, - ), - ( - "1_0_1", - { - NodeAttr.POS.value: (20, 50, 80), - NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPO.value: 0, - NodeAttr.SEG_ID.value: 1, - NodeAttr.AREA.value: 305, - }, - ), - ( - "1_0_2", - { - NodeAttr.POS.value: (60, 50, 45), - NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPO.value: 0, - NodeAttr.SEG_ID.value: 2, - NodeAttr.AREA.value: 305, - }, - ), - ( - "1_1_1", - { - NodeAttr.POS.value: (15, 50, 70), - NodeAttr.TIME.value: 1, - NodeAttr.SEG_HYPO.value: 1, - NodeAttr.SEG_ID.value: 1, - NodeAttr.AREA.value: 305, - }, - ), - ] - edges = [ - ("0_0_1", "1_0_1"), - ("0_0_1", "1_0_2"), - ("0_1_1", "1_0_1"), - ("0_1_1", "1_0_2"), - ("0_0_1", "1_1_1"), - ("0_1_1", "1_1_1"), + (1, 3), ] graph.add_nodes_from(nodes) graph.add_edges_from(edges) diff --git a/tests/test_candidate_graph/test_compute_graph.py b/tests/test_candidate_graph/test_compute_graph.py index 0b0a13b..07a7526 100644 --- a/tests/test_candidate_graph/test_compute_graph.py +++ b/tests/test_candidate_graph/test_compute_graph.py @@ -36,8 +36,8 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): segmentation=segmentation_2d, max_edge_distance=15, ) - assert Counter(list(cand_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert Counter(list(cand_graph.edges)) == Counter([("0_1", "1_2")]) + assert Counter(list(cand_graph.nodes)) == Counter([1, 2, 3]) + assert Counter(list(cand_graph.edges)) == Counter([(1, 3)]) def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): @@ -88,9 +88,7 @@ def test_graph_from_multi_segmentation_2d( assert Counter(list(cand_graph.nodes)) == Counter( list(multi_hypothesis_graph_2d.nodes) ) - assert Counter(list(cand_graph.edges)) == Counter( - [("0_0_1", "1_0_2"), ("0_0_1", "1_1_2"), ("0_1_1", "1_1_2")] - ) + assert Counter(list(cand_graph.edges)) == Counter([(1, 5), (1, 6), (2, 6)]) def test_graph_from_points_list(): diff --git a/tests/test_candidate_graph/test_conflict_sets.py b/tests/test_candidate_graph/test_conflict_sets.py index a820609..669faef 100644 --- a/tests/test_candidate_graph/test_conflict_sets.py +++ b/tests/test_candidate_graph/test_conflict_sets.py @@ -5,13 +5,13 @@ def test_conflict_sets_2d(multi_hypothesis_segmentation_2d): for t in range(multi_hypothesis_segmentation_2d.shape[0]): - conflict_set = compute_conflict_sets(multi_hypothesis_segmentation_2d[:, t], t) + conflict_set = compute_conflict_sets(multi_hypothesis_segmentation_2d[:, t]) if t == 0: - expected = [{"0_1_1", "0_0_1"}] + expected = [{2, 1}] assert len(conflict_set) == 1 assert conflict_set == unordered(expected) elif t == 1: - expected = [{"1_0_2", "1_1_2"}, {"1_0_1", "1_1_1"}] + expected = [{3, 4}, {5, 6}] assert len(conflict_set) == 2 assert conflict_set == unordered(expected) @@ -23,17 +23,20 @@ def test_conflict_sets_2d_reshaped(multi_hypothesis_segmentation_2d): [ multi_hypothesis_segmentation_2d[0, 0], # hypothesis 0 multi_hypothesis_segmentation_2d[0, 1], # hypothesis 1 - multi_hypothesis_segmentation_2d[1, 1], + multi_hypothesis_segmentation_2d[ + 1, 1 + ], # hypothesis 2 (time 1 hypothesis 1) ] - ) # hypothesis 2 - conflict_set = compute_conflict_sets(reshaped, 0) + ) # this is simulating one frame of multi hypothesis data + conflict_set = compute_conflict_sets(reshaped) # note the expected ids are not really there since the # reshaped array is artifically constructed + expected = [ - {"0_0_1", "0_1_2", "0_2_2"}, - {"0_1_1", "0_2_1"}, - {"0_0_1", "0_1_2"}, - {"0_1_2", "0_2_2"}, - {"0_0_1", "0_2_2"}, + {1, 5, 6}, + {3, 4}, + {1, 5}, + {5, 6}, + {1, 6}, ] assert conflict_set == unordered(expected) diff --git a/tests/test_candidate_graph/test_iou.py b/tests/test_candidate_graph/test_iou.py index cba2906..2300353 100644 --- a/tests/test_candidate_graph/test_iou.py +++ b/tests/test_candidate_graph/test_iou.py @@ -7,25 +7,25 @@ def test_compute_ious_2d(segmentation_2d): ious = _compute_ious(segmentation_2d[0], segmentation_2d[1]) expected = [ - (1, 2, 555.46 / 1408.0), + (1, 3, 555.46 / 1408.0), ] for iou, expected_iou in zip(ious, expected, strict=False): assert iou == pytest.approx(expected_iou, abs=0.01) ious = _compute_ious(segmentation_2d[1], segmentation_2d[1]) - expected = [(1, 1, 1.0), (2, 2, 1.0)] + expected = [(2, 2, 1.0), (3, 3, 1.0)] for iou, expected_iou in zip(ious, expected, strict=False): assert iou == pytest.approx(expected_iou, abs=0.01) def test_compute_ious_3d(segmentation_3d): ious = _compute_ious(segmentation_3d[0], segmentation_3d[1]) - expected = [(1, 2, 0.30)] + expected = [(1, 3, 0.30)] for iou, expected_iou in zip(ious, expected, strict=False): assert iou == pytest.approx(expected_iou, abs=0.01) ious = _compute_ious(segmentation_3d[1], segmentation_3d[1]) - expected = [(1, 1, 1.0), (2, 2, 1.0)] + expected = [(2, 2, 1.0), (3, 3, 1.0)] for iou, expected_iou in zip(ious, expected, strict=False): assert iou == pytest.approx(expected_iou, abs=0.01) diff --git a/tests/test_candidate_graph/test_utils.py b/tests/test_candidate_graph/test_utils.py index 3513a82..d1be758 100644 --- a/tests/test_candidate_graph/test_utils.py +++ b/tests/test_candidate_graph/test_utils.py @@ -5,7 +5,6 @@ from motile_toolbox.candidate_graph import ( NodeAttr, add_cand_edges, - get_node_id, nodes_from_segmentation, ) from motile_toolbox.candidate_graph.utils import ( @@ -29,27 +28,27 @@ def test_nodes_from_segmentation_2d(segmentation_2d): node_graph, node_frame_dict = nodes_from_segmentation( segmentation=segmentation_2d, ) - assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 - assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 - assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 305 - assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 80) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 305 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20, 80) - assert node_frame_dict[0] == ["0_1"] - assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) # test with scaling node_graph, node_frame_dict = nodes_from_segmentation( segmentation=segmentation_2d, scale=[1, 1, 2] ) - assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 - assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 - assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 610 - assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 160) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 610 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20, 160) - assert node_frame_dict[0] == ["0_1"] - assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) def test_nodes_from_segmentation_3d(segmentation_3d): @@ -57,27 +56,27 @@ def test_nodes_from_segmentation_3d(segmentation_3d): node_graph, node_frame_dict = nodes_from_segmentation( segmentation=segmentation_3d, ) - assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 - assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 - assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 4169 - assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 50, 80) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 4169 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20, 50, 80) - assert node_frame_dict[0] == ["0_1"] - assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) # test with scaling node_graph, node_frame_dict = nodes_from_segmentation( segmentation=segmentation_3d, scale=[1, 1, 4.5, 1] ) - assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 - assert node_graph.nodes["1_1"][NodeAttr.AREA.value] == 4169 * 4.5 - assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 - assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20.0, 225.0, 80.0) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 4169 * 4.5 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20.0, 225.0, 80.0) - assert node_frame_dict[0] == ["0_1"] - assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) # add_cand_edges @@ -90,21 +89,17 @@ def test_add_cand_edges_2d(graph_2d): def test_add_cand_edges_3d(graph_3d): cand_graph = nx.create_empty_copy(graph_3d) add_cand_edges(cand_graph, max_edge_distance=15) - graph_3d.remove_edge("0_1", "1_1") + graph_3d.remove_edge(1, 2) assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) -def test_get_node_id(): - assert get_node_id(0, 2) == "0_2" - - def test_compute_node_frame_dict(graph_2d): node_frame_dict = _compute_node_frame_dict(graph_2d) expected = { 0: [ - "0_1", + 1, ], - 1: ["1_1", "1_2"], + 1: [2, 3], } assert node_frame_dict == expected diff --git a/tests/test_utils/test_relabel_segmentation.py b/tests/test_utils/test_relabel_segmentation.py index 57d796f..7c9cdf5 100644 --- a/tests/test_utils/test_relabel_segmentation.py +++ b/tests/test_utils/test_relabel_segmentation.py @@ -1,9 +1,67 @@ import numpy as np -from motile_toolbox.utils import relabel_segmentation +import pytest +from motile_toolbox.utils import ( + ensure_unique_labels, + relabel_segmentation_with_track_id, +) from numpy.testing import assert_array_equal from skimage.draw import disk +@pytest.fixture +def segmentation_2d_repeat_labels(): + frame_shape = (100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 1 + # second cell centered at (60, 45) with label 2 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 1 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 2 + return segmentation + + +@pytest.fixture +def multi_hypothesis_segmentation_2d_repeat_labels(): + """ + Creates a multi-hypothesis version of the `segmentation_2d` fixture defined above. + + """ + frame_shape = (100, 100) + total_shape = (2, 2, *frame_shape) # 2 hypotheses, 2 time points, H, W + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 (hypo 1) + rr0, cc0 = disk(center=(50, 50), radius=20, shape=frame_shape) + # make frame with one cell at (45, 45) with label 1 (hypo 2) + rr1, cc1 = disk(center=(45, 45), radius=15, shape=frame_shape) + + segmentation[0, 0][rr0, cc0] = 1 + segmentation[1, 0][rr1, cc1] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 1 + rr0, cc0 = disk(center=(20, 80), radius=10, shape=frame_shape) + rr1, cc1 = disk(center=(15, 75), radius=15, shape=frame_shape) + + segmentation[0, 1][rr0, cc0] = 1 + segmentation[1, 1][rr1, cc1] = 1 + + # second cell centered at (60, 45) with label 2 + rr0, cc0 = disk(center=(60, 45), radius=15, shape=frame_shape) + rr1, cc1 = disk(center=(55, 40), radius=20, shape=frame_shape) + + segmentation[0, 1][rr0, cc0] = 2 + segmentation[1, 1][rr1, cc1] = 2 + + return segmentation + + def test_relabel_segmentation(segmentation_2d, graph_2d): frame_shape = (100, 100) expected = np.zeros(segmentation_2d.shape, dtype="int32") @@ -15,11 +73,48 @@ def test_relabel_segmentation(segmentation_2d, graph_2d): rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) expected[1][rr, cc] = 1 - graph_2d.remove_node("1_2") - relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d) + graph_2d.remove_node(3) + relabeled_seg = relabel_segmentation_with_track_id(graph_2d, segmentation_2d) print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}") print(f"Nonzero expected: {np.count_nonzero(expected)}") print(f"Max relabeled: {np.max(relabeled_seg)}") print(f"Max expected: {np.max(expected)}") assert_array_equal(relabeled_seg, expected) + + +def test_ensure_unique_labels_2d(segmentation_2d_repeat_labels): + expected = segmentation_2d_repeat_labels.copy().astype(np.uint64) + frame = expected[1] + frame[frame == 2] = 3 + frame[frame == 1] = 2 + expected[1] = frame + + print(np.unique(expected[1], return_counts=True)) + result = ensure_unique_labels(segmentation_2d_repeat_labels) + assert_array_equal(expected, result) + + +def test_ensure_unique_labels_2d_multiseg( + multi_hypothesis_segmentation_2d_repeat_labels, +): + expected = multi_hypothesis_segmentation_2d_repeat_labels.copy().astype(np.uint64) + + # add 1 to the first hypothesis second frame + h0f1 = expected[0, 1] + h0f1[h0f1 == 2] = 3 + h0f1[h0f1 == 1] = 2 + expected[0, 1] = h0f1 + # add 3 to the second hypothesis first frame + h1f0 = expected[1, 0] + h1f0[h1f0 == 1] = 4 + expected[1, 0] = h1f0 + # add 4 to the second hypothesis second frame + h1f1 = expected[1, 1] + h1f1[h1f1 == 1] = 5 + h1f1[h1f1 == 2] = 6 + expected[1, 1] = h1f1 + result = ensure_unique_labels( + multi_hypothesis_segmentation_2d_repeat_labels, multiseg=True + ) + assert_array_equal(expected, result)