diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5c55ec7b600..d121a9bcd06 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -11,6 +11,7 @@ Callable, Generic, Literal, + overload, TypeVar, Union, cast, @@ -36,7 +37,7 @@ ) from xarray.core.options import _get_keep_attrs from xarray.core.pycompat import integer_types -from xarray.core.types import Dims, QuantileMethods, T_Xarray +from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray from xarray.core.utils import ( either_dict_or_kwargs, hashable, @@ -56,9 +57,7 @@ GroupKey = Any GroupIndex = Union[int, slice, list[int]] - - T_GroupIndicesListInt = list[list[int]] - T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray] + T_GroupIndices = list[GroupIndex] def check_reduce_dims(reduce_dims, dimensions): @@ -99,8 +98,8 @@ def unique_value_groups( return values, groups, inverse -def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndicesListInt: - groups: T_GroupIndicesListInt = [[] for _ in range(N)] +def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices: + groups: T_GroupIndices = [[] for _ in range(N)] for n, g in enumerate(inverse): if g >= 0: groups[g].append(n) @@ -147,7 +146,7 @@ def _is_one_or_none(obj) -> bool: def _consolidate_slices(slices: list[slice]) -> list[slice]: """Consolidate adjacent slices in a list of slices.""" - result = [] + result: list[slice] = [] last_slice = slice(None) for slice_ in slices: if not isinstance(slice_, slice): @@ -191,7 +190,7 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray return newpositions[newpositions != -1] -class _DummyGroup: +class _DummyGroup(Generic[T_Xarray]): """Class for keeping track of grouped dimensions without coordinates. Should not be user visible. @@ -247,18 +246,19 @@ def to_dataarray(self) -> DataArray: ) -T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup]) +# T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup]) +T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] def _ensure_1d( group: T_Group, obj: T_Xarray -) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable]]: +) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable],]: # 1D cases: do nothing - from xarray.core.dataarray import DataArray - if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1: return group, obj, None, [] + from xarray.core.dataarray import DataArray + if isinstance(group, DataArray): # try to stack the dims of the group into a single dim orig_dims = group.dims @@ -267,7 +267,7 @@ def _ensure_1d( inserted_dims = [dim for dim in group.dims if dim not in group.coords] newgroup = group.stack({stacked_dim: orig_dims}) newobj = obj.stack({stacked_dim: orig_dims}) - return cast(T_Group, newgroup), newobj, stacked_dim, inserted_dims + return newgroup, newobj, stacked_dim, inserted_dims raise TypeError( f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}." @@ -311,25 +311,36 @@ def _apply_loffset( result.index = result.index + loffset -class ResolvedGrouper(ABC): - def __init__(self, grouper: Grouper, group, obj): - self.labels = None - self._group_as_index: pd.Index | None = None +@dataclass +class ResolvedGrouper(ABC, Generic[T_Xarray]): + grouper: Grouper + group: T_Group + obj: T_Xarray + + _group_as_index: pd.Index | None = field(default=None, init=False) + + # Not used here:? + labels: Any | None = field(default=None, init=False) # TODO: Typing? + codes: DataArray = field(init=False) + group_indices: T_GroupIndices = field(init=False) + unique_coord: IndexVariable | _DummyGroup = field(init=False) + full_index: pd.Index = field(init=False) - self.codes: DataArray - self.group_indices: list[int] | list[slice] | list[list[int]] - self.unique_coord: IndexVariable | _DummyGroup - self.full_index: pd.Index + # _ensure_1d: + group1d: T_Group = field(init=False) + stacked_obj: T_Xarray = field(init=False) + stacked_dim: Hashable | None = field(init=False) + inserted_dims: list[Hashable] = field(init=False) - self.grouper = grouper - self.group = _resolve_group(obj, group) + def __post_init__(self) -> None: + self.group: T_Group = _resolve_group(self.obj, self.group) ( self.group1d, self.stacked_obj, self.stacked_dim, self.inserted_dims, - ) = _ensure_1d(self.group, obj) + ) = _ensure_1d(group=self.group, obj=self.obj) @property def name(self) -> Hashable: @@ -340,7 +351,7 @@ def size(self) -> int: return len(self) def __len__(self) -> int: - return len(self.full_index) + return len(self.full_index) # TODO: full_index not def, abstractmethod? @property def dims(self): @@ -364,7 +375,10 @@ def group_as_index(self) -> pd.Index: return self._group_as_index +@dataclass class ResolvedUniqueGrouper(ResolvedGrouper): + grouper: UniqueGrouper + def factorize(self, squeeze) -> None: is_dimension = self.group.dims == (self.group.name,) if is_dimension and self.is_unique_and_monotonic: @@ -407,7 +421,10 @@ def _factorize_dummy(self, squeeze) -> None: self.full_index = IndexVariable(self.name, self.group.values, self.group.attrs) +@dataclass class ResolvedBinGrouper(ResolvedGrouper): + grouper: BinGrouper + def factorize(self, squeeze: bool) -> None: from xarray.core.dataarray import DataArray @@ -438,21 +455,26 @@ def factorize(self, squeeze: bool) -> None: self.group_indices = group_indices +@dataclass class ResolvedTimeResampleGrouper(ResolvedGrouper): - def __init__(self, grouper, group, obj): - from xarray import CFTimeIndex - from xarray.core.resample_cftime import CFTimeGrouper + grouper: TimeResampleGrouper + + def __post_init__(self) -> None: + super().__post_init__() - super().__init__(grouper, group, obj) + from xarray import CFTimeIndex - self._group_as_index = safe_cast_to_index(group) - group_as_index = self._group_as_index + group_as_index = safe_cast_to_index(self.group) + self._group_as_index = group_as_index if not group_as_index.is_monotonic_increasing: # TODO: sort instead of raising an error raise ValueError("index must be monotonic for resampling") + grouper = self.grouper if isinstance(group_as_index, CFTimeIndex): + from xarray.core.resample_cftime import CFTimeGrouper + index_grouper = CFTimeGrouper( freq=grouper.freq, closed=grouper.closed, @@ -501,9 +523,9 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: def factorize(self, squeeze: bool) -> None: self.full_index, first_items, codes = self._get_index_and_items() sbins = first_items.values.astype(np.int64) - self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [ - slice(sbins[-1], None) - ] + self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + self.group_indices += [slice(sbins[-1], None)] + self.unique_coord = IndexVariable( self.group.name, first_items.index, self.group.attrs ) @@ -550,7 +572,7 @@ def _validate_groupby_squeeze(squeeze): raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied") -def _resolve_group(obj, group: T_Group | Hashable) -> T_Group: +def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group: from xarray.core.dataarray import DataArray if isinstance(group, (DataArray, IndexVariable)): @@ -625,6 +647,19 @@ class GroupBy(Generic[T_Xarray]): "_codes", ) _obj: T_Xarray + groupers: tuple[ResolvedGrouper] + _squeeze: bool + _restore_coord_dims: bool + + _original_obj: T_Xarray + _original_group: T_Group + _group_indices: T_GroupIndices + _codes: DataArray + _group_dim: Hashable + + _groups: dict[GroupKey, GroupIndex] | None + _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None + _sizes: Frozen[Hashable, int] | None def __init__( self, @@ -647,7 +682,7 @@ def __init__( """ self.groupers = groupers - self._original_obj: T_Xarray = obj + self._original_obj = obj for grouper_ in self.groupers: grouper_.factorize(squeeze) @@ -656,7 +691,7 @@ def __init__( self._original_group = grouper.group # specification for the groupby operation - self._obj: T_Xarray = grouper.stacked_obj + self._obj = grouper.stacked_obj self._restore_coord_dims = restore_coord_dims self._squeeze = squeeze @@ -666,9 +701,9 @@ def __init__( (self._group_dim,) = grouper.group1d.dims # cached attributes - self._groups: dict[GroupKey, slice | int | list[int]] | None = None - self._dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None = None - self._sizes: Frozen[Hashable, int] | None = None + self._groups = None + self._dims = None + self._sizes = None @property def sizes(self) -> Frozen[Hashable, int]: