diff --git a/doc/whats-new.rst b/doc/whats-new.rst index eed4e16eb62..7c3d57f4fe8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,7 +34,8 @@ Deprecations Bug fixes ~~~~~~~~~ - +- Don't allow passing ``axis`` to :py:meth:`Dataset.reduce` methods (:issue:`3510`, :pull:`4940`). + By `Justus Magin `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9faf74dd4bc..cbc30dddda9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4665,6 +4665,12 @@ def reduce( Dataset with this object's DataArrays replaced with new DataArrays of summarized data and the indicated dimension(s) removed. """ + if "axis" in kwargs: + raise ValueError( + "passing 'axis' to Dataset reduce methods is ambiguous." + " Please use 'dim' instead." + ) + if dim is None or dim is ...: dims = set(self.dims) elif isinstance(dim, str) or not isinstance(dim, Iterable): @@ -6854,7 +6860,7 @@ def idxmax( ) ) - def argmin(self, dim=None, axis=None, **kwargs): + def argmin(self, dim=None, **kwargs): """Indices of the minima of the member variables. If there are multiple minima, the indices of the first one found will be @@ -6868,9 +6874,6 @@ def argmin(self, dim=None, axis=None, **kwargs): this is deprecated, in future will be an error, since DataArray.argmin will return a dict with indices for all dimensions, which does not make sense for a Dataset. - axis : int, optional - Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments - can be supplied. keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -6888,28 +6891,25 @@ def argmin(self, dim=None, axis=None, **kwargs): See Also -------- DataArray.argmin - """ - if dim is None and axis is None: + if dim is None: warnings.warn( - "Once the behaviour of DataArray.argmin() and Variable.argmin() with " - "neither dim nor axis argument changes to return a dict of indices of " - "each dimension, for consistency it will be an error to call " - "Dataset.argmin() with no argument, since we don't return a dict of " - "Datasets.", + "Once the behaviour of DataArray.argmin() and Variable.argmin() without " + "dim changes to return a dict of indices of each dimension, for " + "consistency it will be an error to call Dataset.argmin() with no argument," + "since we don't return a dict of Datasets.", DeprecationWarning, stacklevel=2, ) if ( dim is None - or axis is not None or (not isinstance(dim, Sequence) and dim is not ...) or isinstance(dim, str) ): # Return int index if single dimension is passed, and is not part of a # sequence argmin_func = getattr(duck_array_ops, "argmin") - return self.reduce(argmin_func, dim=dim, axis=axis, **kwargs) + return self.reduce(argmin_func, dim=dim, **kwargs) else: raise ValueError( "When dim is a sequence or ..., DataArray.argmin() returns a dict. " @@ -6917,7 +6917,7 @@ def argmin(self, dim=None, axis=None, **kwargs): "Dataset.argmin() with a sequence or ... for dim" ) - def argmax(self, dim=None, axis=None, **kwargs): + def argmax(self, dim=None, **kwargs): """Indices of the maxima of the member variables. If there are multiple maxima, the indices of the first one found will be @@ -6931,9 +6931,6 @@ def argmax(self, dim=None, axis=None, **kwargs): this is deprecated, in future will be an error, since DataArray.argmax will return a dict with indices for all dimensions, which does not make sense for a Dataset. - axis : int, optional - Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments - can be supplied. keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -6953,26 +6950,24 @@ def argmax(self, dim=None, axis=None, **kwargs): DataArray.argmax """ - if dim is None and axis is None: + if dim is None: warnings.warn( - "Once the behaviour of DataArray.argmax() and Variable.argmax() with " - "neither dim nor axis argument changes to return a dict of indices of " - "each dimension, for consistency it will be an error to call " - "Dataset.argmax() with no argument, since we don't return a dict of " - "Datasets.", + "Once the behaviour of DataArray.argmin() and Variable.argmin() without " + "dim changes to return a dict of indices of each dimension, for " + "consistency it will be an error to call Dataset.argmin() with no argument," + "since we don't return a dict of Datasets.", DeprecationWarning, stacklevel=2, ) if ( dim is None - or axis is not None or (not isinstance(dim, Sequence) and dim is not ...) or isinstance(dim, str) ): # Return int index if single dimension is passed, and is not part of a # sequence argmax_func = getattr(duck_array_ops, "argmax") - return self.reduce(argmax_func, dim=dim, axis=axis, **kwargs) + return self.reduce(argmax_func, dim=dim, **kwargs) else: raise ValueError( "When dim is a sequence or ..., DataArray.argmin() returns a dict. " diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2118bc8b780..8afd6d68ab4 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4746,6 +4746,9 @@ def test_reduce(self): assert_equal(data.mean(dim=[]), data) + with pytest.raises(ValueError): + data.mean(axis=0) + def test_reduce_coords(self): # regression test for GH1470 data = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"b": 4}) @@ -4926,9 +4929,6 @@ def mean_only_one_axis(x, axis): with raises_regex(TypeError, "missing 1 required positional argument: 'axis'"): ds.reduce(mean_only_one_axis) - with raises_regex(TypeError, "non-integer axis"): - ds.reduce(mean_only_one_axis, axis=["x", "y"]) - def test_reduce_no_axis(self): def total_sum(x): return np.sum(x.flatten()) @@ -4938,9 +4938,6 @@ def total_sum(x): actual = ds.reduce(total_sum) assert_identical(expected, actual) - with raises_regex(TypeError, "unexpected keyword argument 'axis'"): - ds.reduce(total_sum, axis=0) - with raises_regex(TypeError, "unexpected keyword argument 'axis'"): ds.reduce(total_sum, dim="x") diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 76dd830de23..8b7835e5da6 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -3972,35 +3972,6 @@ def test_repr(self, func, variant, dtype): @pytest.mark.parametrize( "func", ( - function("all"), - function("any"), - pytest.param( - function("argmax"), - marks=pytest.mark.skip( - reason="calling np.argmax as a function on xarray objects is not " - "supported" - ), - ), - pytest.param( - function("argmin"), - marks=pytest.mark.skip( - reason="calling np.argmin as a function on xarray objects is not " - "supported" - ), - ), - function("max"), - function("min"), - function("mean"), - pytest.param( - function("median"), - marks=pytest.mark.xfail(reason="median does not work with dataset yet"), - ), - function("sum"), - function("prod"), - function("std"), - function("var"), - function("cumsum"), - function("cumprod"), method("all"), method("any"), method("argmax", dim="x"),