Skip to content

Commit

Permalink
Merge pull request #1 from funkelab/split_graph_creation
Browse files Browse the repository at this point in the history
Split graph creation into node and edge functions
  • Loading branch information
cmalinmayor authored Feb 28, 2024
2 parents c340f3b + 7668474 commit 6be58a6
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 116 deletions.
11 changes: 5 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@ repos:
- id: check-yaml
- id: check-added-large-files

- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.252
rev: v0.2.2
hooks:
- id: ruff
args: [--fix]

- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.1
Expand Down
17 changes: 15 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ omit = ["src/motile_toolbox/visualization/*"]
[tool.ruff]
line-length = 88
target-version = "py38"
extend-select = [

[tool.ruff.lint]
select = [
"E", # style errors
"F", # flakes
"I001", # isort
Expand All @@ -58,8 +60,19 @@ extend-select = [
"B", # flake8-bugbear
"A001", # flake8-builtins
"RUF", # ruff-specific rules
"D", # documentation
]
ignore = [
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D107", # Missing docstring in `__init__`
"D205", # 1 blank line required between summary and description
]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["D", "S"]
"*/__init__.py" = ["F401"]

[tool.ruff.lint.pydocstyle]
convention = "google"
227 changes: 173 additions & 54 deletions src/motile_toolbox/candidate_graph/graph_from_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,106 +1,225 @@
import logging
import math
from typing import Any

import networkx as nx
from skimage.measure import regionprops
import numpy as np
from typing import Iterable
from skimage.measure import regionprops
from tqdm import tqdm
import logging
import math

logger = logging.getLogger(__name__)


def get_location(node_data, loc_keys=("z", "y", "x")):
    """Look up one value per location key in a node's attribute mapping.

    Args:
        node_data: Mapping of attribute names to values for a single node.
        loc_keys: Keys to read from `node_data`, in the desired output order.
            Defaults to ("z", "y", "x").

    Returns:
        list: The values stored under each of `loc_keys`, in the same order.

    Raises:
        KeyError: If any key in `loc_keys` is missing from `node_data`.
    """
    lookup = node_data.__getitem__
    return list(map(lookup, loc_keys))
def _get_location(
node_data: dict[str, Any], position_keys: tuple[str, ...] | list[str]
) -> list[Any]:
"""Convenience function to get the location of a networkx node when each dimension
is stored in a different attribute.
Args:
node_data (dict[str, Any]): Dictionary of attributes of a networkx node.
Assumes the provided position keys are in the dictionary.
position_keys (tuple[str, ...] | list[str], optional): Keys to use to get
location information from node_data (assumes they are present in node_data).
Defaults to ("z", "y", "x").
def graph_from_segmentation(
Returns:
list: _description_
Raises:
KeyError if position keys not in node_data
"""
return [node_data[k] for k in position_keys]


def nodes_from_segmentation(
    segmentation: np.ndarray,
    attributes: tuple[str, ...] | list[str] = ("segmentation_id",),
    position_keys: tuple[str, ...] | list[str] = ("y", "x"),
    frame_key: str = "t",
) -> tuple[nx.DiGraph, dict[int, list[Any]]]:
    """Extract candidate nodes from a segmentation. Also computes specified attributes.
    Returns a networkx graph with only nodes, and also a dictionary from frames to
    node_ids for efficient edge adding.

    Node ids are strings with format "{time}_{label id}". Each node stores its
    frame under `frame_key` and its centroid coordinates under `position_keys`.

    Args:
        segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels
            (0 is background, all pixels with value 1 belong to one cell, etc.). The
            time dimension is first, followed by two or three position dimensions. If
            the position dims are not (y, x), use `position_keys` to specify the names
            of the dimensions.
        attributes (tuple[str, ...] | list[str], optional): Set of attributes to
            compute and add to graph nodes. Valid attributes are: "segmentation_id".
            Defaults to ("segmentation_id",).
        position_keys (tuple[str, ...] | list[str], optional): What to label the
            position dimensions in the candidate graph. The order of the names
            corresponds to the order of the dimensions in `segmentation`. Defaults to
            ("y", "x"). The keys are paired with the centroid coordinates using
            `zip`, so a length mismatch is not reported here: surplus keys or
            surplus coordinates are silently dropped.
        frame_key (str, optional): What to label the time dimension in the candidate
            graph. Defaults to 't'.

    Returns:
        tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,
            and a mapping from time frames to node ids. Frames that contain no
            nodes have no entry in the mapping.
    """
    cand_graph = nx.DiGraph()
    # also construct a dictionary from time frame to node_id for efficiency
    node_frame_dict: dict[int, list[Any]] = {}
    # Evaluate the membership test once rather than for every region.
    store_segmentation_id = "segmentation_id" in attributes

    for t in range(len(segmentation)):
        nodes_in_frame = []
        for regionprop in regionprops(segmentation[t]):
            node_id = f"{t}_{regionprop.label}"
            attrs = {frame_key: t}
            if store_segmentation_id:
                attrs["segmentation_id"] = regionprop.label
            centroid = regionprop.centroid  # [z,] y, x
            for label, value in zip(position_keys, centroid):
                attrs[label] = value
            cand_graph.add_node(node_id, **attrs)
            nodes_in_frame.append(node_id)
        # Only record frames that actually contain nodes.
        if nodes_in_frame:
            node_frame_dict[t] = nodes_in_frame
    return cand_graph, node_frame_dict

logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}")

# add edges
def add_cand_edges(
    cand_graph: nx.DiGraph,
    max_edge_distance: float,
    attributes: tuple[str, ...] | list[str] = ("distance",),
    position_keys: tuple[str, ...] | list[str] = ("y", "x"),
    frame_key: str = "t",
    node_frame_dict: None | dict[int, list[Any]] = None,
) -> None:
    """Add candidate edges to a candidate graph by connecting all nodes in adjacent
    frames that are closer than max_edge_distance. Also adds attributes to the edges.

    Args:
        cand_graph (nx.DiGraph): Candidate graph with only nodes populated. Will
            be modified in-place to add edges.
        max_edge_distance (float): Maximum distance that objects can travel between
            frames. All nodes within this distance (inclusive) in adjacent frames
            will be connected with a candidate edge.
        attributes (tuple[str, ...] | list[str], optional): Set of attributes to
            compute and add to graph. Valid attributes are: "distance". Defaults to
            ("distance",).
        position_keys (tuple[str, ...] | list[str], optional): What the position
            dimensions of nodes in the candidate graph are labeled. Defaults to
            ("y", "x").
        frame_key (str, optional): The label of the time dimension in the candidate
            graph. Only read when node_frame_dict has to be computed. Defaults to
            "t".
        node_frame_dict (dict[int, list[Any]] | None, optional): A mapping from frames
            to node ids. If not provided (or empty), it will be computed from
            cand_graph. Defaults to None.

    Raises:
        KeyError: If a node is missing `frame_key` (when the mapping is computed)
            or any of `position_keys`.
    """
    if not node_frame_dict:
        # Build into a fresh dict so a caller-supplied empty mapping is not mutated.
        node_frame_dict = {}
        for node, data in cand_graph.nodes(data=True):
            node_frame_dict.setdefault(data[frame_key], []).append(node)

    # Evaluate the membership test once rather than for every candidate pair.
    store_distance = "distance" in attributes

    for frame in tqdm(sorted(node_frame_dict)):
        if frame + 1 not in node_frame_dict:
            continue
        next_nodes = node_frame_dict[frame + 1]
        # Look up the next frame's locations once per frame, not once per pair.
        next_locs = [
            [cand_graph.nodes[n][key] for key in position_keys] for n in next_nodes
        ]
        for node in node_frame_dict[frame]:
            loc = [cand_graph.nodes[node][key] for key in position_keys]
            for next_id, next_loc in zip(next_nodes, next_locs):
                dist = math.dist(next_loc, loc)
                # Written as `<=` (not `not >`) so a NaN distance never adds an edge.
                if dist <= max_edge_distance:
                    attrs = {"distance": dist} if store_distance else {}
                    cand_graph.add_edge(node, next_id, **attrs)


def graph_from_segmentation(
    segmentation: np.ndarray,
    max_edge_distance: float,
    node_attributes: tuple[str, ...] | list[str] = ("segmentation_id",),
    edge_attributes: tuple[str, ...] | list[str] = ("distance",),
    position_keys: tuple[str, ...] | list[str] = ("y", "x"),
    frame_key: str = "t",
) -> nx.DiGraph:
    """Construct a candidate graph from a segmentation array. Nodes are placed at the
    centroid of each segmentation and edges are added for all nodes in adjacent frames
    within max_edge_distance. The specified attributes are computed during construction.
    Node ids are strings with format "{time}_{label id}".

    Args:
        segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels
            (0 is background, all pixels with value 1 belong to one cell, etc.). The
            time dimension is first, followed by two or three position dimensions. If
            the position dims are not (y, x), use `position_keys` to specify the names
            of the dimensions.
        max_edge_distance (float): Maximum distance that objects can travel between
            frames. All nodes within this distance in adjacent frames will be connected
            with a candidate edge.
        node_attributes (tuple[str, ...] | list[str], optional): Set of attributes to
            compute and add to nodes in graph. Valid attributes are: "segmentation_id".
            Defaults to ("segmentation_id",).
        edge_attributes (tuple[str, ...] | list[str], optional): Set of attributes to
            compute and add to edges in graph. Valid attributes are: "distance".
            Defaults to ("distance",).
        position_keys (tuple[str, ...] | list[str], optional): What to label the
            position dimensions in the candidate graph. The order of the names
            corresponds to the order of the dimensions in `segmentation`. Defaults to
            ("y", "x").
        frame_key (str, optional): What to label the time dimension in the candidate
            graph. Defaults to 't'.

    Returns:
        nx.DiGraph: A candidate graph that can be passed to the motile solver.

    Raises:
        ValueError: if unsupported attribute strings are passed in to the attributes
            arguments, or if the number of position keys provided does not match the
            number of position dimensions.
    """
    # Validate everything up front so no work is done on bad arguments.
    valid_edge_attributes = [
        "distance",
    ]
    for attr in edge_attributes:
        if attr not in valid_edge_attributes:
            raise ValueError(
                f"Invalid attribute {attr} (supported attrs: {valid_edge_attributes})"
            )
    valid_node_attributes = [
        "segmentation_id",
    ]
    for attr in node_attributes:
        if attr not in valid_node_attributes:
            raise ValueError(
                f"Invalid attribute {attr} (supported attrs: {valid_node_attributes})"
            )
    if len(position_keys) != segmentation.ndim - 1:
        raise ValueError(
            f"Position labels {position_keys} does not match number of spatial dims "
            f"({segmentation.ndim - 1})"
        )

    # add nodes
    cand_graph, node_frame_dict = nodes_from_segmentation(
        segmentation,
        node_attributes,
        position_keys=position_keys,
        frame_key=frame_key,
    )
    logger.info("Candidate nodes: %d", cand_graph.number_of_nodes())

    # add edges
    add_cand_edges(
        cand_graph,
        max_edge_distance=max_edge_distance,
        attributes=edge_attributes,
        position_keys=position_keys,
        # Forward the caller's frame key so both helpers agree on the time
        # attribute name if the frame mapping ever has to be recomputed.
        frame_key=frame_key,
        node_frame_dict=node_frame_dict,
    )

    logger.info("Candidate edges: %d", cand_graph.number_of_edges())
    return cand_graph
Loading

0 comments on commit 6be58a6

Please sign in to comment.