From d5142b8d6505716fc3294d3e8a3dbf1c4924d0f7 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 28 Feb 2024 15:44:14 -0500 Subject: [PATCH 1/3] Fix sphere generation flipping axes and ruff format --- .pre-commit-config.yaml | 9 ++++----- .../test_graph_from_segmentation.py | 20 ++++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0981915..ba39671 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,17 +14,16 @@ repos: - id: check-yaml - id: check-added-large-files - - repo: https://github.com/psf/black - rev: 23.1.0 - hooks: - - id: black - - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.0.252 hooks: - id: ruff args: [--fix] + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.0.1 diff --git a/tests/test_candidate_graph/test_graph_from_segmentation.py b/tests/test_candidate_graph/test_graph_from_segmentation.py index 06d52c1..83c9d44 100644 --- a/tests/test_candidate_graph/test_graph_from_segmentation.py +++ b/tests/test_candidate_graph/test_graph_from_segmentation.py @@ -1,15 +1,17 @@ -from motile_toolbox.candidate_graph import graph_from_segmentation -import pytest +from collections import Counter + import numpy as np +import pytest +from motile_toolbox.candidate_graph.graph_from_segmentation import ( + graph_from_segmentation, +) from skimage.draw import disk -from collections import Counter -import math @pytest.fixture def segmentation_2d(): frame_shape = (100, 100) - total_shape = (2,) + frame_shape + total_shape = (2, *frame_shape) segmentation = np.zeros(total_shape, dtype="int32") # make frame with one cell in center with label 1 rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) @@ -27,9 +29,9 @@ def segmentation_2d(): def sphere(center, radius, shape): - distance = np.linalg.norm( - np.subtract(np.indices(shape).T, np.asarray(center)), axis=len(center) - ) + assert len(center) == len(shape) + indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index + distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) mask = distance <= radius return mask @@ -37,7 +39,7 @@ def sphere(center, radius, shape): @pytest.fixture def segmentation_3d(): frame_shape = (100, 100, 100) - total_shape = (2,) + frame_shape + total_shape = (2, *frame_shape) segmentation = np.zeros(total_shape, dtype="int32") # make frame with one cell in center with label 1 mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) From 0f6f52991e06f007ff49a06cf3bd388f8060bd53 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 28 Feb 2024 17:10:48 -0500 Subject: [PATCH 2/3] Separate functions for cand nodes and edges --- .pre-commit-config.yaml | 2 +- pyproject.toml | 17 +- .../graph_from_segmentation.py | 227 +++++++++++++----- .../test_graph_from_segmentation.py | 183 ++++++++++---- 4 files changed, 327 insertions(+), 102 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ba39671..505b57f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,7 @@ repos: - id: check-added-large-files - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.252 + rev: v0.2.2 hooks: - id: ruff args: [--fix] diff --git a/pyproject.toml b/pyproject.toml index b28b134..955425c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,9 @@ omit = ["src/motile_toolbox/visualization/*"] [tool.ruff] line-length = 88 target-version = "py38" -extend-select = [ + +[tool.ruff.lint] +select = [ "E", # style errors "F", # flakes "I001", # isort @@ -58,8 +60,19 @@ extend-select = [ "B", # flake8-bugbear "A001", # flake8-builtins "RUF", # ruff-specific rules + "D", # documentation +] +ignore = [ + "D100", # Missing docstring in public mod + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D107", # Missing docstring in `__init__ + "D205", # 1 blank line required between summary and description ] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "tests/*.py" = ["D", "S"] "*/__init__.py" = ["F401"] + +[tool.ruff.lint.pydocstyle] +convention = "google" diff --git a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py index 83717a3..f8f66b5 100644 --- a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py +++ b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py @@ -1,106 +1,225 @@ +import logging +import math +from typing import Any + import networkx as nx -from skimage.measure import regionprops import numpy as np -from typing import Iterable +from skimage.measure import regionprops from tqdm import tqdm -import logging -import math logger = logging.getLogger(__name__) -def get_location(node_data, loc_keys=("z", "y", "x")): - return [node_data[k] for k in loc_keys] +def _get_location( + node_data: dict[str, Any], position_keys: tuple[str, ...] | list[str] +) -> list[Any]: + """Convenience function to get the location of a networkx node when each dimension + is stored in a different attribute. + Args: + node_data (dict[str, Any]): Dictionary of attributes of a networkx node. + Assumes the provided position keys are in the dictionary. + position_keys (tuple[str, ...] | list[str], optional): Keys to use to get + location information from node_data (assumes they are present in node_data). + Defaults to ("z", "y", "x"). -def graph_from_segmentation( + Returns: + list: _description_ + Raises: + KeyError if position keys not in node_data + """ + return [node_data[k] for k in position_keys] + + +def nodes_from_segmentation( segmentation: np.ndarray, - max_edge_distance: float, - attributes: tuple[str, ...] | list[str] = ("distance",), + attributes: tuple[str, ...] | list[str] = ("segmentation_id",), position_keys: tuple[str, ...] | list[str] = ("y", "x"), frame_key: str = "t", -): - """Construct a candidate graph from a segmentation array. Nodes are placed at the centroid - of each segmentation and edges are added for all nodes in adjacent frames within - max_edge_distance. The specified attributes are computed during construction. - Node ids are strings with format "{time}_{label id}". +) -> tuple[nx.DiGraph, dict[int, list[Any]]]: + """Extract candidate nodes from a segmentation. Also computes specified attributes. + Returns a networkx graph with only nodes, and also a dictionary from frames to + node_ids for efficient edge adding. Args: segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels - (0 is background, all pixels with value 1 belong to one cell, etc.). - The time dimension is first, followed by two or three position dimensions. - If the position dims are not (y, x), use `position_keys` to specify the names of - the dimensions. - max_edge_distance (float): Maximum distance that objects can travel between frames. All - nodes within this distance in adjacent frames will by connected with a candidate edge. - attributes (tuple[str, ...], optional): Set of attributes to compute and add to graph. - Valid attributes are: "distance". Defaults to ("distance",). - position_keys (tuple[str, ...], optional): What to label the position dimensions in the - candidate graph. The order of the names corresponds to the order of the dimensions - in `segmentation`. Defaults to ("y", "x"). - frame_key (str, optional): What to label the time dimension in the candidate graph. - Defaults to 't'. + (0 is background, all pixels with value 1 belong to one cell, etc.). The + time dimension is first, followed by two or three position dimensions. If + the position dims are not (y, x), use `position_keys` to specify the names + of the dimensions. + attributes (tuple[str, ...] | list[str] , optional): Set of attributes to + compute and add to graph nodes. Valid attributes are: "segmentation_id". + Defaults to ("segmentation_id",). + position_keys (tuple[str, ...]| list[str] , optional): What to label the + position dimensions in the candidate graph. The order of the names + corresponds to the order of the dimensions in `segmentation`. Defaults to + ("y", "x"). + frame_key (str, optional): What to label the time dimension in the candidate + graph. Defaults to 't'. Returns: - nx.DiGraph: A candidate graph that can be passed to the motile solver. - - Raises: - ValueError: if unsupported attribute strings are passed in to the attributes argument, - or if the number of position keys provided does not match the number of position dimensions. + tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, + and a mapping from time frames to node ids. """ - valid_attributes = ["distance"] - for attr in attributes: - if attr not in valid_attributes: - raise ValueError( - f"Invalid attribute {attr} (supported attributes: {valid_attributes})" - ) - if len(position_keys) != segmentation.ndim - 1: - raise ValueError( - f"Position labels {position_keys} does not match number of spatial dims ({segmentation.ndim - 1})" - ) - # add nodes - node_frame_dict = ( - {} - ) # construct a dictionary from time frame to node_id for efficiency cand_graph = nx.DiGraph() + # also construct a dictionary from time frame to node_id for efficiency + node_frame_dict = {} for t in range(len(segmentation)): nodes_in_frame = [] props = regionprops(segmentation[t]) - for i, regionprop in enumerate(props): + for regionprop in props: node_id = f"{t}_{regionprop.label}" attrs = { frame_key: t, - "segmentation_id": regionprop.label, } + if "segmentation_id" in attributes: + attrs["segmentation_id"] = regionprop.label centroid = regionprop.centroid # [z,] y, x + print(f"centroid: {centroid}") for label, value in zip(position_keys, centroid): attrs[label] = value cand_graph.add_node(node_id, **attrs) nodes_in_frame.append(node_id) - node_frame_dict[t] = nodes_in_frame + if nodes_in_frame: + node_frame_dict[t] = nodes_in_frame + return cand_graph, node_frame_dict - logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") - # add edges +def add_cand_edges( + cand_graph: nx.DiGraph, + max_edge_distance: float, + attributes: tuple[str, ...] | list[str] = ("distance",), + position_keys: tuple[str, ...] | list[str] = ("y", "x"), + frame_key: str = "t", + node_frame_dict: None | dict[int, list[Any]] = None, +) -> None: + """Add candidate edges to a candidate graph by connecting all nodes in adjacent + frames that are closer than max_edge_distance. Also adds attributes to the edges. + + Args: + cand_graph (nx.DiGraph): Candidate graph with only nodes populated. Will + be modified in-place to add edges. + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes within this distance in adjacent frames will by connected + with a candidate edge. + attributes (tuple[str, ...], optional): Set of attributes to compute and add to + graph.Valid attributes are: "distance". Defaults to ("distance",). + position_keys (tuple[str, ...], optional): What the position dimensions of nodes + in the candidate graph are labeled. Defaults to ("y", "x"). + frame_key (str, optional): The label of the time dimension in the candidate + graph. Defaults to "t". + node_frame_dict (dict[int, list[Any]] | None, optional): A mapping from frames + to node ids. If not provided, it will be computed from cand_graph. Defaults + to None. + """ + if not node_frame_dict: + node_frame_dict = {} + for node, data in cand_graph.nodes(data=True): + print(data) + t = data[frame_key] + if t not in node_frame_dict: + node_frame_dict[t] = [] + node_frame_dict[t].append(node) + print(node_frame_dict) frames = sorted(node_frame_dict.keys()) for frame in tqdm(frames): + print(frame) if frame + 1 not in node_frame_dict: continue next_nodes = node_frame_dict[frame + 1] next_locs = [ - get_location(cand_graph.nodes[n], loc_keys=position_keys) + _get_location(cand_graph.nodes[n], position_keys=position_keys) for n in next_nodes ] for node in node_frame_dict[frame]: - loc = get_location(cand_graph.nodes[node], loc_keys=position_keys) + loc = _get_location(cand_graph.nodes[node], position_keys=position_keys) for next_id, next_loc in zip(next_nodes, next_locs): dist = math.dist(next_loc, loc) attrs = {} if "distance" in attributes: attrs["distance"] = dist - if dist < max_edge_distance: + if dist <= max_edge_distance: cand_graph.add_edge(node, next_id, **attrs) + +def graph_from_segmentation( + segmentation: np.ndarray, + max_edge_distance: float, + node_attributes: tuple[str, ...] | list[str] = ("segmentation_id",), + edge_attributes: tuple[str, ...] | list[str] = ("distance",), + position_keys: tuple[str, ...] | list[str] = ("y", "x"), + frame_key: str = "t", +): + """Construct a candidate graph from a segmentation array. Nodes are placed at the + centroid of each segmentation and edges are added for all nodes in adjacent frames + within max_edge_distance. The specified attributes are computed during construction. + Node ids are strings with format "{time}_{label id}". + + Args: + segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels + (0 is background, all pixels with value 1 belong to one cell, etc.). The + time dimension is first, followed by two or three position dimensions. If + the position dims are not (y, x), use `position_keys` to specify the names + of the dimensions. + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes within this distance in adjacent frames will by connected + with a candidate edge. + node_attributes (tuple[str, ...] | list[str], optional): Set of attributes to + compute and add to nodes in graph. Valid attributes are: "segmentation_id". + Defaults to ("segmentation_id",). + edge_attributes (tuple[str, ...] | list[str], optional): Set of attributes to + compute and add to edges in graph. Valid attributes are: "distance". + Defaults to ("distance",). + position_keys (tuple[str, ...], optional): What to label the position dimensions + in the candidate graph. The order of the names corresponds to the order of + the dimensions in `segmentation`. Defaults to ("y", "x"). + frame_key (str, optional): What to label the time dimension in the candidate + graph. Defaults to 't'. + + Returns: + nx.DiGraph: A candidate graph that can be passed to the motile solver. + + Raises: + ValueError: if unsupported attribute strings are passed in to the attributes + arguments, or if the number of position keys provided does not match the + number of position dimensions. + """ + valid_edge_attributes = [ + "distance", + ] + for attr in edge_attributes: + if attr not in valid_edge_attributes: + raise ValueError( + f"Invalid attribute {attr} (supported attrs: {valid_edge_attributes})" + ) + valid_node_attributes = [ + "segmentation_id", + ] + for attr in node_attributes: + if attr not in valid_node_attributes: + raise ValueError( + f"Invalid attribute {attr} (supported attrs: {valid_node_attributes})" + ) + if len(position_keys) != segmentation.ndim - 1: + raise ValueError( + f"Position labels {position_keys} does not match number of spatial dims " + f"({segmentation.ndim - 1})" + ) + # add nodes + cand_graph, node_frame_dict = nodes_from_segmentation( + segmentation, node_attributes, position_keys=position_keys, frame_key=frame_key + ) + logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") + + # add edges + add_cand_edges( + cand_graph, + max_edge_distance=max_edge_distance, + attributes=edge_attributes, + position_keys=position_keys, + node_frame_dict=node_frame_dict, + ) + logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") return cand_graph diff --git a/tests/test_candidate_graph/test_graph_from_segmentation.py b/tests/test_candidate_graph/test_graph_from_segmentation.py index 83c9d44..1b5a314 100644 --- a/tests/test_candidate_graph/test_graph_from_segmentation.py +++ b/tests/test_candidate_graph/test_graph_from_segmentation.py @@ -1,9 +1,12 @@ from collections import Counter +import networkx as nx import numpy as np import pytest from motile_toolbox.candidate_graph.graph_from_segmentation import ( + add_cand_edges, graph_from_segmentation, + nodes_from_segmentation, ) from skimage.draw import disk @@ -28,6 +31,23 @@ def segmentation_2d(): return segmentation +@pytest.fixture +def graph_2d(): + graph = nx.DiGraph() + nodes = [ + ("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}), + ("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 1}), + ("1_2", {"y": 60, "x": 45, "t": 1, "segmentation_id": 2}), + ] + edges = [ + ("0_1", "1_1", {"distance": 42.43}), + ("0_1", "1_2", {"distance": 11.18}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + def sphere(center, radius, shape): assert len(center) == len(shape) indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index @@ -56,13 +76,108 @@ def segmentation_3d(): return segmentation +@pytest.fixture +def graph_3d(): + graph = nx.DiGraph() + nodes = [ + ("0_1", {"z": 50, "y": 50, "x": 50, "t": 0, "segmentation_id": 1}), + ("1_1", {"z": 20, "y": 50, "x": 80, "t": 1, "segmentation_id": 1}), + ("1_2", {"z": 60, "y": 50, "x": 45, "t": 1, "segmentation_id": 2}), + ] + edges = [ + # math.dist([50, 50], [20, 80]) + ("0_1", "1_1", {"distance": 42.43}), + # math.dist([50, 50], [60, 45]) + ("0_1", "1_2", {"distance": 11.18}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +# nodes_from_segmentation +def test_nodes_from_segmentation_empty(): + # test with empty segmentation + empty_graph, node_frame_dict = nodes_from_segmentation( + np.zeros((3, 10, 10), dtype="int32") + ) + assert Counter(empty_graph.nodes) == Counter([]) + assert node_frame_dict == {} + + +def test_nodes_from_segmentation_2d(segmentation_2d): + # test with 2D segmentation + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_2d, + ) + assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) + assert node_graph.nodes["1_1"]["segmentation_id"] == 1 + assert node_graph.nodes["1_1"]["t"] == 1 + assert node_graph.nodes["1_1"]["y"] == 20 + assert node_graph.nodes["1_1"]["x"] == 80 + + assert node_frame_dict[0] == ["0_1"] + assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + + # remove attrs + node_graph, _ = nodes_from_segmentation( + segmentation=segmentation_2d, + attributes=[], + ) + assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) + assert "segmentation_id" not in node_graph.nodes["0_1"] + + +def test_nodes_from_segmentation_3d(segmentation_3d): + # test with 3D segmentation + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_3d, + attributes=["segmentation_id"], + position_keys=("pos_z", "pos_y", "pos_x"), + ) + assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) + assert node_graph.nodes["1_1"]["segmentation_id"] == 1 + assert node_graph.nodes["1_1"]["t"] == 1 + assert node_graph.nodes["1_1"]["pos_z"] == 20 + assert node_graph.nodes["1_1"]["pos_y"] == 50 + assert node_graph.nodes["1_1"]["pos_x"] == 80 + + assert node_frame_dict[0] == ["0_1"] + assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + + +# add_cand_edges +def test_add_cand_edges_2d(graph_2d): + cand_graph = nx.create_empty_copy(graph_2d) + add_cand_edges(cand_graph, max_edge_distance=50) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) + for edge in cand_graph.edges: + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_2d.edges[edge] + + +def test_add_cand_edges_3d(graph_3d): + cand_graph = nx.create_empty_copy(graph_3d) + add_cand_edges(cand_graph, max_edge_distance=15) + graph_3d.remove_edge("0_1", "1_1") + assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) + for edge in cand_graph.edges: + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] + + +# graph_from_segmentation def test_graph_from_segmentation_invalid(): # test invalid attributes with pytest.raises(ValueError): graph_from_segmentation( np.zeros((3, 10, 10, 10), dtype="int32"), 10, - attributes=["invalid"], + edge_attributes=["invalid"], + ) + with pytest.raises(ValueError): + graph_from_segmentation( + np.zeros((3, 10, 10, 10), dtype="int32"), + 10, + node_attributes=["invalid"], ) with pytest.raises(ValueError): @@ -71,63 +186,41 @@ def test_graph_from_segmentation_invalid(): ) -def test_graph_from_segmentation_empty(): - empty_graph = graph_from_segmentation(np.zeros((3, 10, 10), dtype="int32"), 10) - assert Counter(empty_graph.nodes) == Counter([]) - - -def test_graph_from_segmentation_2d(segmentation_2d): +def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): # test with 2D segmentation - graph_2d = graph_from_segmentation( + cand_graph = graph_from_segmentation( segmentation=segmentation_2d, max_edge_distance=100, ) - assert Counter(list(graph_2d.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert Counter(list(graph_2d.edges)) == Counter([("0_1", "1_1"), ("0_1", "1_2")]) - assert graph_2d.nodes["0_1"]["segmentation_id"] == 1 - assert graph_2d.nodes["0_1"]["t"] == 0 - assert graph_2d.nodes["0_1"]["y"] == 50 - assert graph_2d.nodes["0_1"]["x"] == 50 - assert graph_2d.edges[("0_1", "1_1")]["distance"] == pytest.approx(42.43, abs=0.01) - # math.dist([50, 50], [20, 80]) - assert graph_2d.edges[("0_1", "1_2")]["distance"] == pytest.approx(11.18, abs=0.01) - # math.dist([50, 50], [60, 45]) + assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes)) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node]) + for edge in cand_graph.edges: + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_2d.edges[edge] # lower edge distance - graph_2d = graph_from_segmentation( + cand_graph = graph_from_segmentation( segmentation=segmentation_2d, max_edge_distance=15, ) - assert Counter(list(graph_2d.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert Counter(list(graph_2d.edges)) == Counter([("0_1", "1_2")]) - assert graph_2d.edges[("0_1", "1_2")]["distance"] == pytest.approx(11.18, abs=0.01) + assert Counter(list(cand_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) + assert Counter(list(cand_graph.edges)) == Counter([("0_1", "1_2")]) + assert cand_graph.edges[("0_1", "1_2")]["distance"] == pytest.approx( + 11.18, abs=0.01 + ) -def test_graph_from_segmentation_3d(segmentation_3d): +def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): # test with 3D segmentation - graph_3d = graph_from_segmentation( + cand_graph = graph_from_segmentation( segmentation=segmentation_3d, max_edge_distance=100, position_keys=("z", "y", "x"), ) - assert Counter(list(graph_3d.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert Counter(list(graph_3d.edges)) == Counter([("0_1", "1_1"), ("0_1", "1_2")]) - assert graph_3d.nodes["0_1"]["segmentation_id"] == 1 - assert graph_3d.nodes["0_1"]["t"] == 0 - assert graph_3d.nodes["0_1"]["y"] == 50 - assert graph_3d.nodes["0_1"]["x"] == 50 - assert graph_3d.edges[("0_1", "1_1")]["distance"] == pytest.approx(42.43, abs=0.01) - # math.dist([50, 50], [20, 80]) - assert graph_3d.edges[("0_1", "1_2")]["distance"] == pytest.approx(11.18, abs=0.01) - # math.dist([50, 50], [60, 45]) - - # lower edge distance - graph_3d = graph_from_segmentation( - segmentation=segmentation_3d, - max_edge_distance=15, - position_keys=("z", "y", "x"), - ) - assert Counter(list(graph_3d.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert Counter(list(graph_3d.edges)) == Counter([("0_1", "1_2")]) - assert graph_3d.edges[("0_1", "1_2")]["distance"] == pytest.approx(11.18, abs=0.01) - # math.dist([50, 50], [60, 45]) + assert Counter(list(cand_graph.nodes)) == Counter(list(graph_3d.nodes)) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) + for edge in cand_graph.edges: + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] From 766847461fe46bc85169066c13315674b4aedccc Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 28 Feb 2024 17:18:11 -0500 Subject: [PATCH 3/3] bugfix: Use 3D position keys in edge creation test --- tests/test_candidate_graph/test_graph_from_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_candidate_graph/test_graph_from_segmentation.py b/tests/test_candidate_graph/test_graph_from_segmentation.py index 1b5a314..022b5fb 100644 --- a/tests/test_candidate_graph/test_graph_from_segmentation.py +++ b/tests/test_candidate_graph/test_graph_from_segmentation.py @@ -157,7 +157,7 @@ def test_add_cand_edges_2d(graph_2d): def test_add_cand_edges_3d(graph_3d): cand_graph = nx.create_empty_copy(graph_3d) - add_cand_edges(cand_graph, max_edge_distance=15) + add_cand_edges(cand_graph, max_edge_distance=15, position_keys=("z", "y", "x")) graph_3d.remove_edge("0_1", "1_1") assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) for edge in cand_graph.edges: