Skip to content

Commit

Permalink
Consistent DatasetRolling.construct behavior (pydata#7578)
Browse files Browse the repository at this point in the history
* Removed `.isel` for consistent rolling behavior.

`.isel` causes `DatasetRolling.construct` to behavior to be inconsistent with `DataArrayRolling.construct` when `stride` > 1.

* new rolling construct strategy for coords

* add whats-new

* add new tests with different coords

* next try on aligning strided coords

* add peakmem test for rolling.construct

* increase asv benchmark rolling sizes

---------

Co-authored-by: Michael Niklas <[email protected]>
  • Loading branch information
p4perf4ce and headtr1ck authored Sep 20, 2023
1 parent 04550e6 commit 2b784f2
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 14 deletions.
13 changes: 11 additions & 2 deletions asv_bench/benchmarks/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from . import parameterized, randn, requires_dask

nx = 300
nx = 3000
long_nx = 30000
ny = 200
nt = 100
nt = 1000
window = 20

randn_xy = randn((nx, ny), frac_nan=0.1)
Expand Down Expand Up @@ -115,6 +115,11 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck):
roll = self.ds.var3.rolling(t=100)
getattr(roll, func)()

@parameterized(["stride"], ([None, 5, 50]))
def peakmem_1drolling_construct(self, stride):
self.ds.var2.rolling(t=100).construct("w", stride=stride)
self.ds.var3.rolling(t=100).construct("w", stride=stride)


class DatasetRollingMemory(RollingMemory):
@parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False]))
Expand All @@ -128,3 +133,7 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck):
with xr.set_options(use_bottleneck=use_bottleneck):
roll = self.ds.rolling(t=100)
getattr(roll, func)()

@parameterized(["stride"], ([None, 5, 50]))
def peakmem_1drolling_construct(self, stride):
self.ds.rolling(t=100).construct("w", stride=stride)
5 changes: 4 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,12 @@ Bug fixes
of :py:meth:`DataArray.__setitem__` lose dimension names.
(:issue:`7030`, :pull:`8067`) By `Darsh Ranjan <https://github.com/dranjan>`_.
- Return ``float64`` in presence of ``NaT`` in :py:class:`~core.accessor_dt.DatetimeAccessor` and
special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar()`
special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar`
(:issue:`7928`, :pull:`8084`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Fix :py:meth:`~core.rolling.DatasetRolling.construct` with stride on Datasets without indexes.
(:issue:`7021`, :pull:`7578`).
By `Amrest Chinkamol <https://github.com/p4perf4ce>`_ and `Michael Niklas <https://github.com/headtr1ck>`_.
- Calling plot with kwargs ``col``, ``row`` or ``hue`` no longer squeezes dimensions passed via these arguments
(:issue:`7552`, :pull:`8174`).
By `Wiktor Kraśnicki <https://github.com/wkrasnicki>`_.
Expand Down
9 changes: 6 additions & 3 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,11 +785,14 @@ def construct(
if not keep_attrs:
dataset[key].attrs = {}

# Need to stride coords as well. TODO: is there a better way?
coords = self.obj.isel(
{d: slice(None, None, s) for d, s in zip(self.dim, strides)}
).coords

attrs = self.obj.attrs if keep_attrs else {}

return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel(
{d: slice(None, None, s) for d, s in zip(self.dim, strides)}
)
return Dataset(dataset, coords=coords, attrs=attrs)


class Coarsen(CoarsenArithmetic, Generic[T_Xarray]):
Expand Down
55 changes: 47 additions & 8 deletions xarray/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:

@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
def test_rolling_construct(self, center, window) -> None:
def test_rolling_construct(self, center: bool, window: int) -> None:
s = pd.Series(np.arange(10))
da = DataArray.from_series(s)

Expand Down Expand Up @@ -610,7 +610,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:

@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
def test_rolling_construct(self, center, window) -> None:
def test_rolling_construct(self, center: bool, window: int) -> None:
df = pd.DataFrame(
{
"x": np.random.randn(20),
Expand All @@ -627,19 +627,58 @@ def test_rolling_construct(self, center, window) -> None:
np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values)
np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"])

# with stride
ds_rolling_mean = ds_rolling.construct("window", stride=2).mean("window")
np.testing.assert_allclose(
df_rolling["x"][::2].values, ds_rolling_mean["x"].values
)
np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean["index"])
# with fill_value
ds_rolling_mean = ds_rolling.construct("window", stride=2, fill_value=0.0).mean(
"window"
)
assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all()
assert (ds_rolling_mean["x"] == 0.0).sum() >= 0

@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
def test_rolling_construct_stride(self, center: bool, window: int) -> None:
df = pd.DataFrame(
{
"x": np.random.randn(20),
"y": np.random.randn(20),
"time": np.linspace(0, 1, 20),
}
)
ds = Dataset.from_dataframe(df)
df_rolling_mean = df.rolling(window, center=center, min_periods=1).mean()

# With an index (dimension coordinate)
ds_rolling = ds.rolling(index=window, center=center)
ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w")
np.testing.assert_allclose(
df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values
)
np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"])

# Without index (https://github.com/pydata/xarray/issues/7021)
ds2 = ds.drop_vars("index")
ds2_rolling = ds2.rolling(index=window, center=center)
ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w")
np.testing.assert_allclose(
df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values
)

# Mixed coordinates, indexes and 2D coordinates
ds3 = xr.Dataset(
{"x": ("t", range(20)), "x2": ("y", range(5))},
{
"t": range(20),
"y": ("y", range(5)),
"t2": ("t", range(20)),
"y2": ("y", range(5)),
"yt": (["t", "y"], np.ones((20, 5))),
},
)
ds3_rolling = ds3.rolling(t=window, center=center)
ds3_rolling_mean = ds3_rolling.construct("w", stride=2).mean("w")
for coord in ds3.coords:
assert coord in ds3_rolling_mean.coords

@pytest.mark.slow
@pytest.mark.parametrize("ds", (1, 2), indirect=True)
@pytest.mark.parametrize("center", (True, False))
Expand Down

0 comments on commit 2b784f2

Please sign in to comment.