Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute iou #4

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/motile_toolbox/candidate_graph/graph_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ class EdgeAttr(Enum):
"""

DISTANCE = "distance"
IOU = "iou"
96 changes: 83 additions & 13 deletions src/motile_toolbox/candidate_graph/graph_from_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def nodes_from_segmentation(
cand_graph = nx.DiGraph()
# also construct a dictionary from time frame to node_id for efficiency
node_frame_dict = {}

for t in range(len(segmentation)):
print("Extracting nodes from segmentaiton")
for t in tqdm(range(len(segmentation))):
nodes_in_frame = []
props = regionprops(segmentation[t])
for regionprop in props:
Expand All @@ -87,13 +87,36 @@ def nodes_from_segmentation(
return cand_graph, node_frame_dict


def _compute_node_frame_dict(
cand_graph: nx.DiGraph, frame_key: str = "t"
) -> dict[int, list[Any]]:
"""Compute dictionary from time frames to node ids for candidate graph.

Args:
cand_graph (nx.DiGraph): A networkx graph
frame_key (str, optional): Attribute key that holds the time frame of each
node in cand_graph. Defaults to "t".

Returns:
dict[int, list[Any]]: A mapping from time frames to lists of node ids.
"""
node_frame_dict: dict[int, list[Any]] = {}
for node, data in cand_graph.nodes(data=True):
t = data[frame_key]
if t not in node_frame_dict:
node_frame_dict[t] = []
node_frame_dict[t].append(node)
return node_frame_dict


def add_cand_edges(
cand_graph: nx.DiGraph,
max_edge_distance: float,
attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,),
position_keys: tuple[str, ...] | list[str] = ("y", "x"),
frame_key: str = "t",
node_frame_dict: None | dict[int, list[Any]] = None,
segmentation: None | np.ndarray = None,
) -> None:
"""Add candidate edges to a candidate graph by connecting all nodes in adjacent
frames that are closer than max_edge_distance. Also adds attributes to the edges.
Expand All @@ -104,23 +127,22 @@ def add_cand_edges(
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes within this distance in adjacent frames will by connected
with a candidate edge.
attributes (tuple[str, ...], optional): Set of attributes to compute and add to
graph.Valid attributes are: "distance". Defaults to ("distance",).
attributes (tuple[EdgeAttr, ...], optional): Set of attributes to compute and
add to graph. Defaults to (EdgeAttr.DISTANCE,).
position_keys (tuple[str, ...], optional): What the position dimensions of nodes
in the candidate graph are labeled. Defaults to ("y", "x").
frame_key (str, optional): The label of the time dimension in the candidate
graph. Defaults to "t".
node_frame_dict (dict[int, list[Any]] | None, optional): A mapping from frames
to node ids. If not provided, it will be computed from cand_graph. Defaults
to None.
segmentation (np.ndarray, optional): The segmentation array for optionally
computing attributes such as IOU. Defaults to None.
"""
print("Extracting candidate edges")
if not node_frame_dict:
node_frame_dict = {}
for node, data in cand_graph.nodes(data=True):
t = data[frame_key]
if t not in node_frame_dict:
node_frame_dict[t] = []
node_frame_dict[t].append(node)
node_frame_dict = _compute_node_frame_dict(cand_graph, frame_key=frame_key)

frames = sorted(node_frame_dict.keys())
for frame in tqdm(frames):
if frame + 1 not in node_frame_dict:
Expand All @@ -130,17 +152,63 @@ def add_cand_edges(
_get_location(cand_graph.nodes[n], position_keys=position_keys)
for n in next_nodes
]
if EdgeAttr.IOU in attributes:
if segmentation is None:
raise ValueError("Can't compute IOU without segmentation.")
ious = compute_ious(segmentation[frame], segmentation[frame + 1])
for node in node_frame_dict[frame]:
loc = _get_location(cand_graph.nodes[node], position_keys=position_keys)
for next_id, next_loc in zip(next_nodes, next_locs):
dist = math.dist(next_loc, loc)
attrs = {}
if EdgeAttr.DISTANCE in attributes:
attrs[EdgeAttr.DISTANCE.value] = dist
if dist <= max_edge_distance:
attrs = {}
if EdgeAttr.DISTANCE in attributes:
attrs[EdgeAttr.DISTANCE.value] = dist
if EdgeAttr.IOU in attributes:
node_seg_id = cand_graph.nodes[node][NodeAttr.SEG_ID.value]
next_seg_id = cand_graph.nodes[next_id][NodeAttr.SEG_ID.value]
attrs[EdgeAttr.IOU.value] = ious.get(node_seg_id, {}).get(
next_seg_id, 0
)
cand_graph.add_edge(node, next_id, **attrs)


def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, float]]:
"""Compute label IOUs between two label arrays of the same shape. Ignores background
(label 0).

Args:
frame1 (np.ndarray): Array with integer labels
frame2 (np.ndarray): Array with integer labels

Returns:
dict[int, dict[int, float]]: Dictionary from labels in frame 1 to labels in
frame 2 to iou values. Nodes that have no overlap are not included.
"""
frame1 = frame1.flatten()
frame2 = frame2.flatten()
# get indices where both are not zero (ignore background)
# this speeds up computation significantly
non_zero_indices = np.logical_and(frame1, frame2)
flattened_stacked = np.array([frame1[non_zero_indices], frame2[non_zero_indices]])

values, counts = np.unique(flattened_stacked, axis=1, return_counts=True)
frame1_values, frame1_counts = np.unique(frame1, return_counts=True)
frame1_label_sizes = dict(zip(frame1_values, frame1_counts))
frame2_values, frame2_counts = np.unique(frame2, return_counts=True)
frame2_label_sizes = dict(zip(frame2_values, frame2_counts))
iou_dict: dict[int, dict[int, float]] = {}
for index in range(values.shape[1]):
pair = values[:, index]
intersection = counts[index]
id1, id2 = pair
union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection
if id1 not in iou_dict:
iou_dict[id1] = {}
iou_dict[id1][id2] = intersection / union
return iou_dict


def graph_from_segmentation(
segmentation: np.ndarray,
max_edge_distance: float,
Expand Down Expand Up @@ -189,6 +257,7 @@ def graph_from_segmentation(
f"({segmentation.ndim - 1})"
)
# add nodes

cand_graph, node_frame_dict = nodes_from_segmentation(
segmentation, node_attributes, position_keys=position_keys, frame_key=frame_key
)
Expand All @@ -201,6 +270,7 @@ def graph_from_segmentation(
attributes=edge_attributes,
position_keys=position_keys,
node_frame_dict=node_frame_dict,
segmentation=segmentation,
)

logger.info(f"Candidate edges: {cand_graph.number_of_edges()}")
Expand Down
54 changes: 49 additions & 5 deletions tests/test_candidate_graph/test_graph_from_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import networkx as nx
import numpy as np
import pytest
from motile_toolbox.candidate_graph import NodeAttr
from motile_toolbox.candidate_graph import EdgeAttr, NodeAttr
from motile_toolbox.candidate_graph.graph_from_segmentation import (
add_cand_edges,
compute_ious,
graph_from_segmentation,
nodes_from_segmentation,
)
Expand Down Expand Up @@ -41,8 +42,8 @@ def graph_2d():
("1_2", {"y": 60, "x": 45, "t": 1, "segmentation_id": 2}),
]
edges = [
("0_1", "1_1", {"distance": 42.43}),
("0_1", "1_2", {"distance": 11.18}),
("0_1", "1_1", {"distance": 42.43, "iou": 0.0}),
("0_1", "1_2", {"distance": 11.18, "iou": 0.395}),
]
graph.add_nodes_from(nodes)
graph.add_edges_from(edges)
Expand Down Expand Up @@ -153,7 +154,10 @@ def test_add_cand_edges_2d(graph_2d):
add_cand_edges(cand_graph, max_edge_distance=50)
assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges))
for edge in cand_graph.edges:
assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_2d.edges[edge]
assert (
pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01)
== graph_2d.edges[edge][EdgeAttr.DISTANCE.value]
)


def test_add_cand_edges_3d(graph_3d):
Expand Down Expand Up @@ -192,13 +196,21 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d):
cand_graph = graph_from_segmentation(
segmentation=segmentation_2d,
max_edge_distance=100,
edge_attributes=[EdgeAttr.DISTANCE, EdgeAttr.IOU],
)
assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes))
assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges))
for node in cand_graph.nodes:
assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node])
for edge in cand_graph.edges:
assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_2d.edges[edge]
assert (
pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01)
== graph_2d.edges[edge][EdgeAttr.DISTANCE.value]
)
assert (
pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01)
== graph_2d.edges[edge][EdgeAttr.IOU.value]
)

# lower edge distance
cand_graph = graph_from_segmentation(
Expand All @@ -225,3 +237,35 @@ def test_graph_from_segmentation_3d(segmentation_3d, graph_3d):
assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node])
for edge in cand_graph.edges:
assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge]


def test_compute_ious_2d(segmentation_2d):
ious = compute_ious(segmentation_2d[0], segmentation_2d[1])
expected = {1: {2: 555.46 / 1408.0}}
assert ious.keys() == expected.keys()
assert ious[1].keys() == expected[1].keys()
assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1)

ious = compute_ious(segmentation_2d[1], segmentation_2d[1])
expected = {1: {1: 1.0}, 2: {2: 1.0}}
assert ious.keys() == expected.keys()
assert ious[1].keys() == expected[1].keys()
assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1)
assert ious[2].keys() == expected[2].keys()
assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1)


def test_compute_ious_3d(segmentation_3d):
ious = compute_ious(segmentation_3d[0], segmentation_3d[1])
expected = {1: {2: 0.30}}
assert ious.keys() == expected.keys()
assert ious[1].keys() == expected[1].keys()
assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1)

ious = compute_ious(segmentation_3d[1], segmentation_3d[1])
expected = {1: {1: 1.0}, 2: {2: 1.0}}
assert ious.keys() == expected.keys()
assert ious[1].keys() == expected[1].keys()
assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1)
assert ious[2].keys() == expected[2].keys()
assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1)
Loading