Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify API between single and multi hypothesis segmentations #9

Merged
merged 2 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 7 additions & 25 deletions src/motile_toolbox/candidate_graph/compute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def get_candidate_graph(
segmentation: np.ndarray,
max_edge_distance: float,
iou: bool = False,
multihypo: bool = False,
) -> tuple[nx.DiGraph, list[set[Any]] | None]:
"""Construct a candidate graph from a segmentation array. Nodes are placed at the
centroid of each segmentation and edges are added for all nodes in adjacent frames
Expand All @@ -24,36 +23,21 @@ def get_candidate_graph(

Args:
segmentation (np.ndarray): A numpy array with integer labels and dimensions
(t, [h], [z], y, x), where h is the number of hypotheses.
(t, h, [z], y, x), where h is the number of hypotheses.
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes with centroids within this distance in adjacent frames
will by connected with a candidate edge.
iou (bool, optional): Whether to include IOU on the candidate graph.
Defaults to False.
multihypo (bool, optional): Whether the segmentation contains multiple
hypotheses. Defaults to False.

Returns:
tuple[nx.DiGraph, list[set[Any]] | None]: A candidate graph that can be passed
to the motile solver, and a list of conflicting node ids.
"""
num_hypotheses = segmentation.shape[1]

# add nodes
if multihypo:
cand_graph = nx.DiGraph()
num_frames = segmentation.shape[0]
node_frame_dict: dict[int, list[Any]] = {t: [] for t in range(num_frames)}
num_hypotheses = segmentation.shape[1]
for hypo_id in range(num_hypotheses):
hypothesis = segmentation[:, hypo_id]
node_graph, frame_dict = nodes_from_segmentation(
hypothesis, hypo_id=hypo_id
)
cand_graph.update(node_graph)
for t in range(num_frames):
if t in frame_dict:
node_frame_dict[t].extend(frame_dict[t])
else:
cand_graph, node_frame_dict = nodes_from_segmentation(segmentation)
cand_graph, node_frame_dict = nodes_from_segmentation(segmentation)
logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}")

# add edges
Expand All @@ -63,16 +47,14 @@ def get_candidate_graph(
node_frame_dict=node_frame_dict,
)
if iou:
add_iou(cand_graph, segmentation, node_frame_dict, multihypo=multihypo)
add_iou(cand_graph, segmentation, node_frame_dict)

logger.info(f"Candidate edges: {cand_graph.number_of_edges()}")

# Compute conflict sets between segmentations
if multihypo:
conflicts = []
conflicts = []
if num_hypotheses > 1:
for time, segs in enumerate(segmentation):
conflicts.extend(compute_conflict_sets(segs, time))
else:
conflicts = None

return cand_graph, conflicts
17 changes: 6 additions & 11 deletions src/motile_toolbox/candidate_graph/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,23 @@ def _compute_ious(
return ious


def _get_iou_dict(segmentation, multihypo=False) -> dict[str, dict[str, float]]:
def _get_iou_dict(segmentation) -> dict[str, dict[str, float]]:
"""Get all ious values for the provided segmentation (all frames).
Will return as map from node_id -> dict[node_id] -> iou for easy
navigation when adding to candidate graph.

Args:
segmentation (np.ndarray): Segmentation that was used to create cand_graph.
Has shape (t, [h], [z], y, x), where h is the number of hypotheses.
multihypo (bool, optional): Whether or not the segmentation is multi hypothesis.
Defaults to False.
Has shape (t, h, [z], y, x), where h is the number of hypotheses.

Returns:
dict[str, dict[str, float]]: A map from node id to another dictionary, which
contains node_ids to iou values.
"""
iou_dict: dict[str, dict[str, float]] = {}
hypo_pairs: list[tuple[int | None, ...]]
if multihypo:
num_hypotheses = segmentation.shape[1]
num_hypotheses = segmentation.shape[1]
if num_hypotheses > 1:
hypo_pairs = list(product(range(num_hypotheses), repeat=2))
else:
hypo_pairs = [(None, None)]
Expand All @@ -86,25 +84,22 @@ def add_iou(
cand_graph: nx.DiGraph,
segmentation: np.ndarray,
node_frame_dict: dict[int, list[Any]] | None = None,
multihypo: bool = False,
) -> None:
"""Add IOU to the candidate graph.

Args:
cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated
segmentation (np.ndarray): segmentation that was used to create cand_graph.
Has shape (t, [h], [z], y, x), where h is the number of hypotheses.
Has shape (t, h, [z], y, x), where h is the number of hypotheses.
node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from
time frames to nodes in that frame. Will be computed if not provided,
but can be provided for efficiency (e.g. after running
nodes_from_segmentation). Defaults to None.
multihypo (bool, optional): Whether the segmentation contains multiple
hypotheses. Defaults to False.
"""
if node_frame_dict is None:
node_frame_dict = _compute_node_frame_dict(cand_graph)
frames = sorted(node_frame_dict.keys())
ious = _get_iou_dict(segmentation, multihypo=multihypo)
ious = _get_iou_dict(segmentation)
for frame in tqdm(frames):
if frame + 1 not in node_frame_dict.keys():
continue
Expand Down
52 changes: 28 additions & 24 deletions src/motile_toolbox/candidate_graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,45 +35,49 @@ def get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> s


def nodes_from_segmentation(
segmentation: np.ndarray, hypo_id: int | None = None
segmentation: np.ndarray,
) -> tuple[nx.DiGraph, dict[int, list[Any]]]:
"""Extract candidate nodes from a segmentation. Also computes specified attributes.
Returns a networkx graph with only nodes, and also a dictionary from frames to
node_ids for efficient edge adding.

Args:
segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels
(0 is background, all pixels with value 1 belong to one cell, etc.). The
time dimension is first, followed by two or three position dimensions.
hypo_id (int | None, optional): An id to identify which layer of the multi-
hypothesis segmentation this is. Used to create node id, and is added
to each node if not None. Defaults to None.
segmentation (np.ndarray): A numpy array with integer labels and dimensions
(t, h, [z], y, x), where h is the number of hypotheses.

Returns:
tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,
and a mapping from time frames to node ids.
"""
cand_graph = nx.DiGraph()
# also construct a dictionary from time frame to node_id for efficiency
node_frame_dict = {}
node_frame_dict: dict[int, list[Any]] = {}
print("Extracting nodes from segmentation")
num_hypotheses = segmentation.shape[1]
for t in tqdm(range(len(segmentation))):
nodes_in_frame = []
props = regionprops(segmentation[t])
for regionprop in props:
node_id = get_node_id(t, regionprop.label, hypothesis_id=hypo_id)
attrs = {
NodeAttr.TIME.value: t,
}
attrs[NodeAttr.SEG_ID.value] = regionprop.label
if hypo_id is not None:
attrs[NodeAttr.SEG_HYPO.value] = hypo_id
centroid = regionprop.centroid # [z,] y, x
attrs[NodeAttr.POS.value] = centroid
cand_graph.add_node(node_id, **attrs)
nodes_in_frame.append(node_id)
if nodes_in_frame:
node_frame_dict[t] = nodes_in_frame
segs = segmentation[t]
hypo_id: int | None
for hypo_id, hypo in enumerate(segs):
if num_hypotheses == 1:
hypo_id = None
nodes_in_frame = []
props = regionprops(hypo)
for regionprop in props:
node_id = get_node_id(t, regionprop.label, hypothesis_id=hypo_id)
attrs = {
NodeAttr.TIME.value: t,
}
attrs[NodeAttr.SEG_ID.value] = regionprop.label
if hypo_id is not None:
attrs[NodeAttr.SEG_HYPO.value] = hypo_id
centroid = regionprop.centroid # [z,] y, x
attrs[NodeAttr.POS.value] = centroid
cand_graph.add_node(node_id, **attrs)
nodes_in_frame.append(node_id)
if nodes_in_frame:
if t not in node_frame_dict:
node_frame_dict[t] = []
node_frame_dict[t].extend(nodes_in_frame)
return cand_graph, node_frame_dict


Expand Down
2 changes: 1 addition & 1 deletion src/motile_toolbox/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .saving_utils import relabel_segmentation
from .relabel_segmentation import relabel_segmentation
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@ def relabel_segmentation(
Args:
solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use
for relabeling. Nodes not in graph will be removed from seg. Original
segmentation ids have to be stored in the graph so we can map them back.
segmentation (np.ndarray): Original segmentation with labels ids that correspond
to segmentation id in graph.
frame_key (str, optional): Time frame key in networkx graph. Defaults to "t".
segmentation ids and hypothesis ids have to be stored in the graph so we
can map them back.
segmentation (np.ndarray): Original (potentially multi-hypothesis)
segmentation with dimensions (t,h,[z],y,x), where h is 1 for single
input segmentation.

Returns:
np.ndarray: Relabeled segmentation array where nodes in same track share same
id.
id with shape (t,1,[z],y,x)
"""
tracked_masks = np.zeros_like(segmentation)
output_shape = (segmentation.shape[0], 1, *segmentation.shape[2:])
tracked_masks = np.zeros_like(segmentation, shape=output_shape)
id_counter = 1
parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1]
soln_copy = solution_nx_graph.copy()
Expand All @@ -34,7 +36,13 @@ def relabel_segmentation(
for node in node_set:
time_frame = solution_nx_graph.nodes[node][NodeAttr.TIME.value]
previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value]
previous_seg_mask = segmentation[time_frame] == previous_seg_id
tracked_masks[time_frame][previous_seg_mask] = id_counter
if NodeAttr.SEG_HYPO.value in solution_nx_graph.nodes[node]:
hypothesis_id = solution_nx_graph.nodes[node][NodeAttr.SEG_HYPO.value]
else:
hypothesis_id = 0
previous_seg_mask = (
segmentation[time_frame, hypothesis_id] == previous_seg_id
)
tracked_masks[time_frame, 0][previous_seg_mask] = id_counter
id_counter += 1
return tracked_masks
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def segmentation_2d():
rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape)
segmentation[1][rr, cc] = 2

return segmentation
return np.expand_dims(segmentation, 1)


@pytest.fixture
Expand Down Expand Up @@ -210,7 +210,7 @@ def segmentation_3d():
mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape)
segmentation[1][mask] = 2

return segmentation
return np.expand_dims(segmentation, 1)


@pytest.fixture
Expand Down
2 changes: 0 additions & 2 deletions tests/test_candidate_graph/test_compute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def test_graph_from_multi_segmentation_2d(
segmentation=multi_hypothesis_segmentation_2d,
max_edge_distance=100,
iou=True,
multihypo=True,
)
assert Counter(list(cand_graph.nodes)) == Counter(
list(multi_hypothesis_graph_2d.nodes)
Expand All @@ -87,7 +86,6 @@ def test_graph_from_multi_segmentation_2d(
cand_graph, _ = get_candidate_graph(
segmentation=multi_hypothesis_segmentation_2d,
max_edge_distance=14,
multihypo=True,
)
assert Counter(list(cand_graph.nodes)) == Counter(
list(multi_hypothesis_graph_2d.nodes)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_candidate_graph/test_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_multi_hypo_iou_2d(multi_hypothesis_segmentation_2d, multi_hypothesis_gr
expected = multi_hypothesis_graph_2d
input_graph = multi_hypothesis_graph_2d.copy()
nx.set_edge_attributes(input_graph, -1, name=EdgeAttr.IOU.value)
add_iou(input_graph, multi_hypothesis_segmentation_2d, multihypo=True)
add_iou(input_graph, multi_hypothesis_segmentation_2d)
for s, t, attrs in expected.edges(data=True):
print(s, t)
assert (
Expand Down
16 changes: 10 additions & 6 deletions tests/test_candidate_graph/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
def test_nodes_from_segmentation_empty():
# test with empty segmentation
empty_graph, node_frame_dict = nodes_from_segmentation(
np.zeros((3, 10, 10), dtype="int32")
np.zeros((3, 1, 10, 10), dtype="int32")
)
assert Counter(empty_graph.nodes) == Counter([])
assert node_frame_dict == {}
Expand All @@ -37,19 +37,23 @@ def test_nodes_from_segmentation_2d(segmentation_2d):
assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"])


def test_nodes_from_segmentation_2d_hypo(segmentation_2d):
def test_nodes_from_segmentation_2d_hypo(
multi_hypothesis_segmentation_2d, multi_hypothesis_graph_2d
):
# test with 2D segmentation
node_graph, node_frame_dict = nodes_from_segmentation(
segmentation=segmentation_2d, hypo_id=0
segmentation=multi_hypothesis_segmentation_2d
)
assert Counter(list(node_graph.nodes)) == Counter(
list(multi_hypothesis_graph_2d.nodes)
)
assert Counter(list(node_graph.nodes)) == Counter(["0_0_1", "1_0_1", "1_0_2"])
assert node_graph.nodes["1_0_1"][NodeAttr.SEG_ID.value] == 1
assert node_graph.nodes["1_0_1"][NodeAttr.SEG_HYPO.value] == 0
assert node_graph.nodes["1_0_1"][NodeAttr.TIME.value] == 1
assert node_graph.nodes["1_0_1"][NodeAttr.POS.value] == (20, 80)

assert node_frame_dict[0] == ["0_0_1"]
assert Counter(node_frame_dict[1]) == Counter(["1_0_1", "1_0_2"])
assert Counter(node_frame_dict[0]) == Counter(["0_0_1", "0_1_1"])
assert Counter(node_frame_dict[1]) == Counter(["1_0_1", "1_0_2", "1_1_1", "1_1_2"])


def test_nodes_from_segmentation_3d(segmentation_3d):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ def test_relabel_segmentation(segmentation_2d, graph_2d):
expected = np.zeros(segmentation_2d.shape, dtype="int32")
# make frame with one cell in center with label 1
rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100))
expected[0][rr, cc] = 1
expected[0, 0][rr, cc] = 1

# make frame with cell centered at (20, 80) with label 1
rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape)
expected[1][rr, cc] = 1
expected[1, 0][rr, cc] = 1

graph_2d.remove_node("1_2")
relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d)
Expand Down
Loading