diff --git a/.github/workflows/pypi-release.yaml b/.github/workflows/pypi-release.yaml index b7571b70d6d..bb75f68aacd 100644 --- a/.github/workflows/pypi-release.yaml +++ b/.github/workflows/pypi-release.yaml @@ -88,7 +88,7 @@ jobs: path: dist - name: Publish package to TestPyPI if: github.event_name == 'push' - uses: pypa/gh-action-pypi-publish@v1.10.3 + uses: pypa/gh-action-pypi-publish@v1.11.0 with: repository_url: https://test.pypi.org/legacy/ verbose: true @@ -111,6 +111,6 @@ jobs: name: releases path: dist - name: Publish package to PyPI - uses: pypa/gh-action-pypi-publish@v1.10.3 + uses: pypa/gh-action-pypi-publish@v1.11.0 with: verbose: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 45a667f8aac..2bdb1ecaa69 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.6.9' + rev: 'v0.7.2' hooks: - id: ruff-format - id: ruff @@ -25,7 +25,7 @@ repos: exclude: "generate_aggregations.py" additional_dependencies: ["black==24.8.0"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy # Copied from setup.cfg diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 865c2fd9b19..02c69a41924 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -36,6 +36,14 @@ dependencies: - pre-commit - pyarrow # pandas raises a deprecation warning without this, breaking doctests - pydap + # start pydap server dependencies, can be removed if pydap-server is available + - gunicorn + - PasteDeploy + - docopt-ng + - Webob + - Jinja2 + - beautifulsoup4 + # end pydap server dependencies - pytest - pytest-cov - pytest-env diff --git a/doc/getting-started-guide/faq.rst b/doc/getting-started-guide/faq.rst index 58d5448cdf5..49d7f5e4873 100644 --- a/doc/getting-started-guide/faq.rst +++ b/doc/getting-started-guide/faq.rst @@ -146,6 +146,9 @@ for conflicts between ``attrs`` when combining arrays and datasets, unless explicitly requested with the option ``compat='identical'``. The guiding principle is that metadata should not be allowed to get in the way. +In general, xarray uses the capabilities of the backends for reading and writing +attributes. This has some implications for roundtripping. One example of such an inconsistency is that size-1 lists roundtrip as a single element (for netcdf4 backends). + What other netCDF related Python libraries should I know about? --------------------------------------------------------------- diff --git a/doc/user-guide/data-structures.rst b/doc/user-guide/data-structures.rst index e5e89b0fbbd..9a2f26ec7b5 100644 --- a/doc/user-guide/data-structures.rst +++ b/doc/user-guide/data-structures.rst @@ -40,7 +40,8 @@ alignment, building on the functionality of the ``index`` found on a pandas DataArray objects also can have a ``name`` and can hold arbitrary metadata in the form of their ``attrs`` property. Names and attributes are strictly for users and user-written code: xarray makes no attempt to interpret them, and -propagates them only in unambiguous cases +propagates them only in unambiguous cases. For reading and writing attributes, +xarray relies on the capabilities of the supported backends. (see FAQ, :ref:`approach to metadata`). ..
_creating a dataarray: diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst index 98bd7b4833b..069c7e0cb10 100644 --- a/doc/user-guide/groupby.rst +++ b/doc/user-guide/groupby.rst @@ -294,6 +294,15 @@ is identical to ds.resample(time=TimeResampler("ME")) +The :py:class:`groupers.UniqueGrouper` accepts an optional ``labels`` kwarg that is not present +in :py:meth:`DataArray.groupby` or :py:meth:`Dataset.groupby`. +Specifying ``labels`` is required when grouping by a lazy array type (e.g. dask or cubed). +The ``labels`` are used to construct the output coordinate (say for a reduction), and aggregations +will only be run over the specified labels. +You may also use ``labels`` to specify the ordering of groups to be used during iteration. +The order will be preserved in the output. + + .. _groupby.multiple: Grouping by multiple variables diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 659063f2cbf..4659978df8a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,6 +23,12 @@ New Features ~~~~~~~~~~~~ - Added :py:meth:`DataTree.persist` method (:issue:`9675`, :pull:`9682`). By `Sam Levang `_. +- Added ``write_inherited_coords`` option to :py:meth:`DataTree.to_netcdf` + and :py:meth:`DataTree.to_zarr` (:pull:`9677`). + By `Stephan Hoyer `_. +- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])`` + (:issue:`2852`, :issue:`757`). + By `Deepak Cherian `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -30,12 +36,20 @@ Breaking changes Deprecations ~~~~~~~~~~~~ - +- Grouping by a chunked array (e.g. dask or cubed) currently eagerly loads that variable into + memory. This behaviour is deprecated. If eager loading was intended, please load such arrays + manually using ``.load()`` or ``.compute()``. Otherwise, pass ``eagerly_compute_group=False``, and + provide expected group labels using the ``labels`` kwarg to a grouper object such as + :py:class:`groupers.UniqueGrouper` or :py:class:`groupers.BinGrouper`. Bug fixes ~~~~~~~~~ -- Fix inadvertent deep-copying of child data in DataTree. +- Fix inadvertent deep-copying of child data in DataTree (:issue:`9683`, + :pull:`9684`). + By `Stephan Hoyer `_. +- Avoid including parent groups when writing DataTree subgroups to Zarr or + netCDF (:pull:`9682`). By `Stephan Hoyer `_. - Fix regression in the interoperability of :py:meth:`DataArray.polyfit` and :py:meth:`xr.polyval` for date-time coordinates. (:pull:`9691`). By `Pascal Bourgault `_. @@ -43,6 +57,9 @@ Bug fixes Documentation ~~~~~~~~~~~~~ +- Mention attribute peculiarities in docs/docstrings (:issue:`4798`, :pull:`9700`). + By `Kai Mühlbauer `_. + Internal Changes ~~~~~~~~~~~~~~~~ @@ -91,14 +108,6 @@ New Features (:issue:`9427`, :pull: `9428`). By `Alfonso Ladino `_. -Breaking changes ~~~~~~~~~~~~~~~~ - - -Deprecations ~~~~~~~~~~~~ - - Bug fixes ~~~~~~~~~ diff --git a/xarray/core/common.py b/xarray/core/common.py index 9a6807faad2..6966b1723d3 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1094,7 +1094,7 @@ def _resample( f"Received {type(freq)} instead." ) - rgrouper = ResolvedGrouper(grouper, group, self) + rgrouper = ResolvedGrouper(grouper, group, self, eagerly_compute_group=False) return resample_cls( self, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c6bc082f5ed..bb9360b3175 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -347,6 +347,7 @@ class DataArray( attrs : dict_like or None, optional Attributes to assign to the new instance.
By default, an empty attribute dictionary is initialized. + (see FAQ, :ref:`approach to metadata`) indexes : py:class:`~xarray.Indexes` or dict-like, optional For internal use only. For passing indexes objects to the new DataArray, use the ``coords`` argument instead with a @@ -6747,6 +6748,7 @@ def groupby( *, squeeze: Literal[False] = False, restore_coord_dims: bool = False, + eagerly_compute_group: bool = True, **groupers: Grouper, ) -> DataArrayGroupBy: """Returns a DataArrayGroupBy object for performing grouped operations. @@ -6762,6 +6764,11 @@ def groupby( restore_coord_dims : bool, default: False If True, also restore the dimension order of multi-dimensional coordinates. + eagerly_compute_group: bool + Whether to eagerly compute ``group`` when it is a chunked array. + This option is to maintain backwards compatibility. Set to False + to opt-in to future behaviour, where ``group`` is not automatically loaded + into memory. **groupers : Mapping of str to Grouper or Resampler Mapping of variable name to group by to :py:class:`Grouper` or :py:class:`Resampler` object. One of ``group`` or ``groupers`` must be provided. @@ -6876,7 +6883,9 @@ def groupby( ) _validate_groupby_squeeze(squeeze) - rgroupers = _parse_group_and_groupers(self, group, groupers) + rgroupers = _parse_group_and_groupers( + self, group, groupers, eagerly_compute_group=eagerly_compute_group + ) return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims) @_deprecate_positional_args("v2024.07.0") @@ -6891,6 +6900,7 @@ def groupby_bins( squeeze: Literal[False] = False, restore_coord_dims: bool = False, duplicates: Literal["raise", "drop"] = "raise", + eagerly_compute_group: bool = True, ) -> DataArrayGroupBy: """Returns a DataArrayGroupBy object for performing grouped operations. @@ -6927,6 +6937,11 @@ def groupby_bins( coordinates. duplicates : {"raise", "drop"}, default: "raise" If bin edges are not unique, raise ValueError or drop non-uniques. + eagerly_compute_group: bool + Whether to eagerly compute ``group`` when it is a chunked array. + This option is to maintain backwards compatibility. Set to False + to opt-in to future behaviour, where ``group`` is not automatically loaded + into memory. Returns ------- @@ -6964,7 +6979,9 @@ def groupby_bins( precision=precision, include_lowest=include_lowest, ) - rgrouper = ResolvedGrouper(grouper, group, self) + rgrouper = ResolvedGrouper( + grouper, group, self, eagerly_compute_group=eagerly_compute_group + ) return DataArrayGroupBy( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bc9360a809d..4593fd62730 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -596,6 +596,7 @@ class Dataset( attrs : dict-like, optional Global attributes to save on this dataset. + (see FAQ, :ref:`approach to metadata`) Examples -------- @@ -10378,6 +10379,7 @@ def groupby( *, squeeze: Literal[False] = False, restore_coord_dims: bool = False, + eagerly_compute_group: bool = True, **groupers: Grouper, ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. @@ -10393,6 +10395,11 @@ def groupby( restore_coord_dims : bool, default: False If True, also restore the dimension order of multi-dimensional coordinates. + eagerly_compute_group: bool + Whether to eagerly compute ``group`` when it is a chunked array. + This option is to maintain backwards compatibility. Set to False + to opt-in to future behaviour, where ``group`` is not automatically loaded + into memory. 
**groupers : Mapping of str to Grouper or Resampler Mapping of variable name to group by to :py:class:`Grouper` or :py:class:`Resampler` object. One of ``group`` or ``groupers`` must be provided. @@ -10475,7 +10482,9 @@ def groupby( ) _validate_groupby_squeeze(squeeze) - rgroupers = _parse_group_and_groupers(self, group, groupers) + rgroupers = _parse_group_and_groupers( + self, group, groupers, eagerly_compute_group=eagerly_compute_group + ) return DatasetGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims) @@ -10491,6 +10500,7 @@ def groupby_bins( squeeze: Literal[False] = False, restore_coord_dims: bool = False, duplicates: Literal["raise", "drop"] = "raise", + eagerly_compute_group: bool = True, ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. @@ -10527,6 +10537,11 @@ def groupby_bins( coordinates. duplicates : {"raise", "drop"}, default: "raise" If bin edges are not unique, raise ValueError or drop non-uniques. + eagerly_compute_group: bool + Whether to eagerly compute ``group`` when it is a chunked array. + This option is to maintain backwards compatibility. Set to False + to opt-in to future behaviour, where ``group`` is not automatically loaded + into memory. Returns ------- @@ -10564,7 +10579,9 @@ def groupby_bins( precision=precision, include_lowest=include_lowest, ) - rgrouper = ResolvedGrouper(grouper, group, self) + rgrouper = ResolvedGrouper( + grouper, group, self, eagerly_compute_group=eagerly_compute_group + ) return DatasetGroupBy( self, diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index eab5d30c7dc..efbdd6bc8eb 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1573,6 +1573,7 @@ def to_netcdf( format: T_DataTreeNetcdfTypes | None = None, engine: T_DataTreeNetcdfEngine | None = None, group: str | None = None, + write_inherited_coords: bool = False, compute: bool = True, **kwargs, ): @@ -1609,6 +1610,11 @@ def to_netcdf( group : str, optional Path to the netCDF4 group in the given file to open as the root group of the ``DataTree``. Currently, specifying a group is not supported. + write_inherited_coords : bool, default: False + If true, replicate inherited coordinates on all descendant nodes. + Otherwise, only write coordinates at the level at which they are + originally defined. This saves disk space, but requires opening the + full tree to load inherited coordinates. compute : bool, default: True If true compute immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed later. @@ -1632,6 +1638,7 @@ def to_netcdf( format=format, engine=engine, group=group, + write_inherited_coords=write_inherited_coords, compute=compute, **kwargs, ) @@ -1643,6 +1650,7 @@ def to_zarr( encoding=None, consolidated: bool = True, group: str | None = None, + write_inherited_coords: bool = False, compute: Literal[True] = True, **kwargs, ): @@ -1668,6 +1676,11 @@ def to_zarr( after writing metadata for all groups. group : str, optional Group path. (a.k.a. `path` in zarr terminology.) + write_inherited_coords : bool, default: False + If true, replicate inherited coordinates on all descendant nodes. + Otherwise, only write coordinates at the level at which they are + originally defined. This saves disk space, but requires opening the + full tree to load inherited coordinates. compute : bool, default: True If true compute immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed later. 
Metadata @@ -1690,6 +1703,7 @@ def to_zarr( encoding=encoding, consolidated=consolidated, group=group, + write_inherited_coords=write_inherited_coords, compute=compute, **kwargs, ) diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index da1cc12c92a..3d0daa26b90 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -2,54 +2,15 @@ from collections.abc import Mapping, MutableMapping from os import PathLike -from typing import TYPE_CHECKING, Any, Literal, get_args +from typing import Any, Literal, get_args from xarray.core.datatree import DataTree from xarray.core.types import NetcdfWriteModes, ZarrWriteModes -if TYPE_CHECKING: - from h5netcdf.legacyapi import Dataset as h5Dataset - from netCDF4 import Dataset as ncDataset - T_DataTreeNetcdfEngine = Literal["netcdf4", "h5netcdf"] T_DataTreeNetcdfTypes = Literal["NETCDF4"] -def _get_nc_dataset_class( - engine: T_DataTreeNetcdfEngine | None, -) -> type[ncDataset] | type[h5Dataset]: - if engine == "netcdf4": - from netCDF4 import Dataset as ncDataset - - return ncDataset - if engine == "h5netcdf": - from h5netcdf.legacyapi import Dataset as h5Dataset - - return h5Dataset - if engine is None: - try: - from netCDF4 import Dataset as ncDataset - - return ncDataset - except ImportError: - from h5netcdf.legacyapi import Dataset as h5Dataset - - return h5Dataset - raise ValueError(f"unsupported engine: {engine}") - - -def _create_empty_netcdf_group( - filename: str | PathLike, - group: str, - mode: NetcdfWriteModes, - engine: T_DataTreeNetcdfEngine | None, -) -> None: - ncDataset = _get_nc_dataset_class(engine) - - with ncDataset(filename, mode=mode) as rootgrp: - rootgrp.createGroup(group) - - def _datatree_to_netcdf( dt: DataTree, filepath: str | PathLike, @@ -59,6 +20,7 @@ def _datatree_to_netcdf( format: T_DataTreeNetcdfTypes | None = None, engine: T_DataTreeNetcdfEngine | None = None, group: str | None = None, + write_inherited_coords: bool = False, compute: bool = True, **kwargs, ) -> None: @@ -97,34 +59,23 @@ def _datatree_to_netcdf( unlimited_dims = {} for node in dt.subtree: - ds = node.to_dataset(inherit=False) - group_path = node.path - if ds is None: - _create_empty_netcdf_group(filepath, group_path, mode, engine) - else: - ds.to_netcdf( - filepath, - group=group_path, - mode=mode, - encoding=encoding.get(node.path), - unlimited_dims=unlimited_dims.get(node.path), - engine=engine, - format=format, - compute=compute, - **kwargs, - ) + at_root = node is dt + ds = node.to_dataset(inherit=write_inherited_coords or at_root) + group_path = None if at_root else "/" + node.relative_to(dt) + ds.to_netcdf( + filepath, + group=group_path, + mode=mode, + encoding=encoding.get(node.path), + unlimited_dims=unlimited_dims.get(node.path), + engine=engine, + format=format, + compute=compute, + **kwargs, + ) mode = "a" -def _create_empty_zarr_group( - store: MutableMapping | str | PathLike[str], group: str, mode: ZarrWriteModes -): - import zarr - - root = zarr.open_group(store, mode=mode) - root.create_group(group, overwrite=True) - - def _datatree_to_zarr( dt: DataTree, store: MutableMapping | str | PathLike[str], @@ -132,6 +83,7 @@ def _datatree_to_zarr( encoding: Mapping[str, Any] | None = None, consolidated: bool = True, group: str | None = None, + write_inherited_coords: bool = False, compute: Literal[True] = True, **kwargs, ): @@ -163,19 +115,17 @@ def _datatree_to_zarr( ) for node in dt.subtree: - ds = node.to_dataset(inherit=False) - group_path = node.path - if ds is None: - _create_empty_zarr_group(store, 
group_path, mode) - else: - ds.to_zarr( - store, - group=group_path, - mode=mode, - encoding=encoding.get(node.path), - consolidated=False, - **kwargs, - ) + at_root = node is dt + ds = node.to_dataset(inherit=write_inherited_coords or at_root) + group_path = None if at_root else "/" + node.relative_to(dt) + ds.to_zarr( + store, + group=group_path, + mode=mode, + encoding=encoding.get(node.path), + consolidated=False, + **kwargs, + ) if "w" in mode: mode = "a" diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5a5a241f6c1..5c4633c1612 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -22,6 +22,7 @@ from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce from xarray.core.concat import concat from xarray.core.coordinates import Coordinates, _coordinates_from_variable +from xarray.core.duck_array_ops import where from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( PandasMultiIndex, @@ -40,6 +41,7 @@ FrozenMappingWarningOnValuesAccess, contains_only_chunked_or_numpy, either_dict_or_kwargs, + emit_user_level_warning, hashable, is_scalar, maybe_wrap_array, @@ -47,6 +49,7 @@ peek_at, ) from xarray.core.variable import IndexVariable, Variable +from xarray.namedarray.pycompat import is_chunked_array from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: @@ -190,8 +193,8 @@ def values(self) -> range: return range(self.size) @property - def data(self) -> range: - return range(self.size) + def data(self) -> np.ndarray: + return np.arange(self.size, dtype=int) def __array__( self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None @@ -255,7 +258,9 @@ def _ensure_1d( stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) # these dimensions get created by the stack operation inserted_dims = [dim for dim in group.dims if dim not in group.coords] - newgroup = group.stack({stacked_dim: orig_dims}) + # `newgroup` construction is optimized so we don't create an index unnecessarily, + # or stack any non-dim coords unnecessarily + newgroup = DataArray(group.variable.stack({stacked_dim: orig_dims})) newobj = obj.stack({stacked_dim: orig_dims}) return newgroup, newobj, stacked_dim, inserted_dims @@ -280,6 +285,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]): grouper: Grouper group: T_Group obj: T_DataWithCoords + eagerly_compute_group: bool = field(repr=False) # returned by factorize: encoded: EncodedGroups = field(init=False, repr=False) @@ -302,10 +308,46 @@ def __post_init__(self) -> None: # of pd.cut # We do not want to modify the original object, since the same grouper # might be used multiple times. + from xarray.groupers import BinGrouper, UniqueGrouper + self.grouper = copy.deepcopy(self.grouper) self.group = _resolve_group(self.obj, self.group) + if not isinstance(self.group, _DummyGroup) and is_chunked_array( + self.group.variable._data + ): + if self.eagerly_compute_group is False: + # This requires a pass to discover the groups present + if ( + isinstance(self.grouper, UniqueGrouper) + and self.grouper.labels is None + ): + raise ValueError( + "Please pass `labels` to UniqueGrouper when grouping by a chunked array." + ) + # this requires a pass to compute the bin edges + if isinstance(self.grouper, BinGrouper) and isinstance( + self.grouper.bins, int + ): + raise ValueError( + "Please pass explicit bin edges to BinGrouper using the ``bins`` kwarg " + "when grouping by a chunked array."
+ ) + + if self.eagerly_compute_group: + emit_user_level_warning( + f"""Eagerly computing the DataArray you're grouping by ({self.group.name!r}) + is deprecated and will raise an error in v2025.05.0. + Please load this array's data manually using `.compute` or `.load`. + To intentionally avoid eager loading, either (1) specify + `.groupby({self.group.name}=UniqueGrouper(labels=...), eagerly_compute_group=False)` + or (2) pass explicit bin edges using `.groupby({self.group.name}=BinGrouper(bins=...), + eagerly_compute_group=False)`; as appropriate.""", + DeprecationWarning, + ) + self.group = self.group.compute() + self.encoded = self.grouper.factorize(self.group) @property @@ -326,7 +368,11 @@ def __len__(self) -> int: def _parse_group_and_groupers( - obj: T_Xarray, group: GroupInput, groupers: dict[str, Grouper] + obj: T_Xarray, + group: GroupInput, + groupers: dict[str, Grouper], + *, + eagerly_compute_group: bool, ) -> tuple[ResolvedGrouper, ...]: from xarray.core.dataarray import DataArray from xarray.core.variable import Variable @@ -351,7 +397,11 @@ def _parse_group_and_groupers( rgroupers: tuple[ResolvedGrouper, ...] if isinstance(group, DataArray | Variable): - rgroupers = (ResolvedGrouper(UniqueGrouper(), group, obj),) + rgroupers = ( + ResolvedGrouper( + UniqueGrouper(), group, obj, eagerly_compute_group=eagerly_compute_group + ), + ) else: if group is not None: if TYPE_CHECKING: @@ -364,7 +414,9 @@ def _parse_group_and_groupers( grouper_mapping = cast("Mapping[Hashable, Grouper]", groupers) rgroupers = tuple( - ResolvedGrouper(grouper, group, obj) + ResolvedGrouper( + grouper, group, obj, eagerly_compute_group=eagerly_compute_group + ) for group, grouper in grouper_mapping.items() ) return rgroupers @@ -467,15 +519,21 @@ def factorize(self) -> EncodedGroups: # Restore these after the raveling broadcasted_masks = broadcast(*masks) mask = functools.reduce(np.logical_or, broadcasted_masks) # type: ignore[arg-type] - _flatcodes[mask] = -1 + _flatcodes = where(mask, -1, _flatcodes) full_index = pd.MultiIndex.from_product( (grouper.full_index.values for grouper in groupers), names=tuple(grouper.name for grouper in groupers), ) - # Constructing an index from the product is wrong when there are missing groups - # (e.g. binning, resampling). Account for that now. - midx = full_index[np.sort(pd.unique(_flatcodes[~mask]))] + # This will be unused when grouping by dask arrays, so skip it. + if not is_chunked_array(_flatcodes): + # Constructing an index from the product is wrong when there are missing groups + # (e.g. binning, resampling). Account for that now. + midx = full_index[np.sort(pd.unique(_flatcodes[~mask]))] + group_indices = _codes_to_group_indices(_flatcodes.ravel(), len(full_index)) + else: + midx = full_index + group_indices = None dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers) @@ -485,7 +543,7 @@ def factorize(self) -> EncodedGroups: return EncodedGroups( codes=first_codes.copy(data=_flatcodes), full_index=full_index, - group_indices=_codes_to_group_indices(_flatcodes.ravel(), len(full_index)), + group_indices=group_indices, unique_coord=Variable(dims=(dim_name,), data=midx.values), coords=coords, ) @@ -518,6 +576,7 @@ class GroupBy(Generic[T_Xarray]): "_dims", "_sizes", "_len", + "_by_chunked", # Save unstacked object for flox "_original_obj", "_codes", @@ -535,6 +594,7 @@ class GroupBy(Generic[T_Xarray]): _group_indices: GroupIndices _codes: tuple[DataArray, ...]
_group_dim: Hashable + _by_chunked: bool _groups: dict[GroupKey, GroupIndex] | None _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None @@ -587,6 +647,7 @@ def __init__( # specification for the groupby operation # TODO: handle obj having variables that are not present on any of the groupers # simple broadcasting fails for ExtensionArrays. + # FIXME: Skip this stacking when grouping by a dask array; it's useless in that case. (self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = _ensure_1d( group=self.encoded.codes, obj=obj ) @@ -597,6 +658,7 @@ def __init__( self._dims = None self._sizes = None self._len = len(self.encoded.full_index) + self._by_chunked = is_chunked_array(self.encoded.codes.data) @property def sizes(self) -> Mapping[Hashable, int]: @@ -636,6 +698,15 @@ def reduce( ) -> T_Xarray: raise NotImplementedError() + def _raise_if_by_is_chunked(self): + if self._by_chunked: + raise ValueError( + "This method is not supported when lazily grouping by a chunked array. " + "Either load the array into memory prior to grouping using .load or .compute, " + "or explore another way of applying your function, " + "potentially using the `flox` package." + ) + def _raise_if_not_single_group(self): if len(self.groupers) != 1: raise NotImplementedError( @@ -684,6 +755,7 @@ def __repr__(self) -> str: def _iter_grouped(self) -> Iterator[T_Xarray]: """Iterate over each element in this group""" + self._raise_if_by_is_chunked() for indices in self.encoded.group_indices: if indices: yield self._obj.isel({self._group_dim: indices}) @@ -857,7 +929,7 @@ def _flox_reduce( if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - if Version(flox.__version__) < Version("0.9"): + if Version(flox.__version__) < Version("0.9") and not self._by_chunked: # preserve current strategy (approximately) for dask groupby # on older flox versions to prevent surprises. # flox >=0.9 will choose this on its own. @@ -923,11 +995,15 @@ def _flox_reduce( has_missing_groups = ( self.encoded.unique_coord.size != self.encoded.full_index.size ) - if has_missing_groups or kwargs.get("min_count", 0) > 0: + if self._by_chunked or has_missing_groups or kwargs.get("min_count", 0) > 0: # Xarray *always* returns np.nan when there are no observations in a group, # We can fake that here by forcing min_count=1 when it is not set. # This handles boolean reductions, and count # See GH8090, GH9398 + # Note that `has_missing_groups=False` when `self._by_chunked is True`. + # We *choose* to always do the masking, so that behaviour is predictable + # in some way. The real solution is to expose fill_value as a kwarg, + # and set appropriate defaults :/. kwargs.setdefault("fill_value", np.nan) kwargs.setdefault("min_count", 1) @@ -1266,6 +1342,7 @@ def _iter_grouped_shortcut(self): """Fast version of `_iter_grouped` that yields Variables without metadata """ + self._raise_if_by_is_chunked() var = self._obj.variable for _idx, indices in enumerate(self.encoded.group_indices): if indices: @@ -1426,6 +1503,12 @@ def reduce( Array with summarized data and the indicated dimension(s) removed. """ + if self._by_chunked: + raise ValueError( + "This method is not supported when lazily grouping by a chunked array. " + "Try installing the `flox` package if you are using one of the standard " + "reductions (e.g. `mean`). " + ) if dim is None: dim = [self._group_dim] @@ -1577,6 +1660,14 @@ def reduce( Array with summarized data and the indicated dimension(s) removed.
""" + + if self._by_chunked: + raise ValueError( + "This method is not supported when lazily grouping by a chunked array. " + "Try installing the `flox` package if you are using one of the standard " + "reductions (e.g. `mean`). " + ) + if dim is None: dim = [self._group_dim] diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 4a735959298..81ea9b5dca5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1031,13 +1031,21 @@ def stack( f"from variable {name!r} that wraps a multi-index" ) - split_labels, levels = zip( - *[lev.factorize() for lev in level_indexes], strict=True - ) - labels_mesh = np.meshgrid(*split_labels, indexing="ij") - labels = [x.ravel() for x in labels_mesh] + # from_product sorts by default, so we can't use that always + # https://github.com/pydata/xarray/issues/980 + # https://github.com/pandas-dev/pandas/issues/14672 + if all(index.is_monotonic_increasing for index in level_indexes): + index = pd.MultiIndex.from_product( + level_indexes, sortorder=0, names=variables.keys() + ) + else: + split_labels, levels = zip( + *[lev.factorize() for lev in level_indexes], strict=True + ) + labels_mesh = np.meshgrid(*split_labels, indexing="ij") + labels = [x.ravel() for x in labels_mesh] - index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys()) + index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys()) level_coords_dtype = {k: var.dtype for k, var in variables.items()} return cls(index, dim, level_coords_dtype=level_coords_dtype) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 0a5bf969260..640708ace45 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -1065,6 +1065,7 @@ def contains_only_chunked_or_numpy(obj) -> bool: Expects obj to be Dataset or DataArray""" from xarray.core.dataarray import DataArray + from xarray.core.indexing import ExplicitlyIndexed from xarray.namedarray.pycompat import is_chunked_array if isinstance(obj, DataArray): @@ -1072,8 +1073,10 @@ def contains_only_chunked_or_numpy(obj) -> bool: return all( [ - isinstance(var.data, np.ndarray) or is_chunked_array(var.data) - for var in obj.variables.values() + isinstance(var._data, ExplicitlyIndexed) + or isinstance(var._data, np.ndarray) + or is_chunked_array(var._data) + for var in obj._variables.values() ] ) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 4216e574312..94f95327747 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -408,6 +408,7 @@ def __init__( attrs : dict_like or None, optional Attributes to assign to the new variable. If None (default), an empty attribute dictionary is initialized. + (see FAQ, :ref:`approach to metadata`) encoding : dict_like or None, optional Dictionary specifying how to encode this array's data into a serialized format like netCDF4. 
Currently used keys (for netCDF) diff --git a/xarray/groupers.py b/xarray/groupers.py index 996f86317b9..c4980e6d810 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -13,11 +13,14 @@ import numpy as np import pandas as pd +from numpy.typing import ArrayLike from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.core import duck_array_ops +from xarray.core.computation import apply_ufunc from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.dataarray import DataArray +from xarray.core.duck_array_ops import isnull from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper @@ -29,6 +32,7 @@ SideOptions, ) from xarray.core.variable import Variable +from xarray.namedarray.pycompat import is_chunked_array __all__ = [ "EncodedGroups", @@ -86,11 +90,17 @@ def __init__( self.full_index = full_index if group_indices is None: - self.group_indices = tuple( - g - for g in _codes_to_group_indices(codes.data.ravel(), len(full_index)) - if g - ) + if not is_chunked_array(codes.data): + self.group_indices = tuple( + g + for g in _codes_to_group_indices( + codes.data.ravel(), len(full_index) + ) + if g + ) + else: + # We will not use this when grouping by a chunked array + self.group_indices = tuple() else: self.group_indices = group_indices @@ -141,9 +151,20 @@ class Resampler(Grouper): @dataclass class UniqueGrouper(Grouper): - """Grouper object for grouping by a categorical variable.""" + """ + Grouper object for grouping by a categorical variable. + + Parameters + ---------- + labels: array-like, optional + Group labels to aggregate on. This is required when grouping by a chunked array type + (e.g. dask or cubed) since it is used to construct the coordinate on the output. + Grouped operations will only be run on the specified group labels. Any group that is not + present in ``labels`` will be ignored. + """ _group_as_index: pd.Index | None = field(default=None, repr=False) + labels: ArrayLike | None = field(default=None) @property def group_as_index(self) -> pd.Index: @@ -158,6 +179,14 @@ def group_as_index(self) -> pd.Index: def factorize(self, group: T_Group) -> EncodedGroups: self.group = group + if is_chunked_array(group.data) and self.labels is None: + raise ValueError( + "When grouping by a dask array, `labels` must be passed using " + "a UniqueGrouper object." 
+ ) + if self.labels is not None: + return self._factorize_given_labels(group) + index = self.group_as_index is_unique_and_monotonic = isinstance(self.group, _DummyGroup) or ( index.is_unique @@ -171,6 +200,25 @@ def factorize(self, group: T_Group) -> EncodedGroups: else: return self._factorize_unique() + def _factorize_given_labels(self, group: T_Group) -> EncodedGroups: + codes = apply_ufunc( + _factorize_given_labels, + group, + kwargs={"labels": self.labels}, + dask="parallelized", + output_dtypes=[np.int64], + keep_attrs=True, + ) + return EncodedGroups( + codes=codes, + full_index=pd.Index(self.labels), # type: ignore[arg-type] + unique_coord=Variable( + dims=codes.name, + data=self.labels, + attrs=self.group.attrs, + ), + ) + def _factorize_unique(self) -> EncodedGroups: # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) @@ -280,13 +328,9 @@ def __post_init__(self) -> None: if duck_array_ops.isnull(self.bins).all(): raise ValueError("All bin edges are NaN.") - def factorize(self, group: T_Group) -> EncodedGroups: - from xarray.core.dataarray import DataArray - - data = np.asarray(group.data) # Cast _DummyGroup data to array - - binned, self.bins = pd.cut( # type: ignore [call-overload] - data.ravel(), + def _cut(self, data): + return pd.cut( + np.asarray(data).ravel(), bins=self.bins, right=self.right, labels=self.labels, @@ -296,23 +340,43 @@ def factorize(self, group: T_Group) -> EncodedGroups: retbins=True, ) - binned_codes = binned.codes - if (binned_codes == -1).all(): + def _factorize_lazy(self, group: T_Group) -> DataArray: + def _wrapper(data, **kwargs): + binned, bins = self._cut(data) + if isinstance(self.bins, int): + # we are running eagerly, update self.bins with actual edges instead + self.bins = bins + return binned.codes.reshape(data.shape) + + return apply_ufunc(_wrapper, group, dask="parallelized", keep_attrs=True) + + def factorize(self, group: T_Group) -> EncodedGroups: + if isinstance(group, _DummyGroup): + group = DataArray(group.data, dims=group.dims, name=group.name) + by_is_chunked = is_chunked_array(group.data) + if isinstance(self.bins, int) and by_is_chunked: + raise ValueError( + f"Bin edges must be provided when grouping by chunked arrays. 
Received {self.bins=!r} instead" + ) + codes = self._factorize_lazy(group) + if not by_is_chunked and (codes == -1).all(): raise ValueError( f"None of the data falls within bins with edges {self.bins!r}" ) new_dim_name = f"{group.name}_bins" + codes.name = new_dim_name + + # This seems silly, but it lets us have Pandas handle the complexity + # of `labels`, `precision`, and `include_lowest`, even when group is a chunked array + dummy, _ = self._cut(np.array([0]).astype(group.dtype)) + full_index = dummy.categories + if not by_is_chunked: + uniques = np.sort(pd.unique(codes.data.ravel())) + unique_values = full_index[uniques[uniques != -1]] + else: + unique_values = full_index - full_index = binned.categories - uniques = np.sort(pd.unique(binned_codes)) - unique_values = full_index[uniques[uniques != -1]] - - codes = DataArray( - binned_codes.reshape(group.shape), - getattr(group, "coords", None), - name=new_dim_name, - ) unique_coord = Variable( dims=new_dim_name, data=unique_values, attrs=group.attrs ) @@ -450,6 +514,21 @@ def factorize(self, group: T_Group) -> EncodedGroups: ) +def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray: + # Copied from flox + sorter = np.argsort(labels) + is_sorted = (sorter == np.arange(sorter.size)).all() + codes = np.searchsorted(labels, data, sorter=sorter) + mask = ~np.isin(data, labels) | isnull(data) | (codes == len(labels)) + # codes is the index in to the sorted array. + # if we didn't want sorting, unsort it back + if not is_sorted: + codes[codes == len(labels)] = -1 + codes = sorter[(codes,)] + codes[mask] = -1 + return codes + + def unique_value_groups( ar, sort: bool = True ) -> tuple[np.ndarray | pd.Index, np.ndarray]: diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 6e2d25249fb..608f9645ce8 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -196,6 +196,24 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree): with pytest.raises(ValueError, match="unexpected encoding group.*"): original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine) + def test_write_subgroup(self, tmpdir): + original_dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2, 3]}), + "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}), + } + ).children["child"] + + expected_dt = original_dt.copy() + expected_dt.name = None + + filepath = tmpdir / "test.zarr" + original_dt.to_netcdf(filepath, engine=self.engine) + + with open_datatree(filepath, engine=self.engine) as roundtrip_dt: + assert_equal(original_dt, roundtrip_dt) + assert_identical(expected_dt, roundtrip_dt) + @requires_netCDF4 class TestNetCDF4DatatreeIO(DatatreeIOBase): @@ -556,3 +574,59 @@ def test_open_groups_chunks(self, tmpdir) -> None: for ds in dict_of_datasets.values(): ds.close() + + def test_write_subgroup(self, tmpdir): + original_dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2, 3]}), + "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}), + } + ).children["child"] + + expected_dt = original_dt.copy() + expected_dt.name = None + + filepath = tmpdir / "test.zarr" + original_dt.to_zarr(filepath) + + with open_datatree(filepath, engine="zarr") as roundtrip_dt: + assert_equal(original_dt, roundtrip_dt) + assert_identical(expected_dt, roundtrip_dt) + + def test_write_inherited_coords_false(self, tmpdir): + original_dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2, 3]}), + "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}), + } + ) + + filepath = 
tmpdir / "test.zarr" + original_dt.to_zarr(filepath, write_inherited_coords=False) + + with open_datatree(filepath, engine="zarr") as roundtrip_dt: + assert_identical(original_dt, roundtrip_dt) + + expected_child = original_dt.children["child"].copy(inherit=False) + expected_child.name = None + with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child: + assert_identical(expected_child, roundtrip_child) + + def test_write_inherited_coords_true(self, tmpdir): + original_dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2, 3]}), + "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}), + } + ) + + filepath = tmpdir / "test.zarr" + original_dt.to_zarr(filepath, write_inherited_coords=True) + + with open_datatree(filepath, engine="zarr") as roundtrip_dt: + assert_identical(original_dt, roundtrip_dt) + + expected_child = original_dt.children["child"].copy(inherit=True) + expected_child.name = None + with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child: + assert_identical(expected_child, roundtrip_child) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index bff5ca8298d..057a467dc7f 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -20,6 +20,7 @@ from numpy import RankWarning # type: ignore[no-redef,attr-defined,unused-ignore] import xarray as xr +import xarray.core.missing from xarray import ( DataArray, Dataset, diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 3d948e7840e..c0eeace71af 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -22,6 +22,7 @@ TimeResampler, UniqueGrouper, ) +from xarray.namedarray.pycompat import is_chunked_array from xarray.tests import ( InaccessibleArray, assert_allclose, @@ -29,8 +30,10 @@ assert_identical, create_test_data, has_cftime, + has_dask, has_flox, has_pandas_ge_2_2, + raise_if_dask_computes, requires_cftime, requires_dask, requires_flox, @@ -2604,7 +2607,9 @@ def test_groupby_math_auto_chunk() -> None: sub = xr.DataArray( InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]} ) - actual = da.chunk(x=1, y=2).groupby("label") - sub + chunked = da.chunk(x=1, y=2) + chunked.label.load() + actual = chunked.groupby("label") - sub assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)} @@ -2814,7 +2819,7 @@ def test_multiple_groupers(use_flox) -> None: b = xr.DataArray( np.random.RandomState(0).randn(2, 3, 4), - coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])}, + coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]], {"foo": "bar"})}, dims=["x", "y", "z"], ) gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper()) @@ -2831,10 +2836,40 @@ def test_multiple_groupers(use_flox) -> None: expected.loc[dict(x=1, xy=1)] = expected.sel(x=1, xy=0).data expected.loc[dict(x=1, xy=0)] = np.nan expected.loc[dict(x=1, xy=2)] = newval - expected["xy"] = ("xy", ["a", "b", "c"]) + expected["xy"] = ("xy", ["a", "b", "c"], {"foo": "bar"}) # TODO: is order of dims correct? 
assert_identical(actual, expected.transpose("z", "x", "xy")) + if has_dask: + b["xy"] = b["xy"].chunk() + for eagerly_compute_group in [True, False]: + kwargs = dict( + x=UniqueGrouper(), + xy=UniqueGrouper(labels=["a", "b", "c"]), + eagerly_compute_group=eagerly_compute_group, + ) + expected = xr.DataArray( + [[[1, 1, 1], [np.nan, 1, 2]]] * 4, + dims=("z", "x", "xy"), + coords={"xy": ("xy", ["a", "b", "c"], {"foo": "bar"})}, + ) + if eagerly_compute_group: + with raise_if_dask_computes(max_computes=1): + with pytest.warns(DeprecationWarning): + gb = b.groupby(**kwargs) # type: ignore[arg-type] + assert_identical(gb.count(), expected) + else: + with raise_if_dask_computes(max_computes=0): + gb = b.groupby(**kwargs) # type: ignore[arg-type] + assert is_chunked_array(gb.encoded.codes.data) + assert not gb.encoded.group_indices + if has_flox: + with raise_if_dask_computes(max_computes=1): + assert_identical(gb.count(), expected) + else: + with pytest.raises(ValueError, match="when lazily grouping"): + gb.count() + @pytest.mark.parametrize("use_flox", [True, False]) def test_multiple_groupers_mixed(use_flox) -> None: @@ -2936,12 +2971,6 @@ def test_gappy_resample_reductions(reduction): assert_identical(expected, actual) -# Possible property tests -# 1. lambda x: x -# 2. grouped-reduce on unique coords is identical to array -# 3. group_over == groupby-reduce along other dimensions - - def test_groupby_transpose(): # GH5361 data = xr.DataArray( @@ -2955,6 +2984,96 @@ def test_groupby_transpose(): assert_identical(first, second.transpose(*first.dims)) +@requires_dask +@pytest.mark.parametrize( + "grouper, expect_index", + [ + [UniqueGrouper(labels=np.arange(1, 5)), pd.Index(np.arange(1, 5))], + [UniqueGrouper(labels=np.arange(1, 5)[::-1]), pd.Index(np.arange(1, 5)[::-1])], + [ + BinGrouper(bins=np.arange(1, 5)), + pd.IntervalIndex.from_breaks(np.arange(1, 5)), + ], + ], +) +def test_lazy_grouping(grouper, expect_index): + import dask.array + + data = DataArray( + dims=("x", "y"), + data=dask.array.arange(20, chunks=3).reshape((4, 5)), + name="zoo", + ) + with raise_if_dask_computes(): + encoded = grouper.factorize(data) + assert encoded.codes.ndim == data.ndim + pd.testing.assert_index_equal(encoded.full_index, expect_index) + np.testing.assert_array_equal(encoded.unique_coord.values, np.array(expect_index)) + + eager = ( + xr.Dataset({"foo": data}, coords={"zoo": data.compute()}) + .groupby(zoo=grouper) + .count() + ) + expected = Dataset( + {"foo": (encoded.codes.name, np.ones(encoded.full_index.size))}, + coords={encoded.codes.name: expect_index}, + ) + assert_identical(eager, expected) + + if has_flox: + lazy = ( + xr.Dataset({"foo": data}, coords={"zoo": data}) + .groupby(zoo=grouper, eagerly_compute_group=False) + .count() + ) + assert_identical(eager, lazy) + + +@requires_dask +def test_lazy_grouping_errors(): + import dask.array + + data = DataArray( + dims=("x",), + data=dask.array.arange(20, chunks=3), + name="foo", + coords={"y": ("x", dask.array.arange(20, chunks=3))}, + ) + + gb = data.groupby( + y=UniqueGrouper(labels=np.arange(5, 10)), eagerly_compute_group=False + ) + message = "not supported when lazily grouping by" + with pytest.raises(ValueError, match=message): + gb.map(lambda x: x) + + with pytest.raises(ValueError, match=message): + gb.reduce(np.mean) + + with pytest.raises(ValueError, match=message): + for _, _ in gb: + pass + + +@requires_dask +def test_lazy_int_bins_error(): + import dask.array + + with pytest.raises(ValueError, match="Bin edges must be provided"): + with 
raise_if_dask_computes(): + _ = BinGrouper(bins=4).factorize(DataArray(dask.array.arange(3))) + + +def test_time_grouping_seasons_specified(): + time = xr.date_range("2001-01-01", "2002-01-01", freq="D") + ds = xr.Dataset({"foo": np.arange(time.size)}, coords={"time": ("time", time)}) + labels = ["DJF", "MAM", "JJA", "SON"] + actual = ds.groupby({"time.season": UniqueGrouper(labels=labels)}).sum() + expected = ds.groupby("time.season").sum() + assert_identical(actual, expected.reindex(season=labels)) + + def test_groupby_multiple_bin_grouper_missing_groups(): from numpy import nan @@ -2985,3 +3104,45 @@ def test_groupby_multiple_bin_grouper_missing_groups(): }, ) assert_identical(actual, expected) + + +@requires_dask +def test_groupby_dask_eager_load_warnings(): + ds = xr.Dataset( + {"foo": (("z"), np.arange(12))}, + coords={"x": ("z", np.arange(12)), "y": ("z", np.arange(12))}, + ).chunk(z=6) + + with pytest.warns(DeprecationWarning): + ds.groupby(x=UniqueGrouper()) + + with pytest.warns(DeprecationWarning): + ds.groupby("x") + + with pytest.warns(DeprecationWarning): + ds.groupby(ds.x) + + with pytest.raises(ValueError, match="Please pass"): + ds.groupby("x", eagerly_compute_group=False) + + # This is technically fine but anyone iterating over the groupby object + # will see an error, so let's warn and have them opt-in. + with pytest.warns(DeprecationWarning): + ds.groupby(x=UniqueGrouper(labels=[1, 2, 3])) + + ds.groupby(x=UniqueGrouper(labels=[1, 2, 3]), eagerly_compute_group=False) + + with pytest.warns(DeprecationWarning): + ds.groupby_bins("x", bins=3) + with pytest.raises(ValueError, match="Please pass"): + ds.groupby_bins("x", bins=3, eagerly_compute_group=False) + with pytest.warns(DeprecationWarning): + ds.groupby_bins("x", bins=[1, 2, 3]) + ds.groupby_bins("x", bins=[1, 2, 3], eagerly_compute_group=False) + + +# TODO: Possible property tests to add to this module +# 1. lambda x: x +# 2. grouped-reduce on unique coords is identical to array +# 3. group_over == groupby-reduce along other dimensions +# 4. result is equivalent for transposed input
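
For illustration, here is a minimal usage sketch of the lazy-grouping API exercised by the tests above, assuming dask and flox are installed; the dataset, labels, and bin edges are made up for this example and are not part of the diff:

    import numpy as np
    import xarray as xr
    from xarray.groupers import BinGrouper, UniqueGrouper

    # A dataset whose grouping variable "x" is chunked; with
    # eagerly_compute_group=False it is not loaded into memory up front.
    ds = xr.Dataset(
        {"foo": ("z", np.arange(12))},
        coords={"x": ("z", np.arange(12) % 3)},
    ).chunk(z=6)

    # ``labels`` is required when grouping lazily; it also fixes the order of
    # the output coordinate.
    gb = ds.groupby(x=UniqueGrouper(labels=[0, 1, 2]), eagerly_compute_group=False)
    counts = gb.count()  # lazy reductions require flox

    # BinGrouper needs explicit bin edges; an integer bin count would require
    # computing the chunked array to find its range.
    binned = ds.groupby(
        x=BinGrouper(bins=[-0.5, 0.5, 1.5, 2.5]), eagerly_compute_group=False
    ).sum()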
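
And a short sketch of the new ``write_inherited_coords`` option on DataTree.to_zarr / DataTree.to_netcdf, mirroring the tests above; the output paths and tree contents are illustrative only:

    import xarray as xr

    tree = xr.DataTree.from_dict(
        {
            "/": xr.Dataset(coords={"x": [1, 2, 3]}),
            "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
        }
    )

    # Default: inherited coordinates (here "x") are written only at the level
    # where they are defined, saving disk space but requiring the full tree to
    # reconstruct them for a subgroup.
    tree.to_zarr("tree-compact.zarr", write_inherited_coords=False)

    # Opt-in: replicate inherited coordinates on every descendant group so each
    # group is self-contained when opened on its own.
    tree.to_zarr("tree-replicated.zarr", write_inherited_coords=True)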