Skip to content

Commit

Permalink
make some more requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
toddrjen committed Mar 22, 2020
1 parent a63c656 commit 4b0f4e3
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 102 deletions.
93 changes: 92 additions & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

import numpy as np

from . import duck_array_ops, utils
from . import dtypes, duck_array_ops, utils
from .alignment import deep_align
from .merge import merge_coordinates_without_align
from .nanops import dask_array
from .options import OPTIONS
from .pycompat import dask_array_type
from .utils import is_dict_like
Expand Down Expand Up @@ -1306,3 +1307,93 @@ def where(cond, x, y):
dataset_join="exact",
dask="allowed",
)


def _calc_idxminmax(
*,
array,
func: Callable,
dim: Optional[Hashable],
axis: Optional[int],
skipna: Optional[bool],
promote: Optional[bool],
keep_attrs: Optional[bool],
**kwargs: Any,
):
"""Apply common operations for idxmin and idxmax."""
# This function doesn't make sense for scalars so don't try
if not array.ndim:
ValueError("This function does not apply for scalars")

if dim is not None:
pass # Use the dim if available
elif axis is not None:
dim = array.dims[axis]
elif array.ndim == 1:
# it is okay to guess the dim if there is only 1
dim = array.dims[0]
else:
# The dim is not specified and ambiguous. Don't guess.
raise ValueError(
"Must supply either 'dim' or 'axis' argument " "for multidimensional arrays"
)

if dim in array.coords:
pass # This is okay
elif axis is not None:
raise IndexError(f'Axis "{axis}" does not have coordinates')
else:
raise KeyError(f'Dimension "{dim}" does not have coordinates')

# These are dtypes with NaN values argmin and argmax can handle
na_dtypes = "cf0"

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Need to skip NaN values since argmin and argmax can't handle them
allna = array.isnull().all(dim)
array = array.where(~allna, 0)
hasna = allna.any()
else:
allna = None
hasna = False

# If promote is None we only promote if there are NaN values.
if promote is None:
promote = hasna

if not promote and hasna and array.coords[dim].dtype.kind not in na_dtypes:
raise TypeError(
"NaN values present for NaN-incompatible dtype and Promote=False"
)

# This will run argmin or argmax.
indx = func(array, dim=dim, axis=None, keep_attrs=False, skipna=skipna, **kwargs)

# Get the coordinate we want.
coordarray = array[dim]

# Handle dask arrays.
if isinstance(array, dask_array_type):
res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype)
else:
res = coordarray[
indx,
]

# Promote to a dtype that can handle NaN values if needed.
newdtype, fill_value = dtypes.maybe_promote(res.dtype)
if promote and newdtype != res.dtype:
res = res.astype(newdtype)

# Put the NaN values back in after removing them, if necessary.
if hasna and allna is not None:
res = res.where(~allna, fill_value)

# The dim is gone but we need to remove the corresponding coordinate.
del res.coords[dim]

# Put the attrs back in if needed
if keep_attrs:
res.attrs = array.attrs

return res
105 changes: 6 additions & 99 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@
from .indexes import Indexes, default_indexes, propagate_indexes
from .indexing import is_fancy_indexer
from .merge import PANDAS_TYPES, _extract_indexes_from_coords
from .nanops import dask_array
from .options import OPTIONS
from .pycompat import dask_array_type
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
from .variable import (
IndexVariable,
Expand Down Expand Up @@ -3432,99 +3430,6 @@ def pad(
)
return self._from_temp_dataset(ds)

def _calc_idxminmax(
self,
*,
func: Callable,
dim: Optional[Hashable],
axis: Optional[int],
skipna: Optional[bool],
promote: Optional[bool],
keep_attrs: Optional[bool],
**kwargs: Any,
) -> "DataArray":
"""Apply common operations for idxmin and idxmax."""
# This function doesn't make sense for scalars so don't try
if not self.ndim:
ValueError("This function does not apply for scalars")

if dim is not None:
pass # Use the dim if available
elif axis is not None:
dim = self.dims[axis]
elif self.ndim == 1:
# it is okay to guess the dim if there is only 1
dim = self.dims[0]
else:
# The dim is not specified and ambiguous. Don't guess.
raise ValueError(
"Must supply either 'dim' or 'axis' argument "
"for multidimensional arrays"
)

if dim in self.coords:
pass # This is okay
elif axis is not None:
raise IndexError(f'Axis "{axis}" does not have coordinates')
else:
raise KeyError(f'Dimension "{dim}" does not have coordinates')

# These are dtypes with NaN values argmin and argmax can handle
na_dtypes = "cf0"

if skipna or (skipna is None and self.dtype.kind in na_dtypes):
# Need to skip NaN values since argmin and argmax can't handle them
allna = self.isnull().all(dim)
array = self.where(~allna, 0)
hasna = allna.any()
else:
array = self
allna = None
hasna = False

# If promote is None we only promote if there are NaN values.
if promote is None:
promote = hasna

if not promote and hasna and array.coords[dim].dtype.kind not in na_dtypes:
raise TypeError(
"NaN values present for NaN-incompatible dtype and Promote=False"
)

# This will run argmin or argmax.
indx = func(
array, dim=dim, axis=None, keep_attrs=False, skipna=skipna, **kwargs
)

# Get the coordinate we want.
coordarray = array[dim]

# Handle dask arrays.
if isinstance(array, dask_array_type):
res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype)
else:
res = coordarray[
indx,
]

# Promote to a dtype that can handle NaN values if needed.
newdtype, fill_value = dtypes.maybe_promote(res.dtype)
if promote and newdtype != res.dtype:
res = res.astype(newdtype)

# Put the NaN values back in after removing them, if necessary.
if hasna and allna is not None:
res = res.where(~allna, fill_value)

# The dim is gone but we need to remove the corresponding coordinate.
del res.coords[dim]

# Put the attrs back in if needed
if keep_attrs:
res.attrs = self.attrs

return res

def idxmin(
self,
dim: Hashable = None,
Expand Down Expand Up @@ -3555,7 +3460,7 @@ def idxmin(
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
promote : bool or None, default True
promote : bool or None, default True
If True (default) dtypes that do not support NaN values will be
automatically promoted to those that do. If False a NaN in the
results will raise a TypeError. If None the result will only be
Expand Down Expand Up @@ -3616,7 +3521,8 @@ def idxmin(
--------
Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin
"""
return self._calc_idxminmax(
return computation._calc_idxminmax(
array=self,
func=lambda x, *args, **kwargs: x.argmin(*args, **kwargs),
dim=dim,
axis=axis,
Expand Down Expand Up @@ -3656,7 +3562,7 @@ def idxmax(
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
promote : bool or None, default True
promote : bool or None, default True
If True (default) dtypes that do not support NaN values will be
automatically promoted to those that do. If False a NaN in the
results will raise a TypeError. If None the result will only be
Expand Down Expand Up @@ -3717,7 +3623,8 @@ def idxmax(
--------
Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax
"""
return self._calc_idxminmax(
return computation._calc_idxminmax(
array=self,
func=lambda x, *args, **kwargs: x.argmax(*args, **kwargs),
dim=dim,
axis=axis,
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5951,7 +5951,7 @@ def idxmin(
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
promote : bool or None, default True
promote : bool or None, default True
If True (default) dtypes that do not support NaN values will be
automatically promoted to those that do. If False a NaN in the
results will raise a TypeError. If None the result will only be
Expand Down Expand Up @@ -6052,7 +6052,7 @@ def idxmax(
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
promote : bool or None, default True
promote : bool or None, default True
If True (default) dtypes that do not support NaN values will be
automatically promoted to those that do. If False a NaN in the
results will raise a TypeError. If None the result will only be
Expand Down

0 comments on commit 4b0f4e3

Please sign in to comment.