From 61ef76bf28e0a17029d1e747ed8dfc1b6fc8dbdc Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Wed, 4 Mar 2020 21:20:44 -0500 Subject: [PATCH] implement idxmax and idxmin --- doc/api.rst | 4 + doc/whats-new.rst | 3 + xarray/core/computation.py | 66 +++ xarray/core/dataarray.py | 196 +++++++++ xarray/core/dataset.py | 161 +++++++ xarray/tests/test_dataarray.py | 777 +++++++++++++++++++++++++++++++++ xarray/tests/test_dataset.py | 9 + 7 files changed, 1216 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index 216f47f988f..b37c84e7a81 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -180,6 +180,8 @@ Computation :py:attr:`~Dataset.any` :py:attr:`~Dataset.argmax` :py:attr:`~Dataset.argmin` +:py:attr:`~Dataset.idxmax` +:py:attr:`~Dataset.idxmin` :py:attr:`~Dataset.max` :py:attr:`~Dataset.mean` :py:attr:`~Dataset.median` @@ -362,6 +364,8 @@ Computation :py:attr:`~DataArray.any` :py:attr:`~DataArray.argmax` :py:attr:`~DataArray.argmin` +:py:attr:`~DataArray.idxmax` +:py:attr:`~DataArray.idxmin` :py:attr:`~DataArray.max` :py:attr:`~DataArray.mean` :py:attr:`~DataArray.median` diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f44fdeb3354..2288644760d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -35,6 +35,9 @@ New Features :py:func:`combine_by_coords` and :py:func:`combine_nested` using combine_attrs keyword argument. (:issue:`3865`, :pull:`3877`) By `John Omotani `_ +- Implement :py:meth:`~xarray.DataArray.idxmax`, :py:meth:`~xarray.DataArray.idxmin`, + :py:meth:`~xarray.Dataset.idxmax`, :py:meth:`~xarray.Dataset.idxmin`. (:issue:`60`, :pull:`3871`) + By `Todd Jennings `_ Bug fixes diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 13bf6248331..fb3b1ab8ca7 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -26,6 +26,7 @@ from . import duck_array_ops, utils from .alignment import deep_align from .merge import merge_coordinates_without_align +from .nanops import dask_array from .options import OPTIONS from .pycompat import dask_array_type from .utils import is_dict_like @@ -1338,3 +1339,68 @@ def polyval(coord, coeffs, degree_dim="degree"): coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]}, ) return (lhs * coeffs).sum(degree_dim) + + +def _calc_idxminmax( + *, + array, + func: Callable, + dim: Optional[Hashable], + skipna: Optional[bool], + fill_value: Any = None, + keep_attrs: Optional[bool], + **kwargs: Any, +): + """Apply common operations for idxmin and idxmax.""" + # This function doesn't make sense for scalars so don't try + if not array.ndim: + raise ValueError("This function does not apply for scalars") + + if dim is not None: + pass # Use the dim if available + elif array.ndim == 1: + # it is okay to guess the dim if there is only 1 + dim = array.dims[0] + else: + # The dim is not specified and ambiguous. Don't guess. + raise ValueError("Must supply 'dim' argument for multidimensional arrays") + + if dim not in array.dims: + raise KeyError(f'Dimension "{dim}" not in dimension') + if dim not in array.coords: + raise KeyError(f'Dimension "{dim}" does not have coordinates') + + # These are dtypes with NaN values argmin and argmax can handle + na_dtypes = "cfO" + + if skipna or (skipna is None and array.dtype.kind in na_dtypes): + # Need to skip NaN values since argmin and argmax can't handle them + allna = array.isnull().all(dim) + array = array.where(~allna, 0) + + # This will run argmin or argmax. 
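+    # `func` is the caller-supplied argmin/argmax callable of the DataArray;
+    # with a single `dim` it returns the integer position of the extremum
+    # along that dimension. `keep_attrs=False` is passed here because attrs
+    # are re-attached at the end of this helper when `keep_attrs` is True.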
+ indx = func(array, dim=dim, axis=None, keep_attrs=False, skipna=skipna, **kwargs) + + # Get the coordinate we want. + coordarray = array[dim] + + # Handle dask arrays. + if isinstance(array, dask_array_type): + res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype) + else: + res = coordarray[ + indx, + ] + + if skipna or (skipna is None and array.dtype.kind in na_dtypes): + # Put the NaN values back in after removing them + res = res.where(~allna, fill_value) + + # The dim is gone but we need to remove the corresponding coordinate. + del res.coords[dim] + + # Put the attrs back in if needed + if keep_attrs: + res.attrs = array.attrs + + return res diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6c0e1506dd5..24c9e5d48fd 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3508,6 +3508,202 @@ def pad( ) return self._from_temp_dataset(ds) + def idxmin( + self, + dim: Hashable = None, + skipna: bool = None, + fill_value: Any = np.NaN, + keep_attrs: Optional[bool] = False, + **kwargs: Any, + ) -> "DataArray": + """Return the coordinate of the minimum value along a dimension. + + Returns a new DataArray named after the dimension with the values of + the coordinate along that dimension corresponding to minimum value + along that dimension. + + In comparison to `argmin`, this returns the coordinate while `argmin` + returns the index. + + Parameters + ---------- + dim : str, optional + Dimension over which to apply `idxmin`. This is optional for 1D + arrays, but required for arrays with 2 or more dimensions. + skipna : bool or None, default None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + fill_value : Any, default NaN + Value to be filled in case all of the values along a dimension are + null. By default this is NaN. The fill value and result are + automatically converted to a compatible dtype if possible. + Ignored if `skipna` is False. + keep_attrs : bool, default False + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `idxmin` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with `idxmin` applied to its data and the + indicated dimension removed. + + Examples + -------- + + >>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x", + ... coords={"x": ['a', 'b', 'c', 'd', 'e']}) + >>> array.min() + + array(-2) + >>> array.argmin() + + array(4) + >>> array.idxmin() + + array('e', dtype='>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], + ... [np.NaN, np.NaN, 1., np.NaN, np.NaN]], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], + ... "x": np.arange(5.)**2} + ... 
) + >>> array.min(dim="x") + + array([-2., -4., 1.]) + Coordinates: + * y (y) int64 -1 0 1 + >>> array.argmin(dim="x") + + array([4, 0, 2]) + Coordinates: + * y (y) int64 -1 0 1 + >>> array.idxmin(dim="x") + + array([16., 0., 4.]) + Coordinates: + * y (y) int64 -1 0 1 + + See also + -------- + Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin + """ + return computation._calc_idxminmax( + array=self, + func=lambda x, *args, **kwargs: x.argmin(*args, **kwargs), + dim=dim, + skipna=skipna, + fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + + def idxmax( + self, + dim: Hashable = None, + skipna: bool = None, + fill_value: Any = np.NaN, + keep_attrs: Optional[bool] = False, + **kwargs: Any, + ) -> "DataArray": + """Return the coordinate of the maximum value along a dimension. + + Returns a new DataArray named after the dimension with the values of + the coordinate along that dimension corresponding to maximum value + along that dimension. + + In comparison to `argmax`, this returns the coordinate while `argmax` + returns the index. + + Parameters + ---------- + dim : str, optional + Dimension over which to apply `idxmax`. This is optional for 1D + arrays, but required for arrays with 2 or more dimensions. + skipna : bool or None, default None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + fill_value : Any, default NaN + Value to be filled in case all of the values along a dimension are + null. By default this is NaN. The fill value and result are + automatically converted to a compatible dtype if possible. + Ignored if `skipna` is False. + keep_attrs : bool, default False + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `idxmax` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with `idxmax` applied to its data and the + indicated dimension removed. + + Examples + -------- + + >>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x", + ... coords={"x": ['a', 'b', 'c', 'd', 'e']}) + >>> array.max() + + array(2) + >>> array.argmax() + + array(1) + >>> array.idxmax() + + array('b', dtype='>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], + ... [np.NaN, np.NaN, 1., np.NaN, np.NaN]], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], + ... "x": np.arange(5.)**2} + ... 
) + >>> array.max(dim="x") + + array([2., 2., 1.]) + Coordinates: + * y (y) int64 -1 0 1 + >>> array.argmax(dim="x") + + array([0, 2, 2]) + Coordinates: + * y (y) int64 -1 0 1 + >>> array.idxmax(dim="x") + + array([0., 4., 4.]) + Coordinates: + * y (y) int64 -1 0 1 + + See also + -------- + Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax + """ + return computation._calc_idxminmax( + array=self, + func=lambda x, *args, **kwargs: x.argmax(*args, **kwargs), + dim=dim, + skipna=skipna, + fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = property(StringAccessor) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c49694b1fc0..4a32860c3b5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6,6 +6,7 @@ from collections import defaultdict from html import escape from numbers import Number +from operator import methodcaller from pathlib import Path from typing import ( TYPE_CHECKING, @@ -6093,5 +6094,165 @@ def pad( return self._replace_vars_and_dims(variables) + def idxmin( + self, + dim: Hashable = None, + skipna: bool = None, + fill_value: Any = np.NaN, + keep_attrs: bool = False, + **kwargs: Any, + ) -> "Dataset": + """Return the coordinate of the minimum value along a dimension. + + Returns a new Dataset named after the dimension with the values of + the coordinate along that dimension corresponding to minimum value + along that dimension. + + In comparison to `argmin`, this returns the coordinate while `argmin` + returns the index. + + Parameters + ---------- + dim : str, optional + Dimension over which to apply `idxmin`. This is optional for 1D + variables, but required for variables with 2 or more dimensions. + skipna : bool or None, default None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + fill_value : Any, default NaN + Value to be filled in case all of the values along a dimension are + null. By default this is NaN. The fill value and result are + automatically converted to a compatible dtype if possible. + Ignored if `skipna` is False. + keep_attrs : bool, default False + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `idx` on this object's data. + + Returns + ------- + reduced : Dataset + New Dataset object with `idxmin` applied to its data and the + indicated dimension removed. + + Examples + -------- + + >>> array1 = xr.DataArray([0, 2, 1, 0, -2], dims="x", + ... coords={"x": ['a', 'b', 'c', 'd', 'e']}) + >>> array2 = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], + ... [np.NaN, np.NaN, 1., np.NaN, np.NaN]], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], + ... "x": ['a', 'b', 'c', 'd', 'e']} + ... 
)
+        >>> ds = xr.Dataset({'int': array1, 'float': array2})
+        >>> ds.min(dim='x')
+        <xarray.Dataset>
+        Dimensions:  (y: 3)
+        Coordinates:
+          * y        (y) int64 -1 0 1
+        Data variables:
+            int      int64 -2
+            float    (y) float64 -2.0 -4.0 1.0
+        >>> ds.argmin(dim='x')
+        <xarray.Dataset>
+        Dimensions:  (y: 3)
+        Coordinates:
+          * y        (y) int64 -1 0 1
+        Data variables:
+            int      int64 4
+            float    (y) int64 4 0 2
+        >>> ds.idxmin(dim='x')
+        <xarray.Dataset>
+        Dimensions:  (y: 3)
+        Coordinates:
+          * y        (y) int64 -1 0 1
+        Data variables:
+            int      <U1 'e'
+            float    (y) object 'e' 'a' 'c'
+
+        See also
+        --------
+        DataArray.idxmin, Dataset.idxmax, Dataset.min, Dataset.argmin
+        """
+        return self.map(
+            methodcaller(
+                "idxmin",
+                dim=dim,
+                skipna=skipna,
+                fill_value=fill_value,
+                keep_attrs=keep_attrs,
+                **kwargs,
+            ),
+        )
+
+    def idxmax(
+        self,
+        dim: Hashable = None,
+        skipna: bool = None,
+        fill_value: Any = np.NaN,
+        keep_attrs: bool = False,
+        **kwargs: Any,
+    ) -> "Dataset":
+        """Return the coordinate of the maximum value along a dimension.
+
+        Returns a new Dataset named after the dimension with the values of
+        the coordinate along that dimension corresponding to the maximum
+        value along that dimension.
+
+        In comparison to `argmax`, this returns the coordinate while `argmax`
+        returns the index.
+
+        Parameters
+        ----------
+        dim : str, optional
+            Dimension over which to apply `idxmax`. This is optional for 1D
+            variables, but required for variables with 2 or more dimensions.
+        skipna : bool or None, default None
+            If True, skip missing values (as marked by NaN). By default, only
+            skips missing values for float dtypes; other dtypes either do not
+            have a sentinel missing value (int) or skipna=True has not been
+            implemented (object, datetime64 or timedelta64).
+        fill_value : Any, default NaN
+            Value to be filled in case all of the values along a dimension are
+            null. By default this is NaN. The fill value and result are
+            automatically converted to a compatible dtype if possible.
+            Ignored if `skipna` is False.
+        keep_attrs : bool, default False
+            If True, the attributes (`attrs`) will be copied from the original
+            object to the new one. If False (default), the new object will be
+            returned without attributes.
+        **kwargs : dict
+            Additional keyword arguments passed on to the appropriate array
+            function for calculating `idxmax` on this object's data.
+
+        Returns
+        -------
+        reduced : Dataset
+            New Dataset object with `idxmax` applied to its data and the
+            indicated dimension removed.
+ + See also + -------- + DataArray.idxmax, Dataset.idxmin, Dataset.max, Dataset.argmax + """ + return self.map( + methodcaller( + "idxmax", + dim=dim, + skipna=skipna, + fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ), + ) + ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 600c4254c5d..5b3e122bf72 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4349,6 +4349,783 @@ def test_pad_reflect(self, mode, reflect_type): assert_identical(actual, expected) +class TestReduce: + @pytest.fixture(autouse=True) + def setup(self): + self.attrs = {"attr1": "value1", "attr2": 2929} + + +@pytest.mark.parametrize( + "x, minindex, maxindex, nanindex", + [ + (np.array([0, 1, 2, 0, -2, -4, 2]), 5, 2, None), + (np.array([0.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0]), 5, 2, None), + (np.array([1.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0]), 5, 2, 1), + ( + np.array([1.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0]).astype("object"), + 5, + 2, + 1, + ), + (np.array([np.NaN, np.NaN]), np.NaN, np.NaN, 0), + ( + np.array( + ["2015-12-31", "2020-01-02", "2020-01-01", "2016-01-01"], + dtype="datetime64[ns]", + ), + 0, + 1, + None, + ), + ], +) +class TestReduce1D(TestReduce): + def test_min(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + + if np.isnan(minindex): + minindex = 0 + + expected0 = ar.isel(x=minindex, drop=True) + result0 = ar.min(keep_attrs=True) + assert_identical(result0, expected0) + + result1 = ar.min() + expected1 = expected0.copy() + expected1.attrs = {} + assert_identical(result1, expected1) + + result2 = ar.min(skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = ar.isel(x=nanindex, drop=True) + expected2.attrs = {} + else: + expected2 = expected1 + + assert_identical(result2, expected2) + + def test_max(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + + if np.isnan(minindex): + maxindex = 0 + + expected0 = ar.isel(x=maxindex, drop=True) + result0 = ar.max(keep_attrs=True) + assert_identical(result0, expected0) + + result1 = ar.max() + expected1 = expected0.copy() + expected1.attrs = {} + assert_identical(result1, expected1) + + result2 = ar.max(skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = ar.isel(x=nanindex, drop=True) + expected2.attrs = {} + else: + expected2 = expected1 + + assert_identical(result2, expected2) + + def test_argmin(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) + + if np.isnan(minindex): + with pytest.raises(ValueError): + ar.argmin() + return + + expected0 = indarr[minindex] + result0 = ar.argmin() + assert_identical(result0, expected0) + + result1 = ar.argmin(keep_attrs=True) + expected1 = expected0.copy() + expected1.attrs = self.attrs + assert_identical(result1, expected1) + + result2 = ar.argmin(skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = indarr.isel(x=nanindex, drop=True) + expected2.attrs = {} + else: + expected2 = expected0 + + assert_identical(result2, expected2) + + def test_argmax(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs 
+ ) + indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) + + if np.isnan(maxindex): + with pytest.raises(ValueError): + ar.argmax() + return + + expected0 = indarr[maxindex] + result0 = ar.argmax() + assert_identical(result0, expected0) + + result1 = ar.argmax(keep_attrs=True) + expected1 = expected0.copy() + expected1.attrs = self.attrs + assert_identical(result1, expected1) + + result2 = ar.argmax(skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = indarr.isel(x=nanindex, drop=True) + expected2.attrs = {} + else: + expected2 = expected0 + + assert_identical(result2, expected2) + + def test_idxmin(self, x, minindex, maxindex, nanindex): + ar0 = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs, + ) + + # dim doesn't exist + with pytest.raises(KeyError): + ar0.idxmin(dim="spam") + + # Scalar Dataarray + with pytest.raises(ValueError): + xr.DataArray(5).idxmin() + + coordarr0 = xr.DataArray(ar0.coords["x"], dims=["x"]) + coordarr1 = coordarr0.copy() + + hasna = np.isnan(minindex) + if np.isnan(minindex): + minindex = 0 + + if hasna: + coordarr1[...] = 1 + fill_value_0 = np.NaN + else: + fill_value_0 = 1 + + expected0 = ( + (coordarr1 * fill_value_0).isel(x=minindex, drop=True).astype("float") + ) + expected0.name = "x" + + # Default fill value (NaN) + result0 = ar0.idxmin() + assert_identical(result0, expected0) + + # Manually specify NaN fill_value + result1 = ar0.idxmin(fill_value=np.NaN) + assert_identical(result1, expected0) + + # keep_attrs + result2 = ar0.idxmin(keep_attrs=True) + expected2 = expected0.copy() + expected2.attrs = self.attrs + assert_identical(result2, expected2) + + # skipna=False + if nanindex is not None and ar0.dtype.kind != "O": + expected3 = coordarr0.isel(x=nanindex, drop=True).astype("float") + expected3.name = "x" + expected3.attrs = {} + else: + expected3 = expected0.copy() + + result3 = ar0.idxmin(skipna=False) + assert_identical(result3, expected3) + + # fill_value should be ignored with skipna=False + result4 = ar0.idxmin(skipna=False, fill_value=-100j) + assert_identical(result4, expected3) + + # Float fill_value + if hasna: + fill_value_5 = -1.1 + else: + fill_value_5 = 1 + + expected5 = (coordarr1 * fill_value_5).isel(x=minindex, drop=True) + expected5.name = "x" + + result5 = ar0.idxmin(fill_value=-1.1) + assert_identical(result5, expected5) + + # Integer fill_value + if hasna: + fill_value_6 = -1 + else: + fill_value_6 = 1 + + expected6 = (coordarr1 * fill_value_6).isel(x=minindex, drop=True) + expected6.name = "x" + + result6 = ar0.idxmin(fill_value=-1) + assert_identical(result6, expected6) + + # Complex fill_value + if hasna: + fill_value_7 = -1j + else: + fill_value_7 = 1 + + expected7 = (coordarr1 * fill_value_7).isel(x=minindex, drop=True) + expected7.name = "x" + + result7 = ar0.idxmin(fill_value=-1j) + assert_identical(result7, expected7) + + def test_idxmax(self, x, minindex, maxindex, nanindex): + ar0 = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs, + ) + + # dim doesn't exist + with pytest.raises(KeyError): + ar0.idxmax(dim="spam") + + # Scalar Dataarray + with pytest.raises(ValueError): + xr.DataArray(5).idxmax() + + coordarr0 = xr.DataArray(ar0.coords["x"], dims=["x"]) + coordarr1 = coordarr0.copy() + + hasna = np.isnan(maxindex) + if np.isnan(maxindex): + maxindex = 0 + + if hasna: + coordarr1[...] 
= 1 + fill_value_0 = np.NaN + else: + fill_value_0 = 1 + + expected0 = ( + (coordarr1 * fill_value_0).isel(x=maxindex, drop=True).astype("float") + ) + expected0.name = "x" + + # Default fill value (NaN) + result0 = ar0.idxmax() + assert_identical(result0, expected0) + + # Manually specify NaN fill_value + result1 = ar0.idxmax(fill_value=np.NaN) + assert_identical(result1, expected0) + + # keep_attrs + result2 = ar0.idxmax(keep_attrs=True) + expected2 = expected0.copy() + expected2.attrs = self.attrs + assert_identical(result2, expected2) + + # skipna=False + if nanindex is not None and ar0.dtype.kind != "O": + expected3 = coordarr0.isel(x=nanindex, drop=True).astype("float") + expected3.name = "x" + expected3.attrs = {} + else: + expected3 = expected0.copy() + + result3 = ar0.idxmax(skipna=False) + assert_identical(result3, expected3) + + # fill_value should be ignored with skipna=False + result4 = ar0.idxmax(skipna=False, fill_value=-100j) + assert_identical(result4, expected3) + + # Float fill_value + if hasna: + fill_value_5 = -1.1 + else: + fill_value_5 = 1 + + expected5 = (coordarr1 * fill_value_5).isel(x=maxindex, drop=True) + expected5.name = "x" + + result5 = ar0.idxmax(fill_value=-1.1) + assert_identical(result5, expected5) + + # Integer fill_value + if hasna: + fill_value_6 = -1 + else: + fill_value_6 = 1 + + expected6 = (coordarr1 * fill_value_6).isel(x=maxindex, drop=True) + expected6.name = "x" + + result6 = ar0.idxmax(fill_value=-1) + assert_identical(result6, expected6) + + # Complex fill_value + if hasna: + fill_value_7 = -1j + else: + fill_value_7 = 1 + + expected7 = (coordarr1 * fill_value_7).isel(x=maxindex, drop=True) + expected7.name = "x" + + result7 = ar0.idxmax(fill_value=-1j) + assert_identical(result7, expected7) + + +@pytest.mark.parametrize( + "x, minindex, maxindex, nanindex", + [ + ( + np.array( + [ + [0, 1, 2, 0, -2, -4, 2], + [1, 1, 1, 1, 1, 1, 1], + [0, 0, -10, 5, 20, 0, 0], + ] + ), + [5, 0, 2], + [2, 0, 4], + [None, None, None], + ), + ( + np.array( + [ + [2.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0], + [-4.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0], + [np.NaN] * 7, + ] + ), + [5, 0, np.NaN], + [0, 2, np.NaN], + [None, 1, 0], + ), + ( + np.array( + [ + [2.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0], + [-4.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0], + [np.NaN] * 7, + ] + ).astype("object"), + [5, 0, np.NaN], + [0, 2, np.NaN], + [None, 1, 0], + ), + ( + np.array( + [ + ["2015-12-31", "2020-01-02", "2020-01-01", "2016-01-01"], + ["2020-01-02", "2020-01-02", "2020-01-02", "2020-01-02"], + ["1900-01-01", "1-02-03", "1900-01-02", "1-02-03"], + ], + dtype="datetime64[ns]", + ), + [0, 0, 1], + [1, 0, 2], + [None, None, None], + ), + ], +) +class TestReduce2D(TestReduce): + def test_min(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + + minindex = [x if not np.isnan(x) else 0 for x in minindex] + expected0 = [ + ar.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(minindex) + ] + expected0 = xr.concat(expected0, dim="y") + + result0 = ar.min(dim="x", keep_attrs=True) + assert_identical(result0, expected0) + + result1 = ar.min(dim="x") + expected1 = expected0 + expected1.attrs = {} + assert_identical(result1, expected1) + + result2 = ar.min(axis=1) + assert_identical(result2, expected1) + + minindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(minindex, nanindex) + ] + expected2 = [ + ar.isel(y=yi).isel(x=indi, 
drop=True) for yi, indi in enumerate(minindex) + ] + expected2 = xr.concat(expected2, dim="y") + expected2.attrs = {} + + result3 = ar.min(dim="x", skipna=False) + + assert_identical(result3, expected2) + + def test_max(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + + maxindex = [x if not np.isnan(x) else 0 for x in maxindex] + expected0 = [ + ar.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(maxindex) + ] + expected0 = xr.concat(expected0, dim="y") + + result0 = ar.max(dim="x", keep_attrs=True) + assert_identical(result0, expected0) + + result1 = ar.max(dim="x") + expected1 = expected0.copy() + expected1.attrs = {} + assert_identical(result1, expected1) + + result2 = ar.max(axis=1) + assert_identical(result2, expected1) + + maxindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(maxindex, nanindex) + ] + expected2 = [ + ar.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(maxindex) + ] + expected2 = xr.concat(expected2, dim="y") + expected2.attrs = {} + + result3 = ar.max(dim="x", skipna=False) + + assert_identical(result3, expected2) + + def test_argmin(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + indarr = np.tile(np.arange(x.shape[1], dtype=np.intp), [x.shape[0], 1]) + indarr = xr.DataArray(indarr, dims=ar.dims, coords=ar.coords) + + if np.isnan(minindex).any(): + with pytest.raises(ValueError): + ar.argmin(dim="x") + return + + expected0 = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex) + ] + expected0 = xr.concat(expected0, dim="y") + + result0 = ar.argmin(dim="x") + assert_identical(result0, expected0) + + result1 = ar.argmin(axis=1) + assert_identical(result1, expected0) + + result2 = ar.argmin(dim="x", keep_attrs=True) + expected1 = expected0.copy() + expected1.attrs = self.attrs + assert_identical(result2, expected1) + + minindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(minindex, nanindex) + ] + expected2 = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex) + ] + expected2 = xr.concat(expected2, dim="y") + expected2.attrs = {} + + result3 = ar.argmin(dim="x", skipna=False) + + assert_identical(result3, expected2) + + def test_argmax(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + indarr = np.tile(np.arange(x.shape[1], dtype=np.intp), [x.shape[0], 1]) + indarr = xr.DataArray(indarr, dims=ar.dims, coords=ar.coords) + + if np.isnan(maxindex).any(): + with pytest.raises(ValueError): + ar.argmax(dim="x") + return + + expected0 = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex) + ] + expected0 = xr.concat(expected0, dim="y") + + result0 = ar.argmax(dim="x") + assert_identical(result0, expected0) + + result1 = ar.argmax(axis=1) + assert_identical(result1, expected0) + + result2 = ar.argmax(dim="x", keep_attrs=True) + expected1 = expected0.copy() + expected1.attrs = self.attrs + assert_identical(result2, expected1) + + maxindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(maxindex, nanindex) + ] + expected2 = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in 
enumerate(maxindex) + ] + expected2 = xr.concat(expected2, dim="y") + expected2.attrs = {} + + result3 = ar.argmax(dim="x", skipna=False) + + assert_identical(result3, expected2) + + def test_idxmin(self, x, minindex, maxindex, nanindex): + ar0 = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + + assert_identical(ar0, ar0) + + # No dimension specified + with pytest.raises(ValueError): + ar0.idxmin() + + # dim doesn't exist + with pytest.raises(KeyError): + ar0.idxmin(dim="Y") + + assert_identical(ar0, ar0) + + coordarr0 = xr.DataArray( + np.tile(ar0.coords["x"], [x.shape[0], 1]), dims=ar0.dims, coords=ar0.coords + ) + + hasna = [np.isnan(x) for x in minindex] + coordarr1 = coordarr0.copy() + coordarr1[hasna, :] = 1 + minindex0 = [x if not np.isnan(x) else 0 for x in minindex] + + nan_mult_0 = np.array([np.NaN if x else 1 for x in hasna])[:, None] + expected0 = [ + (coordarr1 * nan_mult_0).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex0) + ] + expected0 = xr.concat(expected0, dim="y") + expected0.name = "x" + + # Default fill value (NaN) + result0 = ar0.idxmin(dim="x") + assert_identical(result0, expected0) + + # Manually specify NaN fill_value + result1 = ar0.idxmin(dim="x", fill_value=np.NaN) + assert_identical(result1, expected0) + + # keep_attrs + result2 = ar0.idxmin(dim="x", keep_attrs=True) + expected2 = expected0.copy() + expected2.attrs = self.attrs + assert_identical(result2, expected2) + + # skipna=False + minindex3 = [ + x if y is None or ar0.dtype.kind == "O" else y + for x, y in zip(minindex0, nanindex) + ] + expected3 = [ + coordarr0.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex3) + ] + expected3 = xr.concat(expected3, dim="y") + expected3.name = "x" + expected3.attrs = {} + + result3 = ar0.idxmin(dim="x", skipna=False) + assert_identical(result3, expected3) + + # fill_value should be ignored with skipna=False + result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j) + assert_identical(result4, expected3) + + # Float fill_value + nan_mult_5 = np.array([-1.1 if x else 1 for x in hasna])[:, None] + expected5 = [ + (coordarr1 * nan_mult_5).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex0) + ] + expected5 = xr.concat(expected5, dim="y") + expected5.name = "x" + + result5 = ar0.idxmin(dim="x", fill_value=-1.1) + assert_identical(result5, expected5) + + # Integer fill_value + nan_mult_6 = np.array([-1 if x else 1 for x in hasna])[:, None] + expected6 = [ + (coordarr1 * nan_mult_6).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex0) + ] + expected6 = xr.concat(expected6, dim="y") + expected6.name = "x" + + result6 = ar0.idxmin(dim="x", fill_value=-1) + assert_identical(result6, expected6) + + # Complex fill_value + nan_mult_7 = np.array([-5j if x else 1 for x in hasna])[:, None] + expected7 = [ + (coordarr1 * nan_mult_7).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex0) + ] + expected7 = xr.concat(expected7, dim="y") + expected7.name = "x" + + result7 = ar0.idxmin(dim="x", fill_value=-5j) + assert_identical(result7, expected7) + + def test_idxmax(self, x, minindex, maxindex, nanindex): + ar0 = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + + # No dimension specified + with pytest.raises(ValueError): + ar0.idxmax() + + # dim doesn't exist + with pytest.raises(KeyError): + 
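+            # Dimension names are case-sensitive, so "Y" does not match the
+            # existing "y" dimension and idxmax raises KeyError.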
ar0.idxmax(dim="Y") + + ar1 = ar0.copy() + del ar1.coords["y"] + with pytest.raises(KeyError): + ar1.idxmax(dim="y") + + coordarr0 = xr.DataArray( + np.tile(ar0.coords["x"], [x.shape[0], 1]), dims=ar0.dims, coords=ar0.coords + ) + + hasna = [np.isnan(x) for x in maxindex] + coordarr1 = coordarr0.copy() + coordarr1[hasna, :] = 1 + maxindex0 = [x if not np.isnan(x) else 0 for x in maxindex] + + nan_mult_0 = np.array([np.NaN if x else 1 for x in hasna])[:, None] + expected0 = [ + (coordarr1 * nan_mult_0).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex0) + ] + expected0 = xr.concat(expected0, dim="y") + expected0.name = "x" + + # Default fill value (NaN) + result0 = ar0.idxmax(dim="x") + assert_identical(result0, expected0) + + # Manually specify NaN fill_value + result1 = ar0.idxmax(dim="x", fill_value=np.NaN) + assert_identical(result1, expected0) + + # keep_attrs + result2 = ar0.idxmax(dim="x", keep_attrs=True) + expected2 = expected0.copy() + expected2.attrs = self.attrs + assert_identical(result2, expected2) + + # skipna=False + maxindex3 = [ + x if y is None or ar0.dtype.kind == "O" else y + for x, y in zip(maxindex0, nanindex) + ] + expected3 = [ + coordarr0.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex3) + ] + expected3 = xr.concat(expected3, dim="y") + expected3.name = "x" + expected3.attrs = {} + + result3 = ar0.idxmax(dim="x", skipna=False) + assert_identical(result3, expected3) + + # fill_value should be ignored with skipna=False + result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j) + assert_identical(result4, expected3) + + # Float fill_value + nan_mult_5 = np.array([-1.1 if x else 1 for x in hasna])[:, None] + expected5 = [ + (coordarr1 * nan_mult_5).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex0) + ] + expected5 = xr.concat(expected5, dim="y") + expected5.name = "x" + + result5 = ar0.idxmax(dim="x", fill_value=-1.1) + assert_identical(result5, expected5) + + # Integer fill_value + nan_mult_6 = np.array([-1 if x else 1 for x in hasna])[:, None] + expected6 = [ + (coordarr1 * nan_mult_6).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex0) + ] + expected6 = xr.concat(expected6, dim="y") + expected6.name = "x" + + result6 = ar0.idxmax(dim="x", fill_value=-1) + assert_identical(result6, expected6) + + # Complex fill_value + nan_mult_7 = np.array([-5j if x else 1 for x in hasna])[:, None] + expected7 = [ + (coordarr1 * nan_mult_7).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex0) + ] + expected7 = xr.concat(expected7, dim="y") + expected7.name = "x" + + result7 = ar0.idxmax(dim="x", fill_value=-5j) + assert_identical(result7, expected7) + + @pytest.fixture(params=[1]) def da(request): if request.param == 1: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 02698253e5d..237c315583c 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4584,6 +4584,7 @@ def test_reduce_non_numeric(self): def test_reduce_strings(self): expected = Dataset({"x": "a"}) ds = Dataset({"x": ("y", ["a", "b"])}) + ds.coords["y"] = [-10, 10] actual = ds.min() assert_identical(expected, actual) @@ -4599,6 +4600,14 @@ def test_reduce_strings(self): actual = ds.argmax() assert_identical(expected, actual) + expected = Dataset({"x": -10}) + actual = ds.idxmin() + assert_identical(expected, actual) + + expected = Dataset({"x": 10}) + actual = ds.idxmax() + assert_identical(expected, actual) + expected = Dataset({"x": b"a"}) ds = 
Dataset({"x": ("y", np.array(["a", "b"], "S1"))}) actual = ds.min()