diff --git a/src/motile_toolbox/candidate_graph/compute_graph.py b/src/motile_toolbox/candidate_graph/compute_graph.py index 4fb4d39..fc05e14 100644 --- a/src/motile_toolbox/candidate_graph/compute_graph.py +++ b/src/motile_toolbox/candidate_graph/compute_graph.py @@ -52,8 +52,7 @@ def get_candidate_graph( ) if iou: # Scale does not matter to IOU, because both numerator and denominator - # are scaled by the anisotropy. It would matter to compare IOUs across - # multiple scales of data, but this is not the current use case. + # are scaled by the anisotropy. add_iou(cand_graph, segmentation, node_frame_dict) logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") @@ -70,6 +69,7 @@ def get_candidate_graph( def get_candidate_graph_from_points_list( points_list: np.ndarray, max_edge_distance: float, + scale: list[float] | None = None, ) -> nx.DiGraph: """Construct a candidate graph from a points list. @@ -79,13 +79,17 @@ def get_candidate_graph_from_points_list( max_edge_distance (float): Maximum distance that objects can travel between frames. All nodes with centroids within this distance in adjacent frames will by connected with a candidate edge. + scale (list[float] | None, optional): Amount to scale the points in each + dimension. Only needed if the provided points are in "voxel" coordinates + instead of world coordinates. Defaults to None, which implies the data is + isotropic. Returns: nx.DiGraph: A candidate graph that can be passed to the motile solver. Multiple hypotheses not supported for points input. """ # add nodes - cand_graph, node_frame_dict = nodes_from_points_list(points_list) + cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale) logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") # add edges add_cand_edges( diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index feebf44..1e5d2e2 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -1,5 +1,6 @@ import logging -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any import networkx as nx import numpy as np @@ -64,7 +65,7 @@ def nodes_from_segmentation( cand_graph = nx.DiGraph() # also construct a dictionary from time frame to node_id for efficiency node_frame_dict: dict[int, list[Any]] = {} - print("Extracting nodes from segmentation") + logger.info("Extracting nodes from segmentation") num_hypotheses = segmentation.shape[1] if scale is None: scale = [ @@ -101,6 +102,7 @@ def nodes_from_segmentation( def nodes_from_points_list( points_list: np.ndarray, + scale: list[float] | None = None, ) -> tuple[nx.DiGraph, dict[int, list[Any]]]: """Extract candidate nodes from a list of points. Uses the index of the point in the list as its unique id. @@ -110,6 +112,10 @@ def nodes_from_points_list( Args: points_list (np.ndarray): An NxD numpy array with N points and D (3 or 4) dimensions. Dimensions should be in order (t, [z], y, x). + scale (list[float] | None, optional): Amount to scale the points in each + dimension. Only needed if the provided points are in "voxel" coordinates + instead of world coordinates. Defaults to None, which implies the data is + isotropic. Returns: tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, @@ -118,7 +124,16 @@ def nodes_from_points_list( cand_graph = nx.DiGraph() # also construct a dictionary from time frame to node_id for efficiency node_frame_dict: dict[int, list[Any]] = {} - print("Extracting nodes from points list") + logger.info("Extracting nodes from points list") + + # scale points + if scale is not None: + assert ( + len(scale) == points_list.shape[1] + ), f"Cannot scale points with {points_list.size[1]} dims by factor {scale}" + points_list = points_list * np.array(scale) + + # add points to graph for i, point in enumerate(points_list): # assume t, [z], y, x t = point[0] @@ -187,7 +202,7 @@ def add_cand_edges( to node ids. If not provided, it will be computed from cand_graph. Defaults to None. """ - print("Extracting candidate edges") + logger.info("Extracting candidate edges") if not node_frame_dict: node_frame_dict = _compute_node_frame_dict(cand_graph) @@ -202,7 +217,9 @@ def add_cand_edges( matched_indices = prev_kdtree.query_ball_tree(next_kdtree, max_edge_distance) - for prev_node_id, next_node_indices in zip(prev_node_ids, matched_indices): + for prev_node_id, next_node_indices in zip( + prev_node_ids, matched_indices, strict=False + ): for next_node_index in next_node_indices: next_node_id = next_node_ids[next_node_index] cand_graph.add_edge(prev_node_id, next_node_id) diff --git a/tests/test_candidate_graph/test_compute_graph.py b/tests/test_candidate_graph/test_compute_graph.py index 19dcf92..755abf2 100644 --- a/tests/test_candidate_graph/test_compute_graph.py +++ b/tests/test_candidate_graph/test_compute_graph.py @@ -6,6 +6,7 @@ from motile_toolbox.candidate_graph.compute_graph import ( get_candidate_graph_from_points_list, ) +from motile_toolbox.candidate_graph.graph_attributes import NodeAttr def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): @@ -91,6 +92,7 @@ def test_graph_from_multi_segmentation_2d( def test_graph_from_points_list(): points_list = np.array( [ + # t, z, y, x [0, 1, 1, 1], [2, 3, 3, 3], [1, 2, 2, 2], @@ -101,3 +103,12 @@ def test_graph_from_points_list(): cand_graph = get_candidate_graph_from_points_list(points_list, max_edge_distance=3) assert cand_graph.number_of_edges() == 3 assert len(cand_graph.in_edges(3)) == 0 + + # test scale + cand_graph = get_candidate_graph_from_points_list( + points_list, max_edge_distance=3, scale=[1, 1, 1, 5] + ) + assert cand_graph.number_of_edges() == 0 + assert len(cand_graph.in_edges(3)) == 0 + assert cand_graph.nodes[0][NodeAttr.POS.value] == [1, 1, 5] + assert cand_graph.nodes[0][NodeAttr.TIME.value] == 0