Skip to content

Commit

Permalink
Convert SegmentationMaskCollection to a dict-like object (#1579)
Browse files Browse the repository at this point in the history
Maps from cell id to mask.

Part of #1497
  • Loading branch information
Tony Tung authored Oct 10, 2019
1 parent 58eb7ac commit a9a3d13
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 35 deletions.
66 changes: 35 additions & 31 deletions starfish/core/segmentation_mask/segmentation_mask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import io
import tarfile
from typing import cast, Dict, Hashable, List, Optional, Sequence, Tuple, Union
from typing import (
cast,
Dict,
Hashable,
Iterable,
Iterator,
MutableMapping,
MutableSequence,
Optional,
Sequence,
Tuple,
Union,
)

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -38,54 +50,47 @@ def _validate_segmentation_mask(arr: xr.DataArray):


class SegmentationMaskCollection:
"""Collection of binary segmentation masks with a list-like access pattern.
"""Collection of binary segmentation masks with a dict-like access pattern.
Parameters
----------
masks : List[xr.DataArray]
masks : Iterable[xr.DataArray]
Segmentation masks.
Attributes
----------
max_shape : Dict[Axes, Optional[int]]
Maximum index of contained masks.
"""
def __init__(self, masks: List[xr.DataArray]):
self._masks: List[xr.DataArray] = []
def __init__(self, masks: Iterable[xr.DataArray]):
self._masks: MutableMapping[int, xr.DataArray] = {}
self.max_shape: Dict[Axes, int] = {
Axes.X: 0,
Axes.Y: 0,
Axes.ZPLANE: 0
}

for mask in masks:
self.append(mask)
for ix, mask in enumerate(masks):
_validate_segmentation_mask(mask)
self._masks[ix] = mask

def __getitem__(self, index):
for axis in Axes:
if axis.value in mask.coords:
max_val = mask.coords[axis.value].values[-1]
if max_val >= self.max_shape[axis]:
self.max_shape[axis] = max_val + 1

def __getitem__(self, index: int) -> xr.DataArray:
return self._masks[index]

def __iter__(self):
return iter(self._masks)
def __iter__(self) -> Iterator[Tuple[int, xr.DataArray]]:
return iter(self._masks.items())

def __len__(self):
def __len__(self) -> int:
return len(self._masks)

def append(self, mask: xr.DataArray):
"""Add an existing segmentation mask.
Parameters
----------
arr : xr.DataArray
Segmentation mask.
"""
_validate_segmentation_mask(mask)
self._masks.append(mask)

for axis in Axes:
if axis.value in mask.coords:
max_val = mask.coords[axis.value].values[-1]
if max_val >= self.max_shape[axis]:
self.max_shape[axis] = max_val + 1
def masks(self) -> Iterator[xr.DataArray]:
return iter(self._masks.values())

@classmethod
def from_label_image(
Expand All @@ -111,8 +116,7 @@ def from_label_image(

dims, _ = _get_axes_names(label_image.ndim)

masks: List[xr.DataArray] = []

masks: MutableSequence[xr.DataArray] = []
coords: Dict[Hashable, Union[list, Tuple[str, Sequence]]]

# for each region (and its properties):
Expand Down Expand Up @@ -172,7 +176,7 @@ def to_label_image(

label_image = np.zeros(shape, dtype=np.uint16)

for i, mask in enumerate(self):
for i, mask in iter(self):
mask = mask.transpose(*[o.value for o in ordering if o in mask.coords])
coords = mask.values.nonzero()
j = 0
Expand Down Expand Up @@ -218,7 +222,7 @@ def save(self, path: str):
Path of the tar file to write to.
"""
with tarfile.open(path, 'w:gz') as t:
for i, mask in enumerate(self._masks):
for i, mask in iter(self):
data = cast(bytes, mask.to_netcdf())
with io.BytesIO(data) as buff:
info = tarfile.TarInfo(name=str(i) + '.nc')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_from_label_image():
physical_ticks = {Coordinates.Y: [1.2, 2.4, 3.6, 4.8, 6.0],
Coordinates.X: [7.2, 8.4, 9.6, 10.8, 12]}

masks = SegmentationMaskCollection.from_label_image(label_image,
physical_ticks)
masks = list(SegmentationMaskCollection.from_label_image(
label_image, physical_ticks).masks())

assert len(masks) == 2

Expand Down Expand Up @@ -134,7 +134,7 @@ def test_save_load():
try:
masks.save(path)
masks2 = SegmentationMaskCollection.from_disk(path)
for m, m2 in zip(masks, masks2):
for m, m2 in zip(masks.masks(), masks2.masks()):
assert np.array_equal(m, m2)
finally:
os.remove(path)
2 changes: 1 addition & 1 deletion starfish/core/spots/AssignTargets/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _assign(

decoded_intensities[Features.CELL_ID] = cell_ids

for mask in masks:
for _, mask in masks:
y_min, y_max = float(mask.y.min()), float(mask.y.max())
x_min, x_max = float(mask.x.min()), float(mask.x.max())

Expand Down

0 comments on commit a9a3d13

Please sign in to comment.