diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index 507c6111b..f833f32d6 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -1,15 +1,27 @@ +from functools import partial + +import numbagg +import numbagg.grouped import numpy as np -from numbagg.grouped import group_nanmean, group_nansum -def nansum_of_squares( - group_idx, array, *, axis=-1, func="sum", size=None, fill_value=None, dtype=None +def _numbagg_wrapper( + group_idx, + array, + *, + axis=-1, + func="sum", + size=None, + fill_value=None, + dtype=None, + numbagg_func=None, ): - return group_nansum( - array**2, + return numbagg_func( + array, group_idx, axis=axis, num_labels=size, + # The following are unsupported # fill_value=fill_value, # dtype=dtype, ) @@ -18,7 +30,7 @@ def nansum_of_squares( def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): if np.issubdtype(array.dtype, np.bool_): array = array.astype(np.in64) - return group_nansum( + return numbagg.grouped.group_nansum( array, group_idx, axis=axis, @@ -31,7 +43,7 @@ def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None) def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): if np.issubdtype(array.dtype, np.int_): array = array.astype(np.float64) - return group_nanmean( + return numbagg.grouped.group_nanmean( array, group_idx, axis=axis, @@ -41,16 +53,13 @@ def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None ) -def nanlen(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): - return group_nansum( - (~np.isnan(array)).astype(int), - group_idx, - axis=axis, - num_labels=size, - # fill_value=fill_value, - # dtype=dtype, - ) - +nansum_of_squares = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nansum_of_squares) +nanlen = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_count) +nanprod = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanprod) +nanfirst = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanfirst) +nanlast = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanlast) +nanargmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmax) +nanargmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmin) # sum = nansum # mean = nanmean