Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ellipsis typehint to reductions #7048

Merged
merged 12 commits into from
Sep 28, 2022
431 changes: 238 additions & 193 deletions xarray/core/_reductions.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .coordinates import Coordinates
from .dataarray import DataArray
from .dataset import Dataset
from .types import CombineAttrsOptions, Ellipsis, JoinOptions
from .types import CombineAttrsOptions, JoinOptions

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
Expand Down Expand Up @@ -1624,7 +1624,7 @@ def cross(

def dot(
*arrays,
dims: str | Iterable[Hashable] | Ellipsis | None = None,
dims: str | Iterable[Hashable] | ellipsis | None = None,
**kwargs: Any,
):
"""Generalized dot product for xarray objects. Like np.einsum, but
Expand Down
23 changes: 12 additions & 11 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
from .types import (
CoarsenBoundaryOptions,
DatetimeUnitOptions,
Ellipsis,
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
InterpOptions,
Expand Down Expand Up @@ -2888,9 +2888,9 @@ def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray:
def reduce(
self: T_DataArray,
func: Callable[..., Any],
dim: None | Hashable | Iterable[Hashable] = None,
dim: str | Iterable[Hashable] | ellipsis | None = None,
headtr1ck marked this conversation as resolved.
Show resolved Hide resolved
*,
axis: None | int | Sequence[int] = None,
axis: int | Sequence[int] | None = None,
keep_attrs: bool | None = None,
keepdims: bool = False,
**kwargs: Any,
Expand All @@ -2903,8 +2903,9 @@ def reduce(
Function which can be called in the form
`f(x, axis=axis, **kwargs)` to return the result of reducing an
np.ndarray over an integer valued axis.
dim : Hashable or Iterable of Hashable, optional
Dimension(s) over which to apply `func`.
dim : "...", str, Iterable of Hashable or None, optional
Dimension(s) over which to apply `func`. By default `func` is
applied over all dimensions.
axis : int or sequence of int, optional
Axis(es) over which to repeatedly apply `func`. Only one of the
'dim' and 'axis' arguments can be supplied. If neither are
Expand Down Expand Up @@ -3770,7 +3771,7 @@ def imag(self: T_DataArray) -> T_DataArray:
def dot(
self: T_DataArray,
other: T_DataArray,
dims: str | Iterable[Hashable] | Ellipsis | None = None,
dims: str | Iterable[Hashable] | ellipsis | None = None,
headtr1ck marked this conversation as resolved.
Show resolved Hide resolved
) -> T_DataArray:
"""Perform dot product of two DataArrays along their shared dims.

Expand Down Expand Up @@ -3890,7 +3891,7 @@ def sortby(
def quantile(
self: T_DataArray,
q: ArrayLike,
dim: str | Iterable[Hashable] | None = None,
dim: Dims = None,
method: QUANTILE_METHODS = "linear",
keep_attrs: bool | None = None,
skipna: bool | None = None,
Expand Down Expand Up @@ -4774,7 +4775,7 @@ def idxmax(
# https://github.com/python/mypy/issues/12846 is resolved
def argmin(
self,
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
dim: str | Iterable[Hashable] | ellipsis | None = None,
headtr1ck marked this conversation as resolved.
Show resolved Hide resolved
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
Expand All @@ -4790,7 +4791,7 @@ def argmin(

Parameters
----------
dim : Hashable, sequence of Hashable, None or ..., optional
dim : "...", str, Iterable of Hashable or None, optional
The dimensions over which to find the minimum. By default, finds minimum over
all dimensions - for now returning an int for backward compatibility, but
this is deprecated, in future will return a dict with indices for all
Expand Down Expand Up @@ -4879,7 +4880,7 @@ def argmin(
# https://github.com/python/mypy/issues/12846 is resolved
def argmax(
self,
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
dim: str | Iterable[Hashable] | ellipsis | None = None,
headtr1ck marked this conversation as resolved.
Show resolved Hide resolved
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
Expand All @@ -4895,7 +4896,7 @@ def argmax(

Parameters
----------
dim : Hashable, sequence of Hashable, None or ..., optional
dim : "...", str, Iterable of Hashable or None, optional
The dimensions over which to find the maximum. By default, finds maximum over
all dimensions - for now returning an int for backward compatibility, but
this is deprecated, in future will return a dict with indices for all
Expand Down
44 changes: 22 additions & 22 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
CombineAttrsOptions,
CompatOptions,
DatetimeUnitOptions,
Ellipsis,
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
InterpOptions,
Expand Down Expand Up @@ -4256,7 +4256,7 @@ def _get_stack_index(

def _stack_once(
self: T_Dataset,
dims: Sequence[Hashable | Ellipsis],
dims: Sequence[Hashable | ellipsis],
new_dim: Hashable,
index_cls: type[Index],
create_index: bool | None = True,
Expand Down Expand Up @@ -4315,10 +4315,10 @@ def _stack_once(

def stack(
self: T_Dataset,
dimensions: Mapping[Any, Sequence[Hashable | Ellipsis]] | None = None,
dimensions: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None,
create_index: bool | None = True,
index_cls: type[Index] = PandasMultiIndex,
**dimensions_kwargs: Sequence[Hashable | Ellipsis],
**dimensions_kwargs: Sequence[Hashable | ellipsis],
) -> T_Dataset:
"""
Stack any number of existing dimensions into a single new dimension.
Expand Down Expand Up @@ -5504,7 +5504,7 @@ def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset:
def reduce(
self: T_Dataset,
func: Callable,
dim: Hashable | Iterable[Hashable] = None,
dim: str | Iterable[Hashable] | ellipsis | None = None,
headtr1ck marked this conversation as resolved.
Show resolved Hide resolved
*,
keep_attrs: bool | None = None,
keepdims: bool = False,
Expand All @@ -5519,8 +5519,8 @@ def reduce(
Function which can be called in the form
`f(x, axis=axis, **kwargs)` to return the result of reducing an
np.ndarray over an integer valued axis.
dim : str or sequence of str, optional
Dimension(s) over which to apply `func`. By default `func` is
dim : str, Iterable of Hashable or None, optional
Dimension(s) over which to apply `func`. By default `func` is
applied over all dimensions.
keep_attrs : bool or None, optional
If True, the dataset's attributes (`attrs`) will be copied from
Expand Down Expand Up @@ -5578,18 +5578,14 @@ def reduce(
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
):
reduce_maybe_single: Hashable | None | list[Hashable]
if len(reduce_dims) == 1:
# unpack dimensions for the benefit of functions
# like np.argmin which can't handle tuple arguments
(reduce_maybe_single,) = reduce_dims
elif len(reduce_dims) == var.ndim:
# prefer to aggregate over axis=None rather than
# axis=(0, 1) if they will be equivalent, because
# the former is often more efficient
reduce_maybe_single = None
else:
reduce_maybe_single = reduce_dims
# prefer to aggregate over axis=None rather than
# axis=(0, 1) if they will be equivalent, because
# the former is often more efficient
reduce_maybe_single = (
None
if len(reduce_dims) == var.ndim and var.ndim != 1
else reduce_dims
)
variables[name] = var.reduce(
func,
dim=reduce_maybe_single,
Expand Down Expand Up @@ -6698,7 +6694,7 @@ def sortby(
def quantile(
self: T_Dataset,
q: ArrayLike,
dim: str | Iterable[Hashable] | None = None,
dim: Dims = None,
method: QUANTILE_METHODS = "linear",
numeric_only: bool = False,
keep_attrs: bool = None,
Expand Down Expand Up @@ -8030,7 +8026,9 @@ def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
# Return int index if single dimension is passed, and is not part of a
# sequence
argmin_func = getattr(duck_array_ops, "argmin")
return self.reduce(argmin_func, dim=dim, **kwargs)
return self.reduce(
argmin_func, dim=None if dim is None else [dim], **kwargs
)
else:
raise ValueError(
"When dim is a sequence or ..., DataArray.argmin() returns a dict. "
Expand Down Expand Up @@ -8088,7 +8086,9 @@ def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
# Return int index if single dimension is passed, and is not part of a
# sequence
argmax_func = getattr(duck_array_ops, "argmax")
return self.reduce(argmax_func, dim=dim, **kwargs)
return self.reduce(
argmax_func, dim=None if dim is None else [dim], **kwargs
)
else:
raise ValueError(
"When dim is a sequence or ..., DataArray.argmin() returns a dict. "
Expand Down
Loading