Skip to content

Commit

Permalink
GroupBy(multiple strings) (pydata#9414)
Browse files Browse the repository at this point in the history
* Group by multiple strings

Closes pydata#9396

* Fix typing

* some more

* fix

* cleanup

* Update xarray/core/dataarray.py

* Update docs

* Revert "Update xarray/core/dataarray.py"

This reverts commit fafd960.

* update docstring

* Add docstring examples

* Update xarray/core/dataarray.py

Co-authored-by: Maximilian Roos <[email protected]>

* Update xarray/core/dataset.py

* fix assert warning / error

* fix assert warning / error

* Silence RTD warnings

---------

Co-authored-by: Maximilian Roos <[email protected]>
Co-authored-by: Maximilian Roos <[email protected]>
  • Loading branch information
3 people authored and hollymandel committed Sep 23, 2024
1 parent 516aca9 commit 0253dc7
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 59 deletions.
6 changes: 6 additions & 0 deletions doc/user-guide/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,12 @@ Use grouper objects to group by multiple dimensions:
from xarray.groupers import UniqueGrouper
da.groupby(["lat", "lon"]).sum()
The above is sugar for using ``UniqueGrouper`` objects directly:

.. ipython:: python
da.groupby(lat=UniqueGrouper(), lon=UniqueGrouper()).sum()
Expand Down
77 changes: 51 additions & 26 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
GroupInput,
InterpOptions,
PadModeOptions,
PadReflectOptions,
Expand Down Expand Up @@ -6707,9 +6708,7 @@ def interp_calendar(
@_deprecate_positional_args("v2024.07.0")
def groupby(
self,
group: (
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
) = None,
group: GroupInput = None,
*,
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
Expand All @@ -6719,7 +6718,7 @@ def groupby(
Parameters
----------
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
group : str or DataArray or IndexVariable or sequence of hashable or mapping of hashable to Grouper
Array whose unique values should be used to group this array. If a
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
must map an existing variable name to a :py:class:`Grouper` instance.
Expand Down Expand Up @@ -6770,6 +6769,52 @@ def groupby(
Coordinates:
* dayofyear (dayofyear) int64 3kB 1 2 3 4 5 6 7 ... 361 362 363 364 365 366
>>> da = xr.DataArray(
... data=np.arange(12).reshape((4, 3)),
... dims=("x", "y"),
... coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
... )
Grouping by a single variable is easy
>>> da.groupby("letters")
<DataArrayGroupBy, grouped over 1 grouper(s), 2 groups in total:
'letters': 2 groups with labels 'a', 'b'>
Execute a reduction
>>> da.groupby("letters").sum()
<xarray.DataArray (letters: 2, y: 3)> Size: 48B
array([[ 9., 11., 13.],
[ 9., 11., 13.]])
Coordinates:
* letters (letters) object 16B 'a' 'b'
Dimensions without coordinates: y
Grouping by multiple variables
>>> da.groupby(["letters", "x"])
<DataArrayGroupBy, grouped over 2 grouper(s), 8 groups in total:
'letters': 2 groups with labels 'a', 'b'
'x': 4 groups with labels 10, 20, 30, 40>
Use Grouper objects to express more complicated GroupBy operations
>>> from xarray.groupers import BinGrouper, UniqueGrouper
>>>
>>> da.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
<xarray.DataArray (x_bins: 2, letters: 2, y: 3)> Size: 96B
array([[[ 0., 1., 2.],
[nan, nan, nan]],
<BLANKLINE>
[[nan, nan, nan],
[ 3., 4., 5.]]])
Coordinates:
* x_bins (x_bins) object 16B (5, 15] (15, 25]
* letters (letters) object 16B 'a' 'b'
Dimensions without coordinates: y
See Also
--------
:ref:`groupby`
Expand All @@ -6791,32 +6836,12 @@ def groupby(
"""
from xarray.core.groupby import (
DataArrayGroupBy,
ResolvedGrouper,
_parse_group_and_groupers,
_validate_groupby_squeeze,
)
from xarray.groupers import UniqueGrouper

_validate_groupby_squeeze(squeeze)

if isinstance(group, Mapping):
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
group = None

rgroupers: tuple[ResolvedGrouper, ...]
if group is not None:
if groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
else:
if not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")
rgroupers = tuple(
ResolvedGrouper(grouper, group, self)
for group, grouper in groupers.items()
)

rgroupers = _parse_group_and_groupers(self, group, groupers)
return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)

@_deprecate_positional_args("v2024.07.0")
Expand Down
75 changes: 50 additions & 25 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
DsCompatible,
ErrorOptions,
ErrorOptionsWithWarn,
GroupInput,
InterpOptions,
JoinOptions,
PadModeOptions,
Expand Down Expand Up @@ -10332,9 +10333,7 @@ def interp_calendar(
@_deprecate_positional_args("v2024.07.0")
def groupby(
self,
group: (
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
) = None,
group: GroupInput = None,
*,
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
Expand All @@ -10344,7 +10343,7 @@ def groupby(
Parameters
----------
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
group : str or DataArray or IndexVariable or sequence of hashable or mapping of hashable to Grouper
Array whose unique values should be used to group this array. If a
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
must map an existing variable name to a :py:class:`Grouper` instance.
Expand All @@ -10366,6 +10365,51 @@ def groupby(
A `DatasetGroupBy` object patterned after `pandas.GroupBy` that can be
iterated over in the form of `(unique_value, grouped_array)` pairs.
Examples
--------
>>> ds = xr.Dataset(
... {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
... coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
... )
Grouping by a single variable is easy
>>> ds.groupby("letters")
<DatasetGroupBy, grouped over 1 grouper(s), 2 groups in total:
'letters': 2 groups with labels 'a', 'b'>
Execute a reduction
>>> ds.groupby("letters").sum()
<xarray.Dataset> Size: 64B
Dimensions: (letters: 2, y: 3)
Coordinates:
* letters (letters) object 16B 'a' 'b'
Dimensions without coordinates: y
Data variables:
foo (letters, y) float64 48B 9.0 11.0 13.0 9.0 11.0 13.0
Grouping by multiple variables
>>> ds.groupby(["letters", "x"])
<DatasetGroupBy, grouped over 2 grouper(s), 8 groups in total:
'letters': 2 groups with labels 'a', 'b'
'x': 4 groups with labels 10, 20, 30, 40>
Use Grouper objects to express more complicated GroupBy operations
>>> from xarray.groupers import BinGrouper, UniqueGrouper
>>>
>>> ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
<xarray.Dataset> Size: 128B
Dimensions: (y: 3, x_bins: 2, letters: 2)
Coordinates:
* x_bins (x_bins) object 16B (5, 15] (15, 25]
* letters (letters) object 16B 'a' 'b'
Dimensions without coordinates: y
Data variables:
foo (y, x_bins, letters) float64 96B 0.0 nan nan 3.0 ... nan nan 5.0
See Also
--------
:ref:`groupby`
Expand All @@ -10387,31 +10431,12 @@ def groupby(
"""
from xarray.core.groupby import (
DatasetGroupBy,
ResolvedGrouper,
_parse_group_and_groupers,
_validate_groupby_squeeze,
)
from xarray.groupers import UniqueGrouper

_validate_groupby_squeeze(squeeze)

if isinstance(group, Mapping):
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
group = None

rgroupers: tuple[ResolvedGrouper, ...]
if group is not None:
if groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
else:
if not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")
rgroupers = tuple(
ResolvedGrouper(grouper, group, self)
for group, grouper in groupers.items()
)
rgroupers = _parse_group_and_groupers(self, group, groupers)

return DatasetGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)

Expand Down
53 changes: 49 additions & 4 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Generic, Literal, Union
from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -54,7 +54,7 @@

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import GroupIndex, GroupIndices, GroupKey
from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
from xarray.core.utils import Frozen
from xarray.groupers import EncodedGroups, Grouper

Expand Down Expand Up @@ -319,6 +319,51 @@ def __len__(self) -> int:
return len(self.encoded.full_index)


def _parse_group_and_groupers(
obj: T_Xarray, group: GroupInput, groupers: dict[str, Grouper]
) -> tuple[ResolvedGrouper, ...]:
from xarray.core.dataarray import DataArray
from xarray.core.variable import Variable
from xarray.groupers import UniqueGrouper

if group is not None and groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)

if group is None and not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")

if isinstance(group, np.ndarray | pd.Index):
raise TypeError(
f"`group` must be a DataArray. Received {type(group).__name__!r} instead"
)

if isinstance(group, Mapping):
grouper_mapping = either_dict_or_kwargs(group, groupers, "groupby")
group = None

rgroupers: tuple[ResolvedGrouper, ...]
if isinstance(group, DataArray | Variable):
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, obj),)
else:
if group is not None:
if TYPE_CHECKING:
assert isinstance(group, str | Sequence)
group_iter: Sequence[Hashable] = (
(group,) if isinstance(group, str) else group
)
grouper_mapping = {g: UniqueGrouper() for g in group_iter}
elif groupers:
grouper_mapping = cast("Mapping[Hashable, Grouper]", groupers)

rgroupers = tuple(
ResolvedGrouper(grouper, group, obj)
for group, grouper in grouper_mapping.items()
)
return rgroupers


def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
# While we don't generally check the type of every arg, passing
# multiple dimensions as multiple arguments is common enough, and the
Expand All @@ -327,7 +372,7 @@ def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
# A future version could make squeeze kwarg only, but would face
# backward-compat issues.
if squeeze is not False:
raise TypeError(f"`squeeze` must be False, but {squeeze} was supplied.")
raise TypeError(f"`squeeze` must be False, but {squeeze!r} was supplied.")


def _resolve_group(
Expand Down Expand Up @@ -626,7 +671,7 @@ def __repr__(self) -> str:
for grouper in self.groupers:
coord = grouper.unique_coord
labels = ", ".join(format_array_flat(coord, 30).split())
text += f"\n\t{grouper.name!r}: {coord.size} groups with labels {labels}"
text += f"\n {grouper.name!r}: {coord.size} groups with labels {labels}"
return text + ">"

def _iter_grouped(self) -> Iterator[T_Xarray]:
Expand Down
13 changes: 11 additions & 2 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,17 @@
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index, Indexes
from xarray.core.utils import Frozen
from xarray.core.variable import Variable
from xarray.groupers import TimeResampler
from xarray.core.variable import IndexVariable, Variable
from xarray.groupers import Grouper, TimeResampler

GroupInput: TypeAlias = (
str
| DataArray
| IndexVariable
| Sequence[Hashable]
| Mapping[Any, Grouper]
| None
)

try:
from dask.array import Array as DaskArray
Expand Down
Loading

0 comments on commit 0253dc7

Please sign in to comment.