Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing of positional arguments in apply for Groupby objects #2413

Merged
merged 11 commits into from
Dec 24, 2018
8 changes: 8 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ Enhancements
- :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` now supports the
``loffset`` kwarg just like Pandas.
By `Deepak Cherian <https://github.com/dcherian>`_
- The `apply` methods for `DatasetGroupBy`, `DataArrayGroupBy`,
`DatasetResample` and `DataArrayResample` can now pass positional arguments to
the applied function.
By `Matti Eskelinen <https://github.com/maaleske>`_.
- 0d slices of ndarrays are now obtained directly through indexing, rather than
extracting and wrapping a scalar, avoiding unnecessary copying. By `Daniel
Wennberg <https://github.com/danielwe>`_.
Expand Down Expand Up @@ -260,13 +264,17 @@ Announcements of note:
for more details.
- We have a new :doc:`roadmap` that outlines our future development plans.

- `Dataset.apply` now properly documents the way `func` is called.
By `Matti Eskelinen <https://github.com/maaleske>`_.

Enhancements
~~~~~~~~~~~~

- :py:meth:`~xarray.DataArray.differentiate` and
:py:meth:`~xarray.Dataset.differentiate` are newly added.
(:issue:`1332`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

- Default colormap for sequential and divergent data can now be set via
:py:func:`~xarray.set_options()`
(:issue:`2394`)
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2953,8 +2953,8 @@ def apply(self, func, keep_attrs=None, args=(), **kwargs):
Parameters
----------
func : function
Function which can be called in the form `f(x, **kwargs)` to
transform each DataArray `x` in this dataset into another
Function which can be called in the form `func(x, *args, **kwargs)`
to transform each DataArray `x` in this dataset into another
DataArray.
keep_attrs : bool, optional
If True, the dataset's attributes (`attrs`) will be copied from
Expand Down
12 changes: 8 additions & 4 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def lookup_order(dimension):
new_order = sorted(stacked.dims, key=lookup_order)
return stacked.transpose(*new_order)

def apply(self, func, shortcut=False, **kwargs):
def apply(self, func, shortcut=False, args=(), **kwargs):
"""Apply a function over each array in the group and concatenate them
together into a new array.

Expand Down Expand Up @@ -532,6 +532,8 @@ def apply(self, func, shortcut=False, **kwargs):
If these conditions are satisfied `shortcut` provides significant
speedup. This should be the case for many common groupby operations
(e.g., applying numpy ufuncs).
args : tuple, optional
Positional arguments passed to `func`.
**kwargs
Used to call `func(ar, **kwargs)` for each array `ar`.

Expand All @@ -544,7 +546,7 @@ def apply(self, func, shortcut=False, **kwargs):
grouped = self._iter_grouped_shortcut()
else:
grouped = self._iter_grouped()
applied = (maybe_wrap_array(arr, func(arr, **kwargs))
applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs))
for arr in grouped)
return self._combine(applied, shortcut=shortcut)

Expand Down Expand Up @@ -642,7 +644,7 @@ def wrapped_func(self, dim=DEFAULT_DIMS, axis=None,


class DatasetGroupBy(GroupBy, ImplementsDatasetReduce):
def apply(self, func, **kwargs):
def apply(self, func, args=(), **kwargs):
"""Apply a function over each Dataset in the group and concatenate them
together into a new Dataset.

Expand All @@ -661,6 +663,8 @@ def apply(self, func, **kwargs):
----------
func : function
Callable to apply to each sub-dataset.
args : tuple, optional
Positional arguments to pass to `func`.
**kwargs
Used to call `func(ds, **kwargs)` for each sub-dataset `ar`.

Expand All @@ -670,7 +674,7 @@ def apply(self, func, **kwargs):
The result of splitting, applying and combining this dataset.
"""
kwargs.pop('shortcut', None) # ignore shortcut if set (for now)
applied = (func(ds, **kwargs) for ds in self._iter_grouped())
applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
return self._combine(applied)

def _combine(self, applied):
Expand Down
12 changes: 8 additions & 4 deletions xarray/core/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(self, *args, **kwargs):
"('{}')! ".format(self._resample_dim, self._dim))
super(DataArrayResample, self).__init__(*args, **kwargs)

def apply(self, func, shortcut=False, **kwargs):
def apply(self, func, shortcut=False, args=(), **kwargs):
"""Apply a function over each array in the group and concatenate them
together into a new array.

Expand Down Expand Up @@ -158,6 +158,8 @@ def apply(self, func, shortcut=False, **kwargs):
If these conditions are satisfied `shortcut` provides significant
speedup. This should be the case for many common groupby operations
(e.g., applying numpy ufuncs).
args : tuple, optional
Positional arguments passed on to `func`.
**kwargs
Used to call `func(ar, **kwargs)` for each array `ar`.

Expand All @@ -167,7 +169,7 @@ def apply(self, func, shortcut=False, **kwargs):
The result of splitting, applying and combining this array.
"""
combined = super(DataArrayResample, self).apply(
func, shortcut=shortcut, **kwargs)
func, shortcut=shortcut, args=args, **kwargs)

# If the aggregation function didn't drop the original resampling
# dimension, then we need to do so before we can rename the proxy
Expand Down Expand Up @@ -240,7 +242,7 @@ def __init__(self, *args, **kwargs):
"('{}')! ".format(self._resample_dim, self._dim))
super(DatasetResample, self).__init__(*args, **kwargs)

def apply(self, func, **kwargs):
def apply(self, func, args=(), **kwargs):
"""Apply a function over each Dataset in the groups generated for
resampling and concatenate them together into a new Dataset.

Expand All @@ -259,6 +261,8 @@ def apply(self, func, **kwargs):
----------
func : function
Callable to apply to each sub-dataset.
args : tuple, optional
Positional arguments passed on to `func`.
**kwargs
Used to call `func(ds, **kwargs)` for each sub-dataset `ar`.

Expand All @@ -268,7 +272,7 @@ def apply(self, func, **kwargs):
The result of splitting, applying and combining this dataset.
"""
kwargs.pop('shortcut', None) # ignore shortcut if set (for now)
applied = (func(ds, **kwargs) for ds in self._iter_grouped())
applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
combined = self._combine(applied)

return combined.rename({self._resample_dim: self._dim})
Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2295,6 +2295,17 @@ def test_resample(self):
with raises_regex(ValueError, 'index must be monotonic'):
array[[2, 0, 1]].resample(time='1D')

def test_da_resample_func_args(self):

def func(arg1, arg2, arg3=0.):
return arg1.mean('time') + arg2 + arg3

times = pd.date_range('2000', periods=3, freq='D')
da = xr.DataArray([1., 1., 1.], coords=[times], dims=['time'])
expected = xr.DataArray([3., 3., 3.], coords=[times], dims=['time'])
actual = da.resample(time='D').apply(func, args=(1.,), arg3=1.)
assert_identical(actual, expected)

@requires_cftime
def test_resample_cftimeindex(self):
cftime = _import_cftime()
Expand Down
13 changes: 13 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2886,6 +2886,19 @@ def test_resample_old_api(self):
with raises_regex(TypeError, r'resample\(\) no longer supports'):
ds.resample('1D', dim='time')

def test_ds_resample_apply_func_args(self):

def func(arg1, arg2, arg3=0.):
return arg1.mean('time') + arg2 + arg3

times = pd.date_range('2000', freq='D', periods=3)
ds = xr.Dataset({'foo': ('time', [1., 1., 1.]),
'time': times})
expected = xr.Dataset({'foo': ('time', [3., 3., 3.]),
'time': times})
actual = ds.resample(time='D').apply(func, args=(1.,), arg3=1.)
assert_identical(expected, actual)

def test_to_array(self):
ds = Dataset(OrderedDict([('a', 1), ('b', ('x', [1, 2, 3]))]),
coords={'c': 42}, attrs={'Conventions': 'None'})
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,26 @@ def test_groupby_input_mutation():
assert_identical(array, array_copy) # should not modify inputs


def test_da_groupby_apply_func_args():

def func(arg1, arg2, arg3=0):
return arg1 + arg2 + arg3

array = xr.DataArray([1, 1, 1], [('x', [1, 2, 3])])
expected = xr.DataArray([3, 3, 3], [('x', [1, 2, 3])])
actual = array.groupby('x').apply(func, args=(1,), arg3=1)
assert_identical(expected, actual)


def test_ds_groupby_apply_func_args():

def func(arg1, arg2, arg3=0):
return arg1 + arg2 + arg3

dataset = xr.Dataset({'foo': ('x', [1, 1, 1])}, {'x': [1, 2, 3]})
expected = xr.Dataset({'foo': ('x', [3, 3, 3])}, {'x': [1, 2, 3]})
actual = dataset.groupby('x').apply(func, args=(1,), arg3=1)
assert_identical(expected, actual)


# TODO: move other groupby tests from test_dataset and test_dataarray over here