From 6951e6a0d0a2c543d23139bedabe649030b6277b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Nov 2024 13:37:07 -0700 Subject: [PATCH] fix --- xarray/core/groupby.py | 8 +++++--- xarray/tests/test_groupby.py | 9 +++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a9221eeeedd..b4a60d8a778 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -740,9 +740,11 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray: shuffled = as_dataset._shuffle( dim=self._group_dim, indices=self.encoded.group_indices, chunks=chunks ) - shuffled = self._maybe_unstack(shuffled) - new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled - return new_obj + unstacked: Dataset = self._maybe_unstack(shuffled) + if was_array: + return self._obj._from_temp_dataset(unstacked) + else: + return unstacked # type: ignore[return-value] def map( self, diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 5a218ed46e5..6c44dd49752 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -1891,10 +1891,10 @@ def resample_as_pandas(array, *args, **kwargs): rs = array.resample(time="24h", closed="right") actual = rs.mean() - shuffled = rs.distributed_shuffle().resample(time="24h", closed="right").mean() + shuffled = rs.distributed_shuffle().resample(time="24h", closed="right") expected = resample_as_pandas(array, "24h", closed="right") assert_identical(expected, actual) - assert_identical(expected, shuffled) + assert_identical(expected, shuffled.mean()) with pytest.raises(ValueError, match=r"Index must be monotonic"): array[[2, 0, 1]].resample(time=resample_freq) @@ -2883,9 +2883,10 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None: coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]], {"foo": "bar"})}, dims=["x", "y", "z"], ) - gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper()) + groupers = dict(x=UniqueGrouper(), y=UniqueGrouper()) + gb = b.groupby(groupers) if shuffle: - gb = gb.distributed_shuffle() + gb = gb.distributed_shuffle().groupby(groupers) repr(gb) with xr.set_options(use_flox=use_flox): assert_identical(gb.mean("z"), b.mean("z"))