From 565b23b95beda893e0d66d1e2c6da49984bb0925 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 25 Sep 2023 06:43:54 +0200 Subject: [PATCH] Rewrite typed_ops (#8204) * rewrite typed_ops * improved typing of rolling instance attrs * add typed_ops xr.Variable tests * add typed_ops test * add minor typehint * adjust to numpy 1.24 * add groupby ops type tests * remove wrong types from ops * fix Dataset not being part of SupportsArray Protocol * ignore mypy align complaint * add reasons for type ignores in test * add overloads for variable typed ops * move tests to their own module * add entry to whats-new --- doc/whats-new.rst | 3 + xarray/core/_typed_ops.py | 591 ++++++++++++++++--------- xarray/core/_typed_ops.pyi | 782 --------------------------------- xarray/core/dataarray.py | 15 +- xarray/core/dataset.py | 21 +- xarray/core/rolling.py | 28 +- xarray/core/types.py | 7 +- xarray/core/weighted.py | 1 + xarray/tests/test_groupby.py | 4 +- xarray/tests/test_typed_ops.py | 246 +++++++++++ xarray/util/generate_ops.py | 286 ++++++------ 11 files changed, 827 insertions(+), 1157 deletions(-) delete mode 100644 xarray/core/_typed_ops.pyi create mode 100644 xarray/tests/test_typed_ops.py diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9a21bcb7ab9..4307c2829ca 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -92,6 +92,9 @@ Bug fixes By `Maximilian Roos `_. - In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`). By `Kai Mühlbauer `_. +- Static typing of dunder ops methods (like :py:meth:`DataArray.__eq__`) has been fixed. + Remaining issues are upstream problems (:issue:`7780`, :pull:`8204`). + By `Michael Niklas `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py index d3a783be45d..330d13bb217 100644 --- a/xarray/core/_typed_ops.py +++ b/xarray/core/_typed_ops.py @@ -1,165 +1,182 @@ """Mixin classes with arithmetic operators.""" # This file was generated using xarray.util.generate_ops. Do not edit manually. +from __future__ import annotations + import operator +from typing import TYPE_CHECKING, Any, Callable, NoReturn, overload from xarray.core import nputils, ops +from xarray.core.types import ( + DaCompatible, + DsCompatible, + GroupByCompatible, + Self, + T_DataArray, + T_Xarray, + VarCompatible, +) + +if TYPE_CHECKING: + from xarray.core.dataset import Dataset class DatasetOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: DsCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: DsCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: DsCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: DsCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -215,157 +232,159 @@ def conjugate(self, *args, **kwargs): class DataArrayOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: DaCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: DaCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -421,157 +440,303 @@ def conjugate(self, *args, **kwargs): class VariableOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: VarCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + @overload + def __add__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __add__(self, other: VarCompatible) -> Self: + ... + + def __add__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + @overload + def __sub__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __sub__(self, other: VarCompatible) -> Self: + ... + + def __sub__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + @overload + def __mul__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __mul__(self, other: VarCompatible) -> Self: + ... + + def __mul__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + @overload + def __pow__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __pow__(self, other: VarCompatible) -> Self: + ... + + def __pow__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + @overload + def __truediv__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __truediv__(self, other: VarCompatible) -> Self: + ... + + def __truediv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + @overload + def __floordiv__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __floordiv__(self, other: VarCompatible) -> Self: + ... + + def __floordiv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + @overload + def __mod__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __mod__(self, other: VarCompatible) -> Self: + ... + + def __mod__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + @overload + def __and__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __and__(self, other: VarCompatible) -> Self: + ... + + def __and__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + @overload + def __xor__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __xor__(self, other: VarCompatible) -> Self: + ... + + def __xor__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + @overload + def __or__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __or__(self, other: VarCompatible) -> Self: + ... + + def __or__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + @overload + def __lshift__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __lshift__(self, other: VarCompatible) -> Self: + ... + + def __lshift__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + @overload + def __rshift__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __rshift__(self, other: VarCompatible) -> Self: + ... + + def __rshift__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + @overload + def __lt__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __lt__(self, other: VarCompatible) -> Self: + ... + + def __lt__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + @overload + def __le__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __le__(self, other: VarCompatible) -> Self: + ... + + def __le__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + @overload + def __gt__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __gt__(self, other: VarCompatible) -> Self: + ... + + def __gt__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + @overload + def __ge__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __ge__(self, other: VarCompatible) -> Self: + ... + + def __ge__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + @overload # type:ignore[override] + def __eq__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __eq__(self, other: VarCompatible) -> Self: + ... + + def __eq__(self, other: VarCompatible) -> Self: return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + @overload # type:ignore[override] + def __ne__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __ne__(self, other: VarCompatible) -> Self: + ... + + def __ne__(self, other: VarCompatible) -> Self: return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: VarCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -627,91 +792,93 @@ def conjugate(self, *args, **kwargs): class DatasetGroupByOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: GroupByCompatible, f: Callable, reflexive: bool = False + ) -> Dataset: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ @@ -747,91 +914,93 @@ def __ror__(self, other): class DataArrayGroupByOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: T_Xarray, f: Callable, reflexive: bool = False + ) -> T_Xarray: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi deleted file mode 100644 index 9e2ba2d3a06..00000000000 --- a/xarray/core/_typed_ops.pyi +++ /dev/null @@ -1,782 +0,0 @@ -"""Stub file for mixin classes with arithmetic operators.""" -# This file was generated using xarray.util.generate_ops. Do not edit manually. - -from typing import NoReturn, TypeVar, overload - -import numpy as np -from numpy.typing import ArrayLike - -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy -from .types import ( - DaCompatible, - DsCompatible, - GroupByIncompatible, - ScalarOrArray, - VarCompatible, -) -from .variable import Variable - -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray # type: ignore - -# DatasetOpsMixin etc. are parent classes of Dataset etc. -# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally -# we use the ones in `types`. (We're open to refining this, and potentially integrating -# the `py` & `pyi` files to simplify them.) -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -class DatasetOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - def __add__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __sub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __mul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __pow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __truediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __floordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __mod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __and__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __xor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __or__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __lshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __lt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __le__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __gt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __ge__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __eq__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] - def __ne__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] - def __radd__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rsub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rmul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rpow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rtruediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rfloordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rmod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rand__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rxor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __ror__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_Dataset) -> T_Dataset: ... - def __pos__(self: T_Dataset) -> T_Dataset: ... - def __abs__(self: T_Dataset) -> T_Dataset: ... - def __invert__(self: T_Dataset) -> T_Dataset: ... - def round(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def argsort(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def conj(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def conjugate(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - -class DataArrayOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __add__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __sub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __mul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __pow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __truediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __floordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __mod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __and__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __xor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __or__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __lshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __lt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __le__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __gt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ge__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __eq__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ne__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __radd__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rsub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rmul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rpow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rtruediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rfloordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rmod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rand__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rxor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ror__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_DataArray) -> T_DataArray: ... - def __pos__(self: T_DataArray) -> T_DataArray: ... - def __abs__(self: T_DataArray) -> T_DataArray: ... - def __invert__(self: T_DataArray) -> T_DataArray: ... - def round(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def argsort(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def conj(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def conjugate(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - -class VariableOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __add__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __sub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __pow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __truediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __floordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __and__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __xor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __or__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __le__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __gt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ge__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __eq__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ne__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __radd__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rsub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rpow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rtruediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rfloordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rand__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rxor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ror__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_Variable) -> T_Variable: ... - def __pos__(self: T_Variable) -> T_Variable: ... - def __abs__(self: T_Variable) -> T_Variable: ... - def __invert__(self: T_Variable) -> T_Variable: ... - def round(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def argsort(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def conj(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def conjugate(self: T_Variable, *args, **kwargs) -> T_Variable: ... - -class DatasetGroupByOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: "DataArray") -> "Dataset": ... - @overload - def __add__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: "DataArray") -> "Dataset": ... - @overload - def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: "DataArray") -> "Dataset": ... - @overload - def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: "DataArray") -> "Dataset": ... - @overload - def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: "DataArray") -> "Dataset": ... - @overload - def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: "DataArray") -> "Dataset": ... - @overload - def __and__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: "DataArray") -> "Dataset": ... - @overload - def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: "DataArray") -> "Dataset": ... - @overload - def __or__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: "DataArray") -> "Dataset": ... - @overload - def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: "DataArray") -> "Dataset": ... - @overload - def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: "DataArray") -> "Dataset": ... - @overload - def __le__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: "DataArray") -> "Dataset": ... - @overload - def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: "DataArray") -> "Dataset": ... - @overload - def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: "DataArray") -> "Dataset": ... - @overload - def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... - -class DataArrayGroupByOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __add__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __and__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __or__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __le__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 724a5fc2580..0b9786dc2b7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4618,25 +4618,22 @@ def _unary_op(self, f: Callable, *args, **kwargs) -> Self: return da def _binary_op( - self, - other: T_Xarray, - f: Callable, - reflexive: bool = False, - ) -> T_Xarray: + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: from xarray.core.groupby import GroupBy if isinstance(other, (Dataset, GroupBy)): return NotImplemented if isinstance(other, DataArray): align_type = OPTIONS["arithmetic_join"] - self, other = align(self, other, join=align_type, copy=False) - other_variable = getattr(other, "variable", other) + self, other = align(self, other, join=align_type, copy=False) # type: ignore[type-var,assignment] + other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other) other_coords = getattr(other, "coords", None) variable = ( - f(self.variable, other_variable) + f(self.variable, other_variable_or_arraylike) if not reflexive - else f(other_variable, self.variable) + else f(other_variable_or_arraylike, self.variable) ) coords, indexes = self.coords._merge_raw(other_coords, reflexive) name = self._result_name(other) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 44016e87306..d24a62414ea 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1475,13 +1475,20 @@ def __bool__(self) -> bool: def __iter__(self) -> Iterator[Hashable]: return iter(self.data_vars) - def __array__(self, dtype=None): - raise TypeError( - "cannot directly convert an xarray.Dataset into a " - "numpy array. Instead, create an xarray.DataArray " - "first, either with indexing on the Dataset or by " - "invoking the `to_array()` method." - ) + if TYPE_CHECKING: + # needed because __getattr__ is returning Any and otherwise + # this class counts as part of the SupportsArray Protocol + __array__ = None + + else: + + def __array__(self, dtype=None): + raise TypeError( + "cannot directly convert an xarray.Dataset into a " + "numpy array. Instead, create an xarray.DataArray " + "first, either with indexing on the Dataset or by " + "invoking the `to_array()` method." + ) @property def nbytes(self) -> int: diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index c6911cbe65b..b85092982e3 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -61,6 +61,11 @@ class Rolling(Generic[T_Xarray]): __slots__ = ("obj", "window", "min_periods", "center", "dim") _attributes = ("window", "min_periods", "center", "dim") + dim: list[Hashable] + window: list[int] + center: list[bool] + obj: T_Xarray + min_periods: int def __init__( self, @@ -91,8 +96,8 @@ def __init__( ------- rolling : type of input argument """ - self.dim: list[Hashable] = [] - self.window: list[int] = [] + self.dim = [] + self.window = [] for d, w in windows.items(): self.dim.append(d) if w <= 0: @@ -100,7 +105,7 @@ def __init__( self.window.append(w) self.center = self._mapping_to_list(center, default=False) - self.obj: T_Xarray = obj + self.obj = obj missing_dims = tuple(dim for dim in self.dim if dim not in self.obj.dims) if missing_dims: @@ -814,6 +819,10 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): ) _attributes = ("windows", "side", "trim_excess") obj: T_Xarray + windows: Mapping[Hashable, int] + side: SideOptions | Mapping[Hashable, SideOptions] + boundary: CoarsenBoundaryOptions + coord_func: Mapping[Hashable, str | Callable] def __init__( self, @@ -855,12 +864,15 @@ def __init__( f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} " f"dimensions {tuple(self.obj.dims)}" ) - if not utils.is_dict_like(coord_func): - coord_func = {d: coord_func for d in self.obj.dims} # type: ignore[misc] + + if utils.is_dict_like(coord_func): + coord_func_map = coord_func + else: + coord_func_map = {d: coord_func for d in self.obj.dims} for c in self.obj.coords: - if c not in coord_func: - coord_func[c] = duck_array_ops.mean # type: ignore[index] - self.coord_func: Mapping[Hashable, str | Callable] = coord_func + if c not in coord_func_map: + coord_func_map[c] = duck_array_ops.mean # type: ignore[index] + self.coord_func = coord_func_map def _get_keep_attrs(self, keep_attrs): if keep_attrs is None: diff --git a/xarray/core/types.py b/xarray/core/types.py index 6b6f9300631..073121b13b1 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -38,7 +38,6 @@ from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.groupby import DataArrayGroupBy, GroupBy from xarray.core.indexes import Index, Indexes from xarray.core.utils import Frozen from xarray.core.variable import Variable @@ -176,10 +175,10 @@ def copy( T_DuckArray = TypeVar("T_DuckArray", bound=Any) ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] -DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"] -DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] VarCompatible = Union["Variable", "ScalarOrArray"] -GroupByIncompatible = Union["Variable", "GroupBy"] +DaCompatible = Union["DataArray", "VarCompatible"] +DsCompatible = Union["Dataset", "DaCompatible"] +GroupByCompatible = Union["Dataset", "DataArray"] Dims = Union[str, Iterable[Hashable], "ellipsis", None] OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None] diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 82ffe684ec7..b1ea1ee625c 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -324,6 +324,7 @@ def _weighted_quantile( def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray: """Return the interpolation parameter.""" # Note that options are not yet exposed in the public API. + h: np.ndarray if method == "linear": h = (n - 1) * q + 1 elif method == "interpolated_inverted_cdf": diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e143e2b8e03..320ba999318 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -810,9 +810,9 @@ def test_groupby_math_more() -> None: with pytest.raises(TypeError, match=r"only support binary ops"): grouped + 1 # type: ignore[operator] with pytest.raises(TypeError, match=r"only support binary ops"): - grouped + grouped + grouped + grouped # type: ignore[operator] with pytest.raises(TypeError, match=r"in-place operations"): - ds += grouped + ds += grouped # type: ignore[arg-type] ds = Dataset( { diff --git a/xarray/tests/test_typed_ops.py b/xarray/tests/test_typed_ops.py new file mode 100644 index 00000000000..1d4ef89ae29 --- /dev/null +++ b/xarray/tests/test_typed_ops.py @@ -0,0 +1,246 @@ +import numpy as np + +from xarray import DataArray, Dataset, Variable + + +def test_variable_typed_ops() -> None: + """Tests for type checking of typed_ops on Variable""" + + var = Variable(dims=["t"], data=[1, 2, 3]) + + def _test(var: Variable) -> None: + # mypy checks the input type + assert isinstance(var, Variable) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + + # __add__ as an example of binary ops + _test(var + _int) + _test(var + _list) + _test(var + _ndarray) + _test(var + var) + + # __radd__ as an example of reflexive binary ops + _test(_int + var) + _test(_list + var) + _test(_ndarray + var) # type: ignore[arg-type] # numpy problem + + # __eq__ as an example of cmp ops + _test(var == _int) + _test(var == _list) + _test(var == _ndarray) + _test(_int == var) # type: ignore[arg-type] # typeshed problem + _test(_list == var) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == var) + + # __lt__ as another example of cmp ops + _test(var < _int) + _test(var < _list) + _test(var < _ndarray) + _test(_int > var) + _test(_list > var) + _test(_ndarray > var) # type: ignore[arg-type] # numpy problem + + # __iadd__ as an example of inplace binary ops + var += _int + var += _list + var += _ndarray + + # __neg__ as an example of unary ops + _test(-var) + + +def test_dataarray_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArray""" + + da = DataArray([1, 2, 3], dims=["t"]) + + def _test(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + + # __add__ as an example of binary ops + _test(da + _int) + _test(da + _list) + _test(da + _ndarray) + _test(da + _var) + _test(da + da) + + # __radd__ as an example of reflexive binary ops + _test(_int + da) + _test(_list + da) + _test(_ndarray + da) # type: ignore[arg-type] # numpy problem + _test(_var + da) + + # __eq__ as an example of cmp ops + _test(da == _int) + _test(da == _list) + _test(da == _ndarray) + _test(da == _var) + _test(_int == da) # type: ignore[arg-type] # typeshed problem + _test(_list == da) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == da) + _test(_var == da) + + # __lt__ as another example of cmp ops + _test(da < _int) + _test(da < _list) + _test(da < _ndarray) + _test(da < _var) + _test(_int > da) + _test(_list > da) + _test(_ndarray > da) # type: ignore[arg-type] # numpy problem + _test(_var > da) + + # __iadd__ as an example of inplace binary ops + da += _int + da += _list + da += _ndarray + da += _var + + # __neg__ as an example of unary ops + _test(-da) + + +def test_dataset_typed_ops() -> None: + """Tests for type checking of typed_ops on Dataset""" + + ds = Dataset({"a": ("t", [1, 2, 3])}) + + def _test(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + _da = DataArray([1, 2, 3], dims=["t"]) + + # __add__ as an example of binary ops + _test(ds + _int) + _test(ds + _list) + _test(ds + _ndarray) + _test(ds + _var) + _test(ds + _da) + _test(ds + ds) + + # __radd__ as an example of reflexive binary ops + _test(_int + ds) + _test(_list + ds) + _test(_ndarray + ds) + _test(_var + ds) + _test(_da + ds) + + # __eq__ as an example of cmp ops + _test(ds == _int) + _test(ds == _list) + _test(ds == _ndarray) + _test(ds == _var) + _test(ds == _da) + _test(_int == ds) # type: ignore[arg-type] # typeshed problem + _test(_list == ds) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == ds) + _test(_var == ds) + _test(_da == ds) + + # __lt__ as another example of cmp ops + _test(ds < _int) + _test(ds < _list) + _test(ds < _ndarray) + _test(ds < _var) + _test(ds < _da) + _test(_int > ds) + _test(_list > ds) + _test(_ndarray > ds) # type: ignore[arg-type] # numpy problem + _test(_var > ds) + _test(_da > ds) + + # __iadd__ as an example of inplace binary ops + ds += _int + ds += _list + ds += _ndarray + ds += _var + ds += _da + + # __neg__ as an example of unary ops + _test(-ds) + + +def test_dataarray_groupy_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArrayGroupBy""" + + da = DataArray([1, 2, 3], coords={"x": ("t", [1, 2, 2])}, dims=["t"]) + grp = da.groupby("x") + + def _testda(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + def _testds(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x") + _ds = _da.to_dataset(name="a") + + # __add__ as an example of binary ops + _testda(grp + _da) + _testds(grp + _ds) + + # __radd__ as an example of reflexive binary ops + _testda(_da + grp) + _testds(_ds + grp) + + # __eq__ as an example of cmp ops + _testda(grp == _da) + _testda(_da == grp) + _testds(grp == _ds) + _testds(_ds == grp) + + # __lt__ as another example of cmp ops + _testda(grp < _da) + _testda(_da > grp) + _testds(grp < _ds) + _testds(_ds > grp) + + +def test_dataset_groupy_typed_ops() -> None: + """Tests for type checking of typed_ops on DatasetGroupBy""" + + ds = Dataset({"a": ("t", [1, 2, 3])}, coords={"x": ("t", [1, 2, 2])}) + grp = ds.groupby("x") + + def _test(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x") + _ds = _da.to_dataset(name="a") + + # __add__ as an example of binary ops + _test(grp + _da) + _test(grp + _ds) + + # __radd__ as an example of reflexive binary ops + _test(_da + grp) + _test(_ds + grp) + + # __eq__ as an example of cmp ops + _test(grp == _da) + _test(_da == grp) + _test(grp == _ds) + _test(_ds == grp) + + # __lt__ as another example of cmp ops + _test(grp < _da) + _test(_da > grp) + _test(grp < _ds) + _test(_ds > grp) diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index cf0673e7cca..632ca06d295 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -3,14 +3,16 @@ For internal xarray development use only. Usage: - python xarray/util/generate_ops.py --module > xarray/core/_typed_ops.py - python xarray/util/generate_ops.py --stubs > xarray/core/_typed_ops.pyi + python xarray/util/generate_ops.py > xarray/core/_typed_ops.py """ # Note: the comments in https://github.com/pydata/xarray/pull/4904 provide some # background to some of the design choices made here. -import sys +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from typing import Optional BINOPS_EQNE = (("__eq__", "nputils.array_eq"), ("__ne__", "nputils.array_ne")) BINOPS_CMP = ( @@ -74,155 +76,178 @@ ("conjugate", "ops.conjugate"), ) + +required_method_binary = """ + def _binary_op( + self, other: {other_type}, f: Callable, reflexive: bool = False + ) -> {return_type}: + raise NotImplementedError""" template_binop = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> {return_type}:{type_ignore} return self._binary_op(other, {func})""" +template_binop_overload = """ + @overload{overload_type_ignore} + def {method}(self, other: {overload_type}) -> NoReturn: + ... + + @overload + def {method}(self, other: {other_type}) -> {return_type}: + ... +""" template_reflexive = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> {return_type}: return self._binary_op(other, {func}, reflexive=True)""" + +required_method_inplace = """ + def _inplace_binary_op(self, other: {other_type}, f: Callable) -> Self: + raise NotImplementedError""" template_inplace = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> Self:{type_ignore} return self._inplace_binary_op(other, {func})""" + +required_method_unary = """ + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: + raise NotImplementedError""" template_unary = """ - def {method}(self): + def {method}(self) -> Self: return self._unary_op({func})""" template_other_unary = """ - def {method}(self, *args, **kwargs): + def {method}(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op({func}, *args, **kwargs)""" -required_method_unary = """ - def _unary_op(self, f, *args, **kwargs): - raise NotImplementedError""" -required_method_binary = """ - def _binary_op(self, other, f, reflexive=False): - raise NotImplementedError""" -required_method_inplace = """ - def _inplace_binary_op(self, other, f): - raise NotImplementedError""" # For some methods we override return type `bool` defined by base class `object`. -OVERRIDE_TYPESHED = {"override": " # type: ignore[override]"} -NO_OVERRIDE = {"override": ""} - -# Note: in some of the overloads below the return value in reality is NotImplemented, -# which cannot accurately be expressed with type hints,e.g. Literal[NotImplemented] -# or type(NotImplemented) are not allowed and NoReturn has a different meaning. -# In such cases we are lending the type checkers a hand by specifying the return type -# of the corresponding reflexive method on `other` which will be called instead. -stub_ds = """\ - def {method}(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...{override}""" -stub_da = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def {method}(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...""" -stub_var = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: T_DataArray) -> T_DataArray: ... - @overload - def {method}(self: T_Variable, other: VarCompatible) -> T_Variable: ...""" -stub_dsgb = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: "DataArray") -> "Dataset": ... - @overload - def {method}(self, other: GroupByIncompatible) -> NoReturn: ...""" -stub_dagb = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: T_DataArray) -> T_DataArray: ... - @overload - def {method}(self, other: GroupByIncompatible) -> NoReturn: ...""" -stub_unary = """\ - def {method}(self: {self_type}) -> {self_type}: ...""" -stub_other_unary = """\ - def {method}(self: {self_type}, *args, **kwargs) -> {self_type}: ...""" -stub_required_unary = """\ - def _unary_op(self, f, *args, **kwargs): ...""" -stub_required_binary = """\ - def _binary_op(self, other, f, reflexive=...): ...""" -stub_required_inplace = """\ - def _inplace_binary_op(self, other, f): ...""" - - -def unops(self_type): - extra_context = {"self_type": self_type} +# We need to add "# type: ignore[override]" +# Keep an eye out for: +# https://discuss.python.org/t/make-type-hints-for-eq-of-primitives-less-strict/34240 +# The type ignores might not be neccesary anymore at some point. +# +# We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray +# In reality this returns NotImplementes, but this is not a valid type in python 3.9. +# Therefore, we use NoReturn which mypy seems to recognise! +# TODO: change once python 3.10 is the minimum. +# +# Mypy seems to require that __iadd__ and __add__ have the same signature. +# This requires some extra type: ignores[misc] in the inplace methods :/ + + +def _type_ignore(ignore: str) -> str: + return f" # type:ignore[{ignore}]" if ignore else "" + + +FuncType = Sequence[tuple[Optional[str], Optional[str]]] +OpsType = tuple[FuncType, str, dict[str, str]] + + +def binops( + other_type: str, return_type: str = "Self", type_ignore_eq: str = "override" +) -> list[OpsType]: + extras = {"other_type": other_type, "return_type": return_type} return [ - ([(None, None)], required_method_unary, stub_required_unary, {}), - (UNARY_OPS, template_unary, stub_unary, extra_context), - (OTHER_UNARY_METHODS, template_other_unary, stub_other_unary, extra_context), + ([(None, None)], required_method_binary, extras), + (BINOPS_NUM + BINOPS_CMP, template_binop, extras | {"type_ignore": ""}), + ( + BINOPS_EQNE, + template_binop, + extras | {"type_ignore": _type_ignore(type_ignore_eq)}, + ), + (BINOPS_REFLEXIVE, template_reflexive, extras), ] -def binops(stub=""): +def binops_overload( + other_type: str, + overload_type: str, + return_type: str = "Self", + type_ignore_eq: str = "override", +) -> list[OpsType]: + extras = {"other_type": other_type, "return_type": return_type} return [ - ([(None, None)], required_method_binary, stub_required_binary, {}), - (BINOPS_NUM + BINOPS_CMP, template_binop, stub, NO_OVERRIDE), - (BINOPS_EQNE, template_binop, stub, OVERRIDE_TYPESHED), - (BINOPS_REFLEXIVE, template_reflexive, stub, NO_OVERRIDE), + ([(None, None)], required_method_binary, extras), + ( + BINOPS_NUM + BINOPS_CMP, + template_binop_overload + template_binop, + extras + | { + "overload_type": overload_type, + "type_ignore": "", + "overload_type_ignore": "", + }, + ), + ( + BINOPS_EQNE, + template_binop_overload + template_binop, + extras + | { + "overload_type": overload_type, + "type_ignore": "", + "overload_type_ignore": _type_ignore(type_ignore_eq), + }, + ), + (BINOPS_REFLEXIVE, template_reflexive, extras), ] -def inplace(): +def inplace(other_type: str, type_ignore: str = "") -> list[OpsType]: + extras = {"other_type": other_type} return [ - ([(None, None)], required_method_inplace, stub_required_inplace, {}), - (BINOPS_INPLACE, template_inplace, "", {}), + ([(None, None)], required_method_inplace, extras), + ( + BINOPS_INPLACE, + template_inplace, + extras | {"type_ignore": _type_ignore(type_ignore)}, + ), + ] + + +def unops() -> list[OpsType]: + return [ + ([(None, None)], required_method_unary, {}), + (UNARY_OPS, template_unary, {}), + (OTHER_UNARY_METHODS, template_other_unary, {}), ] ops_info = {} -ops_info["DatasetOpsMixin"] = binops(stub_ds) + inplace() + unops("T_Dataset") -ops_info["DataArrayOpsMixin"] = binops(stub_da) + inplace() + unops("T_DataArray") -ops_info["VariableOpsMixin"] = binops(stub_var) + inplace() + unops("T_Variable") -ops_info["DatasetGroupByOpsMixin"] = binops(stub_dsgb) -ops_info["DataArrayGroupByOpsMixin"] = binops(stub_dagb) +ops_info["DatasetOpsMixin"] = ( + binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops() +) +ops_info["DataArrayOpsMixin"] = ( + binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops() +) +ops_info["VariableOpsMixin"] = ( + binops_overload(other_type="VarCompatible", overload_type="T_DataArray") + + inplace(other_type="VarCompatible", type_ignore="misc") + + unops() +) +ops_info["DatasetGroupByOpsMixin"] = binops( + other_type="GroupByCompatible", return_type="Dataset" +) +ops_info["DataArrayGroupByOpsMixin"] = binops( + other_type="T_Xarray", return_type="T_Xarray" +) MODULE_PREAMBLE = '''\ """Mixin classes with arithmetic operators.""" # This file was generated using xarray.util.generate_ops. Do not edit manually. -import operator - -from . import nputils, ops''' - -STUBFILE_PREAMBLE = '''\ -"""Stub file for mixin classes with arithmetic operators.""" -# This file was generated using xarray.util.generate_ops. Do not edit manually. - -from typing import NoReturn, TypeVar, overload +from __future__ import annotations -import numpy as np -from numpy.typing import ArrayLike +import operator +from typing import TYPE_CHECKING, Any, Callable, NoReturn, overload -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy -from .types import ( +from xarray.core import nputils, ops +from xarray.core.types import ( DaCompatible, DsCompatible, - GroupByIncompatible, - ScalarOrArray, + GroupByCompatible, + Self, + T_DataArray, + T_Xarray, VarCompatible, ) -from .variable import Variable -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray # type: ignore - -# DatasetOpsMixin etc. are parent classes of Dataset etc. -# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally -# we use the ones in `types`. (We're open to refining this, and potentially integrating -# the `py` & `pyi` files to simplify them.) -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")''' +if TYPE_CHECKING: + from xarray.core.dataset import Dataset''' CLASS_PREAMBLE = """{newline} @@ -233,35 +258,28 @@ class {cls_name}: {method}.__doc__ = {func}.__doc__""" -def render(ops_info, is_module): +def render(ops_info: dict[str, list[OpsType]]) -> Iterator[str]: """Render the module or stub file.""" - yield MODULE_PREAMBLE if is_module else STUBFILE_PREAMBLE + yield MODULE_PREAMBLE for cls_name, method_blocks in ops_info.items(): - yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n" * is_module) - yield from _render_classbody(method_blocks, is_module) + yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n") + yield from _render_classbody(method_blocks) -def _render_classbody(method_blocks, is_module): - for method_func_pairs, method_template, stub_template, extra in method_blocks: - template = method_template if is_module else stub_template +def _render_classbody(method_blocks: list[OpsType]) -> Iterator[str]: + for method_func_pairs, template, extra in method_blocks: if template: for method, func in method_func_pairs: yield template.format(method=method, func=func, **extra) - if is_module: - yield "" - for method_func_pairs, *_ in method_blocks: - for method, func in method_func_pairs: - if method and func: - yield COPY_DOCSTRING.format(method=method, func=func) + yield "" + for method_func_pairs, *_ in method_blocks: + for method, func in method_func_pairs: + if method and func: + yield COPY_DOCSTRING.format(method=method, func=func) if __name__ == "__main__": - option = sys.argv[1].lower() if len(sys.argv) == 2 else None - if option not in {"--module", "--stubs"}: - raise SystemExit(f"Usage: {sys.argv[0]} --module | --stubs") - is_module = option == "--module" - - for line in render(ops_info, is_module): + for line in render(ops_info): print(line)