diff --git a/starfish/core/segmentation_mask/segmentation_mask.py b/starfish/core/segmentation_mask/segmentation_mask.py index 0fb698e36..dbe5a34f8 100644 --- a/starfish/core/segmentation_mask/segmentation_mask.py +++ b/starfish/core/segmentation_mask/segmentation_mask.py @@ -1,6 +1,18 @@ import io import tarfile -from typing import cast, Dict, Hashable, List, Optional, Sequence, Tuple, Union +from typing import ( + cast, + Dict, + Hashable, + Iterable, + Iterator, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import xarray as xr @@ -38,11 +50,11 @@ def _validate_segmentation_mask(arr: xr.DataArray): class SegmentationMaskCollection: - """Collection of binary segmentation masks with a list-like access pattern. + """Collection of binary segmentation masks with a dict-like access pattern. Parameters ---------- - masks : List[xr.DataArray] + masks : Iterable[xr.DataArray] Segmentation masks. Attributes @@ -50,42 +62,35 @@ class SegmentationMaskCollection: max_shape : Dict[Axes, Optional[int]] Maximum index of contained masks. """ - def __init__(self, masks: List[xr.DataArray]): - self._masks: List[xr.DataArray] = [] + def __init__(self, masks: Iterable[xr.DataArray]): + self._masks: MutableMapping[int, xr.DataArray] = {} self.max_shape: Dict[Axes, int] = { Axes.X: 0, Axes.Y: 0, Axes.ZPLANE: 0 } - for mask in masks: - self.append(mask) + for ix, mask in enumerate(masks): + _validate_segmentation_mask(mask) + self._masks[ix] = mask - def __getitem__(self, index): + for axis in Axes: + if axis.value in mask.coords: + max_val = mask.coords[axis.value].values[-1] + if max_val >= self.max_shape[axis]: + self.max_shape[axis] = max_val + 1 + + def __getitem__(self, index: int) -> xr.DataArray: return self._masks[index] - def __iter__(self): - return iter(self._masks) + def __iter__(self) -> Iterator[Tuple[int, xr.DataArray]]: + return iter(self._masks.items()) - def __len__(self): + def __len__(self) -> int: return len(self._masks) - def append(self, mask: xr.DataArray): - """Add an existing segmentation mask. - - Parameters - ---------- - arr : xr.DataArray - Segmentation mask. - """ - _validate_segmentation_mask(mask) - self._masks.append(mask) - - for axis in Axes: - if axis.value in mask.coords: - max_val = mask.coords[axis.value].values[-1] - if max_val >= self.max_shape[axis]: - self.max_shape[axis] = max_val + 1 + def masks(self) -> Iterator[xr.DataArray]: + return iter(self._masks.values()) @classmethod def from_label_image( @@ -111,8 +116,7 @@ def from_label_image( dims, _ = _get_axes_names(label_image.ndim) - masks: List[xr.DataArray] = [] - + masks: MutableSequence[xr.DataArray] = [] coords: Dict[Hashable, Union[list, Tuple[str, Sequence]]] # for each region (and its properties): @@ -172,7 +176,7 @@ def to_label_image( label_image = np.zeros(shape, dtype=np.uint16) - for i, mask in enumerate(self): + for i, mask in iter(self): mask = mask.transpose(*[o.value for o in ordering if o in mask.coords]) coords = mask.values.nonzero() j = 0 @@ -218,7 +222,7 @@ def save(self, path: str): Path of the tar file to write to. """ with tarfile.open(path, 'w:gz') as t: - for i, mask in enumerate(self._masks): + for i, mask in iter(self): data = cast(bytes, mask.to_netcdf()) with io.BytesIO(data) as buff: info = tarfile.TarInfo(name=str(i) + '.nc') diff --git a/starfish/core/segmentation_mask/test/test_segmentation_mask.py b/starfish/core/segmentation_mask/test/test_segmentation_mask.py index 0b18ce294..26137adf1 100644 --- a/starfish/core/segmentation_mask/test/test_segmentation_mask.py +++ b/starfish/core/segmentation_mask/test/test_segmentation_mask.py @@ -70,8 +70,8 @@ def test_from_label_image(): physical_ticks = {Coordinates.Y: [1.2, 2.4, 3.6, 4.8, 6.0], Coordinates.X: [7.2, 8.4, 9.6, 10.8, 12]} - masks = SegmentationMaskCollection.from_label_image(label_image, - physical_ticks) + masks = list(SegmentationMaskCollection.from_label_image( + label_image, physical_ticks).masks()) assert len(masks) == 2 @@ -134,7 +134,7 @@ def test_save_load(): try: masks.save(path) masks2 = SegmentationMaskCollection.from_disk(path) - for m, m2 in zip(masks, masks2): + for m, m2 in zip(masks.masks(), masks2.masks()): assert np.array_equal(m, m2) finally: os.remove(path) diff --git a/starfish/core/spots/AssignTargets/label.py b/starfish/core/spots/AssignTargets/label.py index a1ddbadfd..49c897ca1 100644 --- a/starfish/core/spots/AssignTargets/label.py +++ b/starfish/core/spots/AssignTargets/label.py @@ -32,7 +32,7 @@ def _assign( decoded_intensities[Features.CELL_ID] = cell_ids - for mask in masks: + for _, mask in masks: y_min, y_max = float(mask.y.min()), float(mask.y.max()) x_min, x_max = float(mask.x.min()), float(mask.x.max())