Skip to content

Commit

Permalink
Add option to scale points when making candidate graph
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Sep 17, 2024
1 parent 6667b81 commit 684f761
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
10 changes: 7 additions & 3 deletions src/motile_toolbox/candidate_graph/compute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def get_candidate_graph(
)
if iou:
# Scale does not matter to IOU, because both numerator and denominator
# are scaled by the anisotropy. It would matter to compare IOUs across
# multiple scales of data, but this is not the current use case.
# are scaled by the anisotropy.
add_iou(cand_graph, segmentation, node_frame_dict)

logger.info(f"Candidate edges: {cand_graph.number_of_edges()}")
Expand All @@ -70,6 +69,7 @@ def get_candidate_graph(
def get_candidate_graph_from_points_list(
points_list: np.ndarray,
max_edge_distance: float,
scale: list[float] | None = None,
) -> nx.DiGraph:
"""Construct a candidate graph from a points list.
Expand All @@ -79,13 +79,17 @@ def get_candidate_graph_from_points_list(
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes with centroids within this distance in adjacent frames
will by connected with a candidate edge.
scale (list[float] | None, optional): Amount to scale the points in each
dimension. Only needed if the provided points are in "voxel" coordinates
instead of world coordinates. Defaults to None, which implies the data is
isotropic.
Returns:
nx.DiGraph: A candidate graph that can be passed to the motile solver.
Multiple hypotheses not supported for points input.
"""
# add nodes
cand_graph, node_frame_dict = nodes_from_points_list(points_list)
cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale)
logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}")
# add edges
add_cand_edges(
Expand Down
27 changes: 22 additions & 5 deletions src/motile_toolbox/candidate_graph/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Any, Iterable
from collections.abc import Iterable
from typing import Any

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -64,7 +65,7 @@ def nodes_from_segmentation(
cand_graph = nx.DiGraph()
# also construct a dictionary from time frame to node_id for efficiency
node_frame_dict: dict[int, list[Any]] = {}
print("Extracting nodes from segmentation")
logger.info("Extracting nodes from segmentation")
num_hypotheses = segmentation.shape[1]
if scale is None:
scale = [
Expand Down Expand Up @@ -101,6 +102,7 @@ def nodes_from_segmentation(

def nodes_from_points_list(
points_list: np.ndarray,
scale: list[float] | None = None,
) -> tuple[nx.DiGraph, dict[int, list[Any]]]:
"""Extract candidate nodes from a list of points. Uses the index of the
point in the list as its unique id.
Expand All @@ -110,6 +112,10 @@ def nodes_from_points_list(
Args:
points_list (np.ndarray): An NxD numpy array with N points and D
(3 or 4) dimensions. Dimensions should be in order (t, [z], y, x).
scale (list[float] | None, optional): Amount to scale the points in each
dimension. Only needed if the provided points are in "voxel" coordinates
instead of world coordinates. Defaults to None, which implies the data is
isotropic.
Returns:
tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,
Expand All @@ -118,7 +124,16 @@ def nodes_from_points_list(
cand_graph = nx.DiGraph()
# also construct a dictionary from time frame to node_id for efficiency
node_frame_dict: dict[int, list[Any]] = {}
print("Extracting nodes from points list")
logger.info("Extracting nodes from points list")

# scale points
if scale is not None:
assert (
len(scale) == points_list.shape[1]
), f"Cannot scale points with {points_list.size[1]} dims by factor {scale}"
points_list = points_list * np.array(scale)

# add points to graph
for i, point in enumerate(points_list):
# assume t, [z], y, x
t = point[0]
Expand Down Expand Up @@ -187,7 +202,7 @@ def add_cand_edges(
to node ids. If not provided, it will be computed from cand_graph. Defaults
to None.
"""
print("Extracting candidate edges")
logger.info("Extracting candidate edges")
if not node_frame_dict:
node_frame_dict = _compute_node_frame_dict(cand_graph)

Expand All @@ -202,7 +217,9 @@ def add_cand_edges(

matched_indices = prev_kdtree.query_ball_tree(next_kdtree, max_edge_distance)

for prev_node_id, next_node_indices in zip(prev_node_ids, matched_indices):
for prev_node_id, next_node_indices in zip(
prev_node_ids, matched_indices, strict=False
):
for next_node_index in next_node_indices:
next_node_id = next_node_ids[next_node_index]
cand_graph.add_edge(prev_node_id, next_node_id)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_candidate_graph/test_compute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from motile_toolbox.candidate_graph.compute_graph import (
get_candidate_graph_from_points_list,
)
from motile_toolbox.candidate_graph.graph_attributes import NodeAttr


def test_graph_from_segmentation_2d(segmentation_2d, graph_2d):
Expand Down Expand Up @@ -91,6 +92,7 @@ def test_graph_from_multi_segmentation_2d(
def test_graph_from_points_list():
points_list = np.array(
[
# t, z, y, x
[0, 1, 1, 1],
[2, 3, 3, 3],
[1, 2, 2, 2],
Expand All @@ -101,3 +103,12 @@ def test_graph_from_points_list():
cand_graph = get_candidate_graph_from_points_list(points_list, max_edge_distance=3)
assert cand_graph.number_of_edges() == 3
assert len(cand_graph.in_edges(3)) == 0

# test scale
cand_graph = get_candidate_graph_from_points_list(
points_list, max_edge_distance=3, scale=[1, 1, 1, 5]
)
assert cand_graph.number_of_edges() == 0
assert len(cand_graph.in_edges(3)) == 0
assert cand_graph.nodes[0][NodeAttr.POS.value] == [1, 1, 5]
assert cand_graph.nodes[0][NodeAttr.TIME.value] == 0

0 comments on commit 684f761

Please sign in to comment.