Skip to content

Commit

Permalink
Offer a fixture for unifying DataArray & Dataset tests
Browse files Browse the repository at this point in the history
(stacked on #8512, worth reviewing after that's merged)

Some tests are literally copy & pasted between DataArray & Dataset tests. This change allows them to use a single test. Not everything will work — sometimes we want to check specifics — but sometimes they will...
  • Loading branch information
max-sixty committed Dec 8, 2023
1 parent 500d11f commit 4c8abcb
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 42 deletions.
43 changes: 43 additions & 0 deletions xarray/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -77,3 +79,44 @@ def da(request, backend):
return da
else:
raise ValueError


@pytest.fixture(params=[Dataset, DataArray])
def type(request):
return request.param


@pytest.fixture(params=[1])
def d(request, backend, type) -> DataArray | Dataset:
"""
For tests which can test either a DataArray or a Dataset.
"""
result: DataArray | Dataset
if request.param == 1:
ds = Dataset(
dict(
a=(["x", "y"], np.arange(16).reshape(8, 2)),
b=(["y", "z"], np.arange(12, 32).reshape(2, 10).astype(np.float64)),
),
dict(
x=("x", np.linspace(0, 1.0, 8)),
y=range(2),
z=("z", np.linspace(0, 1.0, 10)),
w=("y", ["a", "b"]),
),
)
if type == DataArray:
result = ds["a"].assign_coords(w=ds.coords["w"])
elif type == Dataset:
result = ds
else:
raise ValueError
else:
raise ValueError

if backend == "dask":
return result.chunk()
elif backend == "numpy":
return result
else:
raise ValueError
67 changes: 25 additions & 42 deletions xarray/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,31 @@ def compute_backend(request):
yield request.param


@pytest.mark.parametrize("func", ["mean", "sum"])
@pytest.mark.parametrize("min_periods", [1, 20])
def test_cumulative(d, func, min_periods) -> None:
# One dim
result = getattr(d.cumulative("x", min_periods=min_periods), func)()
expected = getattr(d.rolling(x=d["x"].size, min_periods=min_periods), func)()
assert_identical(result, expected)

# Multiple dim
result = getattr(d.cumulative(["x", "y"], min_periods=min_periods), func)()
expected = getattr(
d.rolling(x=d["x"].size, y=d["y"].size, min_periods=min_periods),
func,
)()
assert_identical(result, expected)


def test_cumulative_vs_cum(d) -> None:
result = d.cumulative("x").sum()
expected = d.cumsum("x")
# cumsum drops the coord of the dimension; cumulative doesn't
expected = expected.assign_coords(x=result["x"])
assert_identical(result, expected)


class TestDataArrayRolling:
@pytest.mark.parametrize("da", (1, 2), indirect=True)
@pytest.mark.parametrize("center", [True, False])
Expand Down Expand Up @@ -485,29 +510,6 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None:
):
da.rolling_exp(time=10, keep_attrs=True)

@pytest.mark.parametrize("func", ["mean", "sum"])
@pytest.mark.parametrize("min_periods", [1, 20])
def test_cumulative(self, da, func, min_periods) -> None:
# One dim
result = getattr(da.cumulative("time", min_periods=min_periods), func)()
expected = getattr(
da.rolling(time=da.time.size, min_periods=min_periods), func
)()
assert_identical(result, expected)

# Multiple dim
result = getattr(da.cumulative(["time", "a"], min_periods=min_periods), func)()
expected = getattr(
da.rolling(time=da.time.size, a=da.a.size, min_periods=min_periods),
func,
)()
assert_identical(result, expected)

def test_cumulative_vs_cum(self, da) -> None:
result = da.cumulative("time").sum()
expected = da.cumsum("time")
assert_identical(result, expected)


class TestDatasetRolling:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -832,25 +834,6 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:
expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)()
assert_allclose(actual, expected)

@pytest.mark.parametrize("func", ["mean", "sum"])
@pytest.mark.parametrize("ds", (2,), indirect=True)
@pytest.mark.parametrize("min_periods", [1, 10])
def test_cumulative(self, ds, func, min_periods) -> None:
# One dim
result = getattr(ds.cumulative("time", min_periods=min_periods), func)()
expected = getattr(
ds.rolling(time=ds.time.size, min_periods=min_periods), func
)()
assert_identical(result, expected)

# Multiple dim
result = getattr(ds.cumulative(["time", "x"], min_periods=min_periods), func)()
expected = getattr(
ds.rolling(time=ds.time.size, x=ds.x.size, min_periods=min_periods),
func,
)()
assert_identical(result, expected)


@requires_numbagg
class TestDatasetRollingExp:
Expand Down

0 comments on commit 4c8abcb

Please sign in to comment.