diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index 7174552c..c23c5471 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -2,6 +2,7 @@ import numpy as np +from . import xrdtypes as dtypes from .xrutils import is_scalar, isnull, notnull @@ -98,7 +99,7 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non # partition the complex array in-place labels_broadcast = np.broadcast_to(group_idx, array.shape) with np.errstate(invalid="ignore"): - cmplx = labels_broadcast + 1j * array + cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array) cmplx.partition(kth=kth, axis=-1) if is_scalar_q: a_ = cmplx.imag @@ -158,6 +159,8 @@ def _np_grouped_op( def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs): + if fillna in [dtypes.INF, dtypes.NINF]: + fillna = dtypes._get_fill_value(kwargs.get("dtype", array.dtype), fillna) result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs) # np.nanmax([np.nan, np.nan]) = np.nan # To recover this behaviour, we need to search for the fillna value @@ -175,9 +178,9 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs): prod = partial(_np_grouped_op, op=np.multiply.reduceat) nanprod = partial(_nan_grouped_op, func=prod, fillna=1) max = partial(_np_grouped_op, op=np.maximum.reduceat) -nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf) +nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF) min = partial(_np_grouped_op, op=np.minimum.reduceat) -nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf) +nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF) quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False)) nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True)) median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False)) diff --git a/flox/aggregations.py b/flox/aggregations.py index 51e650a6..5515a5b6 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -115,60 +115,6 @@ def generic_aggregate( return result -def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype: - if dtype is None: - dtype = array_dtype - if dtype is np.floating: - # mean, std, var always result in floating - # but we preserve the array's dtype if it is floating - if array_dtype.kind in "fcmM": - dtype = array_dtype - else: - dtype = np.dtype("float64") - elif not isinstance(dtype, np.dtype): - dtype = np.dtype(dtype) - if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]: - dtype = np.result_type(dtype, fill_value) - return dtype - - -def _maybe_promote_int(dtype) -> np.dtype: - # https://numpy.org/doc/stable/reference/generated/numpy.prod.html - # The dtype of a is used by default unless a has an integer dtype of less precision - # than the default platform integer. - if not isinstance(dtype, np.dtype): - dtype = np.dtype(dtype) - if dtype.kind == "i": - dtype = np.result_type(dtype, np.intp) - elif dtype.kind == "u": - dtype = np.result_type(dtype, np.uintp) - return dtype - - -def _get_fill_value(dtype, fill_value): - """Returns dtype appropriate infinity. Returns +Inf equivalent for None.""" - if fill_value in [None, dtypes.NA] and dtype.kind in "US": - return "" - if fill_value == dtypes.INF or fill_value is None: - return dtypes.get_pos_infinity(dtype, max_for_int=True) - if fill_value == dtypes.NINF: - return dtypes.get_neg_infinity(dtype, min_for_int=True) - if fill_value == dtypes.NA: - if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating): - return np.nan - # This is madness, but npg checks that fill_value is compatible - # with array dtype even if the fill_value is never used. - elif ( - np.issubdtype(dtype, np.integer) - or np.issubdtype(dtype, np.timedelta64) - or np.issubdtype(dtype, np.datetime64) - ): - return dtypes.get_neg_infinity(dtype, min_for_int=True) - else: - return None - return fill_value - - def _atleast_1d(inp, min_length: int = 1): if xrutils.is_scalar(inp): inp = (inp,) * min_length @@ -210,6 +156,7 @@ def __init__( final_dtype: DTypeLike | None = None, reduction_type: Literal["reduce", "argreduce"] = "reduce", new_dims_func: Callable | None = None, + preserves_dtype: bool = False, ): """ Blueprint for computing grouped aggregations. @@ -256,6 +203,8 @@ def __init__( Function that receives finalize_kwargs and returns a tupleof sizes of any new dimensions added by the reduction. For e.g. quantile for q=(0.5, 0.85) adds a new dimension of size 2, so returns (2,) + preserves_dtype: bool, + Whether a function preserves the dtype on return E.g. min, max, first, last, mode """ self.name = name # preprocess before blockwise @@ -292,6 +241,7 @@ def __init__( self.new_dims_func: Callable = ( returns_empty_tuple if new_dims_func is None else new_dims_func ) + self.preserves_dtype = preserves_dtype @cached_property def new_dims(self) -> tuple[Dim]: @@ -434,10 +384,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0): ) -min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF) -nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan) -max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF) -nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan) +min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, preserves_dtype=True) +nanmin = Aggregation( + "nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA, preserves_dtype=True +) +max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True) +nanmax = Aggregation( + "nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA, preserves_dtype=True +) def argreduce_preprocess(array, axis): @@ -525,10 +479,14 @@ def _pick_second(*x): final_dtype=np.intp, ) -first = Aggregation("first", chunk=None, combine=None, fill_value=None) -last = Aggregation("last", chunk=None, combine=None, fill_value=None) -nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA) -nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA) +first = Aggregation("first", chunk=None, combine=None, fill_value=None, preserves_dtype=True) +last = Aggregation("last", chunk=None, combine=None, fill_value=None, preserves_dtype=True) +nanfirst = Aggregation( + "nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA, preserves_dtype=True +) +nanlast = Aggregation( + "nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA, preserves_dtype=True +) all_ = Aggregation( "all", @@ -579,8 +537,12 @@ def quantile_new_dims_func(q) -> tuple[Dim]: final_dtype=np.floating, new_dims_func=quantile_new_dims_func, ) -mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None) -nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None) +mode = Aggregation( + name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True +) +nanmode = Aggregation( + name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True +) @dataclass @@ -634,7 +596,7 @@ def last(self) -> AlignedArrays: # TODO: automate? engine="flox", dtype=self.array.dtype, - fill_value=_get_fill_value(self.array.dtype, dtypes.NA), + fill_value=dtypes._get_fill_value(self.array.dtype, dtypes.NA), expected_groups=None, ) return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"]) @@ -729,6 +691,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan) binary_op=None, reduction="nanlast", scan="ffill", + # Important: this must be NaN otherwise, ffill does not work. identity=np.nan, mode="concat_then_scan", ) @@ -737,6 +700,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan) binary_op=None, reduction="nanlast", scan="ffill", + # Important: this must be NaN otherwise, bfill does not work. identity=np.nan, mode="concat_then_scan", preprocess=reverse, @@ -815,17 +779,18 @@ def _initialize_aggregation( dtype_: np.dtype | None = ( np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype ) - - final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value) - if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax"]: - final_dtype = _maybe_promote_int(final_dtype) + final_dtype = dtypes._normalize_dtype( + dtype_ or agg.dtype_init["final"], array_dtype, fill_value + ) + if not agg.preserves_dtype: + final_dtype = dtypes._maybe_promote_int(final_dtype) agg.dtype = { "user": dtype, # Save to automatically choose an engine "final": final_dtype, "numpy": (final_dtype,), "intermediate": tuple( ( - _normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv) + dtypes._normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv) if int_dtype is None else np.dtype(int_dtype) ) @@ -838,10 +803,10 @@ def _initialize_aggregation( # Replace sentinel fill values according to dtype agg.fill_value["user"] = fill_value agg.fill_value["intermediate"] = tuple( - _get_fill_value(dt, fv) + dtypes._get_fill_value(dt, fv) for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"]) ) - agg.fill_value[func] = _get_fill_value(agg.dtype["final"], agg.fill_value[func]) + agg.fill_value[func] = dtypes._get_fill_value(agg.dtype["final"], agg.fill_value[func]) fv = fill_value if fill_value is not None else agg.fill_value[agg.name] if _is_arg_reduction(agg): diff --git a/flox/xrdtypes.py b/flox/xrdtypes.py index 2d6ce369..3fd0f4fe 100644 --- a/flox/xrdtypes.py +++ b/flox/xrdtypes.py @@ -1,6 +1,7 @@ import functools import numpy as np +from numpy.typing import DTypeLike from . import xrutils as utils @@ -147,3 +148,57 @@ def get_neg_infinity(dtype, min_for_int=False): def is_datetime_like(dtype): """Check if a dtype is a subclass of the numpy datetime types""" return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) + + +def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype: + if dtype is None: + dtype = array_dtype + if dtype is np.floating: + # mean, std, var always result in floating + # but we preserve the array's dtype if it is floating + if array_dtype.kind in "fcmM": + dtype = array_dtype + else: + dtype = np.dtype("float64") + elif not isinstance(dtype, np.dtype): + dtype = np.dtype(dtype) + if fill_value not in [None, INF, NINF, NA]: + dtype = np.result_type(dtype, fill_value) + return dtype + + +def _maybe_promote_int(dtype) -> np.dtype: + # https://numpy.org/doc/stable/reference/generated/numpy.prod.html + # The dtype of a is used by default unless a has an integer dtype of less precision + # than the default platform integer. + if not isinstance(dtype, np.dtype): + dtype = np.dtype(dtype) + if dtype.kind == "i": + dtype = np.result_type(dtype, np.intp) + elif dtype.kind == "u": + dtype = np.result_type(dtype, np.uintp) + return dtype + + +def _get_fill_value(dtype, fill_value): + """Returns dtype appropriate infinity. Returns +Inf equivalent for None.""" + if fill_value in [None, NA] and dtype.kind in "US": + return "" + if fill_value == INF or fill_value is None: + return get_pos_infinity(dtype, max_for_int=True) + if fill_value == NINF: + return get_neg_infinity(dtype, min_for_int=True) + if fill_value == NA: + if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating): + return np.nan + # This is madness, but npg checks that fill_value is compatible + # with array dtype even if the fill_value is never used. + elif np.issubdtype(dtype, np.integer): + return get_neg_infinity(dtype, min_for_int=True) + elif np.issubdtype(dtype, np.timedelta64): + return np.timedelta64("NaT") + elif np.issubdtype(dtype, np.datetime64): + return np.datetime64("NaT") + else: + return None + return fill_value diff --git a/tests/test_core.py b/tests/test_core.py index 540e32c0..5d4e7ec3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -13,8 +13,9 @@ from numpy_groupies.aggregate_numpy import aggregate import flox +from flox import xrdtypes as dtypes from flox import xrutils -from flox.aggregations import Aggregation, _initialize_aggregation, _maybe_promote_int +from flox.aggregations import Aggregation, _initialize_aggregation from flox.core import ( HAS_NUMBAGG, _choose_engine, @@ -161,7 +162,7 @@ def test_groupby_reduce( if func == "mean" or func == "nanmean": expected_result = np.array(expected, dtype=np.float64) elif func == "sum": - expected_result = np.array(expected, dtype=_maybe_promote_int(array.dtype)) + expected_result = np.array(expected, dtype=dtypes._maybe_promote_int(array.dtype)) elif func == "count": expected_result = np.array(expected, dtype=np.intp) @@ -389,7 +390,7 @@ def test_groupby_reduce_preserves_dtype(dtype, func): array = np.ones((2, 12), dtype=dtype) by = np.array([labels] * 2) result, _ = groupby_reduce(from_array(array, chunks=(-1, 4)), by, func=func) - expect_dtype = _maybe_promote_int(array.dtype) + expect_dtype = dtypes._maybe_promote_int(array.dtype) assert result.dtype == expect_dtype @@ -1054,7 +1055,7 @@ def test_dtype_preservation(dtype, func, engine): # https://github.com/numbagg/numbagg/issues/121 pytest.skip() if func == "sum": - expected = _maybe_promote_int(dtype) + expected = dtypes._maybe_promote_int(dtype) elif func == "mean" and "int" in dtype: expected = np.float64 else: @@ -1085,7 +1086,7 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype): actual, actual_groups = groupby_reduce(array, labels, func="sum", method=method) assert_equal(actual_groups, np.arange(6, dtype=labels.dtype)) - expect_dtype = _maybe_promote_int(dtype) + expect_dtype = dtypes._maybe_promote_int(dtype) assert_equal(actual, np.array([0, 4, 24, 6, 12, 20], dtype=expect_dtype))