Skip to content

Commit

Permalink
implement idxmax and idxmin
Browse files Browse the repository at this point in the history
  • Loading branch information
toddrjen committed Mar 20, 2020
1 parent 3f34b95 commit 925b5c8
Show file tree
Hide file tree
Showing 5 changed files with 983 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ Computation
:py:attr:`~Dataset.any`
:py:attr:`~Dataset.argmax`
:py:attr:`~Dataset.argmin`
:py:attr:`~Dataset.idxmax`
:py:attr:`~Dataset.idxmin`
:py:attr:`~Dataset.max`
:py:attr:`~Dataset.mean`
:py:attr:`~Dataset.median`
Expand Down Expand Up @@ -359,6 +361,8 @@ Computation
:py:attr:`~DataArray.any`
:py:attr:`~DataArray.argmax`
:py:attr:`~DataArray.argmin`
:py:attr:`~DataArray.idxmax`
:py:attr:`~DataArray.idxmin`
:py:attr:`~DataArray.max`
:py:attr:`~DataArray.mean`
:py:attr:`~DataArray.median`
Expand Down
221 changes: 221 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@
from .indexes import Indexes, default_indexes, propagate_indexes
from .indexing import is_fancy_indexer
from .merge import PANDAS_TYPES, _extract_indexes_from_coords
from .nanops import dask_array
from .options import OPTIONS
from .pycompat import dask_array_type
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
from .variable import (
IndexVariable,
Expand Down Expand Up @@ -3430,6 +3432,225 @@ def pad(
)
return self._from_temp_dataset(ds)

def _calc_idxminmax(
self,
*,
func: str,
dim: Optional[Hashable],
axis: Optional[int],
skipna: Optional[bool],
promote: Optional[bool],
keep_attrs: Optional[bool],
**kwargs: Any,
) -> "DataArray":
"""Apply common operations for idxmin and idxmax."""
# This function doesn't make sense for scalars so don't try
if not self.ndim:
ValueError("This function does not apply for scalars")

if dim is not None:
pass # Use the dim if available
elif axis is not None:
dim = self.dims[axis]
elif self.ndim == 1:
# it is okay to guess the dim if there is only 1
dim = self.dims[0]
else:
# The dim is not specified and ambiguous. Don't guess.
raise ValueError(
"Must supply either 'dim' or 'axis' argument "
"for multidimensional arrays"
)

if dim in self.coords:
pass # This is okay
elif axis is not None:
raise IndexError(f'Axis "{axis}" does not have coordinates')
else:
raise KeyError(f'Dimension "{dim}" does not have coordinates')

# These are dtypes with NaN values argmin and argmax can handle
na_dtypes = "cf0"

if skipna or (skipna is None and self.dtype.kind in na_dtypes):
# Need to skip NaN values since argmin and argmax can't handle them
allna = self.isnull().all(dim)
array = self.where(~allna, 0)
hasna = allna.any()
else:
array = self
allna = None
hasna = False

# If promote is None we only promote if there are NaN values.
if promote is None:
promote = hasna

if not promote and hasna and array.coords[dim].dtype.kind not in na_dtypes:
raise TypeError(
"NaN values present for NaN-incompatible dtype and Promote=False"
)

# This will run argmin or argmax.
indx = getattr(array, func)(
dim=dim, axis=None, keep_attrs=False, skipna=skipna, **kwargs
)

# Get the coordinate we want.
coordarray = array[dim]

# Handle dask arrays.
if isinstance(array, dask_array_type):
res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype)
else:
res = coordarray[
indx,
]

# Promote to a dtype that can handle NaN values if needed.
newdtype, fill_value = dtypes.maybe_promote(res.dtype)
if promote and newdtype != res.dtype:
res = res.astype(newdtype)

# Put the NaN values back in after removing them, if necessary.
if hasna and allna is not None:
res = res.where(~allna, fill_value)

# The dim is gone but we need to remove the corresponding coordinate.
del res.coords[dim]

# Put the attrs back in if needed
if keep_attrs:
res.attrs = self.attrs

return res

def idxmin(
self,
dim: Optional[Hashable] = None,
axis: Optional[int] = None,
skipna: Optional[bool] = None,
promote: Optional[bool] = None,
keep_attrs: Optional[bool] = False,
**kwargs: Any,
) -> "DataArray":
"""Return the coordinate of the minimum value along a dimension.
Returns a new DataArray named after the dimension with the values of
the coordinate along that dimension corresponding to minimum value
along that dimension.
In comparison to `argmin`, this returns the coordinate while `argmin`
returns the index.
Parameters
----------
dim : str (optional)
Dimension over which to apply `idxmin`.
axis : int (optional)
Axis(es) over which to repeatedly apply `idxmin`. Exactly one of
the 'dim' and 'axis' arguments must be supplied.
skipna : bool, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).{min_count_docs}
promote : bool (optional)
If True (default) dtypes that do not support NaN values will be
automatically promoted to those that do. If False a NaN in the
results will raise a TypeError. If None the result will only be
promoted if a NaN is actually present.
keep_attrs : bool, optional
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
returned without attributes.
**kwargs : dict
Additional keyword arguments passed on to the appropriate array
function for calculating `{name}` on this object's data.
Returns
-------
reduced : DataArray
New DataArray object with `idxmin` applied to its data and the
indicated dimension removed.
See also
--------
Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin
"""
return self._calc_idxminmax(
func="argmin",
dim=dim,
axis=axis,
skipna=skipna,
promote=promote,
keep_attrs=keep_attrs,
**kwargs,
)

def idxmax(
self,
dim: Optional[Hashable] = None,
axis: Optional[int] = None,
skipna: Optional[bool] = None,
promote: Optional[bool] = None,
keep_attrs: Optional[bool] = False,
**kwargs: Any,
) -> "DataArray":
"""Return the coordinate of the maximum value along a dimension.
Returns a new DataArray named after the dimension with the values of
the coordinate along that dimension corresponding to maximum value
along that dimension.
In comparison to `argmax`, this returns the coordinate while `argmax`
returns the index.
Parameters
----------
dim : str (optional)
Dimension over which to apply `idxmax`.
axis : int (optional)
Axis(es) over which to repeatedly apply `idxmax`. Exactly one of
the 'dim' and 'axis' arguments must be supplied.
skipna : bool (optional)
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).{min_count_docs}
promote : bool (optional)
If True (default) dtypes that do not support NaN values will be
automatically promoted to those that do. If False a NaN in the
results will raise a TypeError. If None the result will only be
promoted if a NaN is actually present.
keep_attrs : bool (optional)
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
returned without attributes.
**kwargs : dict
Additional keyword arguments passed on to the appropriate array
function for calculating `{name}` on this object's data.
Returns
-------
reduced : DataArray
New DataArray object with `idxmax` applied to its data and the
indicated dimension removed.
See also
--------
Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax
"""
return self._calc_idxminmax(
func="argmax",
dim=dim,
axis=axis,
skipna=skipna,
promote=promote,
keep_attrs=keep_attrs,
**kwargs,
)

# this needs to be at the end, or mypy will confuse with `str`
# https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
str = property(StringAccessor)
Expand Down
126 changes: 126 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5921,5 +5921,131 @@ def pad(

return self._replace_vars_and_dims(variables)

def idxmin(
self,
dim: Optional[Hashable] = None,
axis: Optional[int] = None,
skipna: Optional[bool] = None,
promote: Optional[bool] = None,
keep_attrs: Optional[bool] = False,
**kwargs: Any,
) -> "Dataset":
"""Return the coordinate of the minimum value along a dimension.
Returns a new Dataset named after the dimension with the values of
the coordinate along that dimension corresponding to minimum value
along that dimension.
In comparison to `argmin`, this returns the coordinate while `argmin`
returns the index.
Parameters
----------
dim : str (optional)
Dimension over which to apply `idxmin`.
axis : int (optional)
Axis(es) over which to repeatedly apply `idxmin`. Exactly one of
the 'dim' and 'axis' arguments must be supplied.
skipna : bool, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).{min_count_docs}
promote : bool (optional)
If True (default) dtypes that do not support NaN values will be
automatically promoted to those that do. If False a NaN in the
results will raise a TypeError. If None the result will only be
promoted if a NaN is actually present.
keep_attrs : bool, optional
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
returned without attributes.
**kwargs : dict
Additional keyword arguments passed on to the appropriate array
function for calculating `{name}` on this object's data.
Returns
-------
reduced : Dataset
New Dataset object with `idxmin` applied to its data and the
indicated dimension removed.
See also
--------
DataArray.idxmin, Dataset.idxmax, Dataset.min, Dataset.argmin
"""
return self.map(
"idxmin",
dim=dim,
axis=axis,
skipna=skipna,
promote=promote,
keep_attrs=keep_attrs,
**kwargs,
)

def idxmax(
self,
dim: Optional[Hashable] = None,
axis: Optional[int] = None,
skipna: Optional[bool] = None,
promote: Optional[bool] = None,
keep_attrs: Optional[bool] = False,
**kwargs: Any,
) -> "Dataset":
"""Return the coordinate of the maximum value along a dimension.
Returns a new Dataset named after the dimension with the values of
the coordinate along that dimension corresponding to maximum value
along that dimension.
In comparison to `argmax`, this returns the coordinate while `argmax`
returns the index.
Parameters
----------
dim : str (optional)
Dimension over which to apply `idxmax`.
axis : int (optional)
Axis(es) over which to repeatedly apply `idxmax`. Exactly one of
the 'dim' and 'axis' arguments must be supplied.
skipna : bool (optional)
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).{min_count_docs}
promote : bool (optional)
If True (default) dtypes that do not support NaN values will be
automatically promoted to those that do. If False a NaN in the
results will raise a TypeError. If None the result will only be
promoted if a NaN is actually present.
keep_attrs : bool (optional)
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
returned without attributes.
**kwargs : dict
Additional keyword arguments passed on to the appropriate array
function for calculating `{name}` on this object's data.
Returns
-------
reduced : Dataset
New Dataset object with `idxmax` applied to its data and the
indicated dimension removed.
See also
--------
DataArray.idxmax, Dataset.idxmin, Dataset.max, Dataset.argmax
"""
return self.map(
"idxmax",
dim=dim,
axis=axis,
skipna=skipna,
promote=promote,
keep_attrs=keep_attrs,
**kwargs,
)


ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False)
Loading

0 comments on commit 925b5c8

Please sign in to comment.