diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 25bd86177df..1a632fb6fe9 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -835,13 +835,9 @@ def chunked_nanlast(darray, axis): def shuffle_array(array, indices: list[list[int]], axis: int): # TODO: do chunk manager dance here. - if is_duck_dask_array(array): - if not module_available("dask", minversion="2024.08.0"): - raise ValueError( - "This method is very inefficient on dask<2024.08.0. Please upgrade." - ) - # TODO: handle dimensions - return array.shuffle(indexer=indices, axis=axis) + if is_chunked_array(array): + chunkmanager = get_chunked_array_type(array) + return chunkmanager.shuffle(array, indexer=indices, axis=axis) else: indexer = np.concatenate(indices) # TODO: Do the array API thing here. diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9fbf6778aea..43150b055db 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -538,6 +538,8 @@ def shuffle(self) -> None: if all(isinstance(idx, slice) for idx in self._group_indices): return + indices: tuple[list[int]] = self._group_indices # type: ignore[assignment] + was_array = isinstance(self._obj, DataArray) as_dataset = self._obj._to_temp_dataset() if was_array else self._obj @@ -547,7 +549,7 @@ def shuffle(self) -> None: shuffled[name] = var continue shuffled_data = shuffle_array( - var._data, list(self._group_indices), axis=var.get_axis_num(dim) + var._data, list(indices), axis=var.get_axis_num(dim) ) shuffled[name] = var._replace(data=shuffled_data) @@ -555,12 +557,16 @@ def shuffle(self) -> None: slices = [] start = 0 for idxr in self._group_indices: + if TYPE_CHECKING: + assert not isinstance(idxr, slice) slices.append(slice(start, start + len(idxr))) start += len(idxr) # TODO: we have now broken the invariant # self._group_indices ≠ self.groupers[0].group_indices self._group_indices = tuple(slices) if was_array: + if TYPE_CHECKING: + assert isinstance(self._obj, DataArray) self._obj = self._obj._from_temp_dataset(shuffled) else: self._obj = shuffled diff --git a/xarray/core/types.py b/xarray/core/types.py index 591320d26da..96e75e18b51 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -297,7 +297,7 @@ def copy( ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"] GroupKey = Any -GroupIndex = Union[int, slice, list[int]] +GroupIndex = Union[slice, list[int]] GroupIndices = tuple[GroupIndex, ...] Bins = Union[ int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index 963d12fd865..aa4ced9f37a 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -251,3 +251,12 @@ def store( targets=targets, **kwargs, ) + + def shuffle(self, x: DaskArray, indexer: list[list[int]], axis: int) -> DaskArray: + import dask.array + + if not module_available("dask", minversion="2024.08.0"): + raise ValueError( + "This method is very inefficient on dask<2024.08.0. Please upgrade." + ) + return dask.array.shuffle(x, indexer, axis) diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py index dd555fe200a..f3c73027a8a 100644 --- a/xarray/namedarray/parallelcompat.py +++ b/xarray/namedarray/parallelcompat.py @@ -364,6 +364,11 @@ def compute( """ raise NotImplementedError() + def shuffle( + self, x: T_ChunkedArray, indexer: list[list[int]], axis: int + ) -> T_ChunkedArray: + raise NotImplementedError() + @property def array_api(self) -> Any: """