# Patch commentary (leading text before the first `diff --git` header).
# This part of the patch edits flox/aggregations.py and opens a new file, flox/sketches.py.
# Visible changes:
#   * `sketches` is added to the package-relative import list.
#   * In `_normalize_dtype`, `np.result_type(dtype, fill_value)` is wrapped in
#     try/except TypeError; when TypeError is raised `dtype` is left as it was.
#   * `TDigest` is imported from `crick` at module level, and two Aggregation objects,
#     `quantile_tdigest` and `nanquantile_tdigest`, are defined and then registered in the
#     `aggregations` dict under those same keys.
# NOTE(review): the two Aggregation definitions differ only in their name string; no
#   NaN-specific handling is visible in this hunk -- confirm that is intended.
# NOTE(review): `fill_value=TDigest()` constructs the instance once, when the module is
#   imported; presumably every use of the fill value shares that object -- verify that
#   nothing mutates it.
# NOTE(review): `from crick import TDigest` runs whenever flox/aggregations.py is imported,
#   while flox/sketches.py imports crick inside its functions -- confirm whether crick is
#   meant to be a required dependency.
diff --git a/flox/aggregations.py b/flox/aggregations.py index c61ef31a..19106705 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -10,7 +10,7 @@ import numpy as np from numpy.typing import ArrayLike, DTypeLike -from . import aggregate_flox, aggregate_npg, xrutils +from . import aggregate_flox, aggregate_npg, sketches, xrutils from . import xrdtypes as dtypes if TYPE_CHECKING: @@ -119,7 +119,10 @@ def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) - elif not isinstance(dtype, np.dtype): dtype = np.dtype(dtype) if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]: - dtype = np.result_type(dtype, fill_value) + try: + dtype = np.result_type(dtype, fill_value) + except TypeError: + pass return dtype @@ -567,6 +570,30 @@ def quantile_new_dims_func(q) -> tuple[Dim]: mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None) nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None) + +from crick import TDigest + +quantile_tdigest = Aggregation( + "quantile_tdigest", + numpy=(sketches.tdigest_aggregate,), + chunk=(sketches.tdigest_chunk,), + combine=(sketches.tdigest_combine,), + finalize=sketches.tdigest_aggregate, + fill_value=TDigest(), + final_dtype=np.float64, +) + +nanquantile_tdigest = Aggregation( + "nanquantile_tdigest", + numpy=(sketches.tdigest_aggregate,), + chunk=(sketches.tdigest_chunk,), + combine=(sketches.tdigest_combine,), + finalize=sketches.tdigest_aggregate, + fill_value=TDigest(), + final_dtype=np.float64, +) + + aggregations = { "any": any_, "all": all_, @@ -599,6 +626,8 @@ def quantile_new_dims_func(q) -> tuple[Dim]: "nanquantile": nanquantile, "mode": mode, "nanmode": nanmode, + "quantile_tdigest": quantile_tdigest, + "nanquantile_tdigest": nanquantile_tdigest, } diff --git a/flox/sketches.py b/flox/sketches.py new file mode 100644 index 00000000..2f70d6c5 --- /dev/null +++ b/flox/sketches.py @@ -0,0 +1,42 @@ +import numpy as np +import 
def tdigest_chunk(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, **kwargs):
    """Build one t-digest per group from the values of ``array``.

    Parameters
    ----------
    group_idx
        Group labels, forwarded unchanged to ``numpy_groupies``.
    array
        Values to sketch; its dtype is used to cast each group's values
        before they are fed to the digest.
    axis, size, fill_value
        Forwarded unchanged to ``npg.aggregate_numpy.aggregate``.
    dtype, **kwargs
        Accepted for signature compatibility; not used by this function.

    Returns
    -------
    Whatever ``npg.aggregate_numpy.aggregate`` returns when asked for
    ``dtype=object`` with a per-group function that yields a ``TDigest``.
    """
    # Imported lazily so that importing this module needs neither package.
    import numpy_groupies as npg
    from crick import TDigest

    def _digest_of(values):
        sketch = TDigest()
        # numpy_groupies hands each group over as an object array, so cast it
        # back to the input dtype before updating the digest.
        sketch.update(values.astype(array.dtype, copy=False))
        return sketch

    return npg.aggregate_numpy.aggregate(
        group_idx,
        array,
        func=_digest_of,
        size=size,
        fill_value=fill_value,
        axis=axis,
        dtype=object,
    )


def tdigest_combine(digests, axis=-1, keepdims=True):
    """Merge t-digests along one or more axes.

    Parameters
    ----------
    digests
        Array of ``TDigest`` objects.
    axis
        A single axis, or a tuple/list of axes, to merge along.
    keepdims
        Accepted for signature compatibility; it is not consulted. Every
        merged slice is returned as a length-1 array, so each reduced axis
        is kept with length 1.

    Returns
    -------
    Array of merged ``TDigest`` objects.
    """
    from crick import TDigest

    def _merge(column):
        merged = TDigest()
        merged.merge(*column)
        # Wrap in a one-element object array so that np.apply_along_axis
        # keeps the reduced axis (length 1) and an object dtype.
        return np.array([merged], dtype=object)

    # A list of axes is treated like a tuple instead of being wrapped whole.
    axes = tuple(axis) if isinstance(axis, (tuple, list)) else (axis,)

    # Reduced axes keep length 1, so axis numbers stay valid and the axes can
    # simply be merged one after another.
    result = digests
    for ax in axes:
        result = np.apply_along_axis(_merge, ax, result)

    return result


def tdigest_aggregate(digests, q, axis=-1, keepdims=True):
    """Evaluate the quantile ``q`` of every digest in ``digests``.

    Parameters
    ----------
    digests
        Array whose elements expose a ``quantile(q)`` method.
    q
        Passed unchanged to each element's ``quantile`` method.
    axis, keepdims
        Accepted for signature compatibility; not used by this function.

    Returns
    -------
    A new object array with the same shape as ``digests`` holding
    ``digests[idx].quantile(q)`` at every index. The input array is left
    untouched, so the same digests can be queried again.
    """
    # Write into a fresh array: overwriting ``digests`` in place would destroy
    # the sketches and make a second call fail on the stored quantile values.
    quantiles = np.empty(digests.shape, dtype=object)
    for idx in np.ndindex(digests.shape):
        quantiles[idx] = digests[idx].quantile(q)
    return quantiles


# NOTE: the remainder of this patch (a pyproject.toml hunk, only partly visible here)
# adds "crick" to a `module=[...]` list, between "cachey" and "cftime".