Skip to content

Commit

Permalink
Make regionprops available per mask (#1610)
Browse files Browse the repository at this point in the history
regionprops are useful to characterize and filter segmented cells.  This PR changes what we store for mask data to be a tuple of binary mask and regionprops.  If we construct a mask collection from a labeled image, we retain the region props that are automatically calculated as a part of this conversion process.  If we construct a mask collection from what's stored on disk, we will calculate the regionprops when requested.

Test plan: Save to disk a mask collection generated from a labeled image, and after loading it from disk, verify that the region props of the two mask collections match.

Part of #1497
  • Loading branch information
Tony Tung authored Oct 17, 2019
1 parent 4378e25 commit d2df275
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
67 changes: 59 additions & 8 deletions starfish/core/segmentation_mask/segmentation_mask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import io
import itertools
import tarfile
from dataclasses import dataclass
from typing import (
cast,
Dict,
Expand All @@ -17,6 +19,7 @@
import numpy as np
import xarray as xr
from skimage.measure import regionprops
from skimage.measure._regionprops import _RegionProperties

from starfish.core.types import Axes, Coordinates
from .expand import fill_from_mask
Expand Down Expand Up @@ -50,30 +53,45 @@ def _validate_segmentation_mask(arr: xr.DataArray):
raise TypeError(f"missing coordinates '{dims.union(coords).difference(arr.coords)}'")


@dataclass
class MaskData:
binary_mask: xr.DataArray
region_properties: Optional[_RegionProperties]


class SegmentationMaskCollection:
"""Collection of binary segmentation masks with a dict-like access pattern.
Parameters
----------
masks : Iterable[xr.DataArray]
Segmentation masks.
props : Iterable[_RegionProperties]
Properties for each of the regions in the masks.
Attributes
----------
max_shape : Dict[Axes, Optional[int]]
Maximum index of contained masks.
"""
def __init__(self, masks: Iterable[xr.DataArray]):
self._masks: MutableMapping[int, xr.DataArray] = {}
def __init__(
self,
masks: Iterable[xr.DataArray],
props: Optional[Iterable[Optional[_RegionProperties]]] = None,
):
if props is None:
props = itertools.cycle((None,))
self._masks: MutableMapping[int, MaskData] = {}
self.max_shape: Dict[Axes, int] = {
Axes.X: 0,
Axes.Y: 0,
Axes.ZPLANE: 0
}

for ix, mask in enumerate(masks):
for ix, (mask, mask_props) in enumerate(zip(masks, props)):
_validate_segmentation_mask(mask)
self._masks[ix] = mask

self._masks[ix] = MaskData(mask, mask_props)

for axis in Axes:
if axis.value in mask.coords:
Expand All @@ -82,16 +100,49 @@ def __init__(self, masks: Iterable[xr.DataArray]):
self.max_shape[axis] = max_val + 1

def __getitem__(self, index: int) -> xr.DataArray:
return self._masks[index]
return self._masks[index].binary_mask

def __iter__(self) -> Iterator[Tuple[int, xr.DataArray]]:
return iter(self._masks.items())
for mask_index, mask_data in self._masks.items():
yield mask_index, mask_data.binary_mask

def __len__(self) -> int:
return len(self._masks)

def masks(self) -> Iterator[xr.DataArray]:
return iter(self._masks.values())
for mask_index, mask_data in self._masks.items():
yield mask_data.binary_mask

def mask_regionprops(self, mask_id: int) -> _RegionProperties:
"""
Return the region properties for
Parameters
----------
mask_id
Returns
-------
"""
mask_data = self._masks[mask_id]
if mask_data.region_properties is None:
# recreate the label image (but with just this mask)
image = np.zeros(
shape=tuple(
self.max_shape[axis]
for axis in AXES_ORDER
if self.max_shape[axis] != 0
),
dtype=np.uint32,
)
fill_from_mask(
mask_data.binary_mask,
mask_id + 1,
image,
[axis for axis in AXES_ORDER if self.max_shape[axis] > 0],
)
mask_data.region_properties = regionprops(image)
return mask_data.region_properties

@classmethod
def from_label_image(
Expand Down Expand Up @@ -143,7 +194,7 @@ def from_label_image(
name=name)
masks.append(mask)

return cls(masks)
return cls(masks, props)

def to_label_image(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,9 @@ def test_save_load():
assert np.array_equal(m, m2)
finally:
os.remove(path)

# ensure that the regionprops are equal
for ix in range(len(masks)):
original_props = masks.mask_regionprops(ix)
recalculated_props = masks.mask_regionprops(ix)
assert original_props == recalculated_props

0 comments on commit d2df275

Please sign in to comment.