Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi hypothesis #5

Merged
merged 20 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ repos:
- id: check-yaml
- id: check-added-large-files

- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.2.2
hooks:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dev = [
'pdoc',
'pre-commit',
'types-tqdm',
'pytest-unordered'
]

[project.urls]
Expand Down
4 changes: 3 additions & 1 deletion src/motile_toolbox/candidate_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .compute_graph import get_candidate_graph
from .graph_attributes import EdgeAttr, NodeAttr
from .graph_from_segmentation import graph_from_segmentation
from .graph_to_nx import graph_to_nx
from .iou import add_iou
from .utils import add_cand_edges, get_node_id, nodes_from_segmentation
78 changes: 78 additions & 0 deletions src/motile_toolbox/candidate_graph/compute_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import logging
from typing import Any

import networkx as nx
import numpy as np

from .conflict_sets import compute_conflict_sets
from .iou import add_iou
from .utils import add_cand_edges, nodes_from_segmentation

logger = logging.getLogger(__name__)


def get_candidate_graph(
segmentation: np.ndarray,
max_edge_distance: float,
iou: bool = False,
multihypo: bool = False,
) -> tuple[nx.DiGraph, list[set[Any]] | None]:
"""Construct a candidate graph from a segmentation array. Nodes are placed at the
centroid of each segmentation and edges are added for all nodes in adjacent frames
within max_edge_distance. If segmentation contains multiple hypotheses, will also
return a list of conflicting node ids that cannot be selected together.

Args:
segmentation (np.ndarray): A numpy array with integer labels and dimensions
(t, [h], [z], y, x), where h is the number of hypotheses.
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes with centroids within this distance in adjacent frames
will by connected with a candidate edge.
iou (bool, optional): Whether to include IOU on the candidate graph.
Defaults to False.
multihypo (bool, optional): Whether the segmentation contains multiple
hypotheses. Defaults to False.

Returns:
tuple[nx.DiGraph, list[set[Any]] | None]: A candidate graph that can be passed
to the motile solver, and a list of conflicting node ids.
"""
# add nodes
if multihypo:
cand_graph = nx.DiGraph()
num_frames = segmentation.shape[0]
node_frame_dict: dict[int, list[Any]] = {t: [] for t in range(num_frames)}
num_hypotheses = segmentation.shape[1]
for hypo_id in range(num_hypotheses):
hypothesis = segmentation[:, hypo_id]
node_graph, frame_dict = nodes_from_segmentation(
hypothesis, hypo_id=hypo_id
)
cand_graph.update(node_graph)
for t in range(num_frames):
if t in frame_dict:
node_frame_dict[t].extend(frame_dict[t])
else:
cand_graph, node_frame_dict = nodes_from_segmentation(segmentation)
logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}")

# add edges
add_cand_edges(
cand_graph,
max_edge_distance=max_edge_distance,
node_frame_dict=node_frame_dict,
)
if iou:
add_iou(cand_graph, segmentation, node_frame_dict, multihypo=multihypo)

logger.info(f"Candidate edges: {cand_graph.number_of_edges()}")

# Compute conflict sets between segmentations
if multihypo:
conflicts = []
for time, segs in enumerate(segmentation):
conflicts.extend(compute_conflict_sets(segs, time))
else:
conflicts = None

return cand_graph, conflicts
44 changes: 44 additions & 0 deletions src/motile_toolbox/candidate_graph/conflict_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from itertools import combinations

import numpy as np

from .utils import (
get_node_id,
)


def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]:
"""Compute all sets of node ids that conflict with each other.
Note: Results might include redundant sets, for example {a, b, c} and {a, b}
might both appear in the results.

Args:
segmentation_frame (np.ndarray): One frame of the multiple hypothesis
segmentation. Dimensions are (h, [z], y, x), where h is the number of
hypotheses.
time (int): Time frame, for computing node_ids.

Returns:
list[set]: list of sets of node ids that overlap. Might include some sets
that are subsets of others.
"""
flattened_segs = [seg.flatten() for seg in segmentation_frame]

# get locations where at least two hypotheses have labels
# This approach may be inefficient, but likely doesn't matter compared to np.unique
conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool)
for seg1, seg2 in combinations(flattened_segs, 2):
non_zero_indices = np.logical_and(seg1, seg2)
conflict_indices = np.logical_or(conflict_indices, non_zero_indices)

flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs])
values = np.unique(flattened_stacked, axis=1)
values = np.transpose(values)
conflict_sets = []
for conflicting_labels in values:
id_set = set()
for hypo_id, label in enumerate(conflicting_labels):
if label != 0:
id_set.add(get_node_id(time, label, hypo_id))
conflict_sets.append(id_set)
return conflict_sets
5 changes: 4 additions & 1 deletion src/motile_toolbox/candidate_graph/graph_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ class NodeAttr(Enum):
implementations of commonly used ones, listed here.
"""

SEG_ID = "segmentation_id"
POS = "pos"
TIME = "time"
SEG_ID = "seg_id"
SEG_HYPO = "seg_hypo"


class EdgeAttr(Enum):
Expand Down
Loading
Loading