Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Methods for uncropping binary masks. #1647

Merged
merged 1 commit into from
Nov 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion starfish/core/morphology/binary_mask/binary_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,47 @@ def _format_mask_as_xarray(self, index: int) -> xr.DataArray:
name=f"{index:0{max_mask_name_len}d}"
)

def uncropped_mask(self, index: int) -> xr.DataArray:
"""Convert a np-based mask into an xarray DataArray."""
mask_data = self._masks[index]
uncropped_shape = tuple(
len(self._pixel_ticks[axis])
for axis, _ in zip(*_get_axes_names(len(self._pixel_ticks)))
)

if uncropped_shape == mask_data.binary_mask.shape:
return self._format_mask_as_xarray(index)

max_mask_name_len = len(str(len(self._masks) - 1))

xr_dims: MutableSequence[str] = []
xr_coords: MutableMapping[Hashable, Any] = {}

for ix, (axis, coord) in enumerate(zip(*_get_axes_names(len(self._pixel_ticks)))):
xr_dims.append(axis.value)
xr_coords[axis.value] = self._pixel_ticks[axis.value]
xr_coords[coord.value] = (axis.value, self._physical_ticks[coord.value])

image = np.zeros(
shape=tuple(
len(self._pixel_ticks[axis])
for axis, _ in zip(*_get_axes_names(len(self._pixel_ticks)))
),
dtype=np.bool,
)
fill_from_mask(
mask_data.binary_mask,
mask_data.offsets,
1,
image,
)
return xr.DataArray(
image,
dims=xr_dims,
coords=xr_coords,
name=f"{index:0{max_mask_name_len}d}"
)

def masks(self) -> Iterator[xr.DataArray]:
for mask_index in self._masks.keys():
yield self._format_mask_as_xarray(mask_index)
Expand Down Expand Up @@ -172,7 +213,10 @@ def max_shape(self) -> Mapping[Axes, int]:

@classmethod
def from_label_image(cls, label_image: LabelImage) -> "BinaryMaskCollection":
"""Creates binary masks from a label image.
"""Creates binary masks from a label image. Masks are cropped to the smallest size that
contains the non-zero values, but pixel and physical coordinates ticks are retained. Masks
extracted from BinaryMaskCollections will be cropped. To extract masks sized to the
original label image, use :py:meth:`starfish.BinaryMaskCollection.uncropped_mask`.

Parameters
----------
Expand Down
60 changes: 60 additions & 0 deletions starfish/core/morphology/binary_mask/test/test_binary_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,66 @@ def test_from_label_image():
physical_ticks[Coordinates.X][3:6])


def test_uncropped_mask():
"""Test that BinaryMaskCollection.uncropped_mask() works correctly.
"""
label_image_array = np.zeros((5, 5), dtype=np.int32)
label_image_array[0] = 1
label_image_array[3:5, 3:5] = 2
label_image_array[-1, -1] = 0

physical_ticks = {Coordinates.Y: [1.2, 2.4, 3.6, 4.8, 6.0],
Coordinates.X: [7.2, 8.4, 9.6, 10.8, 12]}

label_image = LabelImage.from_label_array_and_ticks(
label_image_array,
None,
physical_ticks,
None,
)

mask_collection = BinaryMaskCollection.from_label_image(label_image)
assert len(mask_collection) == 2

region_0 = mask_collection.uncropped_mask(0)
assert region_0.shape == label_image_array.shape
assert region_0.dtype == np.bool
assert np.all(region_0[0] == 1)
assert np.all(region_0[1:5] == 0)

region_1 = mask_collection.uncropped_mask(1)
assert region_1.shape == label_image_array.shape
assert region_1.dtype == np.bool
assert np.all(region_1[0:3, :] == 0)
assert np.all(region_1[:, 0:3] == 0)
assert np.all(region_1[3:5, 3:5] == [[1, 1],
[1, 0]])


def test_uncropped_mask_no_uncropping():
"""If the mask doesn't need to be uncropped, it should still work. This is an optimized code
path, so it is separately validated.
"""
label_image_array = np.full((5, 5), fill_value=1, dtype=np.int32)

physical_ticks = {Coordinates.Y: [1.2, 2.4, 3.6, 4.8, 6.0],
Coordinates.X: [7.2, 8.4, 9.6, 10.8, 12]}

label_image = LabelImage.from_label_array_and_ticks(
label_image_array,
None,
physical_ticks,
None,
)

mask_collection = BinaryMaskCollection.from_label_image(label_image)
assert len(mask_collection) == 1

region = mask_collection.uncropped_mask(0)
assert region.shape == label_image_array.shape
assert np.all(region == 1)


def test_to_label_image():
# test via roundtrip
label_image_array = np.zeros((5, 6), dtype=np.int32)
Expand Down