Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add topk #374

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 103 additions & 49 deletions flox/aggregate_flox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from . import xrdtypes as dtypes
from .xrutils import is_scalar, isnull, notnull


Expand Down Expand Up @@ -46,74 +47,122 @@ def _lerp(a, b, *, t, dtype, out=None):
return out


def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))
def quantile_or_topk(
array,
inv_idx,
*,
q=None,
k=None,
axis,
skipna,
group_idx,
dtype=None,
out=None,
fill_value=None,
):
assert q or k
assert axis == -1

array_nanmask = isnull(array)
actual_sizes = np.add.reduceat(~array_nanmask, inv_idx[:-1], axis=axis)
newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
full_sizes = np.reshape(np.diff(inv_idx), newshape)
nanmask = full_sizes != actual_sizes
inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))

# The approach here is to use (complex_array.partition) because
# The approach for quantiles and topk, both of which are basically grouped partition,
# here is to use (complex_array.partition) because
# 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary
# 2. Using record_array.partition(..., order=["labels", "array"]) is incredibly slow.
# partition will first sort by real part, then by imaginary part, so it is a two element lex-partition.
# So we set
# partition will first sort by real part, then by imaginary part, so it is a two element
# lex-partition. Therefore we set
# complex_array = group_idx + 1j * array
# group_idx is an integer (guaranteed), but array can have NaNs. Now,
# 1 + 1j*NaN = NaN + 1j * NaN
# so we must replace all NaNs with the maximum array value in the group so these NaNs
# get sorted to the end.

# Replace NaNs with the maximum value for each group.
# Partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
# TODO: Don't know if this array has been copied in _prepare_for_flox. This is potentially wasteful
array = np.where(array_nanmask, -np.inf, array)
array_nanmask = isnull(array)
actual_sizes = np.add.reduceat(~array_nanmask, inv_idx[:-1], axis=axis)
newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
full_sizes = np.reshape(np.diff(inv_idx), newshape)
nanmask = full_sizes != actual_sizes
# TODO: Don't know if this array has been copied in _prepare_for_flox.
# This is potentially wasteful
array = np.where(array_nanmask, dtypes.get_neg_infinity(array.dtype, min_for_int=True), array)
maxes = np.maximum.reduceat(array, inv_idx[:-1], axis=axis)
replacement = np.repeat(maxes, np.diff(inv_idx), axis=axis)
array[array_nanmask] = replacement[array_nanmask]

qin = q
q = np.atleast_1d(qin)
q = np.reshape(q, (len(q),) + (1,) * array.ndim)

# This is numpy's method="linear"
# TODO: could support all the interpolations here
virtual_index = q * (actual_sizes - 1) + inv_idx[:-1]

is_scalar_q = is_scalar(qin)
if is_scalar_q:
virtual_index = virtual_index.squeeze(axis=0)
idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)
param = q or k
if k is not None:
is_scalar_param = False
param = np.arange(abs(k))
else:
idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
is_scalar_param = is_scalar(q)
param = np.atleast_1d(param)
param = np.reshape(param, (param.size,) + (1,) * array.ndim)

# For topk(.., k=+1 or -1), we always return the singleton dimension.
idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)

if q is not None:
# This is numpy's method="linear"
# TODO: could support all the interpolations here
virtual_index = param * (actual_sizes - 1) + inv_idx[:-1]

if is_scalar_param:
virtual_index = virtual_index.squeeze(axis=0)
idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)

lo_ = np.floor(
virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
)
hi_ = np.ceil(
virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
)
kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))

lo_ = np.floor(
virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
)
hi_ = np.ceil(
virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
)
kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
else:
virtual_index = inv_idx[:-1] + ((actual_sizes - k) if k > 0 else abs(k) - 1)
kth = np.unique(virtual_index)
kth = kth[kth >= 0]
k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim)
lo_ = k_offset + virtual_index[np.newaxis, ...]

# partition the complex array in-place
labels_broadcast = np.broadcast_to(group_idx, array.shape)
with np.errstate(invalid="ignore"):
cmplx = labels_broadcast + 1j * array
cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
cmplx.partition(kth=kth, axis=-1)
if is_scalar_q:

if is_scalar_param:
a_ = cmplx.imag
else:
a_ = np.broadcast_to(cmplx.imag, (q.shape[0],) + array.shape)
a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape)

if array.dtype.kind in "Mm":
a_ = a_.astype(array.dtype)

# get bounds, Broadcast to (num quantiles, ..., num labels)
loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)
if q is not None:
# get bounds, Broadcast to (num quantiles, ..., num labels)
hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)

# TODO: could support all the interpolations here
gamma = np.broadcast_to(virtual_index, idxshape) - lo_
result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
else:
result = loval
# This happens if numel in group < abs(k)
badmask = lo_ < 0
if badmask.any():
result[badmask] = fill_value

# TODO: could support all the interpolations here
gamma = np.broadcast_to(virtual_index, idxshape) - lo_
result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
if not skipna and np.any(nanmask):
result[..., nanmask] = np.nan
result[..., nanmask] = fill_value

if k is not None:
result = result.astype(dtype, copy=False)
if out is not None:
np.copyto(out, result)
return result


Expand All @@ -138,12 +187,14 @@ def _np_grouped_op(

if out is None:
q = kwargs.get("q", None)
if q is None:
k = kwargs.get("k", None)
if not q and not k:
out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
else:
nq = len(np.atleast_1d(q))
nq = len(np.atleast_1d(q)) if q is not None else abs(k)
out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
kwargs["group_idx"] = group_idx
kwargs["fill_value"] = fill_value

if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
# The previous version of this if condition
Expand All @@ -158,6 +209,8 @@ def _np_grouped_op(


def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
if fillna in [dtypes.INF, dtypes.NINF]:
fillna = dtypes._get_fill_value(kwargs.get("dtype", array.dtype), fillna)
result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs)
# np.nanmax([np.nan, np.nan]) = np.nan
# To recover this behaviour, we need to search for the fillna value
Expand All @@ -175,13 +228,14 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
prod = partial(_np_grouped_op, op=np.multiply.reduceat)
nanprod = partial(_nan_grouped_op, func=prod, fillna=1)
max = partial(_np_grouped_op, op=np.maximum.reduceat)
nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf)
nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
min = partial(_np_grouped_op, op=np.minimum.reduceat)
nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf)
quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))
nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=True))
nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False))
topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=False))
nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=True))
# TODO: all, any


Expand Down
Loading
Loading