Skip to content
forked from pydata/xarray

Commit

Permalink
Add typing
Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Apr 3, 2023
1 parent e07ae31 commit d557418
Showing 1 changed file with 75 additions and 40 deletions.
115 changes: 75 additions & 40 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Callable,
Generic,
Literal,
overload,
TypeVar,
Union,
cast,
Expand All @@ -36,7 +37,7 @@
)
from xarray.core.options import _get_keep_attrs
from xarray.core.pycompat import integer_types
from xarray.core.types import Dims, QuantileMethods, T_Xarray
from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray
from xarray.core.utils import (
either_dict_or_kwargs,
hashable,
Expand All @@ -56,9 +57,7 @@

GroupKey = Any
GroupIndex = Union[int, slice, list[int]]

T_GroupIndicesListInt = list[list[int]]
T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray]
T_GroupIndices = list[GroupIndex]


def check_reduce_dims(reduce_dims, dimensions):
Expand Down Expand Up @@ -99,8 +98,8 @@ def unique_value_groups(
return values, groups, inverse


def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndicesListInt:
groups: T_GroupIndicesListInt = [[] for _ in range(N)]
def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices:
groups: T_GroupIndices = [[] for _ in range(N)]
for n, g in enumerate(inverse):
if g >= 0:
groups[g].append(n)
Expand Down Expand Up @@ -147,7 +146,7 @@ def _is_one_or_none(obj) -> bool:

def _consolidate_slices(slices: list[slice]) -> list[slice]:
"""Consolidate adjacent slices in a list of slices."""
result = []
result: list[slice] = []
last_slice = slice(None)
for slice_ in slices:
if not isinstance(slice_, slice):
Expand Down Expand Up @@ -191,7 +190,7 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray
return newpositions[newpositions != -1]


class _DummyGroup:
class _DummyGroup(Generic[T_Xarray]):
"""Class for keeping track of grouped dimensions without coordinates.
Should not be user visible.
Expand Down Expand Up @@ -247,18 +246,19 @@ def to_dataarray(self) -> DataArray:
)


T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup])
# T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup])
T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup]


def _ensure_1d(
group: T_Group, obj: T_Xarray
) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable]]:
) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable],]:
# 1D cases: do nothing
from xarray.core.dataarray import DataArray

if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1:
return group, obj, None, []

from xarray.core.dataarray import DataArray

if isinstance(group, DataArray):
# try to stack the dims of the group into a single dim
orig_dims = group.dims
Expand All @@ -267,7 +267,7 @@ def _ensure_1d(
inserted_dims = [dim for dim in group.dims if dim not in group.coords]
newgroup = group.stack({stacked_dim: orig_dims})
newobj = obj.stack({stacked_dim: orig_dims})
return cast(T_Group, newgroup), newobj, stacked_dim, inserted_dims
return newgroup, newobj, stacked_dim, inserted_dims

raise TypeError(
f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}."
Expand Down Expand Up @@ -311,25 +311,36 @@ def _apply_loffset(
result.index = result.index + loffset


class ResolvedGrouper(ABC):
def __init__(self, grouper: Grouper, group, obj):
self.labels = None
self._group_as_index: pd.Index | None = None
@dataclass
class ResolvedGrouper(ABC, Generic[T_Xarray]):
grouper: Grouper
group: T_Group
obj: T_Xarray

_group_as_index: pd.Index | None = field(default=None, init=False)

# Not used here:?
labels: Any | None = field(default=None, init=False) # TODO: Typing?
codes: DataArray = field(init=False)
group_indices: T_GroupIndices = field(init=False)
unique_coord: IndexVariable | _DummyGroup = field(init=False)
full_index: pd.Index = field(init=False)

self.codes: DataArray
self.group_indices: list[int] | list[slice] | list[list[int]]
self.unique_coord: IndexVariable | _DummyGroup
self.full_index: pd.Index
# _ensure_1d:
group1d: T_Group = field(init=False)
stacked_obj: T_Xarray = field(init=False)
stacked_dim: Hashable | None = field(init=False)
inserted_dims: list[Hashable] = field(init=False)

self.grouper = grouper
self.group = _resolve_group(obj, group)
def __post_init__(self) -> None:
self.group: T_Group = _resolve_group(self.obj, self.group)

(
self.group1d,
self.stacked_obj,
self.stacked_dim,
self.inserted_dims,
) = _ensure_1d(self.group, obj)
) = _ensure_1d(group=self.group, obj=self.obj)

@property
def name(self) -> Hashable:
Expand All @@ -340,7 +351,7 @@ def size(self) -> int:
return len(self)

def __len__(self) -> int:
return len(self.full_index)
return len(self.full_index) # TODO: full_index not def, abstractmethod?

@property
def dims(self):
Expand All @@ -364,7 +375,10 @@ def group_as_index(self) -> pd.Index:
return self._group_as_index


@dataclass
class ResolvedUniqueGrouper(ResolvedGrouper):
grouper: UniqueGrouper

def factorize(self, squeeze) -> None:
is_dimension = self.group.dims == (self.group.name,)
if is_dimension and self.is_unique_and_monotonic:
Expand Down Expand Up @@ -407,7 +421,10 @@ def _factorize_dummy(self, squeeze) -> None:
self.full_index = IndexVariable(self.name, self.group.values, self.group.attrs)


@dataclass
class ResolvedBinGrouper(ResolvedGrouper):
grouper: BinGrouper

def factorize(self, squeeze: bool) -> None:
from xarray.core.dataarray import DataArray

Expand Down Expand Up @@ -438,21 +455,26 @@ def factorize(self, squeeze: bool) -> None:
self.group_indices = group_indices


@dataclass
class ResolvedTimeResampleGrouper(ResolvedGrouper):
def __init__(self, grouper, group, obj):
from xarray import CFTimeIndex
from xarray.core.resample_cftime import CFTimeGrouper
grouper: TimeResampleGrouper

def __post_init__(self) -> None:
super().__post_init__()

super().__init__(grouper, group, obj)
from xarray import CFTimeIndex

self._group_as_index = safe_cast_to_index(group)
group_as_index = self._group_as_index
group_as_index = safe_cast_to_index(self.group)
self._group_as_index = group_as_index

if not group_as_index.is_monotonic_increasing:
# TODO: sort instead of raising an error
raise ValueError("index must be monotonic for resampling")

grouper = self.grouper
if isinstance(group_as_index, CFTimeIndex):
from xarray.core.resample_cftime import CFTimeGrouper

index_grouper = CFTimeGrouper(
freq=grouper.freq,
closed=grouper.closed,
Expand Down Expand Up @@ -501,9 +523,9 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
def factorize(self, squeeze: bool) -> None:
self.full_index, first_items, codes = self._get_index_and_items()
sbins = first_items.values.astype(np.int64)
self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [
slice(sbins[-1], None)
]
self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])]
self.group_indices += [slice(sbins[-1], None)]

self.unique_coord = IndexVariable(
self.group.name, first_items.index, self.group.attrs
)
Expand Down Expand Up @@ -550,7 +572,7 @@ def _validate_groupby_squeeze(squeeze):
raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied")


def _resolve_group(obj, group: T_Group | Hashable) -> T_Group:
def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group:
from xarray.core.dataarray import DataArray

if isinstance(group, (DataArray, IndexVariable)):
Expand Down Expand Up @@ -625,6 +647,19 @@ class GroupBy(Generic[T_Xarray]):
"_codes",
)
_obj: T_Xarray
groupers: tuple[ResolvedGrouper]
_squeeze: bool
_restore_coord_dims: bool

_original_obj: T_Xarray
_original_group: T_Group
_group_indices: T_GroupIndices
_codes: DataArray
_group_dim: Hashable

_groups: dict[GroupKey, GroupIndex] | None
_dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None
_sizes: Frozen[Hashable, int] | None

def __init__(
self,
Expand All @@ -647,7 +682,7 @@ def __init__(
"""
self.groupers = groupers

self._original_obj: T_Xarray = obj
self._original_obj = obj

for grouper_ in self.groupers:
grouper_.factorize(squeeze)
Expand All @@ -656,7 +691,7 @@ def __init__(
self._original_group = grouper.group

# specification for the groupby operation
self._obj: T_Xarray = grouper.stacked_obj
self._obj = grouper.stacked_obj
self._restore_coord_dims = restore_coord_dims
self._squeeze = squeeze

Expand All @@ -666,9 +701,9 @@ def __init__(

(self._group_dim,) = grouper.group1d.dims
# cached attributes
self._groups: dict[GroupKey, slice | int | list[int]] | None = None
self._dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None = None
self._sizes: Frozen[Hashable, int] | None = None
self._groups = None
self._dims = None
self._sizes = None

@property
def sizes(self) -> Frozen[Hashable, int]:
Expand Down

0 comments on commit d557418

Please sign in to comment.