diff --git a/starfish/core/segmentation_mask/segmentation_mask.py b/starfish/core/segmentation_mask/segmentation_mask.py index 1329fdb1d..a54095a5d 100644 --- a/starfish/core/segmentation_mask/segmentation_mask.py +++ b/starfish/core/segmentation_mask/segmentation_mask.py @@ -1,5 +1,7 @@ import io +import itertools import tarfile +from dataclasses import dataclass from typing import ( cast, Dict, @@ -17,6 +19,7 @@ import numpy as np import xarray as xr from skimage.measure import regionprops +from skimage.measure._regionprops import _RegionProperties from starfish.core.types import Axes, Coordinates from .expand import fill_from_mask @@ -50,6 +53,12 @@ def _validate_segmentation_mask(arr: xr.DataArray): raise TypeError(f"missing coordinates '{dims.union(coords).difference(arr.coords)}'") +@dataclass +class MaskData: + binary_mask: xr.DataArray + region_properties: Optional[_RegionProperties] + + class SegmentationMaskCollection: """Collection of binary segmentation masks with a dict-like access pattern. @@ -57,23 +66,32 @@ class SegmentationMaskCollection: ---------- masks : Iterable[xr.DataArray] Segmentation masks. + props : Iterable[_RegionProperties] + Properties for each of the regions in the masks. Attributes ---------- max_shape : Dict[Axes, Optional[int]] Maximum index of contained masks. """ - def __init__(self, masks: Iterable[xr.DataArray]): - self._masks: MutableMapping[int, xr.DataArray] = {} + def __init__( + self, + masks: Iterable[xr.DataArray], + props: Optional[Iterable[Optional[_RegionProperties]]] = None, + ): + if props is None: + props = itertools.cycle((None,)) + self._masks: MutableMapping[int, MaskData] = {} self.max_shape: Dict[Axes, int] = { Axes.X: 0, Axes.Y: 0, Axes.ZPLANE: 0 } - for ix, mask in enumerate(masks): + for ix, (mask, mask_props) in enumerate(zip(masks, props)): _validate_segmentation_mask(mask) - self._masks[ix] = mask + + self._masks[ix] = MaskData(mask, mask_props) for axis in Axes: if axis.value in mask.coords: @@ -82,16 +100,49 @@ def __init__(self, masks: Iterable[xr.DataArray]): self.max_shape[axis] = max_val + 1 def __getitem__(self, index: int) -> xr.DataArray: - return self._masks[index] + return self._masks[index].binary_mask def __iter__(self) -> Iterator[Tuple[int, xr.DataArray]]: - return iter(self._masks.items()) + for mask_index, mask_data in self._masks.items(): + yield mask_index, mask_data.binary_mask def __len__(self) -> int: return len(self._masks) def masks(self) -> Iterator[xr.DataArray]: - return iter(self._masks.values()) + for mask_index, mask_data in self._masks.items(): + yield mask_data.binary_mask + + def mask_regionprops(self, mask_id: int) -> _RegionProperties: + """ + Return the region properties for + Parameters + ---------- + mask_id + + Returns + ------- + + """ + mask_data = self._masks[mask_id] + if mask_data.region_properties is None: + # recreate the label image (but with just this mask) + image = np.zeros( + shape=tuple( + self.max_shape[axis] + for axis in AXES_ORDER + if self.max_shape[axis] != 0 + ), + dtype=np.uint32, + ) + fill_from_mask( + mask_data.binary_mask, + mask_id + 1, + image, + [axis for axis in AXES_ORDER if self.max_shape[axis] > 0], + ) + mask_data.region_properties = regionprops(image) + return mask_data.region_properties @classmethod def from_label_image( @@ -143,7 +194,7 @@ def from_label_image( name=name) masks.append(mask) - return cls(masks) + return cls(masks, props) def to_label_image( self, diff --git a/starfish/core/segmentation_mask/test/test_segmentation_mask.py b/starfish/core/segmentation_mask/test/test_segmentation_mask.py index 26137adf1..bfa4e5a63 100644 --- a/starfish/core/segmentation_mask/test/test_segmentation_mask.py +++ b/starfish/core/segmentation_mask/test/test_segmentation_mask.py @@ -138,3 +138,9 @@ def test_save_load(): assert np.array_equal(m, m2) finally: os.remove(path) + + # ensure that the regionprops are equal + for ix in range(len(masks)): + original_props = masks.mask_regionprops(ix) + recalculated_props = masks.mask_regionprops(ix) + assert original_props == recalculated_props