Skip to content

Commit

Permalink
Merge pull request #25 from funkelab/13-unique-label-ids
Browse files Browse the repository at this point in the history
13 unique label ids
  • Loading branch information
cmalinmayor authored Nov 27, 2024
2 parents 4986526 + 3a8489b commit 6c337e4
Show file tree
Hide file tree
Showing 13 changed files with 253 additions and 233 deletions.
2 changes: 1 addition & 1 deletion src/motile_toolbox/candidate_graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
from .graph_attributes import EdgeAttr, NodeAttr
from .graph_to_nx import graph_to_nx
from .iou import add_iou
from .utils import add_cand_edges, get_node_id, nodes_from_segmentation
from .utils import add_cand_edges, nodes_from_segmentation
2 changes: 1 addition & 1 deletion src/motile_toolbox/candidate_graph/compute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def compute_graph_from_multiseg(
conflicts = []
for time in range(segmentations.shape[1]):
segs = segmentations[:, time]
conflicts.extend(compute_conflict_sets(segs, time))
conflicts.extend(compute_conflict_sets(segs))

return cand_graph, conflicts

Expand Down
11 changes: 2 additions & 9 deletions src/motile_toolbox/candidate_graph/conflict_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@

import numpy as np

from .utils import (
get_node_id,
)


def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]:
def compute_conflict_sets(segmentation_frame: np.ndarray) -> list[set]:
"""Compute all sets of node ids that conflict with each other.
Note: Results might include redundant sets, for example {a, b, c} and {a, b}
might both appear in the results.
Expand Down Expand Up @@ -36,9 +32,6 @@ def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set
values = np.transpose(values)
conflict_sets = []
for conflicting_labels in values:
id_set = set()
for hypo_id, label in enumerate(conflicting_labels):
if label != 0:
id_set.add(get_node_id(time, label, hypo_id))
id_set = {label for label in conflicting_labels if label != 0}
conflict_sets.append(id_set)
return conflict_sets
24 changes: 8 additions & 16 deletions src/motile_toolbox/candidate_graph/iou.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from itertools import product
from typing import Any

import networkx as nx
import numpy as np
from tqdm import tqdm

from .graph_attributes import EdgeAttr
from .utils import _compute_node_frame_dict, get_node_id
from .utils import _compute_node_frame_dict


def _compute_ious(
Expand Down Expand Up @@ -45,7 +44,7 @@ def _compute_ious(
return ious


def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]:
def _get_iou_dict(segmentation, multiseg=False) -> dict[int, dict[int, float]]:
"""Get all ious values for the provided segmentations (all frames).
Will return as map from node_id -> dict[node_id] -> iou for easy
navigation when adding to candidate graph.
Expand All @@ -58,10 +57,10 @@ def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]:
multiple hypothesis segmentations. Defaults to False.
Returns:
dict[str, dict[str, float]]: A map from node id to another dictionary, which
dict[int, dict[int, float]]: A map from node id to another dictionary, which
contains node_ids to iou values.
"""
iou_dict: dict[str, dict[str, float]] = {}
iou_dict: dict[int, dict[int, float]] = {}
hypo_pairs: list[tuple[int, ...]] = [(0, 0)]
if multiseg:
num_hypotheses = segmentation.shape[0]
Expand All @@ -76,23 +75,16 @@ def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]:
seg2 = segmentation[hypo2][frame + 1]
ious = _compute_ious(seg1, seg2)
for label1, label2, iou in ious:
if multiseg:
node_id1 = get_node_id(frame, label1, hypo1)
node_id2 = get_node_id(frame + 1, label2, hypo2)
else:
node_id1 = get_node_id(frame, label1)
node_id2 = get_node_id(frame + 1, label2)

if node_id1 not in iou_dict:
iou_dict[node_id1] = {}
iou_dict[node_id1][node_id2] = iou
if label1 not in iou_dict:
iou_dict[label1] = {}
iou_dict[label1][label2] = iou
return iou_dict


def add_iou(
cand_graph: nx.DiGraph,
segmentation: np.ndarray,
node_frame_dict: dict[int, list[Any]] | None = None,
node_frame_dict: dict[int, list[int]] | None = None,
multiseg=False,
) -> None:
"""Add IOU to the candidate graph.
Expand Down
28 changes: 4 additions & 24 deletions src/motile_toolbox/candidate_graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,6 @@
logger = logging.getLogger(__name__)


def get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> str:
"""Construct a node id given the time frame, segmentation label id, and
optionally the hypothesis id. This function is not designed for candidate graphs
that do not come from segmentations, but could be used if there is a similar
"detection id" that is unique for all cells detected in a given frame.
Args:
time (int): The time frame the node is in
label_id (int): The label the node has in the segmentation.
hypothesis_id (int | None, optional): An integer representing which hypothesis
the segmentation came from, if applicable. Defaults to None.
Returns:
str: A string to use as the node id in the candidate graph. Assuming that label
ids are not repeated in the same time frame and hypothesis, it is unique.
"""
if hypothesis_id is not None:
return f"{time}_{hypothesis_id}_{label_id}"
else:
return f"{time}_{label_id}"


def nodes_from_segmentation(
segmentation: np.ndarray,
scale: list[float] | None = None,
Expand All @@ -52,7 +30,9 @@ def nodes_from_segmentation(
Args:
segmentation (np.ndarray): A numpy array with integer labels and dimensions
(t, [z], y, x).
(t, [z], y, x). Labels must be unique across time, and the label
will be used as the node id. If the labels are not unique, preprocess
with motile_toolbox.utils.ensure_unqiue_ids before calling this function.
scale (list[float] | None, optional): The scale of the segmentation data in all
dimensions (including time, which should have a dummy 1 value).
Will be used to rescale the point locations and attribute computations.
Expand Down Expand Up @@ -82,7 +62,7 @@ def nodes_from_segmentation(
nodes_in_frame = []
props = regionprops(segs, spacing=tuple(scale[1:]))
for regionprop in props:
node_id = get_node_id(t, regionprop.label, hypothesis_id=seg_hypo)
node_id = regionprop.label
attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area}
attrs[NodeAttr.SEG_ID.value] = regionprop.label
if seg_hypo:
Expand Down
5 changes: 4 additions & 1 deletion src/motile_toolbox/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .relabel_segmentation import relabel_segmentation
from .relabel_segmentation import (
ensure_unique_labels,
relabel_segmentation_with_track_id,
)
31 changes: 30 additions & 1 deletion src/motile_toolbox/utils/relabel_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from motile_toolbox.candidate_graph import NodeAttr


def relabel_segmentation(
def relabel_segmentation_with_track_id(
solution_nx_graph: nx.DiGraph,
segmentation: np.ndarray,
) -> np.ndarray:
Expand Down Expand Up @@ -37,3 +37,32 @@ def relabel_segmentation(
tracked_masks[time_frame][previous_seg_mask] = id_counter
id_counter += 1
return tracked_masks


def ensure_unique_labels(
segmentation: np.ndarray,
multiseg: bool = False,
) -> np.ndarray:
"""Relabels the segmentation in place to ensure that label ids are unique across
time. This means that every detection will have a unique label id.
Useful for combining predictions made in each frame independently, or multiple
segmentation outputs that repeat label IDs.
Args:
segmentation (np.ndarray): Segmentation with dimensions ([h], t, [z], y, x).
multiseg (bool, optional): Flag indicating if the segmentation contains
multiple hypotheses in the first dimension. Defaults to False.
"""
segmentation = segmentation.astype(np.uint64)
orig_shape = segmentation.shape
if multiseg:
segmentation = segmentation.reshape((-1, *orig_shape[2:]))
curr_max = 0
for idx in range(segmentation.shape[0]):
frame = segmentation[idx]
frame[frame != 0] += curr_max
curr_max = int(np.max(frame))
segmentation[idx] = frame
if multiseg:
segmentation = segmentation.reshape(orig_shape)
return segmentation
Loading

0 comments on commit 6c337e4

Please sign in to comment.