Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nd-rolling #4219

Merged
merged 27 commits into from
Aug 8, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ def rolling_window(a, axis, window, center, fill_value):
"""
import dask.array as da

# for nd-rolling.
# TODO It can be more efficient. Currently, the chunks at the boundaries
# will be copied, but it might be OK for many-chunked-arrays.
if hasattr(axis, "__len__"):
for ax, win, cen in zip(axis, window, center):
a = rolling_window(a, ax, win, cen, fill_value)
return a
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope it is true, though I didn't check it.
A suspicious part is da.concatenate, which I expect does not copy the original (strided-)array.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that ghosting breaks my expectation. I need to update the algo here.


orig_shape = a.shape
if axis < 0:
axis = a.ndim + axis
Expand Down
22 changes: 15 additions & 7 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,22 @@ def __setitem__(self, key, value):
def rolling_window(a, axis, window, center, fill_value):
""" rolling window with padding. """
pads = [(0, 0) for s in a.shape]
if center:
start = int(window / 2) # 10 -> 5, 9 -> 4
end = window - 1 - start
pads[axis] = (start, end)
else:
pads[axis] = (window - 1, 0)
if not hasattr(axis, "__len__"):
axis = [axis]
window = [window]
center = [center]

for ax, win, cent in zip(axis, window, center):
if cent:
start = int(win / 2) # 10 -> 5, 9 -> 4
end = win - 1 - start
pads[ax] = (start, end)
else:
pads[ax] = (win - 1, 0)
a = np.pad(a, pads, mode="constant", constant_values=fill_value)
return _rolling_window(a, window, axis)
for ax, win in zip(axis, window):
a = _rolling_window(a, win, ax)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very clever! I spent a while trying to figure out how it works...

return a


def _rolling_window(a, window, axis=-1):
Expand Down
118 changes: 77 additions & 41 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,23 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
-------
rolling : type of input argument
"""
if len(windows) != 1:
raise ValueError("exactly one dim/window should be provided")
dim = list(windows.keys())
window = list(windows.values())

dim, window = next(iter(windows.items()))

if window <= 0:
if any([w <= 0 for w in window]):
raise ValueError("window must be > 0")

if center is None or isinstance(center, bool):
center = [center] * len(dim)

self.obj = obj

# attributes
self.window = window
if min_periods is not None and min_periods <= 0:
raise ValueError("min_periods must be greater than zero or None")
self.min_periods = min_periods

self.min_periods = np.prod(window) if min_periods is None else min_periods

self.center = center
self.dim = dim
Expand All @@ -98,17 +100,15 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
keep_attrs = _get_keep_attrs(default=False)
self.keep_attrs = keep_attrs

@property
def _min_periods(self):
return self.min_periods if self.min_periods is not None else self.window

def __repr__(self):
"""provide a nice str repr of our rolling object"""

attrs = [
"{k}->{v}".format(k=k, v=getattr(self, k))
for k in self._attributes
if getattr(self, k, None) is not None
for k in list(self.dim)
+ list(self.window)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this a list of ints? Is it going to do getattr(da, 3)?

+ list(self.center)
+ [self.min_periods]
]
return "{klass} [{attrs}]".format(
klass=self.__class__.__name__, attrs=",".join(attrs)
Expand Down Expand Up @@ -143,7 +143,7 @@ def method(self, **kwargs):

def count(self):
rolling_count = self._counts()
enough_periods = rolling_count >= self._min_periods
enough_periods = rolling_count >= self.min_periods
return rolling_count.where(enough_periods)

count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count")
Expand Down Expand Up @@ -196,17 +196,20 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs
)

self.window_labels = self.obj[self.dim]
# TODO legacy attribute
self.window_labels = self.obj[self.dim[0]]

def __iter__(self):
if len(self.dim) > 1:
raise ValueError("__iter__ is only supported for 1d-rolling")
stops = np.arange(1, len(self.window_labels) + 1)
starts = stops - int(self.window)
starts[: int(self.window)] = 0
starts = stops - int(self.window[0])
starts[: int(self.window[0])] = 0
for (label, start, stop) in zip(self.window_labels, starts, stops):
window = self.obj.isel(**{self.dim: slice(start, stop)})
window = self.obj.isel(**{self.dim[0]: slice(start, stop)})

counts = window.count(dim=self.dim)
window = window.where(counts >= self._min_periods)
counts = window.count(dim=self.dim[0])
window = window.where(counts >= self.min_periods)

yield (label, window)

Expand Down Expand Up @@ -251,13 +254,19 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):

from .dataarray import DataArray

if len(self.dim) == 1 and not isinstance(window_dim, list):
window_dim = [window_dim]
if isinstance(stride, int):
stride = [stride] * len(self.dim)
window = self.obj.variable.rolling_window(
self.dim, self.window, window_dim, self.center, fill_value=fill_value
)
result = DataArray(
window, dims=self.obj.dims + (window_dim,), coords=self.obj.coords
window, dims=self.obj.dims + tuple(window_dim), coords=self.obj.coords
)
return result.isel(
**{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
)
return result.isel(**{self.dim: slice(None, None, stride)})

def reduce(self, func, **kwargs):
"""Reduce the items in this group by applying `func` along some
Expand Down Expand Up @@ -300,25 +309,33 @@ def reduce(self, func, **kwargs):
[ 4., 9., 15., 18.]])

"""
rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim")
rolling_dim = [
utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d))
for d in self.dim
]
windows = self.construct(rolling_dim)
result = windows.reduce(func, dim=rolling_dim, **kwargs)

# Find valid windows based on count.
counts = self._counts()
return result.where(counts >= self._min_periods)
return result.where(counts >= self.min_periods)

def _counts(self):
""" Number of non-nan entries in each rolling window. """

rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim")
rolling_dim = [
utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d))
for d in self.dim
]
# We use False as the fill_value instead of np.nan, since boolean
# array is faster to be reduced than object array.
# The use of skipna==False is also faster since it does not need to
# copy the strided array.
counts = (
self.obj.notnull()
.rolling(center=self.center, **{self.dim: self.window})
.rolling(
center=self.center, **{d: w for d, w in zip(self.dim, self.window)}
)
.construct(rolling_dim, fill_value=False)
.sum(dim=rolling_dim, skipna=False)
)
Expand All @@ -329,39 +346,40 @@ def _bottleneck_reduce(self, func, **kwargs):

# bottleneck doesn't allow min_count to be 0, although it should
# work the same as if min_count = 1
# Note bottleneck only works with 1d-rolling.
if self.min_periods is not None and self.min_periods == 0:
min_count = 1
else:
min_count = self.min_periods

axis = self.obj.get_axis_num(self.dim)
axis = self.obj.get_axis_num(self.dim[0])

padded = self.obj.variable
if self.center:
if self.center[0]:
if isinstance(padded.data, dask_array_type):
# Workaround to make the padded chunk size is larger than
# self.window-1
shift = -(self.window + 1) // 2
offset = (self.window - 1) // 2
shift = -(self.window[0] + 1) // 2
offset = (self.window[0] - 1) // 2
valid = (slice(None),) * axis + (
slice(offset, offset + self.obj.shape[axis]),
)
else:
shift = (-self.window // 2) + 1
shift = (-self.window[0] // 2) + 1
valid = (slice(None),) * axis + (slice(-shift, None),)
padded = padded.pad({self.dim: (0, -shift)}, mode="constant")
padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant")

if isinstance(padded.data, dask_array_type):
raise AssertionError("should not be reachable")
values = dask_rolling_wrapper(
func, padded.data, window=self.window, min_count=min_count, axis=axis
func, padded.data, window=self.window[0], min_count=min_count, axis=axis
)
else:
values = func(
padded.data, window=self.window, min_count=min_count, axis=axis
padded.data, window=self.window[0], min_count=min_count, axis=axis
)

if self.center:
if self.center[0]:
values = values[valid]
result = DataArray(values, self.obj.coords)

Expand All @@ -378,8 +396,10 @@ def _numpy_or_bottleneck_reduce(
)
del kwargs["dim"]

if bottleneck_move_func is not None and not isinstance(
self.obj.data, dask_array_type
if (
bottleneck_move_func is not None
and not isinstance(self.obj.data, dask_array_type)
and len(self.dim) == 1
):
# TODO: renable bottleneck with dask after the issues
# underlying https://github.com/pydata/xarray/issues/2940 are
Expand Down Expand Up @@ -431,13 +451,19 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
DataArray.groupby
"""
super().__init__(obj, windows, min_periods, center, keep_attrs)
if self.dim not in self.obj.dims:
if any(d not in self.obj.dims for d in self.dim):
raise KeyError(self.dim)
# Keep each Rolling object as a dictionary
self.rollings = {}
for key, da in self.obj.data_vars.items():
# keeps rollings only for the dataset depending on slf.dim
if self.dim in da.dims:
dims, center = [], []
for i, d in enumerate(self.dim):
if d in da.dims:
dims.append(d)
center.append(self.center[i])

if len(dims) > 0:
self.rollings[key] = DataArrayRolling(
da, windows, min_periods, center, keep_attrs
)
Expand All @@ -447,7 +473,7 @@ def _dataset_implementation(self, func, **kwargs):

reduced = {}
for key, da in self.obj.data_vars.items():
if self.dim in da.dims:
if any(d in da.dims for d in self.dim):
reduced[key] = func(self.rollings[key], **kwargs)
else:
reduced[key] = self.obj[key]
Expand Down Expand Up @@ -512,19 +538,29 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None)

from .dataset import Dataset

if isinstance(stride, int):
stride = [stride] * len(self.dim)

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)

dataset = {}
for key, da in self.obj.data_vars.items():
if self.dim in da.dims:
# keeps rollings only for the dataset depending on slf.dim
dims, center = [], []
for i, d in enumerate(self.dim):
if d in da.dims:
dims.append(d)
center.append(self.center[i])

if len(dims) > 0:
dataset[key] = self.rollings[key].construct(
window_dim, fill_value=fill_value
)
else:
dataset[key] = da
return Dataset(dataset, coords=self.obj.coords).isel(
**{self.dim: slice(None, None, stride)}
**{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
)


Expand Down
23 changes: 16 additions & 7 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1881,11 +1881,14 @@ def rolling_window(
Parameters
----------
dim: str
Dimension over which to compute rolling_window
Dimension over which to compute rolling_window.
For nd-rolling, should be list of dimensions.
window: int
Window size of the rolling
For nd-rolling, should be list of integers.
window_dim: str
New name of the window dimension.
For nd-rolling, should be list of integers.
center: boolean. default False.
If True, pad fill_value for both ends. Otherwise, pad in the head
of the axis.
Expand Down Expand Up @@ -1919,15 +1922,21 @@ def rolling_window(
dtype = self.dtype
array = self.data

new_dims = self.dims + (window_dim,)
if isinstance(dim, list):
assert len(dim) == len(window)
assert len(dim) == len(window_dim)
assert len(dim) == len(center)
else:
dim = [dim]
window = [window]
window_dim = [window_dim]
center = [center]
axis = [self.get_axis_num(d) for d in dim]
new_dims = self.dims + tuple(window_dim)
return Variable(
new_dims,
duck_array_ops.rolling_window(
array,
axis=self.get_axis_num(dim),
window=window,
center=center,
fill_value=fill_value,
array, axis=axis, window=window, center=center, fill_value=fill_value
),
)

Expand Down
17 changes: 15 additions & 2 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6193,8 +6193,6 @@ def test_rolling_properties(da):
assert rolling_obj.obj.get_axis_num("time") == 1

# catching invalid args
with pytest.raises(ValueError, match="exactly one dim/window should"):
da.rolling(time=7, x=2)
with pytest.raises(ValueError, match="window must be > 0"):
da.rolling(time=-2)
with pytest.raises(ValueError, match="min_periods must be greater than zero"):
Expand Down Expand Up @@ -6399,6 +6397,21 @@ def test_rolling_count_correct():
assert_equal(result, expected)


@pytest.mark.parametrize("da", (1,), indirect=True)
@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("min_periods", (None, 1))
def test_ndrolling_reduce(da, center, min_periods):
rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods)

actual = rolling_obj.sum()
expected = (
da.rolling(time=3, center=center, min_periods=min_periods).sum()
.rolling(x=2, center=center, min_periods=min_periods).sum())

assert_allclose(actual, expected)
assert actual.dims == expected.dims


def test_raise_no_warning_for_nan_in_binary_ops():
with pytest.warns(None) as record:
xr.DataArray([1, 2, np.NaN]) > 0
Expand Down
Loading