nd-rolling #4219
Merged
Changes from 5 commits
27 commits
1dd6b35 nd-rolling (fujiisoup)
6285476 remove unnecessary print (fujiisoup)
dc113dc black (fujiisoup)
9a45fc7 finding a bug... (fujiisoup)
f74b2e8 make tests for ndrolling pass (fujiisoup)
d4990d7 make center and window_dim a dict (fujiisoup)
531b0ec A cleanup. (fujiisoup)
b6cf250 Revert test_units (fujiisoup)
425ed45 make test pass (fujiisoup)
54d84e5 More tests. (fujiisoup)
f7c4911 more docs (fujiisoup)
2d82e68 Merge branch 'master' into nd_rolling (fujiisoup)
2e8a76f mypy (fujiisoup)
31243b1 improve whatsnew (fujiisoup)
22ba8b9 Improve doc (fujiisoup)
620aca3 Merge branch 'master' into nd_rolling (fujiisoup)
7404266 Merge branch 'master' into nd_rolling (fujiisoup)
4b4e64a Support nd-rolling in dask correctly (fujiisoup)
43acf73 Merge branch 'master' into nd_rolling (fujiisoup)
398638c Cleanup according to max's comment (fujiisoup)
c4fd353 flake8 (fujiisoup)
28816f6 black (fujiisoup)
404e78f stop using either_dict_or_kwargs (fujiisoup)
4cac857 Better tests. (fujiisoup)
4bb7804 typo (fujiisoup)
9703e11 mypy (fujiisoup)
f44dd5d typo2 (fujiisoup)
@@ -135,14 +135,22 @@ def __setitem__(self, key, value):
 def rolling_window(a, axis, window, center, fill_value):
     """ rolling window with padding. """
     pads = [(0, 0) for s in a.shape]
-    if center:
-        start = int(window / 2)  # 10 -> 5, 9 -> 4
-        end = window - 1 - start
-        pads[axis] = (start, end)
-    else:
-        pads[axis] = (window - 1, 0)
+    if not hasattr(axis, "__len__"):
+        axis = [axis]
+        window = [window]
+        center = [center]
+
+    for ax, win, cent in zip(axis, window, center):
+        if cent:
+            start = int(win / 2)  # 10 -> 5, 9 -> 4
+            end = win - 1 - start
+            pads[ax] = (start, end)
+        else:
+            pads[ax] = (win - 1, 0)
     a = np.pad(a, pads, mode="constant", constant_values=fill_value)
-    return _rolling_window(a, window, axis)
+    for ax, win in zip(axis, window):
+        a = _rolling_window(a, win, ax)
Review comment: This is very clever! I spent a while trying to figure out how it works...
+    return a


 def _rolling_window(a, window, axis=-1):
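To see why looping over the axes works, here is a minimal sketch, not the PR's implementation. It assumes NumPy 1.20 or later and uses `numpy.lib.stride_tricks.sliding_window_view` as a stand-in for the private `_rolling_window` helper; the name `rolling_window_nd` is hypothetical. Each pass restores the rolled axis to its original length and appends one window axis at the end, so the axis numbers used by later passes stay valid.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_window_nd(a, axes, windows, centers, fill_value=np.nan):
    """Pad `a`, then take a 1-d sliding window along each axis in turn.

    Hypothetical helper; sliding_window_view stands in for _rolling_window.
    """
    pads = [(0, 0)] * a.ndim
    for ax, win, cent in zip(axes, windows, centers):
        if cent:
            start = win // 2  # 10 -> 5, 9 -> 4
            pads[ax] = (start, win - 1 - start)
        else:
            pads[ax] = (win - 1, 0)
    a = np.pad(a.astype(float), pads, mode="constant", constant_values=fill_value)
    for ax, win in zip(axes, windows):
        # The padded axis shrinks back to its original length and a new
        # trailing axis of length `win` is appended.
        a = sliding_window_view(a, win, axis=ax)
    return a


x = np.arange(12.0).reshape(3, 4)
w = rolling_window_nd(x, axes=[0, 1], windows=[2, 3], centers=[False, False])
print(w.shape)  # (3, 4, 2, 3)
print(w[2, 3])  # the 2x3 block of x ending at row 2, column 3
```

With trailing (non-centered) windows, `w[i, j]` holds the block of `x` that ends at position `(i, j)`, padded with `fill_value` where the block runs off the array.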
@@ -75,21 +75,23 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
         -------
         rolling : type of input argument
         """
-        if len(windows) != 1:
-            raise ValueError("exactly one dim/window should be provided")
+        dim = list(windows.keys())
+        window = list(windows.values())

-        dim, window = next(iter(windows.items()))
-
-        if window <= 0:
+        if any([w <= 0 for w in window]):
             raise ValueError("window must be > 0")

+        if center is None or isinstance(center, bool):
+            center = [center] * len(dim)
+
         self.obj = obj

         # attributes
         self.window = window
         if min_periods is not None and min_periods <= 0:
             raise ValueError("min_periods must be greater than zero or None")
-        self.min_periods = min_periods
+
+        self.min_periods = np.prod(window) if min_periods is None else min_periods

         self.center = center
         self.dim = dim
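The argument handling in this hunk can be sketched in isolation. The function name `normalize_rolling_args` is hypothetical and `math.prod` (Python 3.8+) stands in for `np.prod`; the point is only that a scalar `center` is broadcast to one entry per dimension and that the default `min_periods` becomes the product of all window sizes.

```python
import math


def normalize_rolling_args(windows, center=False, min_periods=None):
    """Hypothetical helper mirroring the checks in the hunk above."""
    dim = list(windows.keys())
    window = list(windows.values())
    if any(w <= 0 for w in window):
        raise ValueError("window must be > 0")
    # A single bool (or None) applies to every rolling dimension.
    if center is None or isinstance(center, bool):
        center = [center] * len(dim)
    if min_periods is not None and min_periods <= 0:
        raise ValueError("min_periods must be greater than zero or None")
    # By default every element of the n-d window must be present.
    if min_periods is None:
        min_periods = math.prod(window)
    return dim, window, center, min_periods


print(normalize_rolling_args({"x": 3, "y": 2}, center=True))
# (['x', 'y'], [3, 2], [True, True], 6)
```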
@@ -98,17 +100,15 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
             keep_attrs = _get_keep_attrs(default=False)
         self.keep_attrs = keep_attrs

-    @property
-    def _min_periods(self):
-        return self.min_periods if self.min_periods is not None else self.window
-
     def __repr__(self):
         """provide a nice str repr of our rolling object"""

         attrs = [
             "{k}->{v}".format(k=k, v=getattr(self, k))
-            for k in self._attributes
-            if getattr(self, k, None) is not None
+            for k in list(self.dim)
+            + list(self.window)
Review comment: Isn't this a list of
+            + list(self.center)
+            + [self.min_periods]
         ]
         return "{klass} [{attrs}]".format(
             klass=self.__class__.__name__, attrs=",".join(attrs)
@@ -143,7 +143,7 @@ def method(self, **kwargs):

     def count(self):
         rolling_count = self._counts()
-        enough_periods = rolling_count >= self._min_periods
+        enough_periods = rolling_count >= self.min_periods
         return rolling_count.where(enough_periods)

     count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count")
@@ -196,17 +196,20 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
             obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs
         )

-        self.window_labels = self.obj[self.dim]
+        # TODO legacy attribute
+        self.window_labels = self.obj[self.dim[0]]

     def __iter__(self):
+        if len(self.dim) > 1:
+            raise ValueError("__iter__ is only supported for 1d-rolling")
         stops = np.arange(1, len(self.window_labels) + 1)
-        starts = stops - int(self.window)
-        starts[: int(self.window)] = 0
+        starts = stops - int(self.window[0])
+        starts[: int(self.window[0])] = 0
         for (label, start, stop) in zip(self.window_labels, starts, stops):
-            window = self.obj.isel(**{self.dim: slice(start, stop)})
+            window = self.obj.isel(**{self.dim[0]: slice(start, stop)})

-            counts = window.count(dim=self.dim)
-            window = window.where(counts >= self._min_periods)
+            counts = window.count(dim=self.dim[0])
+            window = window.where(counts >= self.min_periods)

             yield (label, window)

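The `starts`/`stops` arithmetic in `__iter__` is easier to follow on plain integers. The sketch below is a pure-Python equivalent under the assumption of a single rolling dimension; `window_bounds` is a hypothetical name. Clamping `stop - window` at zero has the same effect as `starts[: int(self.window[0])] = 0`.

```python
def window_bounds(n, window):
    """Return (start, stop) slice bounds for each of n trailing windows."""
    stops = list(range(1, n + 1))
    # The first `window` entries would be <= 0, so clamp them to 0.
    starts = [max(stop - window, 0) for stop in stops]
    return list(zip(starts, stops))


print(window_bounds(5, 3))
# [(0, 1), (0, 2), (0, 3), (1, 4), (2, 5)]
```

The leading windows are shorter than `window`; in the hunk they are then masked by the `counts >= self.min_periods` check.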
@@ -251,13 +254,19 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):

         from .dataarray import DataArray

+        if len(self.dim) == 1 and not isinstance(window_dim, list):
+            window_dim = [window_dim]
+        if isinstance(stride, int):
+            stride = [stride] * len(self.dim)
         window = self.obj.variable.rolling_window(
             self.dim, self.window, window_dim, self.center, fill_value=fill_value
         )
         result = DataArray(
-            window, dims=self.obj.dims + (window_dim,), coords=self.obj.coords
+            window, dims=self.obj.dims + tuple(window_dim), coords=self.obj.coords
         )
-        return result.isel(**{self.dim: slice(None, None, stride)})
+        return result.isel(
+            **{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
+        )

     def reduce(self, func, **kwargs):
         """Reduce the items in this group by applying `func` along some
@@ -300,25 +309,33 @@ def reduce(self, func, **kwargs):
                [ 4., 9., 15., 18.]])

         """
-        rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim")
+        rolling_dim = [
+            utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d))
+            for d in self.dim
+        ]
         windows = self.construct(rolling_dim)
         result = windows.reduce(func, dim=rolling_dim, **kwargs)

         # Find valid windows based on count.
         counts = self._counts()
-        return result.where(counts >= self._min_periods)
+        return result.where(counts >= self.min_periods)

     def _counts(self):
         """ Number of non-nan entries in each rolling window. """

-        rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim")
+        rolling_dim = [
+            utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d))
+            for d in self.dim
+        ]
         # We use False as the fill_value instead of np.nan, since boolean
         # array is faster to be reduced than object array.
         # The use of skipna==False is also faster since it does not need to
         # copy the strided array.
         counts = (
             self.obj.notnull()
-            .rolling(center=self.center, **{self.dim: self.window})
+            .rolling(
+                center=self.center, **{d: w for d, w in zip(self.dim, self.window)}
+            )
             .construct(rolling_dim, fill_value=False)
             .sum(dim=rolling_dim, skipna=False)
         )
@@ -329,39 +346,40 @@ def _bottleneck_reduce(self, func, **kwargs):

         # bottleneck doesn't allow min_count to be 0, although it should
         # work the same as if min_count = 1
+        # Note bottleneck only works with 1d-rolling.
         if self.min_periods is not None and self.min_periods == 0:
             min_count = 1
         else:
             min_count = self.min_periods

-        axis = self.obj.get_axis_num(self.dim)
+        axis = self.obj.get_axis_num(self.dim[0])

         padded = self.obj.variable
-        if self.center:
+        if self.center[0]:
             if isinstance(padded.data, dask_array_type):
                 # Workaround to make the padded chunk size is larger than
                 # self.window-1
-                shift = -(self.window + 1) // 2
-                offset = (self.window - 1) // 2
+                shift = -(self.window[0] + 1) // 2
+                offset = (self.window[0] - 1) // 2
                 valid = (slice(None),) * axis + (
                     slice(offset, offset + self.obj.shape[axis]),
                 )
             else:
-                shift = (-self.window // 2) + 1
+                shift = (-self.window[0] // 2) + 1
                 valid = (slice(None),) * axis + (slice(-shift, None),)
-            padded = padded.pad({self.dim: (0, -shift)}, mode="constant")
+            padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant")

         if isinstance(padded.data, dask_array_type):
             raise AssertionError("should not be reachable")
             values = dask_rolling_wrapper(
-                func, padded.data, window=self.window, min_count=min_count, axis=axis
+                func, padded.data, window=self.window[0], min_count=min_count, axis=axis
             )
         else:
             values = func(
-                padded.data, window=self.window, min_count=min_count, axis=axis
+                padded.data, window=self.window[0], min_count=min_count, axis=axis
             )

-        if self.center:
+        if self.center[0]:
             values = values[valid]
         result = DataArray(values, self.obj.coords)

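The non-dask centering branch turns a trailing moving-window function into a centered one by padding only at the end and then dropping the leading entries. The sketch below illustrates that arithmetic in pure Python. `trailing_sum` is a hypothetical stand-in for a bottleneck move function, with `None` playing the role of NaN; it is not bottleneck's API.

```python
def trailing_sum(values, window):
    """Trailing moving sum; None when the window is incomplete or has gaps."""
    out = []
    for i in range(len(values)):
        chunk = values[max(i + 1 - window, 0) : i + 1]
        complete = len(chunk) == window and None not in chunk
        out.append(sum(chunk) if complete else None)
    return out


def centered_sum(values, window):
    """Centered moving sum built from the trailing one, as in the hunk above."""
    shift = (-window // 2) + 1  # 3 -> -1, 4 -> -1, 5 -> -2
    padded = list(values) + [None] * -shift  # pad at the end only
    # Dropping the first -shift results realigns each window on its center.
    return trailing_sum(padded, window)[-shift:]


print(centered_sum([1, 2, 3, 4, 5], 3))
# [None, 6, 9, 12, None]
```

For an even window the result has `window // 2` elements before the center and one fewer after it, which matches the `(start, end)` padding computed in `rolling_window`.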
@@ -378,8 +396,10 @@ def _numpy_or_bottleneck_reduce(
             )
             del kwargs["dim"]

-        if bottleneck_move_func is not None and not isinstance(
-            self.obj.data, dask_array_type
+        if (
+            bottleneck_move_func is not None
+            and not isinstance(self.obj.data, dask_array_type)
+            and len(self.dim) == 1
         ):
             # TODO: renable bottleneck with dask after the issues
             # underlying https://github.com/pydata/xarray/issues/2940 are
@@ -431,13 +451,19 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
         DataArray.groupby
         """
         super().__init__(obj, windows, min_periods, center, keep_attrs)
-        if self.dim not in self.obj.dims:
+        if any(d not in self.obj.dims for d in self.dim):
             raise KeyError(self.dim)
         # Keep each Rolling object as a dictionary
         self.rollings = {}
         for key, da in self.obj.data_vars.items():
             # keeps rollings only for the dataset depending on slf.dim
-            if self.dim in da.dims:
+            dims, center = [], []
+            for i, d in enumerate(self.dim):
+                if d in da.dims:
+                    dims.append(d)
+                    center.append(self.center[i])
+
+            if len(dims) > 0:
                 self.rollings[key] = DataArrayRolling(
                     da, windows, min_periods, center, keep_attrs
                 )
@@ -447,7 +473,7 @@ def _dataset_implementation(self, func, **kwargs):

         reduced = {}
         for key, da in self.obj.data_vars.items():
-            if self.dim in da.dims:
+            if any(d in da.dims for d in self.dim):
                 reduced[key] = func(self.rollings[key], **kwargs)
             else:
                 reduced[key] = self.obj[key]
@@ -512,19 +538,29 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None)

         from .dataset import Dataset

+        if isinstance(stride, int):
+            stride = [stride] * len(self.dim)
+
         if keep_attrs is None:
             keep_attrs = _get_keep_attrs(default=True)

         dataset = {}
         for key, da in self.obj.data_vars.items():
-            if self.dim in da.dims:
+            # keeps rollings only for the dataset depending on slf.dim
+            dims, center = [], []
+            for i, d in enumerate(self.dim):
+                if d in da.dims:
+                    dims.append(d)
+                    center.append(self.center[i])
+
+            if len(dims) > 0:
                 dataset[key] = self.rollings[key].construct(
                     window_dim, fill_value=fill_value
                 )
             else:
                 dataset[key] = da
         return Dataset(dataset, coords=self.obj.coords).isel(
-            **{self.dim: slice(None, None, stride)}
+            **{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
         )


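The per-variable filtering that the two Dataset hunks repeat can be sketched with plain dictionaries. `select_rolling_dims` and the input layout are hypothetical. Restricting `windows` to the selected dimensions is also an assumption of this sketch: the hunks shown here pass `windows` through unchanged and only filter `center`.

```python
def select_rolling_dims(var_dims, rolling_dims, windows, center):
    """For each variable, keep only the rolling dims it actually has.

    Hypothetical helper; variables without any rolling dim are skipped.
    """
    selected = {}
    for name, dims_of_var in var_dims.items():
        dims, cent = [], []
        for i, d in enumerate(rolling_dims):
            if d in dims_of_var:
                dims.append(d)
                cent.append(center[i])
        if len(dims) > 0:
            selected[name] = {
                # Assumption: windows are restricted alongside center.
                "windows": {d: windows[d] for d in dims},
                "center": cent,
            }
    return selected


var_dims = {"a": ("x", "y"), "b": ("x",), "c": ("z",)}
print(select_rolling_dims(var_dims, ["x", "y"], {"x": 3, "y": 2}, [True, False]))
# {'a': {'windows': {'x': 3, 'y': 2}, 'center': [True, False]},
#  'b': {'windows': {'x': 3}, 'center': [True]}}
```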
I hope it is true, though I didn't check it. A suspicious part is da.concatenate, which I expect does not copy the original (strided-)array.

I noticed that ghosting breaks my expectation. I need to update the algorithm here.