From 75af56c33a29529269a73bdd00df2d3af17ee0f5 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 9 Oct 2023 23:37:19 -0700 Subject: [PATCH] Enable `.rolling_exp` to work on dask arrays (#8284) --- doc/whats-new.rst | 3 +++ xarray/core/rolling_exp.py | 4 ++-- xarray/tests/test_rolling.py | 4 +++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 92acc3f90c0..8f576f486dc 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,9 @@ New Features - :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for the ``variables`` parameter, passing the object as the only argument. By `Maximilian Roos `_. +- ``.rolling_exp`` functions can now operate on dask-backed arrays, assuming the + core dim has exactly one chunk. (:pull:`8284`). + By `Maximilian Roos `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index c56bf6a384e..cb77358869c 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -147,9 +147,9 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: input_core_dims=[[self.dim]], kwargs=dict(alpha=self.alpha, axis=-1), output_core_dims=[[self.dim]], - exclude_dims={self.dim}, keep_attrs=keep_attrs, on_missing_core_dim="copy", + dask="parallelized", ).transpose(*dim_order) def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: @@ -183,7 +183,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: input_core_dims=[[self.dim]], kwargs=dict(alpha=self.alpha, axis=-1), output_core_dims=[[self.dim]], - exclude_dims={self.dim}, keep_attrs=keep_attrs, on_missing_core_dim="copy", + dask="parallelized", ).transpose(*dim_order) diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 2dc8ae24438..da834b76124 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -788,7 +788,9 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None: @requires_numbagg class TestDatasetRollingExp: - @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + @pytest.mark.parametrize( + "backend", ["numpy", pytest.param("dask", marks=requires_dask)], indirect=True + ) def test_rolling_exp(self, ds) -> None: result = ds.rolling_exp(time=10, window_type="span").mean() assert isinstance(result, Dataset)