diff --git a/starfish/core/morphology/binary_mask/binary_mask.py b/starfish/core/morphology/binary_mask/binary_mask.py index 21aece936..a32f71023 100644 --- a/starfish/core/morphology/binary_mask/binary_mask.py +++ b/starfish/core/morphology/binary_mask/binary_mask.py @@ -1,9 +1,12 @@ import os +from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from dataclasses import dataclass +from functools import partial from pathlib import Path from typing import ( Any, + Callable, Hashable, Iterator, Mapping, @@ -508,6 +511,71 @@ def to_targz(self, path: Union[str, Path]): with open(os.fspath(path), "wb") as fh: _io.BinaryMaskIO.write_versioned_binary_mask(fh, self) + def _apply( + self, + function: Callable[[np.ndarray], np.ndarray], + *args, + n_processes: Optional[int] = None, + **kwargs + ) -> "BinaryMaskCollection": + """Given a function that takes an ndarray and outputs another, apply that method to all the + masks in this collection to form a new collection. All the masks are uncropped before being + passed to the function. + + Parameters + ---------- + function : Callable[[np.ndarray], np.ndarray] + A function that should produce a mask array when given a mask array. + n_processes : Optional[int] + The number of processes to use for apply. If None, uses the output of os.cpu_count() + (default = None). + """ + if n_processes is None: + n_processes = os.cpu_count() + + applied_func = partial( + BinaryMaskCollection._apply_single_mask, + function=function, + mask_collection=self, + args=args, + kwargs=kwargs, + ) + selectors = range(len(self._masks)) + + with ThreadPoolExecutor(max_workers=n_processes) as tpe: + results: Iterator[MaskData] = tpe.map(applied_func, selectors) + + return BinaryMaskCollection( + self._pixel_ticks, + self._physical_ticks, + list(results), + self._log, + ) + + @staticmethod + def _apply_single_mask( + mask_index: int, + mask_collection: "BinaryMaskCollection", + args: Sequence, + kwargs: Mapping, + function: Callable[[np.ndarray], np.ndarray], + ) -> MaskData: + """Given a mask collection, and an index, retrieve a mask and apply a function to that mask. + Return the output along with the offsets of the original mask. If the original mask is + uncropped, then the offsets should all be 0. If the original mask is not uncropped, it is + propagated from the input masks's offsets. + """ + input_mask = mask_collection.uncropped_mask(mask_index) + output_mask = function(input_mask.values, *args, **kwargs) # type: ignore + + selection_range: Sequence[slice] = BinaryMaskCollection._crop_mask(output_mask) + + return MaskData( + output_mask[selection_range], + tuple(selection.start for selection in selection_range), + None + ) + # these need to be at the end to avoid recursive imports from . import _io # noqa diff --git a/starfish/core/morphology/binary_mask/test/test_binary_mask.py b/starfish/core/morphology/binary_mask/test/test_binary_mask.py index 5182e4260..a61c0986e 100644 --- a/starfish/core/morphology/binary_mask/test/test_binary_mask.py +++ b/starfish/core/morphology/binary_mask/test/test_binary_mask.py @@ -1,4 +1,5 @@ import numpy as np +from skimage.morphology import binary_dilation from starfish.core.morphology.label_image import LabelImage from starfish.core.types import Axes, Coordinates @@ -157,3 +158,30 @@ def test_from_empty_label_image(tmp_path): original_props = binary_mask_collection.mask_regionprops(ix) recalculated_props = binary_mask_collection.mask_regionprops(ix) assert original_props == recalculated_props + + +def test_apply(): + input_mask_collection = binary_mask_collection_2d() + output_mask_collection = input_mask_collection._apply(binary_dilation) + + assert input_mask_collection._pixel_ticks == output_mask_collection._pixel_ticks + assert input_mask_collection._physical_ticks == output_mask_collection._physical_ticks + assert input_mask_collection._log == output_mask_collection._log + assert len(input_mask_collection) == len(output_mask_collection) + + region_0, region_1 = output_mask_collection.masks() + + assert region_0.name == '0' + assert region_1.name == '1' + + temp = np.ones((2, 6), dtype=np.bool) + assert np.array_equal(region_0, temp) + temp = np.ones((3, 4), dtype=np.bool) + temp[0, 0] = 0 + assert np.array_equal(region_1, temp) + + assert np.array_equal(region_0[Axes.Y.value], [0, 1]) + assert np.array_equal(region_0[Axes.X.value], [0, 1, 2, 3, 4, 5]) + + assert np.array_equal(region_1[Axes.Y.value], [2, 3, 4]) + assert np.array_equal(region_1[Axes.X.value], [2, 3, 4, 5])