Skip to content

Commit

Permalink
move typed ops
Browse files Browse the repository at this point in the history
  • Loading branch information
headtr1ck committed Oct 14, 2023
1 parent dafd726 commit 657e7af
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 366 deletions.
356 changes: 1 addition & 355 deletions xarray/core/_typed_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@
from __future__ import annotations

import operator
from typing import TYPE_CHECKING, Any, Callable, overload
from typing import TYPE_CHECKING, Any, Callable

from xarray.core import nputils, ops
from xarray.core.types import (
DaCompatible,
DsCompatible,
GroupByCompatible,
Self,
T_DataArray,
T_Xarray,
VarCompatible,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -437,358 +435,6 @@ def conjugate(self, *args: Any, **kwargs: Any) -> Self:
conjugate.__doc__ = ops.conjugate.__doc__


class VariableOpsMixin:
__slots__ = ()

def _binary_op(
self, other: VarCompatible, f: Callable, reflexive: bool = False
) -> Self:
raise NotImplementedError

@overload
def __add__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __add__(self, other: VarCompatible) -> Self:
...

def __add__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.add)

@overload
def __sub__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __sub__(self, other: VarCompatible) -> Self:
...

def __sub__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.sub)

@overload
def __mul__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __mul__(self, other: VarCompatible) -> Self:
...

def __mul__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.mul)

@overload
def __pow__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __pow__(self, other: VarCompatible) -> Self:
...

def __pow__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.pow)

@overload
def __truediv__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __truediv__(self, other: VarCompatible) -> Self:
...

def __truediv__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.truediv)

@overload
def __floordiv__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __floordiv__(self, other: VarCompatible) -> Self:
...

def __floordiv__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.floordiv)

@overload
def __mod__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __mod__(self, other: VarCompatible) -> Self:
...

def __mod__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.mod)

@overload
def __and__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __and__(self, other: VarCompatible) -> Self:
...

def __and__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.and_)

@overload
def __xor__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __xor__(self, other: VarCompatible) -> Self:
...

def __xor__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.xor)

@overload
def __or__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __or__(self, other: VarCompatible) -> Self:
...

def __or__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.or_)

@overload
def __lshift__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __lshift__(self, other: VarCompatible) -> Self:
...

def __lshift__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.lshift)

@overload
def __rshift__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __rshift__(self, other: VarCompatible) -> Self:
...

def __rshift__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.rshift)

@overload
def __lt__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __lt__(self, other: VarCompatible) -> Self:
...

def __lt__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.lt)

@overload
def __le__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __le__(self, other: VarCompatible) -> Self:
...

def __le__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.le)

@overload
def __gt__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __gt__(self, other: VarCompatible) -> Self:
...

def __gt__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.gt)

@overload
def __ge__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __ge__(self, other: VarCompatible) -> Self:
...

def __ge__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.ge)

@overload # type:ignore[override]
def __eq__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __eq__(self, other: VarCompatible) -> Self:
...

def __eq__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, nputils.array_eq)

@overload # type:ignore[override]
def __ne__(self, other: T_DataArray) -> T_DataArray:
...

@overload
def __ne__(self, other: VarCompatible) -> Self:
...

def __ne__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, nputils.array_ne)

def __radd__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)

def __rsub__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.sub, reflexive=True)

def __rmul__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.mul, reflexive=True)

def __rpow__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.pow, reflexive=True)

def __rtruediv__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.truediv, reflexive=True)

def __rfloordiv__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.floordiv, reflexive=True)

def __rmod__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.mod, reflexive=True)

def __rand__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.and_, reflexive=True)

def __rxor__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.xor, reflexive=True)

def __ror__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.or_, reflexive=True)

def _inplace_binary_op(self, other: VarCompatible, f: Callable) -> Self:
raise NotImplementedError

def __iadd__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.iadd)

def __isub__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.isub)

def __imul__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.imul)

def __ipow__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ipow)

def __itruediv__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.itruediv)

def __ifloordiv__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ifloordiv)

def __imod__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.imod)

def __iand__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.iand)

def __ixor__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ixor)

def __ior__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ior)

def __ilshift__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ilshift)

def __irshift__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.irshift)

def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError

def __neg__(self) -> Self:
return self._unary_op(operator.neg)

def __pos__(self) -> Self:
return self._unary_op(operator.pos)

def __abs__(self) -> Self:
return self._unary_op(operator.abs)

def __invert__(self) -> Self:
return self._unary_op(operator.invert)

def round(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.round_, *args, **kwargs)

def argsort(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.argsort, *args, **kwargs)

def conj(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.conj, *args, **kwargs)

def conjugate(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.conjugate, *args, **kwargs)

__add__.__doc__ = operator.add.__doc__
__sub__.__doc__ = operator.sub.__doc__
__mul__.__doc__ = operator.mul.__doc__
__pow__.__doc__ = operator.pow.__doc__
__truediv__.__doc__ = operator.truediv.__doc__
__floordiv__.__doc__ = operator.floordiv.__doc__
__mod__.__doc__ = operator.mod.__doc__
__and__.__doc__ = operator.and_.__doc__
__xor__.__doc__ = operator.xor.__doc__
__or__.__doc__ = operator.or_.__doc__
__lshift__.__doc__ = operator.lshift.__doc__
__rshift__.__doc__ = operator.rshift.__doc__
__lt__.__doc__ = operator.lt.__doc__
__le__.__doc__ = operator.le.__doc__
__gt__.__doc__ = operator.gt.__doc__
__ge__.__doc__ = operator.ge.__doc__
__eq__.__doc__ = nputils.array_eq.__doc__
__ne__.__doc__ = nputils.array_ne.__doc__
__radd__.__doc__ = operator.add.__doc__
__rsub__.__doc__ = operator.sub.__doc__
__rmul__.__doc__ = operator.mul.__doc__
__rpow__.__doc__ = operator.pow.__doc__
__rtruediv__.__doc__ = operator.truediv.__doc__
__rfloordiv__.__doc__ = operator.floordiv.__doc__
__rmod__.__doc__ = operator.mod.__doc__
__rand__.__doc__ = operator.and_.__doc__
__rxor__.__doc__ = operator.xor.__doc__
__ror__.__doc__ = operator.or_.__doc__
__iadd__.__doc__ = operator.iadd.__doc__
__isub__.__doc__ = operator.isub.__doc__
__imul__.__doc__ = operator.imul.__doc__
__ipow__.__doc__ = operator.ipow.__doc__
__itruediv__.__doc__ = operator.itruediv.__doc__
__ifloordiv__.__doc__ = operator.ifloordiv.__doc__
__imod__.__doc__ = operator.imod.__doc__
__iand__.__doc__ = operator.iand.__doc__
__ixor__.__doc__ = operator.ixor.__doc__
__ior__.__doc__ = operator.ior.__doc__
__ilshift__.__doc__ = operator.ilshift.__doc__
__irshift__.__doc__ = operator.irshift.__doc__
__neg__.__doc__ = operator.neg.__doc__
__pos__.__doc__ = operator.pos.__doc__
__abs__.__doc__ = operator.abs.__doc__
__invert__.__doc__ = operator.invert.__doc__
round.__doc__ = ops.round_.__doc__
argsort.__doc__ = ops.argsort.__doc__
conj.__doc__ = ops.conj.__doc__
conjugate.__doc__ = ops.conjugate.__doc__


class DatasetGroupByOpsMixin:
__slots__ = ()

Expand Down
Loading

0 comments on commit 657e7af

Please sign in to comment.