diff --git a/src/fidder/_tests/test_utils.py b/src/fidder/_tests/test_utils.py index bd014a6..365dba2 100644 --- a/src/fidder/_tests/test_utils.py +++ b/src/fidder/_tests/test_utils.py @@ -1,13 +1,13 @@ import torch -from fidder.utils import connected_component_transform_2d +from fidder.utils import pixel_count_map_2d def test_connected_component_transform_2d(): mask = torch.zeros((10, 10)) mask[::2, ::2] = 1 mask[1::2, 1::2] = 1 - connected_component_image = connected_component_transform_2d(mask).long() + connected_component_image = pixel_count_map_2d(mask).long() assert torch.allclose( connected_component_image[::2, ::2], torch.tensor(1).long()) assert torch.allclose( diff --git a/src/fidder/predict/probabilities_to_mask.py b/src/fidder/predict/probabilities_to_mask.py index d97c388..b4d3f59 100644 --- a/src/fidder/predict/probabilities_to_mask.py +++ b/src/fidder/predict/probabilities_to_mask.py @@ -1,5 +1,5 @@ import torch -from ..utils import connected_component_transform_2d +from ..utils import pixel_count_map_2d def probabilities_to_mask( @@ -27,6 +27,6 @@ def probabilities_to_mask( `(h, w)` boolean array. """ mask = probabilities > threshold - pixel_count_map = connected_component_transform_2d(mask) + pixel_count_map = pixel_count_map_2d(mask) mask[pixel_count_map < connected_pixel_count_threshold] = 0 return mask diff --git a/src/fidder/utils.py b/src/fidder/utils.py index 543ab29..3dca809 100644 --- a/src/fidder/utils.py +++ b/src/fidder/utils.py @@ -125,11 +125,10 @@ def get_pixel_spacing_from_header(image: Path) -> float: return float(mrc.voxel_size.x) -def connected_component_transform_2d(mask: torch.Tensor): - """Perform a connected component transform on a binary 2D image. +def pixel_count_map_2d(mask: torch.Tensor): + """Calculate a pixel count map from a binary 2D image. - A connected component transform replaces every pixel in a binary image with - the number of connected components in the region around that pixel. + https://haesleinhuepf.github.io/BioImageAnalysisNotebooks/60_data_visualization/parametric_maps.html Parameters ---------- @@ -142,8 +141,16 @@ def connected_component_transform_2d(mask: torch.Tensor): """ labels, n = ndi.label(mask.cpu().numpy()) labels = torch.tensor(labels, dtype=torch.long) - labels_one_hot = F.one_hot(labels, num_classes=(n + 1)) - counts = reduce(labels_one_hot, "h w c -> c", reduction="sum") - counts = {label_index: count.item() for label_index, count in enumerate(counts)} - connected_component_image = np.vectorize(counts.__getitem__)(labels) - return torch.tensor(connected_component_image) + if n < 100: # vectorised, not memory efficient + labels_one_hot = F.one_hot(labels, num_classes=(n + 1)) + counts = reduce(labels_one_hot, "h w c -> c", reduction="sum") + counts = {label_index: count.item() for label_index, count in enumerate(counts)} + connected_component_image = np.vectorize(counts.__getitem__)(labels) + connected_component_image = torch.tensor(connected_component_image) + else: + connected_component_image = torch.zeros_like(labels) + for label_id in range(n + 1): + conected_component_mask = labels == label_id + n_connected_components = torch.sum(conected_component_mask) + connected_component_image[conected_component_mask] = n_connected_components + return connected_component_image