Skip to content

Commit

Permalink
Enable .rolling_exp to work on dask arrays (#8284)
Browse files Browse the repository at this point in the history
  • Loading branch information
max-sixty authored Oct 10, 2023
1 parent 46643bb commit 75af56c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ New Features
- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for
the ``variables`` parameter, passing the object as the only argument.
By `Maximilian Roos <https://github.com/max-sixty>`_.
- ``.rolling_exp`` functions can now operate on dask-backed arrays, assuming the
core dim has exactly one chunk. (:pull:`8284`).
By `Maximilian Roos <https://github.com/max-sixty>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/rolling_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
input_core_dims=[[self.dim]],
kwargs=dict(alpha=self.alpha, axis=-1),
output_core_dims=[[self.dim]],
exclude_dims={self.dim},
keep_attrs=keep_attrs,
on_missing_core_dim="copy",
dask="parallelized",
).transpose(*dim_order)

def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
Expand Down Expand Up @@ -183,7 +183,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
input_core_dims=[[self.dim]],
kwargs=dict(alpha=self.alpha, axis=-1),
output_core_dims=[[self.dim]],
exclude_dims={self.dim},
keep_attrs=keep_attrs,
on_missing_core_dim="copy",
dask="parallelized",
).transpose(*dim_order)
4 changes: 3 additions & 1 deletion xarray/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,9 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:

@requires_numbagg
class TestDatasetRollingExp:
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
@pytest.mark.parametrize(
"backend", ["numpy", pytest.param("dask", marks=requires_dask)], indirect=True
)
def test_rolling_exp(self, ds) -> None:
result = ds.rolling_exp(time=10, window_type="span").mean()
assert isinstance(result, Dataset)
Expand Down

0 comments on commit 75af56c

Please sign in to comment.