diff --git a/pyproject.toml b/pyproject.toml index af5444f..23bd142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ ignore = [ "D105", # Missing docstring in magic method "D107", # Missing docstring in `__init__ "D205", # 1 blank line required between summary and description + "S101", # Use of assert detected ] [tool.ruff.lint.per-file-ignores] diff --git a/src/motile_toolbox/candidate_graph/compute_graph.py b/src/motile_toolbox/candidate_graph/compute_graph.py index 2c911cc..4fb4d39 100644 --- a/src/motile_toolbox/candidate_graph/compute_graph.py +++ b/src/motile_toolbox/candidate_graph/compute_graph.py @@ -15,6 +15,7 @@ def get_candidate_graph( segmentation: np.ndarray, max_edge_distance: float, iou: bool = False, + scale: list[float] | None = None, ) -> tuple[nx.DiGraph, list[set[Any]] | None]: """Construct a candidate graph from a segmentation array. Nodes are placed at the centroid of each segmentation and edges are added for all nodes in adjacent frames @@ -29,6 +30,9 @@ def get_candidate_graph( will by connected with a candidate edge. iou (bool, optional): Whether to include IOU on the candidate graph. Defaults to False. + scale (list[float] | None, optional): The scale of the segmentation data. + Will be used to rescale the point locations and attribute computations. + Defaults to None, which implies the data is isotropic. Returns: tuple[nx.DiGraph, list[set[Any]] | None]: A candidate graph that can be passed @@ -37,7 +41,7 @@ def get_candidate_graph( num_hypotheses = segmentation.shape[1] # add nodes - cand_graph, node_frame_dict = nodes_from_segmentation(segmentation) + cand_graph, node_frame_dict = nodes_from_segmentation(segmentation, scale=scale) logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") # add edges @@ -47,6 +51,9 @@ def get_candidate_graph( node_frame_dict=node_frame_dict, ) if iou: + # Scale does not matter to IOU, because both numerator and denominator + # are scaled by the anisotropy. It would matter to compare IOUs across + # multiple scales of data, but this is not the current use case. add_iou(cand_graph, segmentation, node_frame_dict) logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index 902f4c0..7775410 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -36,6 +36,7 @@ def get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> s def nodes_from_segmentation( segmentation: np.ndarray, + scale: list[float] | None = None, ) -> tuple[nx.DiGraph, dict[int, list[Any]]]: """Extract candidate nodes from a segmentation. Also computes specified attributes. Returns a networkx graph with only nodes, and also a dictionary from frames to @@ -44,6 +45,10 @@ def nodes_from_segmentation( Args: segmentation (np.ndarray): A numpy array with integer labels and dimensions (t, h, [z], y, x), where h is the number of hypotheses. + scale (list[float] | None, optional): The scale of the segmentation data. + Will be used to rescale the point locations and attribute computations. + Defaults to None, which implies the data is isotropic. Should include + time and all spatial dimentsions. Returns: tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, @@ -54,6 +59,14 @@ def nodes_from_segmentation( node_frame_dict: dict[int, list[Any]] = {} print("Extracting nodes from segmentation") num_hypotheses = segmentation.shape[1] + if scale is None: + scale = [ + 1, + ] * (segmentation.ndim - 1) # don't include hypothesis + else: + assert ( + len(scale) == segmentation.ndim - 1 + ), f"Scale {scale} should have {segmentation.ndim - 1} dims" for t in tqdm(range(len(segmentation))): segs = segmentation[t] hypo_id: int | None @@ -61,7 +74,7 @@ def nodes_from_segmentation( if num_hypotheses == 1: hypo_id = None nodes_in_frame = [] - props = regionprops(hypo) + props = regionprops(hypo, spacing=tuple(scale[1:])) for regionprop in props: node_id = get_node_id(t, regionprop.label, hypothesis_id=hypo_id) attrs = { @@ -136,6 +149,17 @@ def _compute_node_frame_dict(cand_graph: nx.DiGraph) -> dict[int, list[Any]]: def create_kdtree(cand_graph: nx.DiGraph, node_ids: Iterable[Any]) -> KDTree: + """Create a kdtree with the given nodes from the candidate graph. + Will fail if provided node ids are not in the candidate graph. + + Args: + cand_graph (nx.DiGraph): A candidate graph + node_ids (Iterable[Any]): The nodes within the candidate graph to + include in the KDTree. Useful for limiting to one time frame. + + Returns: + KDTree: A KDTree containing the positions of the given nodes. + """ positions = [cand_graph.nodes[node][NodeAttr.POS.value] for node in node_ids] return KDTree(positions) diff --git a/tests/test_candidate_graph/test_utils.py b/tests/test_candidate_graph/test_utils.py index 3bd96ac..f03c404 100644 --- a/tests/test_candidate_graph/test_utils.py +++ b/tests/test_candidate_graph/test_utils.py @@ -37,6 +37,18 @@ def test_nodes_from_segmentation_2d(segmentation_2d): assert node_frame_dict[0] == ["0_1"] assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + # test with scaling + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_2d, scale=[1, 1, 2] + ) + assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) + assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 160) + + assert node_frame_dict[0] == ["0_1"] + assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + def test_nodes_from_segmentation_2d_hypo( multi_hypothesis_segmentation_2d, multi_hypothesis_graph_2d @@ -70,6 +82,18 @@ def test_nodes_from_segmentation_3d(segmentation_3d): assert node_frame_dict[0] == ["0_1"] assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + # test with scaling + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_3d, scale=[1, 1, 4.5, 1] + ) + assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) + assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20.0, 225.0, 80.0) + + assert node_frame_dict[0] == ["0_1"] + assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + # add_cand_edges def test_add_cand_edges_2d(graph_2d):