Skip to content

Commit

Permalink
Refactor reduce to take an optional module and only a function name. (#…
Browse files Browse the repository at this point in the history
…1386)

Module defaults to np (== numpy). The class no longer accepts a callable. The alias list is a dict that is relatively easy to extend.

Test plan: travis
  • Loading branch information
Tony Tung authored Jun 19, 2019
1 parent 89dc61e commit fcded6d
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 49 deletions.
169 changes: 134 additions & 35 deletions starfish/core/image/_filter/reduce.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import importlib
from enum import Enum
from typing import (
Callable,
cast,
Iterable,
Mapping,
MutableMapping,
Optional,
Sequence,
Union
)
Expand All @@ -17,53 +22,141 @@

class Reduce(FilterAlgorithmBase):
"""
Reduces the dimensions of the ImageStack by applying a function
along one or more axes.
Reduces the cardinality of one or more axes to 1 by applying a function across those axes.
Parameters
----------
dims : Axes
one or more Axes to project over
func : Union[str, Callable]
function to apply across the dimension(s) specified by dims.
If a function is provided, it should follow the form specified by
DataArray.reduce():
http://xarray.pydata.org/en/stable/generated/xarray.DataArray.reduce.html
If a string is provided, it should correspond to a numpy function that
matches the form specified above
(i.e., function is resolved: func = getattr(np, func)).
Some common examples below:
amax: maximum intensity projection (applies numpy.amax)
max: maximum intensity projection (this is an alias for amax and applies numpy.amax)
mean: take the mean across the dim(s) (applies numpy.mean)
sum: sum across the dim(s) (applies numpy.sum)
one or more Axes to reduce over
func : str
Name of a function in the module specified by the ``module`` parameter to apply across the
dimension(s) specified by dims. The function is resolved by ``getattr(<module>, func)``,
except in the cases of predefined aliases. See :py:class:`FunctionSource` for more
information about aliases.
Some common examples for the np FunctionSource:
- amax: maximum intensity projection (applies np.amax)
- max: maximum intensity projection (this is an alias for amax and applies np.amax)
- mean: take the mean across the dim(s) (applies np.mean)
- sum: sum across the dim(s) (applies np.sum)
module : FunctionSource
Python module that serves as the source of the function. It must be listed as one of the
members of :py:class:`FunctionSource`.
Currently, the supported FunctionSources are:
- ``np``: the top-level package of numpy
- ``scipy``: the top-level package of scipy
clip_method : Clip
(Default Clip.CLIP) Controls the way that data are scaled to retain skimage dtype
requirements that float data fall in [0, 1].
Clip.CLIP: data above 1 are set to 1, and below 0 are set to 0
Clip.SCALE_BY_IMAGE: data above 1 are scaled by the maximum value, with the maximum
value calculated over the entire ImageStack
Examples
--------
Reducing via max projection.
>>> from starfish.core.imagestack.test.factories.synthetic_stack import synthetic_stack
>>> from starfish.image import Filter
>>> from starfish.types import Axes
>>> stack = synthetic_stack()
>>> reducer = Filter.Reduce({Axes.ROUND}, func="max")
>>> max_proj = reducer.run(stack)
Reducing via linalg.norm
>>> from starfish.core.imagestack.test.factories.synthetic_stack import synthetic_stack
>>> from starfish.image import Filter
>>> from starfish.types import Axes
>>> stack = synthetic_stack()
>>> reducer = Filter.Reduce(
{Axes.ROUND},
func="linalg.norm",
module=Filter.Reduce.FunctionSource.scipy,
ord=2,
)
>>> norm = reducer.run(stack)
See Also
--------
starfish.types.Axes
starfish.core.types.Axes
"""

class FunctionSource(Enum):
    """Each FunctionSource declares a package from which reduction methods can be obtained.

    Generally, the packages should be those that are included as starfish's dependencies for
    reproducibility.

    Many packages are broken into subpackages which are not necessarily implicitly imported when
    importing the top-level package.  For example, ``scipy.linalg`` is not implicitly imported
    when one imports ``scipy``.  To avoid the complexity of enumerating each scipy subpackage in
    FunctionSource, we assemble the fully-qualified method name, and then try every way of
    splitting it into a module to import and an attribute path to look up on that module.

    In the example of ``scipy.linalg.norm``, we try the following:

    1. import ``scipy``, attempt to resolve ``linalg.norm``.
    2. import ``scipy.linalg``, attempt to resolve ``norm``.

    Each member's value is a tuple of ``(top_level_package, aliases)``; ``aliases`` is optional.
    """

    def __init__(self, top_level_package: str, aliases: Optional[Mapping[str, str]] = None):
        # Name of the package that is imported first when resolving a method.
        self.top_level_package = top_level_package
        # Maps user-facing shorthand names to the real attribute path inside the package.
        self.aliases = aliases or {}

    def _resolve_method(self, method: str) -> Callable:
        """Resolve a method name (optionally dotted) to an object from this source's package.

        The method itself might be enclosed in a package, such as subpackage.actual_method.
        In that case, we will need to attempt to resolve it in the following sequence:

        1. import top_level_package, then try to resolve subpackage.actual_method recursively
           through ``getattr`` calls.
        2. import top_level_package.subpackage, then try to resolve actual_method through
           ``getattr`` calls.

        This is done instead of just creating a bunch of FunctionSource for libraries that have
        a lot of packages that are not implicitly imported by importing the top-level package.

        Parameters
        ----------
        method : str
            Name of the method, or one of this source's aliases.  May contain dots to reach
            into subpackages (e.g., ``"linalg.norm"``).

        Returns
        -------
        Callable
            The object found at the resolved path.

        Raises
        ------
        AttributeError
            If no split of the name into "module to import" and "attributes to look up"
            yields an object.
        ImportError
            If the top-level package itself cannot be imported, or if importing a candidate
            module fails because of a missing module other than the candidate itself.
        """
        # first resolve the aliases.
        actual_method = self.aliases.get(method, method)
        parts = [self.top_level_package, *actual_method.split(".")]

        for divider in range(1, len(parts)):
            module_name = ".".join(parts[:divider])
            attribute_path = parts[divider:]

            try:
                resolved = importlib.import_module(module_name)
            except ModuleNotFoundError as exc:
                # A candidate sub-module that does not exist simply means this split is wrong;
                # move on so that the descriptive AttributeError below is raised.  A missing
                # top-level package (divider == 1) or a missing unrelated dependency is a real
                # environment problem and must not be masked.
                missing = exc.name or ""
                candidate_is_missing = divider > 1 and (
                    module_name == missing or module_name.startswith(missing + ".")
                )
                if not candidate_is_missing:
                    raise
                continue

            try:
                for attribute_name in attribute_path:
                    resolved = getattr(resolved, attribute_name)
            except AttributeError:
                continue
            return cast(Callable, resolved)

        raise AttributeError(
            f"Unable to resolve the method {actual_method} from package "
            f"{self.top_level_package}")

    np = ("numpy", {'max': 'amax'})
    """Function source for the numpy libraries"""
    scipy = ("scipy",)
    """Function source for the scipy libraries"""

def __init__(
self, dims: Iterable[Union[Axes, str]], func: Union[str, Callable] = 'max',
clip_method: Clip = Clip.CLIP
self,
dims: Iterable[Union[Axes, str]],
func: str = "max",
module: FunctionSource = FunctionSource.np,
clip_method: Clip = Clip.CLIP,
**kwargs
) -> None:

self.dims = dims
self.func = module._resolve_method(func)
self.clip_method = clip_method

# If the user provided a string, convert to callable
if isinstance(func, str):
if func == 'max':
func = 'amax'
func = getattr(np, func)
self.func = func
self.kwargs = kwargs

_DEFAULT_TESTING_PARAMETERS = {"dims": ['r'], "func": 'max'}

Expand All @@ -88,7 +181,8 @@ def run(
"""

# Apply the reducing function
reduced = stack._data.reduce(self.func, dim=[Axes(dim).value for dim in self.dims])
reduced = stack._data.reduce(
self.func, dim=[Axes(dim).value for dim in self.dims], **self.kwargs)

# Add the reduced dims back and align with the original stack
reduced = reduced.expand_dims(tuple(Axes(dim).value for dim in self.dims))
Expand Down Expand Up @@ -128,15 +222,20 @@ def run(
"--dims r --dims c")
@click.option(
"--func",
type=click.Choice(["max", "mean", "sum"]),
type=str,
help="The function to apply across dims."
)
@click.option(
"--module",
type=click.Choice([member.name for member in list(FunctionSource)]),
multiple=False,
help="The function to apply across dims"
"Valid function names: max, mean, sum."
help="Module to source the function from.",
default=FunctionSource.np.name,
)
@click.option(
"--clip-method", default=Clip.CLIP, type=Clip,
help="method to constrain data to [0,1]. options: 'clip', 'scale_by_image', "
"'scale_by_chunk'")
help="method to constrain data to [0,1]. options: 'clip', 'scale_by_image'")
@click.pass_context
def _cli(ctx, dims, func, clip_method):
ctx.obj["component"]._cli_run(ctx, Reduce(dims, func, clip_method))
def _cli(ctx, dims, func, module, clip_method):
ctx.obj["component"]._cli_run(
ctx, Reduce(dims, func, Reduce.FunctionSource[module], clip_method))
54 changes: 40 additions & 14 deletions starfish/core/image/_filter/test/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import pytest
import xarray as xr

from starfish import data
from starfish import ImageStack
Expand All @@ -18,10 +17,7 @@


def make_image_stack():
'''
Make a test ImageStack
'''
"""Make a test ImageStack."""

# Make the test image
test = np.ones((2, 4, 1, 2, 2), dtype='float32') * 0.1
Expand All @@ -40,9 +36,7 @@ def make_image_stack():


def make_expected_image_stack(func):
'''
Make the expected image stack result
'''
"""Make the expected image stack result"""

if func == 'max':
reduced = np.array(
Expand All @@ -69,32 +63,64 @@ def make_expected_image_stack(func):
elif func == 'sum':
reduced = np.array(
[[[[[0.85, 0.2],
[0.2, 0.2]]],
[0.2, 0.2]]],
[[[0.2, 1],
[0.2, 0.2]]],
[[[0.2, 0.2],
[1, 0.2]]],
[[[0.2, 0.2],
[0.2, 1]]]]], dtype='float32'
)
elif func == 'norm':
reduced = np.array(
[[[[[0.75663730, 0.14142136],
[0.14142136, 0.14142136]]],
[[[0.14142136, 1.00000000],
[0.14142136, 0.14142136]]],
[[[0.14142136, 0.14142136],
[1.00000000, 0.14142136]]],
[[[0.14142136, 0.14142136],
[0.14142136, 1.00000000]]]]],
dtype=np.float32
)
else:
raise ValueError("Unsupported func")

expected_stack = ImageStack.from_numpy(reduced)

return expected_stack


@pytest.mark.parametrize("func", ['max', 'mean', 'sum'])
def test_image_stack_reduce(func):
@pytest.mark.parametrize(
"expected_result,func,module,kwargs",
[
(make_expected_image_stack('max'), 'max', None, None),
(make_expected_image_stack('mean'), 'mean', None, None),
(make_expected_image_stack('sum'), 'sum', None, None),
(
make_expected_image_stack('norm'),
'linalg.norm',
Reduce.FunctionSource.scipy,
{'ord': 2},
),
]
)
def test_image_stack_reduce(expected_result, func, module, kwargs):

# Get the test stack and expected result
test_stack = make_image_stack()
expected_result = make_expected_image_stack(func=func)

# Filter
red = Reduce(dims=[Axes.ROUND], func=func)
if kwargs is None:
actual_kwargs = {}
else:
actual_kwargs = kwargs.copy()
if module is not None:
actual_kwargs['module'] = module
red = Reduce(dims=[Axes.ROUND], func=func, **actual_kwargs)
reduced = red.run(test_stack)

xr.testing.assert_equal(reduced.xarray, expected_result.xarray)
assert np.allclose(reduced.xarray, expected_result.xarray)


def test_max_projection_preserves_coordinates():
Expand Down

0 comments on commit fcded6d

Please sign in to comment.