diff --git a/doc/computation.rst b/doc/computation.rst index 3660aed93ed..474c3905981 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -188,9 +188,16 @@ a value when aggregating: r = arr.rolling(y=3, center=True, min_periods=2) r.mean() +From version 0.17, xarray supports multidimensional rolling, + +.. ipython:: python + + r = arr.rolling(x=2, y=3, min_periods=2) + r.mean() + .. tip:: - Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects. + Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects with 1d-rolling. .. _bottleneck: https://github.com/pydata/bottleneck/ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c4cf931be61..9d4261da2dd 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,9 @@ Breaking changes New Features ~~~~~~~~~~~~ +- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling` + now accept more than 1 dimension.(:pull:`4219`) + By `Keisuke Fujii `_. - Build :py:meth:`CFTimeIndex.__repr__` explicitly as :py:class:`pandas.Index`. Add ``calendar`` as a new property for :py:class:`CFTimeIndex` and show ``calendar`` and ``length`` in :py:meth:`CFTimeIndex.__repr__` (:issue:`2416`, :pull:`4092`) diff --git a/xarray/core/common.py b/xarray/core/common.py index c95df77313e..bc5035b682e 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -786,7 +786,7 @@ def rolling( self, dim: Mapping[Hashable, int] = None, min_periods: int = None, - center: bool = False, + center: Union[bool, Mapping[Hashable, bool]] = False, keep_attrs: bool = None, **window_kwargs: int, ): @@ -802,7 +802,7 @@ def rolling( Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. - center : boolean, default False + center : boolean, or a mapping, default False Set the labels at the center of the window. keep_attrs : bool, optional If True, the object's attributes (`attrs`) will be copied from diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 87f646352eb..74474f4321e 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -32,69 +32,80 @@ def rolling_window(a, axis, window, center, fill_value): """ import dask.array as da + if not hasattr(axis, "__len__"): + axis = [axis] + window = [window] + center = [center] + orig_shape = a.shape - if axis < 0: - axis = a.ndim + axis depth = {d: 0 for d in range(a.ndim)} - depth[axis] = int(window / 2) - # For evenly sized window, we need to crop the first point of each block. - offset = 1 if window % 2 == 0 else 0 - - if depth[axis] > min(a.chunks[axis]): - raise ValueError( - "For window size %d, every chunk should be larger than %d, " - "but the smallest chunk size is %d. Rechunk your array\n" - "with a larger chunk size or a chunk size that\n" - "more evenly divides the shape of your array." - % (window, depth[axis], min(a.chunks[axis])) - ) - - # Although da.overlap pads values to boundaries of the array, - # the size of the generated array is smaller than what we want - # if center == False. - if center: - start = int(window / 2) # 10 -> 5, 9 -> 4 - end = window - 1 - start - else: - start, end = window - 1, 0 - pad_size = max(start, end) + offset - depth[axis] - drop_size = 0 - # pad_size becomes more than 0 when the overlapped array is smaller than - # needed. In this case, we need to enlarge the original array by padding - # before overlapping. - if pad_size > 0: - if pad_size < depth[axis]: - # overlapping requires each chunk larger than depth. If pad_size is - # smaller than the depth, we enlarge this and truncate it later. - drop_size = depth[axis] - pad_size - pad_size = depth[axis] - shape = list(a.shape) - shape[axis] = pad_size - chunks = list(a.chunks) - chunks[axis] = (pad_size,) - fill_array = da.full(shape, fill_value, dtype=a.dtype, chunks=chunks) - a = da.concatenate([fill_array, a], axis=axis) - + offset = [0] * a.ndim + drop_size = [0] * a.ndim + pad_size = [0] * a.ndim + for ax, win, cent in zip(axis, window, center): + if ax < 0: + ax = a.ndim + ax + depth[ax] = int(win / 2) + # For evenly sized window, we need to crop the first point of each block. + offset[ax] = 1 if win % 2 == 0 else 0 + + if depth[ax] > min(a.chunks[ax]): + raise ValueError( + "For window size %d, every chunk should be larger than %d, " + "but the smallest chunk size is %d. Rechunk your array\n" + "with a larger chunk size or a chunk size that\n" + "more evenly divides the shape of your array." + % (win, depth[ax], min(a.chunks[ax])) + ) + + # Although da.overlap pads values to boundaries of the array, + # the size of the generated array is smaller than what we want + # if center == False. + if cent: + start = int(win / 2) # 10 -> 5, 9 -> 4 + end = win - 1 - start + else: + start, end = win - 1, 0 + pad_size[ax] = max(start, end) + offset[ax] - depth[ax] + drop_size[ax] = 0 + # pad_size becomes more than 0 when the overlapped array is smaller than + # needed. In this case, we need to enlarge the original array by padding + # before overlapping. + if pad_size[ax] > 0: + if pad_size[ax] < depth[ax]: + # overlapping requires each chunk larger than depth. If pad_size is + # smaller than the depth, we enlarge this and truncate it later. + drop_size[ax] = depth[ax] - pad_size[ax] + pad_size[ax] = depth[ax] + + # TODO maybe following two lines can be summarized. + a = da.pad( + a, [(p, 0) for p in pad_size], mode="constant", constant_values=fill_value + ) boundary = {d: fill_value for d in range(a.ndim)} # create overlap arrays ag = da.overlap.overlap(a, depth=depth, boundary=boundary) - # apply rolling func - def func(x, window, axis=-1): + def func(x, window, axis): x = np.asarray(x) - rolling = nputils._rolling_window(x, window, axis) - return rolling[(slice(None),) * axis + (slice(offset, None),)] - - chunks = list(a.chunks) - chunks.append(window) + index = [slice(None)] * x.ndim + for ax, win in zip(axis, window): + x = nputils._rolling_window(x, win, ax) + index[ax] = slice(offset[ax], None) + return x[tuple(index)] + + chunks = list(a.chunks) + window + new_axis = [a.ndim + i for i in range(len(axis))] out = ag.map_blocks( - func, dtype=a.dtype, new_axis=a.ndim, chunks=chunks, window=window, axis=axis + func, dtype=a.dtype, new_axis=new_axis, chunks=chunks, window=window, axis=axis ) # crop boundary. - index = (slice(None),) * axis + (slice(drop_size, drop_size + orig_shape[axis]),) - return out[index] + index = [slice(None)] * a.ndim + for ax in axis: + index[ax] = slice(drop_size[ax], drop_size[ax] + orig_shape[ax]) + return out[tuple(index)] def least_squares(lhs, rhs, rcond=None, skipna=False): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9fbaf7479db..68f76c7af2c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5968,7 +5968,7 @@ def polyfit( skipna_da = np.any(da.isnull()) dims_to_stack = [dimname for dimname in da.dims if dimname != dim] - stacked_coords = {} + stacked_coords: Dict[Hashable, DataArray] = {} if dims_to_stack: stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked") rhs = da.transpose(dim, *dims_to_stack).stack( diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index fa6df63e0ea..4f592eb3c5c 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -135,14 +135,22 @@ def __setitem__(self, key, value): def rolling_window(a, axis, window, center, fill_value): """ rolling window with padding. """ pads = [(0, 0) for s in a.shape] - if center: - start = int(window / 2) # 10 -> 5, 9 -> 4 - end = window - 1 - start - pads[axis] = (start, end) - else: - pads[axis] = (window - 1, 0) + if not hasattr(axis, "__len__"): + axis = [axis] + window = [window] + center = [center] + + for ax, win, cent in zip(axis, window, center): + if cent: + start = int(win / 2) # 10 -> 5, 9 -> 4 + end = win - 1 - start + pads[ax] = (start, end) + else: + pads[ax] = (win - 1, 0) a = np.pad(a, pads, mode="constant", constant_values=fill_value) - return _rolling_window(a, window, axis) + for ax, win in zip(axis, window): + a = _rolling_window(a, win, ax) + return a def _rolling_window(a, window, axis=-1): diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index ecba5307680..5f996565243 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -75,40 +75,32 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None ------- rolling : type of input argument """ - if len(windows) != 1: - raise ValueError("exactly one dim/window should be provided") - - dim, window = next(iter(windows.items())) - - if window <= 0: - raise ValueError("window must be > 0") - + self.dim, self.window = [], [] + for d, w in windows.items(): + self.dim.append(d) + if w <= 0: + raise ValueError("window must be > 0") + self.window.append(w) + + self.center = self._mapping_to_list(center, default=False) self.obj = obj # attributes - self.window = window if min_periods is not None and min_periods <= 0: raise ValueError("min_periods must be greater than zero or None") - self.min_periods = min_periods - self.center = center - self.dim = dim + self.min_periods = np.prod(self.window) if min_periods is None else min_periods if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) self.keep_attrs = keep_attrs - @property - def _min_periods(self): - return self.min_periods if self.min_periods is not None else self.window - def __repr__(self): """provide a nice str repr of our rolling object""" attrs = [ "{k}->{v}".format(k=k, v=getattr(self, k)) - for k in self._attributes - if getattr(self, k, None) is not None + for k in list(self.dim) + self.window + self.center + [self.min_periods] ] return "{klass} [{attrs}]".format( klass=self.__class__.__name__, attrs=",".join(attrs) @@ -143,11 +135,29 @@ def method(self, **kwargs): def count(self): rolling_count = self._counts() - enough_periods = rolling_count >= self._min_periods + enough_periods = rolling_count >= self.min_periods return rolling_count.where(enough_periods) count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count") + def _mapping_to_list( + self, arg, default=None, allow_default=True, allow_allsame=True + ): + if utils.is_dict_like(arg): + if allow_default: + return [arg.get(d, default) for d in self.dim] + else: + for d in self.dim: + if d not in arg: + raise KeyError("argument has no key {}.".format(d)) + return [arg[d] for d in self.dim] + elif allow_allsame: # for single argument + return [arg] * len(self.dim) + elif len(self.dim) == 1: + return [arg] + else: + raise ValueError("Mapping argument is necessary.") + class DataArrayRolling(Rolling): __slots__ = ("window_labels",) @@ -196,33 +206,41 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs ) - self.window_labels = self.obj[self.dim] + # TODO legacy attribute + self.window_labels = self.obj[self.dim[0]] def __iter__(self): + if len(self.dim) > 1: + raise ValueError("__iter__ is only supported for 1d-rolling") stops = np.arange(1, len(self.window_labels) + 1) - starts = stops - int(self.window) - starts[: int(self.window)] = 0 + starts = stops - int(self.window[0]) + starts[: int(self.window[0])] = 0 for (label, start, stop) in zip(self.window_labels, starts, stops): - window = self.obj.isel(**{self.dim: slice(start, stop)}) + window = self.obj.isel(**{self.dim[0]: slice(start, stop)}) - counts = window.count(dim=self.dim) - window = window.where(counts >= self._min_periods) + counts = window.count(dim=self.dim[0]) + window = window.where(counts >= self.min_periods) yield (label, window) - def construct(self, window_dim, stride=1, fill_value=dtypes.NA): + def construct( + self, window_dim=None, stride=1, fill_value=dtypes.NA, **window_dim_kwargs + ): """ Convert this rolling object to xr.DataArray, where the window dimension is stacked as a new dimension Parameters ---------- - window_dim: str - New name of the window dimension. - stride: integer, optional + window_dim: str or a mapping, optional + A mapping from dimension name to the new window dimension names. + Just a string can be used for 1d-rolling. + stride: integer or a mapping, optional Size of stride for the rolling window. fill_value: optional. Default dtypes.NA Filling value to match the dimension size. + **window_dim_kwargs : {dim: new_name, ...}, optional + The keyword arguments form of ``window_dim``. Returns ------- @@ -251,13 +269,27 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA): from .dataarray import DataArray + if window_dim is None: + if len(window_dim_kwargs) == 0: + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + window_dim = {d: window_dim_kwargs[d] for d in self.dim} + + window_dim = self._mapping_to_list( + window_dim, allow_default=False, allow_allsame=False + ) + stride = self._mapping_to_list(stride, default=1) + window = self.obj.variable.rolling_window( self.dim, self.window, window_dim, self.center, fill_value=fill_value ) result = DataArray( - window, dims=self.obj.dims + (window_dim,), coords=self.obj.coords + window, dims=self.obj.dims + tuple(window_dim), coords=self.obj.coords + ) + return result.isel( + **{d: slice(None, None, s) for d, s in zip(self.dim, stride)} ) - return result.isel(**{self.dim: slice(None, None, stride)}) def reduce(self, func, **kwargs): """Reduce the items in this group by applying `func` along some @@ -300,27 +332,36 @@ def reduce(self, func, **kwargs): [ 4., 9., 15., 18.]]) """ - rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim") + rolling_dim = { + d: utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d)) + for d in self.dim + } windows = self.construct(rolling_dim) - result = windows.reduce(func, dim=rolling_dim, **kwargs) + result = windows.reduce(func, dim=list(rolling_dim.values()), **kwargs) # Find valid windows based on count. counts = self._counts() - return result.where(counts >= self._min_periods) + return result.where(counts >= self.min_periods) def _counts(self): """ Number of non-nan entries in each rolling window. """ - rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim") + rolling_dim = { + d: utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d)) + for d in self.dim + } # We use False as the fill_value instead of np.nan, since boolean # array is faster to be reduced than object array. # The use of skipna==False is also faster since it does not need to # copy the strided array. counts = ( self.obj.notnull() - .rolling(center=self.center, **{self.dim: self.window}) + .rolling( + center={d: self.center[i] for i, d in enumerate(self.dim)}, + **{d: w for d, w in zip(self.dim, self.window)}, + ) .construct(rolling_dim, fill_value=False) - .sum(dim=rolling_dim, skipna=False) + .sum(dim=list(rolling_dim.values()), skipna=False) ) return counts @@ -329,39 +370,40 @@ def _bottleneck_reduce(self, func, **kwargs): # bottleneck doesn't allow min_count to be 0, although it should # work the same as if min_count = 1 + # Note bottleneck only works with 1d-rolling. if self.min_periods is not None and self.min_periods == 0: min_count = 1 else: min_count = self.min_periods - axis = self.obj.get_axis_num(self.dim) + axis = self.obj.get_axis_num(self.dim[0]) padded = self.obj.variable - if self.center: + if self.center[0]: if isinstance(padded.data, dask_array_type): # Workaround to make the padded chunk size is larger than # self.window-1 - shift = -(self.window + 1) // 2 - offset = (self.window - 1) // 2 + shift = -(self.window[0] + 1) // 2 + offset = (self.window[0] - 1) // 2 valid = (slice(None),) * axis + ( slice(offset, offset + self.obj.shape[axis]), ) else: - shift = (-self.window // 2) + 1 + shift = (-self.window[0] // 2) + 1 valid = (slice(None),) * axis + (slice(-shift, None),) - padded = padded.pad({self.dim: (0, -shift)}, mode="constant") + padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant") if isinstance(padded.data, dask_array_type): raise AssertionError("should not be reachable") values = dask_rolling_wrapper( - func, padded.data, window=self.window, min_count=min_count, axis=axis + func, padded.data, window=self.window[0], min_count=min_count, axis=axis ) else: values = func( - padded.data, window=self.window, min_count=min_count, axis=axis + padded.data, window=self.window[0], min_count=min_count, axis=axis ) - if self.center: + if self.center[0]: values = values[valid] result = DataArray(values, self.obj.coords) @@ -378,8 +420,10 @@ def _numpy_or_bottleneck_reduce( ) del kwargs["dim"] - if bottleneck_move_func is not None and not isinstance( - self.obj.data, dask_array_type + if ( + bottleneck_move_func is not None + and not isinstance(self.obj.data, dask_array_type) + and len(self.dim) == 1 ): # TODO: renable bottleneck with dask after the issues # underlying https://github.com/pydata/xarray/issues/2940 are @@ -412,7 +456,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. - center : boolean, default False + center : boolean, or a mapping from dimension name to boolean, default False Set the labels at the center of the window. keep_attrs : bool, optional If True, the object's attributes (`attrs`) will be copied from @@ -431,15 +475,22 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None DataArray.groupby """ super().__init__(obj, windows, min_periods, center, keep_attrs) - if self.dim not in self.obj.dims: + if any(d not in self.obj.dims for d in self.dim): raise KeyError(self.dim) # Keep each Rolling object as a dictionary self.rollings = {} for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on slf.dim - if self.dim in da.dims: + dims, center = [], {} + for i, d in enumerate(self.dim): + if d in da.dims: + dims.append(d) + center[d] = self.center[i] + + if len(dims) > 0: + w = {d: windows[d] for d in dims} self.rollings[key] = DataArrayRolling( - da, windows, min_periods, center, keep_attrs + da, w, min_periods, center, keep_attrs ) def _dataset_implementation(self, func, **kwargs): @@ -447,7 +498,7 @@ def _dataset_implementation(self, func, **kwargs): reduced = {} for key, da in self.obj.data_vars.items(): - if self.dim in da.dims: + if any(d in da.dims for d in self.dim): reduced[key] = func(self.rollings[key], **kwargs) else: reduced[key] = self.obj[key] @@ -491,19 +542,29 @@ def _numpy_or_bottleneck_reduce( **kwargs, ) - def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None): + def construct( + self, + window_dim=None, + stride=1, + fill_value=dtypes.NA, + keep_attrs=None, + **window_dim_kwargs, + ): """ Convert this rolling object to xr.Dataset, where the window dimension is stacked as a new dimension Parameters ---------- - window_dim: str - New name of the window dimension. + window_dim: str or a mapping, optional + A mapping from dimension name to the new window dimension names. + Just a string can be used for 1d-rolling. stride: integer, optional size of stride for the rolling window. fill_value: optional. Default dtypes.NA Filling value to match the dimension size. + **window_dim_kwargs : {dim: new_name, ...}, optional + The keyword arguments form of ``window_dim``. Returns ------- @@ -512,19 +573,35 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None) from .dataset import Dataset + if window_dim is None: + if len(window_dim_kwargs) == 0: + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + window_dim = {d: window_dim_kwargs[d] for d in self.dim} + + window_dim = self._mapping_to_list( + window_dim, allow_default=False, allow_allsame=False + ) + stride = self._mapping_to_list(stride, default=1) + if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) dataset = {} for key, da in self.obj.data_vars.items(): - if self.dim in da.dims: + # keeps rollings only for the dataset depending on slf.dim + dims = [d for d in self.dim if d in da.dims] + if len(dims) > 0: + wi = {d: window_dim[i] for i, d in enumerate(self.dim) if d in da.dims} + st = {d: stride[i] for i, d in enumerate(self.dim) if d in da.dims} dataset[key] = self.rollings[key].construct( - window_dim, fill_value=fill_value + window_dim=wi, fill_value=fill_value, stride=st ) else: dataset[key] = da return Dataset(dataset, coords=self.obj.coords).isel( - **{self.dim: slice(None, None, stride)} + **{d: slice(None, None, s) for d, s in zip(self.dim, stride)} ) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d13de439a69..1f86a40348c 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1883,11 +1883,14 @@ def rolling_window( Parameters ---------- dim: str - Dimension over which to compute rolling_window + Dimension over which to compute rolling_window. + For nd-rolling, should be list of dimensions. window: int Window size of the rolling + For nd-rolling, should be list of integers. window_dim: str New name of the window dimension. + For nd-rolling, should be list of integers. center: boolean. default False. If True, pad fill_value for both ends. Otherwise, pad in the head of the axis. @@ -1921,15 +1924,21 @@ def rolling_window( dtype = self.dtype array = self.data - new_dims = self.dims + (window_dim,) + if isinstance(dim, list): + assert len(dim) == len(window) + assert len(dim) == len(window_dim) + assert len(dim) == len(center) + else: + dim = [dim] + window = [window] + window_dim = [window_dim] + center = [center] + axis = [self.get_axis_num(d) for d in dim] + new_dims = self.dims + tuple(window_dim) return Variable( new_dims, duck_array_ops.rolling_window( - array, - axis=self.get_axis_num(dim), - window=window, - center=center, - fill_value=fill_value, + array, axis=axis, window=window, center=center, fill_value=fill_value ), ) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e0da3f1527f..7ccf1eb14bc 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6193,8 +6193,6 @@ def test_rolling_properties(da): assert rolling_obj.obj.get_axis_num("time") == 1 # catching invalid args - with pytest.raises(ValueError, match="exactly one dim/window should"): - da.rolling(time=7, x=2) with pytest.raises(ValueError, match="window must be > 0"): da.rolling(time=-2) with pytest.raises(ValueError, match="min_periods must be greater than zero"): @@ -6399,6 +6397,47 @@ def test_rolling_count_correct(): assert_equal(result, expected) +@pytest.mark.parametrize("da", (1,), indirect=True) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1)) +@pytest.mark.parametrize("name", ("sum", "mean", "max")) +def test_ndrolling_reduce(da, center, min_periods, name): + rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods) + + actual = getattr(rolling_obj, name)() + expected = getattr( + getattr( + da.rolling(time=3, center=center, min_periods=min_periods), name + )().rolling(x=2, center=center, min_periods=min_periods), + name, + )() + + assert_allclose(actual, expected) + assert actual.dims == expected.dims + + +@pytest.mark.parametrize("center", (True, False, (True, False))) +@pytest.mark.parametrize("fill_value", (np.nan, 0.0)) +def test_ndrolling_construct(center, fill_value): + da = DataArray( + np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float), + dims=["x", "y", "z"], + coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)}, + ) + actual = da.rolling(x=3, z=2, center=center).construct( + x="x1", z="z1", fill_value=fill_value + ) + if not isinstance(center, tuple): + center = (center, center) + expected = ( + da.rolling(x=3, center=center[0]) + .construct(x="x1", fill_value=fill_value) + .rolling(z=2, center=center[1]) + .construct(z="z1", fill_value=fill_value) + ) + assert_allclose(actual, expected) + + def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: xr.DataArray([1, 2, np.NaN]) > 0 diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 9037013cc79..da7621dceb8 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5881,8 +5881,6 @@ def test_rolling_keep_attrs(): def test_rolling_properties(ds): # catching invalid args - with pytest.raises(ValueError, match="exactly one dim/window should"): - ds.rolling(time=7, x=2) with pytest.raises(ValueError, match="window must be > 0"): ds.rolling(time=-2) with pytest.raises(ValueError, match="min_periods must be greater than zero"): @@ -6007,6 +6005,66 @@ def test_rolling_reduce(ds, center, min_periods, window, name): assert src_var.dims == actual[key].dims +@pytest.mark.parametrize("ds", (2,), indirect=True) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1)) +@pytest.mark.parametrize("name", ("sum", "max")) +@pytest.mark.parametrize("dask", (True, False)) +def test_ndrolling_reduce(ds, center, min_periods, name, dask): + if dask and has_dask: + ds = ds.chunk({"x": 4}) + + rolling_obj = ds.rolling(time=4, x=3, center=center, min_periods=min_periods) + + actual = getattr(rolling_obj, name)() + expected = getattr( + getattr( + ds.rolling(time=4, center=center, min_periods=min_periods), name + )().rolling(x=3, center=center, min_periods=min_periods), + name, + )() + assert_allclose(actual, expected) + assert actual.dims == expected.dims + + # Do it in the opposite order + expected = getattr( + getattr( + ds.rolling(x=3, center=center, min_periods=min_periods), name + )().rolling(time=4, center=center, min_periods=min_periods), + name, + )() + + assert_allclose(actual, expected) + assert actual.dims == expected.dims + + +@pytest.mark.parametrize("center", (True, False, (True, False))) +@pytest.mark.parametrize("fill_value", (np.nan, 0.0)) +@pytest.mark.parametrize("dask", (True, False)) +def test_ndrolling_construct(center, fill_value, dask): + da = DataArray( + np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float), + dims=["x", "y", "z"], + coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)}, + ) + ds = xr.Dataset({"da": da}) + if dask and has_dask: + ds = ds.chunk({"x": 4}) + + actual = ds.rolling(x=3, z=2, center=center).construct( + x="x1", z="z1", fill_value=fill_value + ) + if not isinstance(center, tuple): + center = (center, center) + expected = ( + ds.rolling(x=3, center=center[0]) + .construct(x="x1", fill_value=fill_value) + .rolling(z=2, center=center[1]) + .construct(z="z1", fill_value=fill_value) + ) + assert_allclose(actual, expected) + + def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: Dataset(data_vars={"x": ("y", [1, 2, np.NaN])}) > 0 diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py index 1002a9dd9e3..ccb825dc7e9 100644 --- a/xarray/tests/test_nputils.py +++ b/xarray/tests/test_nputils.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from numpy.testing import assert_array_equal from xarray.core.nputils import NumpyVIndexAdapter, _is_contiguous, rolling_window @@ -47,3 +48,19 @@ def test_rolling(): actual = rolling_window(x, axis=-1, window=3, center=False, fill_value=0.0) expected = np.stack([expected, expected * 1.1], axis=0) assert_array_equal(actual, expected) + + +@pytest.mark.parametrize("center", [[True, True], [False, False]]) +@pytest.mark.parametrize("axis", [(0, 1), (1, 2), (2, 0)]) +def test_nd_rolling(center, axis): + x = np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float) + window = [3, 3] + actual = rolling_window( + x, axis=axis, window=window, center=center, fill_value=np.nan + ) + expected = x + for ax, win, cent in zip(axis, window, center): + expected = rolling_window( + expected, axis=ax, window=win, center=cent, fill_value=np.nan + ) + assert_array_equal(actual, expected) diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py index 39ad250246b..adc29a3cc92 100644 --- a/xarray/tests/test_testing.py +++ b/xarray/tests/test_testing.py @@ -37,7 +37,7 @@ def test_allclose_regression(): "obj1,obj2", ( pytest.param( - xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable", + xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable" ), pytest.param( xr.DataArray([1e-17, 2], dims="x"),