Skip to content

Commit

Permalink
Support more aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Oct 3, 2023
1 parent e92f483 commit 76317f0
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions flox/aggregate_numbagg.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
from functools import partial

import numbagg
import numbagg.grouped
import numpy as np
from numbagg.grouped import group_nanmean, group_nansum


def nansum_of_squares(
group_idx, array, *, axis=-1, func="sum", size=None, fill_value=None, dtype=None
def _numbagg_wrapper(
group_idx,
array,
*,
axis=-1,
func="sum",
size=None,
fill_value=None,
dtype=None,
numbagg_func=None,
):
return group_nansum(
array**2,
return numbagg_func(
array,
group_idx,
axis=axis,
num_labels=size,
# The following are unsupported
# fill_value=fill_value,
# dtype=dtype,
)
Expand All @@ -18,7 +30,7 @@ def nansum_of_squares(
def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
if np.issubdtype(array.dtype, np.bool_):
array = array.astype(np.int64)
return group_nansum(
return numbagg.grouped.group_nansum(
array,
group_idx,
axis=axis,
Expand All @@ -31,7 +43,7 @@ def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None)
def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
if np.issubdtype(array.dtype, np.int_):
array = array.astype(np.float64)
return group_nanmean(
return numbagg.grouped.group_nanmean(
array,
group_idx,
axis=axis,
Expand All @@ -41,16 +53,13 @@ def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None
)


def nanlen(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
    """Count the non-NaN entries of ``array`` in each group.

    A 0/1 integer mask of the non-NaN positions is built and summed per group
    with ``group_nansum``. ``fill_value`` and ``dtype`` are accepted for
    signature compatibility but are not forwarded.
    """
    is_valid = ~np.isnan(array)
    ones_where_valid = is_valid.astype(int)
    return group_nansum(ones_where_valid, group_idx, axis=axis, num_labels=size)

# Each aggregation below is `_numbagg_wrapper` with its numbagg grouped
# function pre-bound through `functools.partial`.
# NOTE(review): `nanlen` is bound to `group_count`; confirm that is the intended
# numbagg function for counting non-NaN values.
nansum_of_squares = partial(
    _numbagg_wrapper,
    numbagg_func=numbagg.grouped.group_nansum_of_squares,
)
nanlen = partial(
    _numbagg_wrapper,
    numbagg_func=numbagg.grouped.group_count,
)
nanprod = partial(
    _numbagg_wrapper,
    numbagg_func=numbagg.grouped.group_nanprod,
)
nanfirst = partial(
    _numbagg_wrapper,
    numbagg_func=numbagg.grouped.group_nanfirst,
)
nanlast = partial(
    _numbagg_wrapper,
    numbagg_func=numbagg.grouped.group_nanlast,
)
nanargmax = partial(
    _numbagg_wrapper,
    numbagg_func=numbagg.grouped.group_nanargmax,
)
nanargmin = partial(
    _numbagg_wrapper,
    numbagg_func=numbagg.grouped.group_nanargmin,
)

# sum = nansum
# mean = nanmean
Expand Down

0 comments on commit 76317f0

Please sign in to comment.