Skip to content

Commit

Permalink
Add non-vectorised code path to pixel count map generation (#24)
Browse files Browse the repository at this point in the history
* add compatibility for pytorch 2.0 and pytorch-lightning 2.0

* add a non-vectorised path for generating pixel count maps
  • Loading branch information
alisterburt authored May 12, 2023
1 parent 66604e3 commit ac14536
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/fidder/_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import torch

from fidder.utils import connected_component_transform_2d
from fidder.utils import pixel_count_map_2d


def test_connected_component_transform_2d():
mask = torch.zeros((10, 10))
mask[::2, ::2] = 1
mask[1::2, 1::2] = 1
connected_component_image = connected_component_transform_2d(mask).long()
connected_component_image = pixel_count_map_2d(mask).long()
assert torch.allclose(
connected_component_image[::2, ::2], torch.tensor(1).long())
assert torch.allclose(
Expand Down
4 changes: 2 additions & 2 deletions src/fidder/predict/probabilities_to_mask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from ..utils import connected_component_transform_2d
from ..utils import pixel_count_map_2d


def probabilities_to_mask(
Expand Down Expand Up @@ -27,6 +27,6 @@ def probabilities_to_mask(
`(h, w)` boolean array.
"""
mask = probabilities > threshold
pixel_count_map = connected_component_transform_2d(mask)
pixel_count_map = pixel_count_map_2d(mask)
mask[pixel_count_map < connected_pixel_count_threshold] = 0
return mask
25 changes: 16 additions & 9 deletions src/fidder/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,10 @@ def get_pixel_spacing_from_header(image: Path) -> float:
return float(mrc.voxel_size.x)


def connected_component_transform_2d(mask: torch.Tensor):
"""Perform a connected component transform on a binary 2D image.
def pixel_count_map_2d(mask: torch.Tensor):
"""Calculate a pixel count map from a binary 2D image.
A connected component transform replaces every pixel in a binary image with
the number of connected components in the region around that pixel.
https://haesleinhuepf.github.io/BioImageAnalysisNotebooks/60_data_visualization/parametric_maps.html
Parameters
----------
Expand All @@ -142,8 +141,16 @@ def connected_component_transform_2d(mask: torch.Tensor):
"""
labels, n = ndi.label(mask.cpu().numpy())
labels = torch.tensor(labels, dtype=torch.long)
labels_one_hot = F.one_hot(labels, num_classes=(n + 1))
counts = reduce(labels_one_hot, "h w c -> c", reduction="sum")
counts = {label_index: count.item() for label_index, count in enumerate(counts)}
connected_component_image = np.vectorize(counts.__getitem__)(labels)
return torch.tensor(connected_component_image)
if n < 100: # vectorised, not memory efficient
labels_one_hot = F.one_hot(labels, num_classes=(n + 1))
counts = reduce(labels_one_hot, "h w c -> c", reduction="sum")
counts = {label_index: count.item() for label_index, count in enumerate(counts)}
connected_component_image = np.vectorize(counts.__getitem__)(labels)
connected_component_image = torch.tensor(connected_component_image)
else:
connected_component_image = torch.zeros_like(labels)
for label_id in range(n + 1):
conected_component_mask = labels == label_id
n_connected_components = torch.sum(conected_component_mask)
connected_component_image[conected_component_mask] = n_connected_components
return connected_component_image

0 comments on commit ac14536

Please sign in to comment.