diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index f46830a8..bcbc982b 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -98,10 +98,8 @@ def quantile_or_topk( param = np.atleast_1d(param) param = np.reshape(param, (param.size,) + (1,) * array.ndim) - if is_scalar_param: - idxshape = array.shape[:-1] + (actual_sizes.shape[-1],) - else: - idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],) + # For topk(.., k=+1 or -1), we always return the singleton dimension. + idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],) if q is not None: # This is numpy's method="linear" @@ -110,6 +108,7 @@ def quantile_or_topk( if is_scalar_param: virtual_index = virtual_index.squeeze(axis=0) + idxshape = array.shape[:-1] + (actual_sizes.shape[-1],) lo_ = np.floor( virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64) @@ -122,7 +121,7 @@ def quantile_or_topk( else: virtual_index = inv_idx[:-1] + ((actual_sizes - k) if k > 0 else abs(k) - 1) kth = np.unique(virtual_index) - kth = kth[kth > 0] + kth = kth[kth >= 0] k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim) lo_ = k_offset + virtual_index[np.newaxis, ...] 
@@ -147,12 +146,18 @@ def quantile_or_topk( result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype) else: result = loval - result[lo_ < 0] = fill_value + # A negative lo_ marks a group with fewer than abs(k) elements + badmask = lo_ < 0 + if badmask.any(): + result[badmask] = fill_value + if not skipna and np.any(nanmask): result[..., nanmask] = fill_value + if k is not None: result = result.astype(dtype, copy=False) - np.copyto(out, result) + if out is not None: + np.copyto(out, result) return result diff --git a/flox/aggregations.py b/flox/aggregations.py index a0988751..ccfc413c 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -821,7 +821,7 @@ def _initialize_aggregation( ) final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value) - if agg.name not in ["min", "max", "nanmin", "nanmax"]: + if agg.name not in ["min", "max", "nanmin", "nanmax", "topk"]: final_dtype = _maybe_promote_int(final_dtype) agg.dtype = { "user": dtype, # Save to automatically choose an engine @@ -883,6 +883,8 @@ def _initialize_aggregation( if isinstance(combine, str): simple_combine.append(getattr(np, combine)) else: + if agg.name == "topk": + combine = partial(combine, **finalize_kwargs) simple_combine.append(combine) agg.simple_combine = tuple(simple_combine) diff --git a/flox/core.py b/flox/core.py index 8d0578f2..9dca484b 100644 --- a/flox/core.py +++ b/flox/core.py @@ -958,7 +958,7 @@ def chunk_reduce( nfuncs = len(funcs) dtypes = _atleast_1d(dtype, nfuncs) fill_values = _atleast_1d(fill_value, nfuncs) - kwargss = _atleast_1d({}, nfuncs) if kwargs is None else kwargs + kwargss = _atleast_1d({} if kwargs is None else kwargs, nfuncs) if isinstance(axis, Sequence): axes: T_Axes = axis @@ -1645,6 +1645,7 @@ def dask_groupby_agg( dtype=agg.dtype["intermediate"], reindex=reindex, user_dtype=agg.dtype["user"], + kwargs=agg.finalize_kwargs if agg.name == "topk" else None, ) if do_simple_combine: # Add a dummy dimension that then gets reduced over @@ -2372,6 
+2373,9 @@ def groupby_reduce( "Use engine='flox' instead (it is also much faster), " "or set engine=None to use the default." ) + if func == "topk": + if finalize_kwargs is None or "k" not in finalize_kwargs: + raise ValueError("Please pass `k` for topk calculations.") bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) nby = len(bys) diff --git a/flox/xrutils.py b/flox/xrutils.py index 9f72f04b..e85e327f 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -389,10 +389,13 @@ def topk(a, k, axis, keepdims): are not sorted internally. """ assert keepdims is True - axis = axis[0] + (axis,) = axis + axis = normalize_axis_index(axis, a.ndim) if abs(k) >= a.shape[axis]: return a + # TODO: handle NaNs a = np.partition(a, -k, axis=axis) k_slice = slice(-k, None) if k > 0 else slice(-k) - return a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))] + result = a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))] + return result diff --git a/tests/test_properties.py b/tests/test_properties.py index 0d75e2a3..b7a5cf7b 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -230,3 +230,16 @@ def reverse(arr): backward = groupby_scan(a, by, func="bfill") forward_reversed = reverse(groupby_scan(reverse(a), reverse(by), func="ffill")) assert_equal(forward_reversed, backward) + + +@given(data=st.data(), array=chunked_arrays()) +def test_topk_max_min(data, array): + "top 1 == max; top -1 == min" + size = array.shape[-1] + by = data.draw(by_arrays(shape=(size,))) + k, npfunc = data.draw(st.sampled_from([(1, "max"), (-1, "min")])) + + for a in (array, array.compute()): + actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k}) + expected, _ = groupby_reduce(a, by, func=npfunc) + assert_equal(actual, expected[np.newaxis, :])