From 4b0f4e3a76571bca3e2723439bd214d8a16e5535 Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Sat, 21 Mar 2020 21:37:39 -0400 Subject: [PATCH] make some more requested changes --- xarray/core/computation.py | 93 +++++++++++++++++++++++++++++++- xarray/core/dataarray.py | 105 +++---------------------------------- xarray/core/dataset.py | 4 +- 3 files changed, 100 insertions(+), 102 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f2941a3d0ba..93378616c03 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -23,9 +23,10 @@ import numpy as np -from . import duck_array_ops, utils +from . import dtypes, duck_array_ops, utils from .alignment import deep_align from .merge import merge_coordinates_without_align +from .nanops import dask_array from .options import OPTIONS from .pycompat import dask_array_type from .utils import is_dict_like @@ -1306,3 +1307,93 @@ def where(cond, x, y): dataset_join="exact", dask="allowed", ) + + +def _calc_idxminmax( + *, + array, + func: Callable, + dim: Optional[Hashable], + axis: Optional[int], + skipna: Optional[bool], + promote: Optional[bool], + keep_attrs: Optional[bool], + **kwargs: Any, +): + """Apply common operations for idxmin and idxmax.""" + # This function doesn't make sense for scalars so don't try + if not array.ndim: + raise ValueError("This function does not apply for scalars") + + if dim is not None: + pass # Use the dim if available + elif axis is not None: + dim = array.dims[axis] + elif array.ndim == 1: + # it is okay to guess the dim if there is only 1 + dim = array.dims[0] + else: + # The dim is not specified and ambiguous. Don't guess. 
+ raise ValueError( + "Must supply either 'dim' or 'axis' argument " "for multidimensional arrays" + ) + + if dim in array.coords: + pass # This is okay + elif axis is not None: + raise IndexError(f'Axis "{axis}" does not have coordinates') + else: + raise KeyError(f'Dimension "{dim}" does not have coordinates') + + # These are dtypes with NaN values argmin and argmax can handle + na_dtypes = "cf0" + + if skipna or (skipna is None and array.dtype.kind in na_dtypes): + # Need to skip NaN values since argmin and argmax can't handle them + allna = array.isnull().all(dim) + array = array.where(~allna, 0) + hasna = allna.any() + else: + allna = None + hasna = False + + # If promote is None we only promote if there are NaN values. + if promote is None: + promote = hasna + + if not promote and hasna and array.coords[dim].dtype.kind not in na_dtypes: + raise TypeError( + "NaN values present for NaN-incompatible dtype and Promote=False" + ) + + # This will run argmin or argmax. + indx = func(array, dim=dim, axis=None, keep_attrs=False, skipna=skipna, **kwargs) + + # Get the coordinate we want. + coordarray = array[dim] + + # Handle dask arrays. NOTE(review): should this test array.data? Confirm. + if isinstance(array, dask_array_type): + res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype) + else: + res = coordarray[ + indx, + ] + + # Promote to a dtype that can handle NaN values if needed. + newdtype, fill_value = dtypes.maybe_promote(res.dtype) + if promote and newdtype != res.dtype: + res = res.astype(newdtype) + + # Put the NaN values back in after removing them, if necessary. + if hasna and allna is not None: + res = res.where(~allna, fill_value) + + # The dim is gone but we need to remove the corresponding coordinate. 
+ del res.coords[dim] + + # Put the attrs back in if needed + if keep_attrs: + res.attrs = array.attrs + + return res diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 883cb179b37..8e9f6943974 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -55,9 +55,7 @@ from .indexes import Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer from .merge import PANDAS_TYPES, _extract_indexes_from_coords -from .nanops import dask_array from .options import OPTIONS -from .pycompat import dask_array_type from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs from .variable import ( IndexVariable, @@ -3432,99 +3430,6 @@ def pad( ) return self._from_temp_dataset(ds) - def _calc_idxminmax( - self, - *, - func: Callable, - dim: Optional[Hashable], - axis: Optional[int], - skipna: Optional[bool], - promote: Optional[bool], - keep_attrs: Optional[bool], - **kwargs: Any, - ) -> "DataArray": - """Apply common operations for idxmin and idxmax.""" - # This function doesn't make sense for scalars so don't try - if not self.ndim: - ValueError("This function does not apply for scalars") - - if dim is not None: - pass # Use the dim if available - elif axis is not None: - dim = self.dims[axis] - elif self.ndim == 1: - # it is okay to guess the dim if there is only 1 - dim = self.dims[0] - else: - # The dim is not specified and ambiguous. Don't guess. 
- raise ValueError( - "Must supply either 'dim' or 'axis' argument " - "for multidimensional arrays" - ) - - if dim in self.coords: - pass # This is okay - elif axis is not None: - raise IndexError(f'Axis "{axis}" does not have coordinates') - else: - raise KeyError(f'Dimension "{dim}" does not have coordinates') - - # These are dtypes with NaN values argmin and argmax can handle - na_dtypes = "cf0" - - if skipna or (skipna is None and self.dtype.kind in na_dtypes): - # Need to skip NaN values since argmin and argmax can't handle them - allna = self.isnull().all(dim) - array = self.where(~allna, 0) - hasna = allna.any() - else: - array = self - allna = None - hasna = False - - # If promote is None we only promote if there are NaN values. - if promote is None: - promote = hasna - - if not promote and hasna and array.coords[dim].dtype.kind not in na_dtypes: - raise TypeError( - "NaN values present for NaN-incompatible dtype and Promote=False" - ) - - # This will run argmin or argmax. - indx = func( - array, dim=dim, axis=None, keep_attrs=False, skipna=skipna, **kwargs - ) - - # Get the coordinate we want. - coordarray = array[dim] - - # Handle dask arrays. - if isinstance(array, dask_array_type): - res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype) - else: - res = coordarray[ - indx, - ] - - # Promote to a dtype that can handle NaN values if needed. - newdtype, fill_value = dtypes.maybe_promote(res.dtype) - if promote and newdtype != res.dtype: - res = res.astype(newdtype) - - # Put the NaN values back in after removing them, if necessary. - if hasna and allna is not None: - res = res.where(~allna, fill_value) - - # The dim is gone but we need to remove the corresponding coordinate. 
- del res.coords[dim] - - # Put the attrs back in if needed - if keep_attrs: - res.attrs = self.attrs - - return res - def idxmin( self, dim: Hashable = None, @@ -3555,7 +3460,7 @@ def idxmin( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - promote : bool or None, default True + promote : bool or None, default True If True (default) dtypes that do not support NaN values will be automatically promoted to those that do. If False a NaN in the results will raise a TypeError. If None the result will only be @@ -3616,7 +3521,8 @@ def idxmin( -------- Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin """ - return self._calc_idxminmax( + return computation._calc_idxminmax( + array=self, func=lambda x, *args, **kwargs: x.argmin(*args, **kwargs), dim=dim, axis=axis, @@ -3656,7 +3562,7 @@ def idxmax( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - promote : bool or None, default True + promote : bool or None, default True If True (default) dtypes that do not support NaN values will be automatically promoted to those that do. If False a NaN in the results will raise a TypeError. 
If None the result will only be @@ -3717,7 +3623,8 @@ def idxmax( -------- Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax """ - return self._calc_idxminmax( + return computation._calc_idxminmax( + array=self, func=lambda x, *args, **kwargs: x.argmax(*args, **kwargs), dim=dim, axis=axis, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4edd164535b..34947a810a4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5951,7 +5951,7 @@ def idxmin( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - promote : bool or None, default True + promote : bool or None, default True If True (default) dtypes that do not support NaN values will be automatically promoted to those that do. If False a NaN in the results will raise a TypeError. If None the result will only be @@ -6052,7 +6052,7 @@ def idxmax( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - promote : bool or None, default True + promote : bool or None, default True If True (default) dtypes that do not support NaN values will be automatically promoted to those that do. If False a NaN in the results will raise a TypeError. If None the result will only be