Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Variable aggregations to NamedArray #8304

Merged
merged 17 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions doc/api-hidden.rst
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,85 @@
IndexVariable.sizes
IndexVariable.values


NamedArray.all
NamedArray.any
..
NamedArray.argmax
NamedArray.argmin
NamedArray.argsort
NamedArray.astype
..
NamedArray.broadcast_equals
NamedArray.chunk
NamedArray.clip
NamedArray.coarsen
NamedArray.compute
NamedArray.concat
NamedArray.conj
NamedArray.conjugate
NamedArray.copy
NamedArray.count
NamedArray.cumprod
NamedArray.cumsum
..
NamedArray.equals
NamedArray.fillna
NamedArray.get_axis_num
..
NamedArray.identical
..
NamedArray.isel
..
NamedArray.isnull
NamedArray.item
NamedArray.load
NamedArray.max
NamedArray.mean
NamedArray.median
NamedArray.min
..
NamedArray.no_conflicts
NamedArray.notnull
NamedArray.pad
NamedArray.prod
NamedArray.quantile
..
NamedArray.rank
NamedArray.reduce
..
NamedArray.roll
NamedArray.rolling_window
NamedArray.round
NamedArray.searchsorted
NamedArray.set_dims
NamedArray.shift
NamedArray.squeeze
NamedArray.stack
NamedArray.std
NamedArray.sum
..
NamedArray.to_dict
NamedArray.transpose
NamedArray.unstack
NamedArray.var
..
NamedArray.where
NamedArray.T
NamedArray.attrs
NamedArray.chunks
NamedArray.data
NamedArray.dims
NamedArray.dtype
NamedArray.imag
NamedArray.nbytes
NamedArray.ndim
NamedArray.real
NamedArray.shape
NamedArray.size
NamedArray.sizes
NamedArray.values

plot.plot
plot.line
plot.step
Expand Down
3 changes: 0 additions & 3 deletions xarray/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.ops import (
IncludeCumMethods,
IncludeNumpySameMethods,
IncludeReduceMethods,
)
Expand Down Expand Up @@ -99,8 +98,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):

class VariableArithmetic(
ImplementsArrayReduce,
IncludeReduceMethods,
IncludeCumMethods,
IncludeNumpySameMethods,
SupportsArithmetic,
VariableOpsMixin,
Expand Down
25 changes: 0 additions & 25 deletions xarray/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
"var",
"median",
]
NAN_CUM_METHODS = ["cumsum", "cumprod"]
# TODO: wrap take, dot, sort


Expand Down Expand Up @@ -263,20 +262,6 @@ def inject_reduce_methods(cls):
setattr(cls, name, func)


def inject_cum_methods(cls):
methods = [(name, getattr(duck_array_ops, name), True) for name in NAN_CUM_METHODS]
for name, f, include_skipna in methods:
numeric_only = getattr(f, "numeric_only", False)
func = cls._reduce_method(f, include_skipna, numeric_only)
func.__name__ = name
func.__doc__ = _CUM_DOCSTRING_TEMPLATE.format(
name=name,
cls=cls.__name__,
extra_args=cls._cum_extra_args_docstring.format(name=name),
)
setattr(cls, name, func)


def op_str(name):
return f"__{name}__"

Expand Down Expand Up @@ -316,16 +301,6 @@ def __init_subclass__(cls, **kwargs):
inject_reduce_methods(cls)


class IncludeCumMethods:
__slots__ = ()

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)

if getattr(cls, "_reduce_method", None):
inject_cum_methods(cls)


class IncludeNumpySameMethods:
__slots__ = ()

Expand Down
63 changes: 13 additions & 50 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import numbers
import warnings
from collections.abc import Hashable, Iterable, Mapping, Sequence
from collections.abc import Hashable, Mapping, Sequence
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast
Expand Down Expand Up @@ -1743,59 +1743,22 @@ def reduce(
Array with summarized data and the indicated dimension(s)
removed.
"""
if dim == ...:
dim = None
if dim is not None and axis is not None:
raise ValueError("cannot supply both 'axis' and 'dim' arguments")

if dim is not None:
axis = self.get_axis_num(dim)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", r"Mean of empty slice", category=RuntimeWarning
)
if axis is not None:
if isinstance(axis, tuple) and len(axis) == 1:
# unpack axis for the benefit of functions
# like np.argmin which can't handle tuple arguments
axis = axis[0]
data = func(self.data, axis=axis, **kwargs)
else:
data = func(self.data, **kwargs)

if getattr(data, "shape", ()) == self.shape:
dims = self.dims
else:
removed_axes: Iterable[int]
if axis is None:
removed_axes = range(self.ndim)
else:
removed_axes = np.atleast_1d(axis) % self.ndim
if keepdims:
# Insert np.newaxis for removed dims
slices = tuple(
np.newaxis if i in removed_axes else slice(None, None)
for i in range(self.ndim)
)
if getattr(data, "shape", None) is None:
# Reduce has produced a scalar value, not an array-like
data = np.asanyarray(data)[slices]
else:
data = data[slices]
dims = self.dims
else:
dims = tuple(
adim for n, adim in enumerate(self.dims) if n not in removed_axes
)
keep_attrs_ = (
_get_keep_attrs(default=False) if keep_attrs is None else keep_attrs
)

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
attrs = self._attrs if keep_attrs else None
result = super().reduce(
func=func,
dim=dim,
axis=axis,
keep_attrs=keep_attrs_,
keepdims=keepdims,
**kwargs,
)

# We need to return `Variable` rather than the type of `self` at the moment, ref
# #8216
return Variable(dims, data, attrs=attrs)
return Variable(result.dims, result._data, attrs=result._attrs)

@classmethod
def concat(
Expand Down
Loading
Loading