diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index c2b718e8..2bef93a5 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -30,6 +30,8 @@ "nanmean": {np.int_: np.float64}, "nanvar": {np.int_: np.float64}, "nanstd": {np.int_: np.float64}, + "nanfirst": {np.datetime64: np.int64, np.timedelta64: np.int64}, + "nanlast": {np.datetime64: np.int64, np.timedelta64: np.int64}, } @@ -51,7 +53,7 @@ def _numbagg_wrapper( if cast_to: for from_, to_ in cast_to.items(): if np.issubdtype(array.dtype, from_): - array = array.astype(to_) + array = array.astype(to_, copy=False) func_ = getattr(numbagg.grouped, f"group_{func}") diff --git a/flox/core.py b/flox/core.py index 91903ded..7f15ae2f 100644 --- a/flox/core.py +++ b/flox/core.py @@ -45,6 +45,10 @@ ) from .cache import memoize from .xrutils import ( + _contains_cftime_datetimes, + _datetime_nanmin, + _to_pytimedelta, + datetime_to_numeric, is_chunked_array, is_duck_array, is_duck_cubed_array, @@ -2473,7 +2477,8 @@ def groupby_reduce( has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_) has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_) - if _is_first_last_reduction(func): + is_first_last = _is_first_last_reduction(func) + if is_first_last: if has_dask and nax != 1: raise ValueError( "For dask arrays: first, last, nanfirst, nanlast reductions are " @@ -2486,6 +2491,24 @@ def groupby_reduce( "along a single axis or when reducing across all dimensions of `by`." ) + # Flox's count works with non-numeric and its faster than converting. + is_npdatetime = array.dtype.kind in "Mm" + is_cftime = _contains_cftime_datetimes(array) + requires_numeric = ( + (func not in ["count", "any", "all"] and not is_first_last) + or (func == "count" and engine != "flox") + or (is_first_last and is_cftime) + ) + if requires_numeric: + if is_npdatetime: + offset = _datetime_nanmin(array) + # xarray always uses np.datetime64[ns] for np.datetime64 data + dtype = "timedelta64[ns]" + array = datetime_to_numeric(array, offset) + elif is_cftime: + offset = array.min() + array = datetime_to_numeric(array, offset, datetime_unit="us") + if nax == 1 and by_.ndim > 1 and expected_ is None: # When we reduce along all axes, we are guaranteed to see all # groups in the final combine stage, so everything works. @@ -2671,6 +2694,14 @@ def groupby_reduce( if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)): result = result.astype(bool) + + # Output of count has an int dtype. + if requires_numeric and func != "count": + if is_npdatetime: + return result.astype(dtype) + offset + elif is_cftime: + return _to_pytimedelta(result, unit="us") + offset + return (result, *groups) diff --git a/flox/xarray.py b/flox/xarray.py index 1562acc8..fbeeedba 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -7,7 +7,6 @@ import pandas as pd import xarray as xr from packaging.version import Version -from xarray.core.duck_array_ops import _datetime_nanmin from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func from .core import ( @@ -18,7 +17,6 @@ ) from .core import rechunk_for_blockwise as rechunk_array_for_blockwise from .core import rechunk_for_cohorts as rechunk_array_for_cohorts -from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric if TYPE_CHECKING: from xarray.core.types import T_DataArray, T_Dataset @@ -366,22 +364,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): if "nan" not in func and func not in ["all", "any", "count"]: func = f"nan{func}" - # Flox's count works with non-numeric and its faster than converting. - requires_numeric = func not in ["count", "any", "all"] or ( - func == "count" and kwargs["engine"] != "flox" - ) - if requires_numeric: - is_npdatetime = array.dtype.kind in "Mm" - is_cftime = _contains_cftime_datetimes(array) - if is_npdatetime: - offset = _datetime_nanmin(array) - # xarray always uses np.datetime64[ns] for np.datetime64 data - dtype = "timedelta64[ns]" - array = datetime_to_numeric(array, offset) - elif is_cftime: - offset = array.min() - array = datetime_to_numeric(array, offset, datetime_unit="us") - result, *groups = groupby_reduce(array, *by, func=func, **kwargs) # Transpose the new quantile dimension to the end. This is ugly. @@ -395,13 +377,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): # output dim order: (*broadcast_dims, *group_dims, quantile_dim) result = np.moveaxis(result, 0, -1) - # Output of count has an int dtype. - if requires_numeric and func != "count": - if is_npdatetime: - return result.astype(dtype) + offset - elif is_cftime: - return _to_pytimedelta(result, unit="us") + offset - return result # These data variables do not have any of the core dimension, diff --git a/flox/xrutils.py b/flox/xrutils.py index ba8a5672..9ae0ae02 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -345,6 +345,28 @@ def _contains_cftime_datetimes(array) -> bool: return False +def _datetime_nanmin(array): + """nanmin() function for datetime64. + + Caveats that this function deals with: + + - In numpy < 1.18, min() on datetime64 incorrectly ignores NaT + - numpy nanmin() don't work on datetime64 (all versions at the moment of writing) + - dask min() does not work on datetime64 (all versions at the moment of writing) + """ + from .xrdtypes import is_datetime_like + + dtype = array.dtype + assert is_datetime_like(dtype) + # (NaT).astype(float) does not produce NaN... + array = np.where(pd.isnull(array), np.nan, array.astype(float)) + array = min(array, skipna=True) + if isinstance(array, float): + array = np.array(array) + # ...but (NaN).astype("M8") does produce NaT + return array.astype(dtype) + + def _select_along_axis(values, idx, axis): other_ind = np.ix_(*[np.arange(s) for s in idx.shape]) sl = other_ind[:axis] + (idx,) + other_ind[axis:] diff --git a/tests/test_core.py b/tests/test_core.py index 164f87b3..055c641c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2006,3 +2006,19 @@ def test_blockwise_avoid_rechunk(): actual, groups = groupby_reduce(array, by, func="first") assert_equal(groups, ["", "0", "1"]) assert_equal(actual, np.array([0, 0, 0], dtype=np.int64)) + + +@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"]) +def test_datetime_timedelta_first_last(engine, func): + import flox + + idx = 0 if "first" in func else -1 + + dt = pd.date_range("2001-01-01", freq="d", periods=5).values + by = np.ones(dt.shape, dtype=int) + actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine) + assert_equal(actual, dt[[idx]]) + + dt = dt - dt[0] + actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine) + assert_equal(actual, dt[[idx]])