Support quantile, median, mode with method="blockwise". (#269)
* Support quantile, median with method="blockwise".

We allow method="blockwise" when grouping by a dask array.
This only works when `expected_groups` is provided, in which case we
set reindex=True. (A usage sketch follows the commit list below.)

* Update flox/core.py

* fix comment

* Update validate_reindex test

* Fix

* Fix

* Raise early if `q` is not provided for quantile

* WIP test

* narrow type

* fix type

* Mode + Tests

* limit tests

* cleanup tests

* fix bool tests

* Revert "limit tests"

This reverts commit d46c3ae.

* Small cleanup

* more cleanup

* fix test

* ignore scipy typing

* update docs
dcherian authored Oct 5, 2023
1 parent 30d522d commit 68b122e
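
For orientation, here is a usage sketch of the new capability (not part of the commit; the data, labels, and chunking are illustrative, and flox with dask is assumed to be installed):

    import dask.array as da
    import numpy as np

    from flox import groupby_reduce

    # The group labels live in a dask array, so flox cannot discover them
    # while building the graph; expected_groups supplies them up front,
    # and flox then sets reindex=True for this case.
    by = da.from_array(np.array([0, 1, 0, 1, 2, 2]), chunks=6)
    data = da.from_array(np.arange(6, dtype="f8"), chunks=6)

    result, groups = groupby_reduce(
        data,
        by,
        func="nanmedian",
        method="blockwise",
        expected_groups=np.array([0, 1, 2]),
    )
    print(result.compute())  # medians for labels 0, 1, 2 -> 1.0, 2.0, 4.5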
Showing 10 changed files with 261 additions and 62 deletions.
1 change: 1 addition & 0 deletions ci/environment.yml
@@ -22,3 +22,4 @@ dependencies:
   - pooch
   - toolz
   - numba
+  - scipy
7 changes: 5 additions & 2 deletions docs/source/aggregations.md
@@ -11,8 +11,11 @@ the `func` kwarg:
 - `"std"`, `"nanstd"`
 - `"argmin"`
 - `"argmax"`
-- `"first"`
-- `"last"`
+- `"first"`, `"nanfirst"`
+- `"last"`, `"nanlast"`
+- `"median"`, `"nanmedian"`
+- `"mode"`, `"nanmode"`
+- `"quantile"`, `"nanquantile"`
 
 ```{tip}
 We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome!
9 changes: 7 additions & 2 deletions docs/source/user-stories/custom-aggregations.ipynb
@@ -15,8 +15,13 @@
     ">\n",
     "> A = da.groupby(['lon_bins', 'lat_bins']).mode()\n",
     "\n",
-    "This notebook will describe how to accomplish this using a custom `Aggregation`\n",
-    "since `mode` and `median` aren't supported by flox yet.\n"
+    "This notebook will describe how to accomplish this using a custom `Aggregation`.\n",
+    "\n",
+    "\n",
+    "```{tip}\n",
+    "flox now supports `mode`, `nanmode`, `quantile`, `nanquantile`, `median`, `nanmedian` using exactly the same \n",
+    "approach as shown below\n",
+    "```\n"
    ]
  },
 {
81 changes: 81 additions & 0 deletions flox/aggregate_npg.py
@@ -100,3 +100,84 @@ def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None,
 
 len = partial(_len, func="len")
 nanlen = partial(_len, func="nanlen")
+
+
+def median(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=np.median,
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def nanmedian(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=np.nanmedian,
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def quantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(np.quantile, q=q),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def nanquantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(np.nanquantile, q=q),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def mode_(array, nan_policy, dtype):
+    from scipy.stats import mode
+
+    # npg splits `array` into object arrays for each group
+    # scipy.stats.mode does not like that
+    # here we cast back
+    return mode(array.astype(dtype, copy=False), nan_policy=nan_policy, axis=-1).mode
+
+
+def mode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(mode_, nan_policy="propagate", dtype=array.dtype),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def nanmode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(mode_, nan_policy="omit", dtype=array.dtype),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
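
To see the shape of these wrappers' interface, here is a minimal sketch calling the new `median` directly (the sample data is illustrative, and numpy_groupies is assumed installed):

    import numpy as np

    from flox import aggregate_npg

    group_idx = np.array([0, 0, 1, 1, 1])
    array = np.array([1.0, 3.0, 2.0, 4.0, 6.0])

    # numpy_groupies applies np.median to each group's values; the engine
    # argument exists for interface compatibility and is not used here.
    result = aggregate_npg.median(
        group_idx, array, engine="numpy", size=2, fill_value=np.nan
    )
    print(result)  # group 0 -> 2.0, group 1 -> 4.0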
48 changes: 36 additions & 12 deletions flox/aggregations.py
@@ -6,7 +6,6 @@
 from typing import TYPE_CHECKING, Any, Callable, TypedDict
 
 import numpy as np
-import numpy_groupies as npg
 from numpy.typing import DTypeLike
 
 from . import aggregate_flox, aggregate_npg, xrutils
@@ -35,6 +34,16 @@ class AggDtype(TypedDict):
     intermediate: tuple[np.dtype | type[np.intp], ...]
 
 
+def get_npg_aggregation(func, *, engine):
+    try:
+        method_ = getattr(aggregate_npg, func)
+        method = partial(method_, engine=engine)
+    except AttributeError:
+        aggregate = aggregate_npg._get_aggregate(engine).aggregate
+        method = partial(aggregate, func=func)
+    return method
+
+
 def generic_aggregate(
     group_idx,
     array,
@@ -51,14 +60,11 @@ def generic_aggregate(
         try:
             method = getattr(aggregate_flox, func)
         except AttributeError:
-            method = partial(npg.aggregate_numpy.aggregate, func=func)
+            method = get_npg_aggregation(func, engine="numpy")
 
     elif engine in ["numpy", "numba"]:
-        try:
-            method_ = getattr(aggregate_npg, func)
-            method = partial(method_, engine=engine)
-        except AttributeError:
-            aggregate = aggregate_npg._get_aggregate(engine).aggregate
-            method = partial(aggregate, func=func)
+        method = get_npg_aggregation(func, engine=engine)
 
     else:
         raise ValueError(
             f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
@@ -465,10 +471,22 @@ def _pick_second(*x):
     final_dtype=bool,
 )
 
-# numpy_groupies does not support median
-# And the dask version is really hard!
-# median = Aggregation("median", chunk=None, combine=None, fill_value=None)
-# nanmedian = Aggregation("nanmedian", chunk=None, combine=None, fill_value=None)
+# Support statistical quantities only blockwise
+# The parallel versions will be approximate and are hard to implement!
+median = Aggregation(
+    name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+nanmedian = Aggregation(
+    name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+quantile = Aggregation(
+    name="quantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+nanquantile = Aggregation(
+    name="nanquantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
+nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
 
 aggregations = {
     "any": any_,
@@ -496,6 +514,12 @@ def _pick_second(*x):
     "nanfirst": nanfirst,
     "last": last,
     "nanlast": nanlast,
+    "median": median,
+    "nanmedian": nanmedian,
+    "quantile": quantile,
+    "nanquantile": nanquantile,
+    "mode": mode,
+    "nanmode": nanmode,
 }
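
With these Aggregation objects registered, the new names work through the public API; on numpy inputs the whole reduction runs in a single pass, so no combine step is needed. A sketch with made-up data (scipy is required for mode):

    import numpy as np

    from flox import groupby_reduce

    labels = np.array([0, 0, 0, 1, 1, 1])
    values = np.array([1.0, 1.0, 2.0, 3.0, 4.0, 4.0])

    # chunk=None / combine=None means no tree-reduction is defined;
    # numpy inputs reduce each group directly.
    result, groups = groupby_reduce(values, labels, func="mode")
    print(result)  # most common value per group: 1.0 and 4.0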
54 changes: 36 additions & 18 deletions flox/core.py
@@ -1307,15 +1307,14 @@ def dask_groupby_agg(
     assert isinstance(axis, Sequence)
     assert all(ax >= 0 for ax in axis)
 
-    if method == "blockwise" and not isinstance(by, np.ndarray):
-        raise NotImplementedError
-
     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
     token = dask.base.tokenize(array, by, agg, expected_groups, axis)
 
     if expected_groups is None and reindex:
         expected_groups = _get_expected_groups(by, sort=sort)
+    if method == "cohorts":
+        assert reindex is False
 
     by_input = by
 
@@ -1349,7 +1348,6 @@
     # b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction.
     #    This allows us to discover groups at compute time, support argreductions, lower intermediate
     #    memory usage (but method="cohorts" would also work to reduce memory in some cases)
-
     do_simple_combine = not _is_arg_reduction(agg)
 
     if method == "blockwise":
@@ -1375,7 +1373,7 @@
         partial(
             blockwise_method,
             axis=axis,
-            expected_groups=None if method == "cohorts" else expected_groups,
+            expected_groups=expected_groups if reindex else None,
             engine=engine,
             sort=sort,
         ),
@@ -1468,14 +1466,24 @@
 
     elif method == "blockwise":
         reduced = intermediate
-        # Here one input chunk → one output chunks
-        # find number of groups in each chunk, this is needed for output chunks
-        # along the reduced axis
-        slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
-        groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
-        groups = (np.concatenate(groups_in_block),)
-        ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
-        group_chunks = (ngroups_per_block,)
+        if reindex:
+            if TYPE_CHECKING:
+                assert expected_groups is not None
+            # TODO: we could have `expected_groups` be a dask array with appropriate chunks
+            # for now, we have a numpy array that is interpreted as listing all group labels
+            # that are present in every chunk
+            groups = (expected_groups,)
+            group_chunks = ((len(expected_groups),),)
+        else:
+            # Here one input chunk → one output chunks
+            # find number of groups in each chunk, this is needed for output chunks
+            # along the reduced axis
+            # TODO: this logic is very specialized for the resampling case
+            slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
+            groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
+            groups = (np.concatenate(groups_in_block),)
+            ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
+            group_chunks = (ngroups_per_block,)
     else:
         raise ValueError(f"Unknown method={method}.")
 
@@ -1547,7 +1555,7 @@ def _validate_reindex(
     if reindex is True and not all_numpy:
         if _is_arg_reduction(func):
             raise NotImplementedError
-        if method in ["blockwise", "cohorts"]:
+        if method == "cohorts" or (method == "blockwise" and not any_by_dask):
            raise ValueError(
                 "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
             )
@@ -1562,7 +1570,11 @@
         # have to do the grouped_combine since there's no good fill_value
         reindex = False
 
-    if method == "blockwise" or _is_arg_reduction(func):
+    if method == "blockwise":
+        # for grouping by dask arrays, we set reindex=True
+        reindex = any_by_dask
+
+    elif _is_arg_reduction(func):
         reindex = False
 
     elif method == "cohorts":
@@ -1767,7 +1779,10 @@
     *by : ndarray or DaskArray
         Array of labels to group over. Must be aligned with ``array`` so that
        ``array.shape[-by.ndim :] == by.shape``
-    func : str or Aggregation
+    func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
+        "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
+        "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
+        "first", "nanfirst", "last", "nanlast"} or Aggregation
         Single function name or an Aggregation instance
     expected_groups : (optional) Sequence
         Expected unique labels.
@@ -1835,7 +1850,7 @@
         boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
         original block size. Avoid that by using ``method="cohorts"``. By default, it is turned off for argreductions.
     finalize_kwargs : dict, optional
-        Kwargs passed to finalize the reduction such as ``ddof`` for var, std.
+        Kwargs passed to finalize the reduction such as ``ddof`` for var, std or ``q`` for quantile.
 
     Returns
     -------
@@ -1855,6 +1870,9 @@
             "Try engine='numpy' or engine='numba' instead."
         )
 
+    if func == "quantile" and (finalize_kwargs is None or "q" not in finalize_kwargs):
+        raise ValueError("Please pass `q` for quantile calculations.")
+
     bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
     nby = len(bys)
     by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
@@ -2023,7 +2041,7 @@
     result, groups = partial_agg(
         array,
         by_,
-        expected_groups=None if method == "blockwise" else expected_groups,
+        expected_groups=expected_groups,
         agg=agg,
         reindex=reindex,
         method=method,
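
A sketch of the new early validation, with illustrative data: `q` travels via `finalize_kwargs`, and omitting it now fails at call time instead of deep inside the computation:

    import numpy as np

    from flox import groupby_reduce

    labels = np.array([0, 0, 1, 1])
    values = np.array([1.0, 2.0, 3.0, 4.0])

    # q is forwarded to the quantile aggregation at finalize time
    result, groups = groupby_reduce(
        values, labels, func="quantile", finalize_kwargs={"q": 0.5}
    )
    print(result)  # [1.5 3.5]

    try:
        groupby_reduce(values, labels, func="quantile")
    except ValueError as err:
        print(err)  # Please pass `q` for quantile calculations.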
9 changes: 6 additions & 3 deletions flox/xarray.py
@@ -88,8 +88,11 @@ def xarray_reduce(
         Xarray object to reduce
     *by : DataArray or iterable of str or iterable of DataArray
         Variables with which to group by ``obj``
-    func : str or Aggregation
-        Reduction method
+    func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
+        "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
+        "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
+        "first", "nanfirst", "last", "nanlast"} or Aggregation
+        Single function name or an Aggregation instance
     expected_groups : str or sequence
         expected group labels corresponding to each `by` variable
     isbin : iterable of bool
@@ -164,7 +167,7 @@
         boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
         original block size. Avoid that by using method="cohorts". By default, it is turned off for arg reductions.
     **finalize_kwargs
-        kwargs passed to the finalize function, like ``ddof`` for var, std.
+        kwargs passed to the finalize function, like ``ddof`` for var, std or ``q`` for quantile.
 
     Returns
     -------
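
The same kwargs plumbing at the xarray layer, sketched with an illustrative DataArray; extra keywords such as `q` flow through ``**finalize_kwargs``:

    import numpy as np
    import xarray as xr

    from flox.xarray import xarray_reduce

    arr = xr.DataArray(
        np.arange(6.0),
        dims="x",
        coords={"labels": ("x", [0, 0, 1, 1, 2, 2])},
        name="arr",
    )

    # q reaches the quantile aggregation through **finalize_kwargs
    out = xarray_reduce(arr, "labels", func="quantile", q=0.5)
    print(out)  # per-label medians: 0.5, 2.5, 4.5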
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -118,7 +118,8 @@ module=[
     "matplotlib.*",
     "pandas",
     "setuptools",
-    "toolz"
+    "scipy.*",
+    "toolz",
 ]
 ignore_missing_imports = true
1 change: 1 addition & 0 deletions tests/__init__.py
@@ -47,6 +47,7 @@ def LooseVersion(vstring):
 
 has_dask, requires_dask = _importorskip("dask")
 has_numba, requires_numba = _importorskip("numba")
+has_scipy, requires_scipy = _importorskip("scipy")
 has_xarray, requires_xarray = _importorskip("xarray")
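
A hedged sketch of how such a marker is conventionally used in the test suite (the test name and import path here are made up):

    from tests import requires_scipy

    @requires_scipy
    def test_mode_requires_scipy():
        # runs only when scipy is importable; skipped otherwise
        ...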
