From 60d76197388180945434beae2e3cda4287e1254f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 6 Aug 2024 20:36:50 -0600 Subject: [PATCH] Cleanup --- xarray/core/duck_array_ops.py | 15 --------------- xarray/core/groupby.py | 12 +++++++----- xarray/core/types.py | 2 +- xarray/core/variable.py | 18 +++++++++++++++++- xarray/namedarray/daskmanager.py | 9 +++++++++ xarray/namedarray/parallelcompat.py | 5 +++++ 6 files changed, 39 insertions(+), 22 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 25bd86177df..8993c136ba6 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -831,18 +831,3 @@ def chunked_nanfirst(darray, axis): def chunked_nanlast(darray, axis): return _chunked_first_or_last(darray, axis, op=nputils.nanlast) - - -def shuffle_array(array, indices: list[list[int]], axis: int): - # TODO: do chunk manager dance here. - if is_duck_dask_array(array): - if not module_available("dask", minversion="2024.08.0"): - raise ValueError( - "This method is very inefficient on dask<2024.08.0. Please upgrade." - ) - # TODO: handle dimensions - return array.shuffle(indexer=indices, axis=axis) - else: - indexer = np.concatenate(indices) - # TODO: Do the array API thing here. - return np.take(array, indices=indexer, axis=axis) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9fbf6778aea..95a1680e6f0 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -529,7 +529,6 @@ def shuffle(self) -> None: """ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.duck_array_ops import shuffle_array (grouper,) = self.groupers dim = self._group_dim @@ -538,6 +537,8 @@ def shuffle(self) -> None: if all(isinstance(idx, slice) for idx in self._group_indices): return + indices: tuple[list[int]] = self._group_indices # type: ignore[assignment] + was_array = isinstance(self._obj, DataArray) as_dataset = self._obj._to_temp_dataset() if was_array else self._obj @@ -546,21 +547,22 @@ def shuffle(self) -> None: if dim not in var.dims: shuffled[name] = var continue - shuffled_data = shuffle_array( - var._data, list(self._group_indices), axis=var.get_axis_num(dim) - ) - shuffled[name] = var._replace(data=shuffled_data) + shuffled[name] = var._shuffle(indices=list(indices), dim=dim) # Replace self._group_indices with slices slices = [] start = 0 for idxr in self._group_indices: + if TYPE_CHECKING: + assert not isinstance(idxr, slice) slices.append(slice(start, start + len(idxr))) start += len(idxr) # TODO: we have now broken the invariant # self._group_indices ≠ self.groupers[0].group_indices self._group_indices = tuple(slices) if was_array: + if TYPE_CHECKING: + assert isinstance(self._obj, DataArray) self._obj = self._obj._from_temp_dataset(shuffled) else: self._obj = shuffled diff --git a/xarray/core/types.py b/xarray/core/types.py index 591320d26da..96e75e18b51 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -297,7 +297,7 @@ def copy( ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"] GroupKey = Any -GroupIndex = Union[int, slice, list[int]] +GroupIndex = Union[slice, list[int]] GroupIndices = tuple[GroupIndex, ...] Bins = Union[ int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 828c53e6187..b37959f2a38 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -44,7 +44,13 @@ maybe_coerce_to_str, ) from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions -from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import ( + integer_types, + is_0d_dask_array, + is_chunked_array, + to_duck_array, +) from xarray.util.deprecation_helpers import deprecate_dims NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( @@ -998,6 +1004,16 @@ def compute(self, **kwargs): new = self.copy(deep=False) return new.load(**kwargs) + def _shuffle(self, indices: list[list[int]], dim: Hashable) -> Self: + array = self._data + if is_chunked_array(array): + chunkmanager = get_chunked_array_type(array) + return chunkmanager.shuffle( + array, indexer=indices, axis=self.get_axis_num(dim) + ) + else: + return self.isel({dim: np.concatenate(indices)}) + def isel( self, indexers: Mapping[Any, Any] | None = None, diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index 963d12fd865..aa4ced9f37a 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -251,3 +251,12 @@ def store( targets=targets, **kwargs, ) + + def shuffle(self, x: DaskArray, indexer: list[list[int]], axis: int) -> DaskArray: + import dask.array + + if not module_available("dask", minversion="2024.08.0"): + raise ValueError( + "This method is very inefficient on dask<2024.08.0. Please upgrade." + ) + return dask.array.shuffle(x, indexer, axis) diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py index dd555fe200a..f3c73027a8a 100644 --- a/xarray/namedarray/parallelcompat.py +++ b/xarray/namedarray/parallelcompat.py @@ -364,6 +364,11 @@ def compute( """ raise NotImplementedError() + def shuffle( + self, x: T_ChunkedArray, indexer: list[list[int]], axis: int + ) -> T_ChunkedArray: + raise NotImplementedError() + @property def array_api(self) -> Any: """