# ---------------------------------------------------------------------------
# New file: starfish/core/morphology/Binarize/__init__.py
# ---------------------------------------------------------------------------
"""Algorithms in this module binarize an ImageStack into a BinaryMaskCollection."""
from ._base import BinarizeAlgorithm
from .threshold import ThresholdBinarize

# autodoc's automodule directive only captures the modules explicitly listed in __all__.
#
# The namespace is snapshotted with list(...) before it is filtered: iterating the live
# ``locals()`` mapping directly inside a comprehension can raise "dictionary changed size during
# iteration" when the module is imported under tracing (e.g. coverage measurement).
all_filters = {
    filter_name: filter_cls
    for filter_name, filter_cls in list(locals().items())
    if isinstance(filter_cls, type) and issubclass(filter_cls, BinarizeAlgorithm)
}
__all__ = list(all_filters)


# ---------------------------------------------------------------------------
# New file: starfish/core/morphology/Binarize/_base.py
# ---------------------------------------------------------------------------
from abc import abstractmethod

from starfish.core.imagestack.imagestack import ImageStack
from starfish.core.morphology.binary_mask import BinaryMaskCollection
from starfish.core.pipeline.algorithmbase import AlgorithmBase


class BinarizeAlgorithm(metaclass=AlgorithmBase):
    """Base class for algorithms that turn an ImageStack into a BinaryMaskCollection.

    Subclasses must implement :py:meth:`run`.
    """

    @abstractmethod
    def run(self, image: ImageStack, *args, **kwargs) -> BinaryMaskCollection:
        """Performs binarization on the stack provided.

        Parameters
        ----------
        image : ImageStack
            The image to binarize.

        Returns
        -------
        BinaryMaskCollection
            The result of the binarization.
        """
        raise NotImplementedError()


# ---------------------------------------------------------------------------
# New file (empty): starfish/core/morphology/Binarize/test/__init__.py
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# New file: starfish/core/morphology/Binarize/test/test_threshold.py
# ---------------------------------------------------------------------------
import numpy as np
import pytest

from starfish import ImageStack
from starfish.types import Number
from ..threshold import ThresholdBinarize


@pytest.mark.parametrize(["threshold"], [[threshold] for threshold in np.linspace(0, 1, 3)])
def test_binarize(threshold: Number, num_rounds=1, num_chs=1, num_zplanes=4, ysize=5, xsize=6):
    """Binarizing a single-round, single-channel stack yields exactly one mask, and that mask
    equals ``data >= threshold``."""
    data = np.linspace(0, 1, num_rounds * num_chs * num_zplanes * ysize * xsize, dtype=np.float32)
    data = data.reshape((num_rounds, num_chs, num_zplanes, ysize, xsize))

    imagestack = ImageStack.from_numpy(data)
    binarizer = ThresholdBinarize(threshold)
    binary_mask_collection = binarizer.run(imagestack)

    assert len(binary_mask_collection) == 1
    mask = binary_mask_collection.uncropped_mask(0)

    expected_value = data[0, 0] >= threshold

    assert np.array_equal(mask, expected_value)


@pytest.mark.parametrize(
    ["num_rounds", "num_chs"],
    [
        [1, 2],
        [2, 1],
        [2, 2],
    ])
def test_binarize_non_3d(num_rounds, num_chs, num_zplanes=4, ysize=5, xsize=6):
    """Binarizing a stack with more than one round or more than one channel raises ValueError."""
    data = np.linspace(0, 1, num_rounds * num_chs * num_zplanes * ysize * xsize, dtype=np.float32)
    data = data.reshape((num_rounds, num_chs, num_zplanes, ysize, xsize))

    imagestack = ImageStack.from_numpy(data)
    binarizer = ThresholdBinarize(0.0)

    with pytest.raises(ValueError):
        binarizer.run(imagestack)


# ---------------------------------------------------------------------------
# New file: starfish/core/morphology/Binarize/threshold.py
# ---------------------------------------------------------------------------
from typing import Mapping, Union

import numpy as np
import xarray as xr

from starfish.core.imagestack.imagestack import ImageStack
from starfish.core.morphology.binary_mask import BinaryMaskCollection
from starfish.core.morphology.util import _get_axes_names
from starfish.core.types import ArrayLike, Axes, Coordinates, Number
from ._base import BinarizeAlgorithm


class ThresholdBinarize(BinarizeAlgorithm):
    """Binarizes an image using a threshold.  Pixels whose value is greater than or equal to the
    threshold are considered True and all remaining pixels are considered False.

    The image being binarized must be an ImageStack with num_rounds == 1 and num_chs == 1.

    Parameters
    ----------
    threshold : Number
        Value each pixel is compared against (``pixel >= threshold``).
    """
    def __init__(self, threshold: Number):
        self.threshold = threshold

    def _binarize(self, result: np.ndarray, tile_data: Union[np.ndarray, xr.DataArray]) -> None:
        """Write ``tile_data >= self.threshold`` into ``result`` in place."""
        result[:] = np.asarray(tile_data) >= self.threshold

    def run(self, image: ImageStack, *args, **kwargs) -> BinaryMaskCollection:
        """Binarize ``image`` against the configured threshold.

        Parameters
        ----------
        image : ImageStack
            The image to binarize.  Must have exactly one round and one channel.

        Returns
        -------
        BinaryMaskCollection
            A collection built from a single boolean array holding the thresholded image.

        Raises
        ------
        ValueError
            If ``image.num_rounds != 1`` or ``image.num_chs != 1``.
        """
        if image.num_rounds != 1:
            raise ValueError(
                f"{ThresholdBinarize.__name__} given an image with more than one round "
                f"{image.num_rounds}")
        if image.num_chs != 1:
            raise ValueError(
                f"{ThresholdBinarize.__name__} given an image with more than one channel "
                f"{image.num_chs}")

        # Look up the 3D axis / coordinate names once instead of on every use below.
        axes_names, coords_names = _get_axes_names(3)

        # The builtin ``bool`` is used as the dtype: the ``np.bool`` alias was deprecated in
        # NumPy 1.20 and removed in NumPy 1.24, where accessing it raises AttributeError.
        result_array = np.empty(
            shape=[
                image.shape[axis]
                for axis, _ in zip(axes_names, coords_names)
            ],
            dtype=bool)

        # TODO: (ttung) This could theoretically be done with ImageStack.transform, but
        # ImageStack.transform doesn't provide the selectors to the worker method.  In this case,
        # we need the selectors to select the correct region of the output array.  The alternative
        # is for each worker thread to create a new array, and then merge them at the end, but that
        # effectively doubles our memory consumption.
        #
        # For now, we will just do it in-process, because it's not a particularly compute-intensive
        # task.
        #
        # Index 0 along the first two dimensions; given the checks above these are presumably the
        # sole round and channel, leaving the 3D volume.
        self._binarize(result_array, image.xarray[0, 0])

        pixel_ticks: Mapping[Axes, ArrayLike[int]] = {
            Axes(axis): axis_data
            for axis, axis_data in image.xarray.coords.items()
            if axis in axes_names
        }
        physical_ticks: Mapping[Coordinates, ArrayLike[Number]] = {
            Coordinates(coord): coord_data
            for coord, coord_data in image.xarray.coords.items()
            if coord in coords_names
        }

        return BinaryMaskCollection.from_binary_arrays_and_ticks(
            (result_array,),
            pixel_ticks,
            physical_ticks,
            image.log,
        )