Skip to content

Commit

Permalink
update typed ops
Browse files Browse the repository at this point in the history
  • Loading branch information
headtr1ck committed Jul 12, 2024
1 parent 69b93dd commit ab4dbf5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 37 deletions.
69 changes: 36 additions & 33 deletions xarray/core/_typed_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import operator
from typing import Any, Callable, overload
from typing import TYPE_CHECKING, Any, Callable, overload

from xarray.core import nputils, ops
from xarray.core.types import (
Expand All @@ -15,8 +15,11 @@
T_Xarray,
VarCompatible,
)
from xarray.core.types import T_DataArray as T_DA
from xarray.core.types import T_Dataset as T_DS

if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import T_DataArray as T_DA


class DatasetOpsMixin:
Expand Down Expand Up @@ -767,96 +770,96 @@ class DatasetGroupByOpsMixin:
__slots__ = ()

def _binary_op(
self, other: T_DS | T_DA, f: Callable, reflexive: bool = False
) -> T_DS:
self, other: Dataset | DataArray, f: Callable, reflexive: bool = False
) -> Dataset:
raise NotImplementedError

def __add__(self, other: T_DS | T_DA) -> T_DS:
def __add__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.add)

def __sub__(self, other: T_DS | T_DA) -> T_DS:
def __sub__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.sub)

def __mul__(self, other: T_DS | T_DA) -> T_DS:
def __mul__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.mul)

def __pow__(self, other: T_DS | T_DA) -> T_DS:
def __pow__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.pow)

def __truediv__(self, other: T_DS | T_DA) -> T_DS:
def __truediv__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.truediv)

def __floordiv__(self, other: T_DS | T_DA) -> T_DS:
def __floordiv__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.floordiv)

def __mod__(self, other: T_DS | T_DA) -> T_DS:
def __mod__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.mod)

def __and__(self, other: T_DS | T_DA) -> T_DS:
def __and__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.and_)

def __xor__(self, other: T_DS | T_DA) -> T_DS:
def __xor__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.xor)

def __or__(self, other: T_DS | T_DA) -> T_DS:
def __or__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.or_)

def __lshift__(self, other: T_DS | T_DA) -> T_DS:
def __lshift__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.lshift)

def __rshift__(self, other: T_DS | T_DA) -> T_DS:
def __rshift__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.rshift)

def __lt__(self, other: T_DS | T_DA) -> T_DS:
def __lt__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.lt)

def __le__(self, other: T_DS | T_DA) -> T_DS:
def __le__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.le)

def __gt__(self, other: T_DS | T_DA) -> T_DS:
def __gt__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.gt)

def __ge__(self, other: T_DS | T_DA) -> T_DS:
def __ge__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.ge)

def __eq__(self, other: T_DS | T_DA) -> T_DS: # type:ignore[override]
def __eq__(self, other: Dataset | DataArray) -> Dataset: # type:ignore[override]
return self._binary_op(other, nputils.array_eq)

def __ne__(self, other: T_DS | T_DA) -> T_DS: # type:ignore[override]
def __ne__(self, other: Dataset | DataArray) -> Dataset: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)

# When __eq__ is defined but __hash__ is not, then an object is unhashable,
# and it should be declared as follows:
__hash__: None # type:ignore[assignment]

def __radd__(self, other: T_DS | T_DA) -> T_DS:
def __radd__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.add, reflexive=True)

def __rsub__(self, other: T_DS | T_DA) -> T_DS:
def __rsub__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.sub, reflexive=True)

def __rmul__(self, other: T_DS | T_DA) -> T_DS:
def __rmul__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.mul, reflexive=True)

def __rpow__(self, other: T_DS | T_DA) -> T_DS:
def __rpow__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.pow, reflexive=True)

def __rtruediv__(self, other: T_DS | T_DA) -> T_DS:
def __rtruediv__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.truediv, reflexive=True)

def __rfloordiv__(self, other: T_DS | T_DA) -> T_DS:
def __rfloordiv__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.floordiv, reflexive=True)

def __rmod__(self, other: T_DS | T_DA) -> T_DS:
def __rmod__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.mod, reflexive=True)

def __rand__(self, other: T_DS | T_DA) -> T_DS:
def __rand__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.and_, reflexive=True)

def __rxor__(self, other: T_DS | T_DA) -> T_DS:
def __rxor__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.xor, reflexive=True)

def __ror__(self, other: T_DS | T_DA) -> T_DS:
def __ror__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.or_, reflexive=True)

__add__.__doc__ = operator.add.__doc__
Expand Down
11 changes: 7 additions & 4 deletions xarray/util/generate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def unops() -> list[OpsType]:
+ unops()
)
ops_info["DatasetGroupByOpsMixin"] = binops(
other_type="T_DS | T_DA", return_type="T_DS"
other_type="Dataset | DataArray", return_type="Dataset"
)
ops_info["DataArrayGroupByOpsMixin"] = binops(
other_type="T_Xarray", return_type="T_Xarray"
Expand All @@ -245,7 +245,7 @@ def unops() -> list[OpsType]:
from __future__ import annotations
import operator
from typing import Any, Callable, overload
from typing import TYPE_CHECKING, Any, Callable, overload
from xarray.core import nputils, ops
from xarray.core.types import (
Expand All @@ -255,8 +255,11 @@ def unops() -> list[OpsType]:
T_Xarray,
VarCompatible,
)
from xarray.core.types import T_DataArray as T_DA
from xarray.core.types import T_Dataset as T_DS'''
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import T_DataArray as T_DA'''


CLASS_PREAMBLE = """{newline}
Expand Down

0 comments on commit ab4dbf5

Please sign in to comment.