From 11046e3407b63887eb68a94a123ced8828b92f07 Mon Sep 17 00:00:00 2001 From: Tony Tung Date: Mon, 18 Nov 2019 13:42:44 -0800 Subject: [PATCH] Methods for uncropping binary masks. The existing builders for binary masks crop the masks to the smallest box that fits the non-zero data. This adds a method to reverse the crop when retrieving the masks. Documentation is also added to the builders' docblocks to indicate that they crop the masks. Note that the constructor of BinaryMaskCollection does _not_ crop the masks. The builders decide what they want to do. Depends on #1632 Test plan: Add tests to verify the uncropping works as expected. --- .../morphology/binary_mask/binary_mask.py | 46 +++++++++++++- .../binary_mask/test/test_binary_mask.py | 60 +++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/starfish/core/morphology/binary_mask/binary_mask.py b/starfish/core/morphology/binary_mask/binary_mask.py index 9b5fac38c..e94f71799 100644 --- a/starfish/core/morphology/binary_mask/binary_mask.py +++ b/starfish/core/morphology/binary_mask/binary_mask.py @@ -126,6 +126,47 @@ def _format_mask_as_xarray(self, index: int) -> xr.DataArray: name=f"{index:0{max_mask_name_len}d}" ) + def uncropped_mask(self, index: int) -> xr.DataArray: + """Convert a np-based mask into an xarray DataArray.""" + mask_data = self._masks[index] + uncropped_shape = tuple( + len(self._pixel_ticks[axis]) + for axis, _ in zip(*_get_axes_names(len(self._pixel_ticks))) + ) + + if uncropped_shape == mask_data.binary_mask.shape: + return self._format_mask_as_xarray(index) + + max_mask_name_len = len(str(len(self._masks) - 1)) + + xr_dims: MutableSequence[str] = [] + xr_coords: MutableMapping[Hashable, Any] = {} + + for ix, (axis, coord) in enumerate(zip(*_get_axes_names(len(self._pixel_ticks)))): + xr_dims.append(axis.value) + xr_coords[axis.value] = self._pixel_ticks[axis.value] + xr_coords[coord.value] = (axis.value, self._physical_ticks[coord.value]) + + image = np.zeros( + shape=tuple( + len(self._pixel_ticks[axis]) + for axis, _ in zip(*_get_axes_names(len(self._pixel_ticks))) + ), + dtype=np.bool, + ) + fill_from_mask( + mask_data.binary_mask, + mask_data.offsets, + 1, + image, + ) + return xr.DataArray( + image, + dims=xr_dims, + coords=xr_coords, + name=f"{index:0{max_mask_name_len}d}" + ) + def masks(self) -> Iterator[xr.DataArray]: for mask_index in self._masks.keys(): yield self._format_mask_as_xarray(mask_index) @@ -172,7 +213,10 @@ def max_shape(self) -> Mapping[Axes, int]: @classmethod def from_label_image(cls, label_image: LabelImage) -> "BinaryMaskCollection": - """Creates binary masks from a label image. + """Creates binary masks from a label image. Masks are cropped to the smallest size that + contains the non-zero values, but pixel and physical coordinates ticks are retained. Masks + extracted from BinaryMaskCollections will be cropped. To extract masks sized to the + original label image, use :py:meth:`starfish.BinaryMaskCollection.uncropped_mask`. Parameters ---------- diff --git a/starfish/core/morphology/binary_mask/test/test_binary_mask.py b/starfish/core/morphology/binary_mask/test/test_binary_mask.py index 474c80e2e..ff1954fc0 100644 --- a/starfish/core/morphology/binary_mask/test/test_binary_mask.py +++ b/starfish/core/morphology/binary_mask/test/test_binary_mask.py @@ -51,6 +51,66 @@ def test_from_label_image(): physical_ticks[Coordinates.X][3:6]) +def test_uncropped_mask(): + """Test that BinaryMaskCollection.uncropped_mask() works correctly. + """ + label_image_array = np.zeros((5, 5), dtype=np.int32) + label_image_array[0] = 1 + label_image_array[3:5, 3:5] = 2 + label_image_array[-1, -1] = 0 + + physical_ticks = {Coordinates.Y: [1.2, 2.4, 3.6, 4.8, 6.0], + Coordinates.X: [7.2, 8.4, 9.6, 10.8, 12]} + + label_image = LabelImage.from_label_array_and_ticks( + label_image_array, + None, + physical_ticks, + None, + ) + + mask_collection = BinaryMaskCollection.from_label_image(label_image) + assert len(mask_collection) == 2 + + region_0 = mask_collection.uncropped_mask(0) + assert region_0.shape == label_image_array.shape + assert region_0.dtype == np.bool + assert np.all(region_0[0] == 1) + assert np.all(region_0[1:5] == 0) + + region_1 = mask_collection.uncropped_mask(1) + assert region_1.shape == label_image_array.shape + assert region_1.dtype == np.bool + assert np.all(region_1[0:3, :] == 0) + assert np.all(region_1[:, 0:3] == 0) + assert np.all(region_1[3:5, 3:5] == [[1, 1], + [1, 0]]) + + +def test_uncropped_mask_no_uncropping(): + """If the mask doesn't need to be uncropped, it should still work. This is an optimized code + path, so it is separately validated. + """ + label_image_array = np.full((5, 5), fill_value=1, dtype=np.int32) + + physical_ticks = {Coordinates.Y: [1.2, 2.4, 3.6, 4.8, 6.0], + Coordinates.X: [7.2, 8.4, 9.6, 10.8, 12]} + + label_image = LabelImage.from_label_array_and_ticks( + label_image_array, + None, + physical_ticks, + None, + ) + + mask_collection = BinaryMaskCollection.from_label_image(label_image) + assert len(mask_collection) == 1 + + region = mask_collection.uncropped_mask(0) + assert region.shape == label_image_array.shape + assert np.all(region == 1) + + def test_to_label_image(): # test via roundtrip label_image_array = np.zeros((5, 6), dtype=np.int32)