Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

raise on passing axis to Dataset.reduce methods #4940

Merged
merged 8 commits into from
Mar 4, 2021
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Deprecations

Bug fixes
~~~~~~~~~

- Don't allow passing ``axis`` to :py:meth:`Dataset.reduce` methods (:issue:`3510`, :pull:`4940`).
By `Justus Magin <https://github.com/keewis>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
45 changes: 20 additions & 25 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4665,6 +4665,12 @@ def reduce(
Dataset with this object's DataArrays replaced with new DataArrays
of summarized data and the indicated dimension(s) removed.
"""
if "axis" in kwargs:
raise ValueError(
"passing 'axis' to Dataset reduce methods is ambiguous."
" Please use 'dim' instead."
)

if dim is None or dim is ...:
dims = set(self.dims)
elif isinstance(dim, str) or not isinstance(dim, Iterable):
Expand Down Expand Up @@ -6854,7 +6860,7 @@ def idxmax(
)
)

def argmin(self, dim=None, axis=None, **kwargs):
def argmin(self, dim=None, **kwargs):
"""Indices of the minima of the member variables.

If there are multiple minima, the indices of the first one found will be
Expand All @@ -6868,9 +6874,6 @@ def argmin(self, dim=None, axis=None, **kwargs):
this is deprecated, in future will be an error, since DataArray.argmin will
return a dict with indices for all dimensions, which does not make sense for
a Dataset.
axis : int, optional
Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments
can be supplied.
keep_attrs : bool, optional
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
Expand All @@ -6888,36 +6891,33 @@ def argmin(self, dim=None, axis=None, **kwargs):
See Also
--------
DataArray.argmin

"""
if dim is None and axis is None:
if dim is None:
warnings.warn(
"Once the behaviour of DataArray.argmin() and Variable.argmin() with "
"neither dim nor axis argument changes to return a dict of indices of "
"each dimension, for consistency it will be an error to call "
"Dataset.argmin() with no argument, since we don't return a dict of "
"Datasets.",
"Once the behaviour of DataArray.argmin() and Variable.argmin() without "
"dim changes to return a dict of indices of each dimension, for "
"consistency it will be an error to call Dataset.argmin() with no argument, "
"since we don't return a dict of Datasets.",
DeprecationWarning,
stacklevel=2,
)
if (
dim is None
or axis is not None
or (not isinstance(dim, Sequence) and dim is not ...)
or isinstance(dim, str)
):
# Return int index if single dimension is passed, and is not part of a
# sequence
argmin_func = getattr(duck_array_ops, "argmin")
return self.reduce(argmin_func, dim=dim, axis=axis, **kwargs)
return self.reduce(argmin_func, dim=dim, **kwargs)
else:
raise ValueError(
"When dim is a sequence or ..., DataArray.argmin() returns a dict. "
"dicts cannot be contained in a Dataset, so cannot call "
"Dataset.argmin() with a sequence or ... for dim"
)

def argmax(self, dim=None, axis=None, **kwargs):
def argmax(self, dim=None, **kwargs):
"""Indices of the maxima of the member variables.

If there are multiple maxima, the indices of the first one found will be
Expand All @@ -6931,9 +6931,6 @@ def argmax(self, dim=None, axis=None, **kwargs):
this is deprecated, in future will be an error, since DataArray.argmax will
return a dict with indices for all dimensions, which does not make sense for
a Dataset.
axis : int, optional
Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments
can be supplied.
keep_attrs : bool, optional
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
Expand All @@ -6953,26 +6950,24 @@ def argmax(self, dim=None, axis=None, **kwargs):
DataArray.argmax

"""
if dim is None and axis is None:
if dim is None:
warnings.warn(
"Once the behaviour of DataArray.argmax() and Variable.argmax() with "
"neither dim nor axis argument changes to return a dict of indices of "
"each dimension, for consistency it will be an error to call "
"Dataset.argmax() with no argument, since we don't return a dict of "
"Datasets.",
"Once the behaviour of DataArray.argmax() and Variable.argmax() without "
"dim changes to return a dict of indices of each dimension, for "
"consistency it will be an error to call Dataset.argmax() with no argument, "
"since we don't return a dict of Datasets.",
DeprecationWarning,
stacklevel=2,
)
if (
dim is None
or axis is not None
or (not isinstance(dim, Sequence) and dim is not ...)
or isinstance(dim, str)
):
# Return int index if single dimension is passed, and is not part of a
# sequence
argmax_func = getattr(duck_array_ops, "argmax")
return self.reduce(argmax_func, dim=dim, axis=axis, **kwargs)
return self.reduce(argmax_func, dim=dim, **kwargs)
else:
raise ValueError(
"When dim is a sequence or ..., DataArray.argmax() returns a dict. "
Expand Down
9 changes: 3 additions & 6 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4746,6 +4746,9 @@ def test_reduce(self):

assert_equal(data.mean(dim=[]), data)

with pytest.raises(ValueError):
data.mean(axis=0)

def test_reduce_coords(self):
# regression test for GH1470
data = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"b": 4})
Expand Down Expand Up @@ -4926,9 +4929,6 @@ def mean_only_one_axis(x, axis):
with raises_regex(TypeError, "missing 1 required positional argument: 'axis'"):
ds.reduce(mean_only_one_axis)

with raises_regex(TypeError, "non-integer axis"):
ds.reduce(mean_only_one_axis, axis=["x", "y"])

def test_reduce_no_axis(self):
def total_sum(x):
return np.sum(x.flatten())
Expand All @@ -4938,9 +4938,6 @@ def total_sum(x):
actual = ds.reduce(total_sum)
assert_identical(expected, actual)

with raises_regex(TypeError, "unexpected keyword argument 'axis'"):
ds.reduce(total_sum, axis=0)

with raises_regex(TypeError, "unexpected keyword argument 'axis'"):
ds.reduce(total_sum, dim="x")

Expand Down
29 changes: 0 additions & 29 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -3972,35 +3972,6 @@ def test_repr(self, func, variant, dtype):
@pytest.mark.parametrize(
"func",
(
function("all"),
function("any"),
pytest.param(
function("argmax"),
marks=pytest.mark.skip(
reason="calling np.argmax as a function on xarray objects is not "
"supported"
),
),
pytest.param(
function("argmin"),
marks=pytest.mark.skip(
reason="calling np.argmin as a function on xarray objects is not "
"supported"
),
),
function("max"),
function("min"),
function("mean"),
pytest.param(
function("median"),
marks=pytest.mark.xfail(reason="median does not work with dataset yet"),
),
function("sum"),
function("prod"),
function("std"),
function("var"),
function("cumsum"),
function("cumprod"),
method("all"),
method("any"),
method("argmax", dim="x"),
Expand Down