diff --git a/starfish/core/image/_filter/reduce.py b/starfish/core/image/_filter/reduce.py new file mode 100644 index 000000000..21d65f3d3 --- /dev/null +++ b/starfish/core/image/_filter/reduce.py @@ -0,0 +1,142 @@ +from typing import ( + Callable, + Iterable, + MutableMapping, + Sequence, + Union +) + +import numpy as np + +from starfish.core.imagestack.imagestack import ImageStack +from starfish.core.types import Axes, Clip, Coordinates, Number +from starfish.core.util import click +from starfish.core.util.dtype import preserve_float_range +from ._base import FilterAlgorithmBase + + +class Reduce(FilterAlgorithmBase): + """ + Reduces the dimensions of the ImageStack by applying a function + along one or more axes. + + Parameters + ---------- + dims : Axes + one or more Axes to project over + func : Union[str, Callable] + function to apply across the dimension(s) specified by dims. + If a function is provided, it should follow the form specified by + DataArray.reduce(): + http://xarray.pydata.org/en/stable/generated/xarray.DataArray.reduce.html + If a string is provided, it should correspond to a numpy function that + matches the form specified above + (i.e., function is resolved: func = getattr(np, func)). + Some common examples below: + amax: maximum intensity projection (applies numpy.amax) + max: maximum intensity projection (this is an alias for amax and applies numpy.amax) + mean: take the mean across the dim(s) (applies numpy.mean) + sum: sum across the dim(s) (applies numpy.sum) + clip_method : Clip + (Default Clip.CLIP) Controls the way that data are scaled to retain skimage dtype + requirements that float data fall in [0, 1]. + Clip.CLIP: data above 1 are set to 1, and below 0 are set to 0 + Clip.SCALE_BY_IMAGE: data above 1 are scaled by the maximum value, with the maximum + value calculated over the entire ImageStack + + See Also + -------- + starfish.types.Axes + + """ + + def __init__( + self, dims: Iterable[Union[Axes, str]], func: Union[str, Callable] = 'max', + clip_method: Clip = Clip.CLIP + ) -> None: + + self.dims = dims + self.clip_method = clip_method + + # If the user provided a string, convert to callable + if isinstance(func, str): + if func == 'max': + func = 'amax' + func = getattr(np, func) + self.func = func + + _DEFAULT_TESTING_PARAMETERS = {"dims": ['r'], "func": 'max'} + + def run( + self, + stack: ImageStack, + *args, + ) -> ImageStack: + """Performs the dimension reduction with the specifed function + + Parameters + ---------- + stack : ImageStack + Stack to be filtered. + + Returns + ------- + ImageStack : + If in-place is False, return the results of filter as a new stack. Otherwise return the + original stack. + + """ + + # Apply the reducing function + reduced = stack._data.reduce(self.func, dim=[Axes(dim).value for dim in self.dims]) + + # Add the reduced dims back and align with the original stack + reduced = reduced.expand_dims(tuple(Axes(dim).value for dim in self.dims)) + reduced = reduced.transpose(*stack.xarray.dims) + + if self.clip_method == Clip.CLIP: + reduced = preserve_float_range(reduced, rescale=False) + else: + reduced = preserve_float_range(reduced, rescale=True) + + # Update the physical coordinates + physical_coords: MutableMapping[Coordinates, Sequence[Number]] = {} + for axis, coord in ( + (Axes.X, Coordinates.X), + (Axes.Y, Coordinates.Y), + (Axes.ZPLANE, Coordinates.Z)): + if axis in self.dims: + # this axis was projected out of existence. + assert coord.value not in reduced.coords + physical_coords[coord] = [np.average(stack._data.coords[coord.value])] + else: + physical_coords[coord] = reduced.coords[coord.value] + reduced_stack = ImageStack.from_numpy(reduced.values, coordinates=physical_coords) + + return reduced_stack + + @staticmethod + @click.command("Reduce") + @click.option( + "--dims", + type=click.Choice( + [Axes.ROUND.value, Axes.CH.value, Axes.ZPLANE.value, Axes.X.value, Axes.Y.value] + ), + multiple=True, + help="The dimensions the Imagestack should max project over." + "For multiple dimensions add multiple --dims. Ex." + "--dims r --dims c") + @click.option( + "--func", + type=click.Choice(["max", "mean", "sum"]), + multiple=False, + help="The function to apply across dims" + "Valid function names: max, mean, sum." + ) + @click.option( + "--clip-method", default=Clip.CLIP, type=Clip, + help="method to constrain data to [0,1]. options: 'clip', 'scale_by_image', " + "'scale_by_chunk'") + @click.pass_context + def _cli(ctx, dims, func, clip_method): + ctx.obj["component"]._cli_run(ctx, Reduce(dims, func, clip_method)) diff --git a/starfish/core/image/_filter/test/test_api_contract.py b/starfish/core/image/_filter/test/test_api_contract.py index 3824450e0..792349b9a 100644 --- a/starfish/core/image/_filter/test/test_api_contract.py +++ b/starfish/core/image/_filter/test/test_api_contract.py @@ -25,6 +25,7 @@ from starfish import ImageStack from starfish.core.image import Filter from starfish.core.image._filter.max_proj import MaxProject +from starfish.core.image._filter.reduce import Reduce methods: Mapping[str, Type] = Filter._algorithm_to_class_map() @@ -53,38 +54,42 @@ def test_all_methods_adhere_to_contract(filter_class): # assert isinstance(volume_param, bool), \ # f'{filter_class} is_volume must be a bool, not {type(volume_param)}' - # always emits an Image, even if in_place=True and the resulting filter operates in-place data = generate_default_data() - try: - filtered = instance.run(data, in_place=True) - except TypeError: - raise AssertionError(f'{filter_class} must accept in_place parameter') - assert isinstance(filtered, ImageStack) - if filter_class is not MaxProject: - # Max Proj does not have an in place option, so we need to skip this assertion + + # Max Proj and Reduce don't have an in_place, n_processes, verbose option, + # so we need to skip these tests + if filter_class not in [MaxProject, Reduce]: + # always emits an Image, even if in_place=True and the resulting filter operates in-place + try: + filtered = instance.run(data, in_place=True) + except TypeError: + raise AssertionError(f'{filter_class} must accept in_place parameter') + assert isinstance(filtered, ImageStack) assert data is filtered, \ f'{filter_class} should return a reference to the input ImageStack when run in_place' - # operates out of place - data = generate_default_data() - filtered = instance.run(data, in_place=False) - assert data is not filtered, \ - f'{filter_class} should output a new ImageStack when run out-of-place' - - # accepts n_processes - # TODO shanaxel: verify that this causes more than one process to be generated - data = generate_default_data() - try: - instance.run(data, n_processes=1) - except TypeError: - raise AssertionError(f'{filter_class} must accept n_processes parameter') - - # accepts verbose, and if passed, prints progress - data = generate_default_data() - try: - instance.run(data, verbose=True) - except TypeError: - raise AssertionError(f'{filter_class} must accept verbose parameter') + # operates out of place + data = generate_default_data() + filtered = instance.run(data, in_place=False) + assert data is not filtered, \ + f'{filter_class} should output a new ImageStack when run out-of-place' + + # accepts n_processes + # TODO shanaxel: verify that this causes more than one process to be generated + data = generate_default_data() + try: + instance.run(data, n_processes=1) + except TypeError: + raise AssertionError(f'{filter_class} must accept n_processes parameter') + + # accepts verbose, and if passed, prints progress + data = generate_default_data() + try: + instance.run(data, verbose=True) + except TypeError: + raise AssertionError(f'{filter_class} must accept verbose parameter') + else: + filtered = instance.run(data) # output is dtype float and within the expected interval of [0, 1] assert filtered.xarray.dtype == np.float32, f'{filter_class} must output float32 data' diff --git a/starfish/core/image/_filter/test/test_reduce.py b/starfish/core/image/_filter/test/test_reduce.py new file mode 100644 index 000000000..e51775c5b --- /dev/null +++ b/starfish/core/image/_filter/test/test_reduce.py @@ -0,0 +1,126 @@ +from collections import OrderedDict + +import numpy as np +import pytest +import xarray as xr + +from starfish import data +from starfish import ImageStack +from starfish.core.image._filter.reduce import Reduce +from starfish.core.imagestack.test.factories import imagestack_with_coords_factory +from starfish.core.imagestack.test.imagestack_test_utils import verify_physical_coordinates +from starfish.types import Axes, PhysicalCoordinateTypes + + +X_COORDS = 1, 2 +Y_COORDS = 4, 6 +Z_COORDS = 1, 3 + + +def make_image_stack(): + ''' + Make a test ImageStack + + ''' + + # Make the test image + test = np.ones((2, 4, 1, 2, 2), dtype='float32') * 0.1 + + x = [0, 0, 1, 1] + y = [0, 1, 0, 1] + + for i in range(4): + test[0, i, 0, x[i], y[i]] = 1 + test[0, 0, 0, 0, 0] = 0.75 + + # Make the ImageStack + test_stack = ImageStack.from_numpy(test) + + return test_stack + + +def make_expected_image_stack(func): + ''' + Make the expected image stack result + ''' + + if func == 'max': + reduced = np.array( + [[[[[0.75, 0.1], + [0.1, 0.1]]], + [[[0.1, 1], + [0.1, 0.1]]], + [[[0.1, 0.1], + [1, 0.1]]], + [[[0.1, 0.1], + [0.1, 1]]]]], dtype='float32' + ) + elif func == 'mean': + reduced = np.array( + [[[[[0.425, 0.1], + [0.1, 0.1]]], + [[[0.1, 0.55], + [0.1, 0.1]]], + [[[0.1, 0.1], + [0.55, 0.1]]], + [[[0.1, 0.1], + [0.1, 0.55]]]]], dtype='float32' + ) + elif func == 'sum': + reduced = np.array( + [[[[[0.85, 0.2], + [0.2, 0.2]]], + [[[0.2, 1], + [0.2, 0.2]]], + [[[0.2, 0.2], + [1, 0.2]]], + [[[0.2, 0.2], + [0.2, 1]]]]], dtype='float32' + ) + + expected_stack = ImageStack.from_numpy(reduced) + + return expected_stack + + +@pytest.mark.parametrize("func", ['max', 'mean', 'sum']) +def test_image_stack_reduce(func): + + # Get the test stack and expected result + test_stack = make_image_stack() + expected_result = make_expected_image_stack(func=func) + + # Filter + red = Reduce(dims=[Axes.ROUND], func=func) + reduced = red.run(test_stack) + + xr.testing.assert_equal(reduced.xarray, expected_result.xarray) + + +def test_max_projection_preserves_coordinates(): + e = data.ISS(use_test_data=True) + nuclei = e.fov().get_image('nuclei') + + red = Reduce(dims=[Axes.ROUND, Axes.CH, Axes.ZPLANE], func='max') + nuclei_proj = red.run(nuclei) + + # Since this data already has only 1 round, 1 ch, 1 zplane + # let's just assert that the max_proj operation didn't change anything + assert nuclei.xarray.equals(nuclei_proj.xarray) + + stack_shape = OrderedDict([(Axes.ROUND, 3), (Axes.CH, 2), + (Axes.ZPLANE, 3), (Axes.Y, 10), (Axes.X, 10)]) + + # Create stack with coordinates, verify coords unaffected by max_poj + physical_coords = OrderedDict([(PhysicalCoordinateTypes.X_MIN, X_COORDS[0]), + (PhysicalCoordinateTypes.X_MAX, X_COORDS[1]), + (PhysicalCoordinateTypes.Y_MIN, Y_COORDS[0]), + (PhysicalCoordinateTypes.Y_MAX, Y_COORDS[1]), + (PhysicalCoordinateTypes.Z_MIN, Z_COORDS[0]), + (PhysicalCoordinateTypes.Z_MAX, Z_COORDS[1])]) + + stack = imagestack_with_coords_factory(stack_shape, physical_coords) + + stack_proj = red.run(stack) + expected_z = np.average(Z_COORDS) + verify_physical_coordinates(stack_proj, X_COORDS, Y_COORDS, expected_z)