diff --git a/starfish/core/image/_filter/reduce.py b/starfish/core/image/_filter/reduce.py index 21d65f3d3..f440043aa 100644 --- a/starfish/core/image/_filter/reduce.py +++ b/starfish/core/image/_filter/reduce.py @@ -1,7 +1,12 @@ +import importlib +from enum import Enum from typing import ( Callable, + cast, Iterable, + Mapping, MutableMapping, + Optional, Sequence, Union ) @@ -17,26 +22,31 @@ class Reduce(FilterAlgorithmBase): """ - Reduces the dimensions of the ImageStack by applying a function - along one or more axes. + Reduces the cardinality of one or more axes to 1 by applying a function across those axes. Parameters ---------- dims : Axes - one or more Axes to project over - func : Union[str, Callable] - function to apply across the dimension(s) specified by dims. - If a function is provided, it should follow the form specified by - DataArray.reduce(): - http://xarray.pydata.org/en/stable/generated/xarray.DataArray.reduce.html - If a string is provided, it should correspond to a numpy function that - matches the form specified above - (i.e., function is resolved: func = getattr(np, func)). - Some common examples below: - amax: maximum intensity projection (applies numpy.amax) - max: maximum intensity projection (this is an alias for amax and applies numpy.amax) - mean: take the mean across the dim(s) (applies numpy.mean) - sum: sum across the dim(s) (applies numpy.sum) + one or more Axes to reduce over + func : str + Name of a function in the module specified by the ``module`` parameter to apply across the + dimension(s) specified by dims. The function is resolved by ``getattr(module, func)``, + except in the cases of predefined aliases. See :py:class:`FunctionSource` for more + information about aliases. 
+ + Some common examples for the np FunctionSource: + + - amax: maximum intensity projection (applies np.amax) + - max: maximum intensity projection (this is an alias for amax and applies np.amax) + - mean: take the mean across the dim(s) (applies np.mean) + - sum: sum across the dim(s) (applies np.sum) + module : FunctionSource + Python module that serves as the source of the function. It must be listed as one of the + members of :py:class:`FunctionSource`. + + Currently, the supported FunctionSources are: + - ``np``: the top-level package of numpy + - ``scipy``: the top-level package of scipy clip_method : Clip (Default Clip.CLIP) Controls the way that data are scaled to retain skimage dtype requirements that float data fall in [0, 1]. @@ -44,26 +54,109 @@ class Reduce(FilterAlgorithmBase): Clip.SCALE_BY_IMAGE: data above 1 are scaled by the maximum value, with the maximum value calculated over the entire ImageStack + Examples + -------- + Reducing via max projection. + >>> from starfish.core.imagestack.test.factories.synthetic_stack import synthetic_stack + >>> from starfish.image import Filter + >>> from starfish.types import Axes + >>> stack = synthetic_stack() + >>> reducer = Filter.Reduce({Axes.ROUND}, func="max") + >>> max_proj = reducer.run(stack) + + Reducing via linalg.norm + >>> from starfish.core.imagestack.test.factories.synthetic_stack import synthetic_stack + >>> from starfish.image import Filter + >>> from starfish.types import Axes + >>> stack = synthetic_stack() + >>> reducer = Filter.Reduce( + {Axes.ROUND}, + func="linalg.norm", + module=Filter.Reduce.FunctionSource.scipy, + ord=2, + ) + >>> norm = reducer.run(stack) + See Also -------- - starfish.types.Axes + starfish.core.types.Axes """ + class FunctionSource(Enum): + """Each FunctionSource declares a package from which reduction methods can be obtained. + Generally, the packages should be those that are included as starfish's dependencies for + reproducibility. 
+ + Many packages are broken into subpackages which are not necessarily implicitly imported when + importing the top-level package. For example, ``scipy.linalg`` is not implicitly imported + when one imports ``scipy``. To avoid the complexity of enumerating each scipy subpackage in + FunctionSource, we assemble the fully-qualified method name, and then try all the + permutations of how one could import that method. + + In the example of ``scipy.linalg.norm``, we try the following: + + 1. import ``scipy``, attempt to resolve ``linalg.norm``. + 2. import ``scipy.linalg``, attempt to resolve ``norm``. + """ + + def __init__(self, top_level_package: str, aliases: Optional[Mapping[str, str]] = None): + self.top_level_package = top_level_package + self.aliases = aliases or {} + + def _resolve_method(self, method: str) -> Callable: + """Resolve a method. The method itself might be enclosed in a package, such as + subpackage.actual_method. In that case, we will need to attempt to resolve it in the + following sequence: + + 1. import top_level_package, then try to resolve subpackage.actual_method recursively + through ``getattr`` calls. + 2. import top_level_package.subpackage, then try to resolve actual_method through + ``getattr`` calls. + + This is done instead of just creating a bunch of FunctionSource members for libraries that have + a lot of packages that are not implicitly imported by importing the top-level package. + """ + # first resolve the aliases. 
+ actual_method = self.aliases.get(method, method) + + method_splitted = actual_method.split(".") + splitted = [self.top_level_package] + splitted.extend(method_splitted) + + for divider in range(1, len(splitted)): + import_section = splitted[:divider] + getattr_section = splitted[divider:] + + imported = importlib.import_module(".".join(import_section)) + + try: + for getattr_name in getattr_section: + imported = getattr(imported, getattr_name) + return cast(Callable, imported) + except AttributeError: + pass + + raise AttributeError( + f"Unable to resolve the method {actual_method} from package " + f"{self.top_level_package}") + + np = ("numpy", {'max': 'amax'}) + """Function source for the numpy libraries""" + scipy = ("scipy",) + def __init__( - self, dims: Iterable[Union[Axes, str]], func: Union[str, Callable] = 'max', - clip_method: Clip = Clip.CLIP + self, + dims: Iterable[Union[Axes, str]], + func: str = "max", + module: FunctionSource = FunctionSource.np, + clip_method: Clip = Clip.CLIP, + **kwargs ) -> None: - self.dims = dims + self.func = module._resolve_method(func) self.clip_method = clip_method - - # If the user provided a string, convert to callable - if isinstance(func, str): - if func == 'max': - func = 'amax' - func = getattr(np, func) - self.func = func + self.kwargs = kwargs _DEFAULT_TESTING_PARAMETERS = {"dims": ['r'], "func": 'max'} @@ -88,7 +181,8 @@ def run( """ # Apply the reducing function - reduced = stack._data.reduce(self.func, dim=[Axes(dim).value for dim in self.dims]) + reduced = stack._data.reduce( + self.func, dim=[Axes(dim).value for dim in self.dims], **self.kwargs) # Add the reduced dims back and align with the original stack reduced = reduced.expand_dims(tuple(Axes(dim).value for dim in self.dims)) @@ -128,15 +222,20 @@ def run( "--dims r --dims c") @click.option( "--func", - type=click.Choice(["max", "mean", "sum"]), + type=str, + help="The function to apply across dims." 
+ ) + @click.option( + "--module", + type=click.Choice([member.name for member in list(FunctionSource)]), multiple=False, - help="The function to apply across dims" - "Valid function names: max, mean, sum." + help="Module to source the function from.", + default=FunctionSource.np.name, ) @click.option( "--clip-method", default=Clip.CLIP, type=Clip, - help="method to constrain data to [0,1]. options: 'clip', 'scale_by_image', " - "'scale_by_chunk'") + help="method to constrain data to [0,1]. options: 'clip', 'scale_by_image'") @click.pass_context - def _cli(ctx, dims, func, clip_method): - ctx.obj["component"]._cli_run(ctx, Reduce(dims, func, clip_method)) + def _cli(ctx, dims, func, module, clip_method): + ctx.obj["component"]._cli_run( + ctx, Reduce(dims, func, Reduce.FunctionSource[module], clip_method)) diff --git a/starfish/core/image/_filter/test/test_reduce.py b/starfish/core/image/_filter/test/test_reduce.py index e51775c5b..128ac5058 100644 --- a/starfish/core/image/_filter/test/test_reduce.py +++ b/starfish/core/image/_filter/test/test_reduce.py @@ -2,7 +2,6 @@ import numpy as np import pytest -import xarray as xr from starfish import data from starfish import ImageStack @@ -18,10 +17,7 @@ def make_image_stack(): - ''' - Make a test ImageStack - - ''' + """Make a test ImageStack.""" # Make the test image test = np.ones((2, 4, 1, 2, 2), dtype='float32') * 0.1 @@ -40,9 +36,7 @@ def make_image_stack(): def make_expected_image_stack(func): - ''' - Make the expected image stack result - ''' + """Make the expected image stack result""" if func == 'max': reduced = np.array( @@ -69,7 +63,7 @@ def make_expected_image_stack(func): elif func == 'sum': reduced = np.array( [[[[[0.85, 0.2], - [0.2, 0.2]]], + [0.2, 0.2]]], [[[0.2, 1], [0.2, 0.2]]], [[[0.2, 0.2], @@ -77,24 +71,56 @@ def make_expected_image_stack(func): [[[0.2, 0.2], [0.2, 1]]]]], dtype='float32' ) + elif func == 'norm': + reduced = np.array( + [[[[[0.75663730, 0.14142136], + [0.14142136, 0.14142136]]], + 
[[[0.14142136, 1.00000000], + [0.14142136, 0.14142136]]], + [[[0.14142136, 0.14142136], + [1.00000000, 0.14142136]]], + [[[0.14142136, 0.14142136], + [0.14142136, 1.00000000]]]]], + dtype=np.float32 + ) + else: + raise ValueError("Unsupported func") expected_stack = ImageStack.from_numpy(reduced) return expected_stack -@pytest.mark.parametrize("func", ['max', 'mean', 'sum']) -def test_image_stack_reduce(func): +@pytest.mark.parametrize( + "expected_result,func,module,kwargs", + [ + (make_expected_image_stack('max'), 'max', None, None), + (make_expected_image_stack('mean'), 'mean', None, None), + (make_expected_image_stack('sum'), 'sum', None, None), + ( + make_expected_image_stack('norm'), + 'linalg.norm', + Reduce.FunctionSource.scipy, + {'ord': 2}, + ), + ] +) +def test_image_stack_reduce(expected_result, func, module, kwargs): # Get the test stack and expected result test_stack = make_image_stack() - expected_result = make_expected_image_stack(func=func) # Filter - red = Reduce(dims=[Axes.ROUND], func=func) + if kwargs is None: + actual_kwargs = {} + else: + actual_kwargs = kwargs.copy() + if module is not None: + actual_kwargs['module'] = module + red = Reduce(dims=[Axes.ROUND], func=func, **actual_kwargs) reduced = red.run(test_stack) - xr.testing.assert_equal(reduced.xarray, expected_result.xarray) + assert np.allclose(reduced.xarray, expected_result.xarray) def test_max_projection_preserves_coordinates():