Skip to content

Commit

Permalink
Provide an apply method to binary mask collections.
Browse files Browse the repository at this point in the history
This allows a method to be applied to the individual masks and constituted into a new binary mask collection.

Depends on #1653
Test plan: Wrote a test that performed binary dilation on the standard set of masks used for testing, and verified the results were somewhat sane.
  • Loading branch information
Tony Tung committed Nov 20, 2019
1 parent 995db89 commit 5bd39db
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 0 deletions.
68 changes: 68 additions & 0 deletions starfish/core/morphology/binary_mask/binary_mask.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import (
Any,
Callable,
Hashable,
Iterator,
Mapping,
Expand Down Expand Up @@ -508,6 +511,71 @@ def to_targz(self, path: Union[str, Path]):
with open(os.fspath(path), "wb") as fh:
_io.BinaryMaskIO.write_versioned_binary_mask(fh, self)

def _apply(
self,
function: Callable[[np.ndarray], np.ndarray],
*args,
n_processes: Optional[int] = None,
**kwargs
) -> "BinaryMaskCollection":
"""Given a function that takes an ndarray and outputs another, apply that method to all the
masks in this collection to form a new collection. All the masks are uncropped before being
passed to the function.
Parameters
----------
function : Callable[[np.ndarray], np.ndarray]
A function that should produce a mask array when given a mask array.
n_processes : Optional[int]
The number of processes to use for apply. If None, uses the output of os.cpu_count()
(default = None).
"""
if n_processes is None:
n_processes = os.cpu_count()

applied_func = partial(
BinaryMaskCollection._apply_single_mask,
function=function,
mask_collection=self,
args=args,
kwargs=kwargs,
)
selectors = range(len(self._masks))

with ThreadPoolExecutor(max_workers=n_processes) as tpe:
results: Iterator[MaskData] = tpe.map(applied_func, selectors)

return BinaryMaskCollection(
self._pixel_ticks,
self._physical_ticks,
list(results),
self._log,
)

@staticmethod
def _apply_single_mask(
mask_index: int,
mask_collection: "BinaryMaskCollection",
args: Sequence,
kwargs: Mapping,
function: Callable[[np.ndarray], np.ndarray],
) -> MaskData:
"""Given a mask collection, and an index, retrieve a mask and apply a function to that mask.
Return the output along with the offsets of the original mask. If the original mask is
uncropped, then the offsets should all be 0. If the original mask is not uncropped, it is
propagated from the input masks's offsets.
"""
input_mask = mask_collection.uncropped_mask(mask_index)
output_mask = function(input_mask.values, *args, **kwargs) # type: ignore

selection_range: Sequence[slice] = BinaryMaskCollection._crop_mask(output_mask)

return MaskData(
output_mask[selection_range],
tuple(selection.start for selection in selection_range),
None
)


# these need to be at the end to avoid recursive imports
from . import _io # noqa
28 changes: 28 additions & 0 deletions starfish/core/morphology/binary_mask/test/test_binary_mask.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from skimage.morphology import binary_dilation

from starfish.core.morphology.label_image import LabelImage
from starfish.core.types import Axes, Coordinates
Expand Down Expand Up @@ -157,3 +158,30 @@ def test_from_empty_label_image(tmp_path):
original_props = binary_mask_collection.mask_regionprops(ix)
recalculated_props = binary_mask_collection.mask_regionprops(ix)
assert original_props == recalculated_props


def test_apply():
input_mask_collection = binary_mask_collection_2d()
output_mask_collection = input_mask_collection._apply(binary_dilation)

assert input_mask_collection._pixel_ticks == output_mask_collection._pixel_ticks
assert input_mask_collection._physical_ticks == output_mask_collection._physical_ticks
assert input_mask_collection._log == output_mask_collection._log
assert len(input_mask_collection) == len(output_mask_collection)

region_0, region_1 = output_mask_collection.masks()

assert region_0.name == '0'
assert region_1.name == '1'

temp = np.ones((2, 6), dtype=np.bool)
assert np.array_equal(region_0, temp)
temp = np.ones((3, 4), dtype=np.bool)
temp[0, 0] = 0
assert np.array_equal(region_1, temp)

assert np.array_equal(region_0[Axes.Y.value], [0, 1])
assert np.array_equal(region_0[Axes.X.value], [0, 1, 2, 3, 4, 5])

assert np.array_equal(region_1[Axes.Y.value], [2, 3, 4])
assert np.array_equal(region_1[Axes.X.value], [2, 3, 4, 5])

0 comments on commit 5bd39db

Please sign in to comment.