diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e9b1e9d6b..c36fa38cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.0.276' + rev: 'v0.0.292' hooks: - id: ruff args: ["--fix"] @@ -18,12 +18,12 @@ repos: - id: check-docstring-first - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.9.1 hooks: - id: black - repo: https://github.com/executablebooks/mdformat - rev: 0.7.16 + rev: 0.7.17 hooks: - id: mdformat additional_dependencies: @@ -44,13 +44,13 @@ repos: args: [--extra-keys=metadata.kernelspec metadata.language_info.version] - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: v2.2.6 hooks: - id: codespell additional_dependencies: - tomli - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.13 + rev: v0.14 hooks: - id: validate-pyproject diff --git a/ci/environment.yml b/ci/environment.yml index cd96707ce..c565ac3f5 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -22,5 +22,6 @@ dependencies: - pooch - toolz - numba + - scipy - pip: - git+https://github.com/numbagg/numbagg diff --git a/docs/source/aggregations.md b/docs/source/aggregations.md index e6c10e4ba..d3591d2dc 100644 --- a/docs/source/aggregations.md +++ b/docs/source/aggregations.md @@ -11,8 +11,11 @@ the `func` kwarg: - `"std"`, `"nanstd"` - `"argmin"` - `"argmax"` -- `"first"` -- `"last"` +- `"first"`, `"nanfirst"` +- `"last"`, `"nanlast"` +- `"median"`, `"nanmedian"` +- `"mode"`, `"nanmode"` +- `"quantile"`, `"nanquantile"` ```{tip} We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome! diff --git a/docs/source/implementation.md b/docs/source/implementation.md index f3a2a87f7..29d9faf46 100644 --- a/docs/source/implementation.md +++ b/docs/source/implementation.md @@ -199,7 +199,7 @@ width: 100% 1. Group labels must be known at graph construction time, so this only works for numpy arrays. 1. This does require more tasks and a more complicated graph, but the communication overhead can be significantly lower. 1. The detection of "cohorts" is currently slow but could be improved. -1. The extra effort of detecting cohorts and mul;tiple copying of intermediate blocks may be worthwhile only if the chunk sizes are small +1. The extra effort of detecting cohorts and multiple copying of intermediate blocks may be worthwhile only if the chunk sizes are small relative to the approximate period of group labels, or small relative to the size of spatially localized groups. 
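The aggregations.md hunk above lists the reductions this patch adds (`median`/`nanmedian`, `quantile`/`nanquantile`, `mode`/`nanmode`). Not part of the patch — a minimal usage sketch based on the call pattern the test changes below use (`flox.groupby_reduce` with `finalize_kwargs={"q": ...}`); the array, labels, and `q=0.9` are illustrative.

```python
import numpy as np
import flox

# Illustrative data; the labels mirror the pattern used in the test suite.
array = np.random.rand(4, 12)
labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])

# median/mode flavours need no extra arguments;
# mode/nanmode additionally require scipy (added to ci/environment.yml above).
medians, groups = flox.groupby_reduce(array, labels, func="nanmedian")

# quantile requires `q` via finalize_kwargs; the patch raises ValueError otherwise.
q90, _ = flox.groupby_reduce(
    array, labels, func="nanquantile", finalize_kwargs={"q": 0.9}
)
```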
### Example : sensitivity to chunking diff --git a/docs/source/user-stories.md b/docs/source/user-stories.md index 22b37939e..0241e01dc 100644 --- a/docs/source/user-stories.md +++ b/docs/source/user-stories.md @@ -8,4 +8,5 @@ user-stories/climatology.ipynb user-stories/climatology-hourly.ipynb user-stories/custom-aggregations.ipynb + user-stories/nD-bins.ipynb ``` diff --git a/docs/source/user-stories/custom-aggregations.ipynb b/docs/source/user-stories/custom-aggregations.ipynb index 7b4167b98..8b9be09e9 100644 --- a/docs/source/user-stories/custom-aggregations.ipynb +++ b/docs/source/user-stories/custom-aggregations.ipynb @@ -15,8 +15,13 @@ ">\n", "> A = da.groupby(['lon_bins', 'lat_bins']).mode()\n", "\n", - "This notebook will describe how to accomplish this using a custom `Aggregation`\n", - "since `mode` and `median` aren't supported by flox yet.\n" + "This notebook will describe how to accomplish this using a custom `Aggregation`.\n", + "\n", + "\n", + "```{tip}\n", + "flox now supports `mode`, `nanmode`, `quantile`, `nanquantile`, `median`, `nanmedian` using exactly the same \n", + "approach as shown below\n", + "```\n" ] }, { @@ -135,7 +140,7 @@ " # The next are for dask inputs and describe how to reduce\n", " # the data in parallel\n", " chunk=(\"sum\", \"nanlen\"), # first compute these blockwise : (grouped_sum, grouped_count)\n", - " combine=(\"sum\", \"sum\"), # reduce intermediate reuslts (sum the sums, sum the counts)\n", + " combine=(\"sum\", \"sum\"), # reduce intermediate results (sum the sums, sum the counts)\n", " finalize=lambda sum_, count: sum_ / count, # final mean value (divide sum by count)\n", "\n", " fill_value=(0, 0), # fill value for intermediate sums and counts when groups have no members\n", diff --git a/docs/source/user-stories/nD-bins.ipynb b/docs/source/user-stories/nD-bins.ipynb new file mode 100644 index 000000000..87ef942bf --- /dev/null +++ b/docs/source/user-stories/nD-bins.ipynb @@ -0,0 +1,373 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e970d800-c612-482a-bb3a-b1eb7ad53d88", + "metadata": { + "tags": [], + "user_expressions": [] + }, + "source": [ + "# Binning with multi-dimensional bins\n", + "\n", + "```{warning}\n", + "This post is a proof-of-concept for discussion. Expect APIs to change to enable this use case.\n", + "```\n", + "\n", + "Here we explore a binning problem where the bins are multidimensional\n", + "([xhistogram issue](https://github.com/xgcm/xhistogram/issues/28))\n", + "\n", + "> One of such multi-dim bin applications is the ranked probability score rps we\n", + "> use in `xskillscore.rps`, where we want to know how many forecasts fell into\n", + "> which bins. 
Bins are often defined as terciles of the forecast distribution\n", + "> and the bins for these terciles\n", + "> (`forecast_with_lon_lat_time_dims.quantile(q=[.33,.66],dim='time')`) depend on\n", + "> `lon` and `lat`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01f1a2ef-de62-45d0-a04e-343cd78debc5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import math\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "\n", + "import flox\n", + "import flox.xarray" + ] + }, + { + "cell_type": "markdown", + "id": "0be3e214-0cf0-426f-8ebb-669cc5322310", + "metadata": { + "user_expressions": [] + }, + "source": [ + "## Create test data\n" + ] + }, + { + "cell_type": "markdown", + "id": "ce239000-e053-4fc3-ad14-e9e0160da869", + "metadata": { + "user_expressions": [] + }, + "source": [ + "Data to be reduced\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7659c24e-f5a1-4e59-84c0-5ec965ef92d2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "array = xr.DataArray(\n", + " np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),\n", + " dims=(\"space\", \"time\"),\n", + " name=\"array\",\n", + ")\n", + "array" + ] + }, + { + "cell_type": "markdown", + "id": "da0c0ac9-ad75-42cd-a1ea-99069f5bef00", + "metadata": { + "user_expressions": [] + }, + "source": [ + "Array to group by\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4601e744-5d22-447e-97ce-9644198d485e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "by = xr.DataArray(\n", + " np.array([[1, 2, 3], [3, 4, 5], [5, 6, 7], [6, 7, 9]]),\n", + " dims=(\"space\", \"time\"),\n", + " name=\"by\",\n", + ")\n", + "by" + ] + }, + { + "cell_type": "markdown", + "id": "61c21c94-7b6e-46a6-b9c2-59d7b2d40c81", + "metadata": { + "tags": [], + "user_expressions": [] + }, + "source": [ + "Multidimensional bins:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "863a1991-ab8d-47c0-aa48-22b422fcea8c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "bins = by + 0.5\n", + "bins = xr.DataArray(\n", + " np.concatenate([bins, bins[:, [-1]] + 1], axis=-1)[:, :-1].T,\n", + " dims=(\"time\", \"nbins\"),\n", + " name=\"bins\",\n", + ")\n", + "bins" + ] + }, + { + "cell_type": "markdown", + "id": "e65ecaba-d1cc-4485-ae58-c390cb2ebfab", + "metadata": { + "user_expressions": [] + }, + "source": [ + "## Concept\n", + "\n", + "The key idea is that GroupBy is two steps:\n", + "\n", + "1. Factorize (a.k.a \"digitize\") : convert the `by` data to a set of integer\n", + " codes representing the bins.\n", + "2. Apply the reduction.\n", + "\n", + "We treat multi-dimensional binning as a slightly complicated factorization\n", + "problem. Assume that bins are a function of `time`. So we\n", + "\n", + "1. generate a set of appropriate integer codes by:\n", + " 1. Loop over \"time\" and factorize the data appropriately.\n", + " 2. Add an offset to these codes so that \"bin 0\" for `time=0` is different\n", + " from \"bin 0\" for `time=1`\n", + "2. apply the groupby reduction to the \"offset codes\"\n", + "3. 
reshape the output to the right shape\n", + "\n", + "We will work at the xarray level, so its easy to keep track of the different\n", + "dimensions.\n", + "\n", + "### Factorizing\n", + "\n", + "The core `factorize_` function (which wraps `pd.cut`) only handles 1D bins, so\n", + "we use `xr.apply_ufunc` to vectorize it for us.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa33ab2c-0ecf-4198-a033-2a77f5d83c99", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "factorize_loop_dim = \"time\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afcddcc1-dd57-461e-a649-1f8bcd30342f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def factorize_nd_bins_core(by, bins):\n", + " group_idx, *_, props = flox.core.factorize_(\n", + " (by,),\n", + " axes=(-1,),\n", + " expected_groups=(pd.IntervalIndex.from_breaks(bins),),\n", + " )\n", + " # Use -1 as the NaN sentinel value\n", + " group_idx[props.nanmask] = -1\n", + " return group_idx\n", + "\n", + "\n", + "codes = xr.apply_ufunc(\n", + " factorize_nd_bins_core,\n", + " by,\n", + " bins,\n", + " # TODO: avoid hardcoded dim names\n", + " input_core_dims=[[\"space\"], [\"nbins\"]],\n", + " output_core_dims=[[\"space\"]],\n", + " vectorize=True,\n", + ")\n", + "codes" + ] + }, + { + "cell_type": "markdown", + "id": "1661312a-dc61-4a26-bfd8-12c2dc01eb15", + "metadata": { + "user_expressions": [] + }, + "source": [ + "### Offset the codes\n", + "\n", + "These are integer codes appropriate for a single timestep.\n", + "\n", + "We now add an offset that changes in time, to make sure \"bin 0\" for `time=0` is\n", + "different from \"bin 0\" for `time=1` (taken from\n", + "[this StackOverflow thread](https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy)).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e5801cb-a79c-4670-ad10-36bb19f1a6ff", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "N = math.prod([codes.sizes[d] for d in codes.dims if d != factorize_loop_dim])\n", + "offset = xr.DataArray(np.arange(codes.sizes[factorize_loop_dim]), dims=factorize_loop_dim)\n", + "# TODO: think about N-1 here\n", + "offset_codes = (codes + offset * (N - 1)).rename(by.name)\n", + "offset_codes.data[codes == -1] = -1\n", + "offset_codes" + ] + }, + { + "cell_type": "markdown", + "id": "6c06c48b-316b-4a33-9bc3-921acd10bcba", + "metadata": { + "user_expressions": [] + }, + "source": [ + "### Reduce\n", + "\n", + "Now that we have appropriate codes, let's apply the reduction\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2cf1295e-4585-48b9-ac2b-9e00d03b2b9a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "interim = flox.xarray.xarray_reduce(\n", + " array,\n", + " offset_codes,\n", + " func=\"sum\",\n", + " # We use RangeIndex to indicate that `-1` code can be safely ignored\n", + " # (it indicates values outside the bins)\n", + " # TODO: Avoid hardcoding 9 = sizes[\"time\"] x (sizes[\"nbins\"] - 1)\n", + " expected_groups=pd.RangeIndex(9),\n", + ")\n", + "interim" + ] + }, + { + "cell_type": "markdown", + "id": "3539509b-d9b4-4342-a679-6ada6f285dfb", + "metadata": { + "user_expressions": [] + }, + "source": [ + "## Make final result\n", + "\n", + "Now reshape that 1D result appropriately.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1389d37-d76d-4a50-9dfb-8710258de3fd", + "metadata": { + "tags": [] + }, + "outputs": [], 
+ "source": [ + "final = (\n", + " interim.coarsen(by=3)\n", + " # bin_number dimension is last, this makes sense since it is the core dimension\n", + " # and we vectorize over the loop dims.\n", + " # So the first (Nbins-1) elements are for the first index of the loop dim\n", + " .construct({\"by\": (factorize_loop_dim, \"bin_number\")})\n", + " .transpose(..., factorize_loop_dim)\n", + " .drop_vars(\"by\")\n", + ")\n", + "final" + ] + }, + { + "cell_type": "markdown", + "id": "a98b5e60-94af-45ae-be1b-4cb47e2d77ba", + "metadata": { + "user_expressions": [] + }, + "source": [ + "I think this is the expected answer.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "053a8643-f6d9-4fd1-b014-230fa716449c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "array.isel(space=slice(1, None)).rename({\"space\": \"bin_number\"}).identical(final)" + ] + }, + { + "cell_type": "markdown", + "id": "619ba4c4-7c87-459a-ab86-c187d3a86c67", + "metadata": { + "tags": [], + "user_expressions": [] + }, + "source": [ + "## TODO\n", + "\n", + "This could be extended to:\n", + "\n", + "1. handle multiple `factorize_loop_dim`\n", + "2. avoid hard coded dimension names in the `apply_ufunc` call for factorizing\n", + "3. avoid hard coded number of output elements in the `xarray_reduce` call.\n", + "4. Somehow propagate the bin edges to the final output.\n" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/flox/aggregate_npg.py b/flox/aggregate_npg.py index 30e0eb257..966bd43b8 100644 --- a/flox/aggregate_npg.py +++ b/flox/aggregate_npg.py @@ -100,3 +100,84 @@ def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None, len = partial(_len, func="len") nanlen = partial(_len, func="nanlen") + + +def median(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=np.median, + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) + + +def nanmedian(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=np.nanmedian, + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) + + +def quantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=partial(np.quantile, q=q), + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) + + +def nanquantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=partial(np.nanquantile, q=q), + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) + + +def mode_(array, nan_policy, dtype): + from scipy.stats import mode + + # npg splits `array` into object arrays for each group + # scipy.stats.mode does not like that + # here we cast back + return mode(array.astype(dtype, copy=False), nan_policy=nan_policy, axis=-1).mode + + +def mode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=partial(mode_, nan_policy="propagate", 
dtype=array.dtype), + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) + + +def nanmode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None): + return npg.aggregate_numpy.aggregate( + group_idx, + array, + func=partial(mode_, nan_policy="omit", dtype=array.dtype), + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) diff --git a/flox/aggregations.py b/flox/aggregations.py index 7b90b00ab..d2ccf4c64 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -479,10 +479,22 @@ def _pick_second(*x): final_dtype=bool, ) -# numpy_groupies does not support median -# And the dask version is really hard! -# median = Aggregation("median", chunk=None, combine=None, fill_value=None) -# nanmedian = Aggregation("nanmedian", chunk=None, combine=None, fill_value=None) +# Support statistical quantities only blockwise +# The parallel versions will be approximate and are hard to implement! +median = Aggregation( + name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64 +) +nanmedian = Aggregation( + name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64 +) +quantile = Aggregation( + name="quantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64 +) +nanquantile = Aggregation( + name="nanquantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64 +) +mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None) +nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None) aggregations = { "any": any_, @@ -510,6 +522,12 @@ def _pick_second(*x): "nanfirst": nanfirst, "last": last, "nanlast": nanlast, + "median": median, + "nanmedian": nanmedian, + "quantile": quantile, + "nanquantile": nanquantile, + "mode": mode, + "nanmode": nanmode, } diff --git a/flox/core.py b/flox/core.py index a5e07d37d..01f031240 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1307,15 +1307,14 @@ def dask_groupby_agg( assert isinstance(axis, Sequence) assert all(ax >= 0 for ax in axis) - if method == "blockwise" and not isinstance(by, np.ndarray): - raise NotImplementedError - inds = tuple(range(array.ndim)) name = f"groupby_{agg.name}" token = dask.base.tokenize(array, by, agg, expected_groups, axis) if expected_groups is None and reindex: expected_groups = _get_expected_groups(by, sort=sort) + if method == "cohorts": + assert reindex is False by_input = by @@ -1349,7 +1348,6 @@ def dask_groupby_agg( # b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction. 
# This allows us to discover groups at compute time, support argreductions, lower intermediate # memory usage (but method="cohorts" would also work to reduce memory in some cases) - do_simple_combine = not _is_arg_reduction(agg) if method == "blockwise": @@ -1375,7 +1373,7 @@ def dask_groupby_agg( partial( blockwise_method, axis=axis, - expected_groups=None if method == "cohorts" else expected_groups, + expected_groups=expected_groups if reindex else None, engine=engine, sort=sort, ), @@ -1468,14 +1466,24 @@ def dask_groupby_agg( elif method == "blockwise": reduced = intermediate - # Here one input chunk → one output chunks - # find number of groups in each chunk, this is needed for output chunks - # along the reduced axis - slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis)) - groups_in_block = tuple(_unique(by_input[slc]) for slc in slices) - groups = (np.concatenate(groups_in_block),) - ngroups_per_block = tuple(len(grp) for grp in groups_in_block) - group_chunks = (ngroups_per_block,) + if reindex: + if TYPE_CHECKING: + assert expected_groups is not None + # TODO: we could have `expected_groups` be a dask array with appropriate chunks + # for now, we have a numpy array that is interpreted as listing all group labels + # that are present in every chunk + groups = (expected_groups,) + group_chunks = ((len(expected_groups),),) + else: + # Here one input chunk → one output chunks + # find number of groups in each chunk, this is needed for output chunks + # along the reduced axis + # TODO: this logic is very specialized for the resampling case + slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis)) + groups_in_block = tuple(_unique(by_input[slc]) for slc in slices) + groups = (np.concatenate(groups_in_block),) + ngroups_per_block = tuple(len(grp) for grp in groups_in_block) + group_chunks = (ngroups_per_block,) else: raise ValueError(f"Unknown method={method}.") @@ -1547,7 +1555,7 @@ def _validate_reindex( if reindex is True and not all_numpy: if _is_arg_reduction(func): raise NotImplementedError - if method in ["blockwise", "cohorts"]: + if method == "cohorts" or (method == "blockwise" and not any_by_dask): raise ValueError( "reindex=True is not a valid choice for method='blockwise' or method='cohorts'." ) @@ -1562,7 +1570,11 @@ def _validate_reindex( # have to do the grouped_combine since there's no good fill_value reindex = False - if method == "blockwise" or _is_arg_reduction(func): + if method == "blockwise": + # for grouping by dask arrays, we set reindex=True + reindex = any_by_dask + + elif _is_arg_reduction(func): reindex = False elif method == "cohorts": @@ -1767,7 +1779,10 @@ def groupby_reduce( *by : ndarray or DaskArray Array of labels to group over. Must be aligned with ``array`` so that ``array.shape[-by.ndim :] == by.shape`` - func : str or Aggregation + func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \ + "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \ + "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \ + "first", "nanfirst", "last", "nanlast"} or Aggregation Single function name or an Aggregation instance expected_groups : (optional) Sequence Expected unique labels. @@ -1838,7 +1853,7 @@ def groupby_reduce( boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the original block size. Avoid that by using ``method="cohorts"``. By default, it is turned off for argreductions. 
finalize_kwargs : dict, optional - Kwargs passed to finalize the reduction such as ``ddof`` for var, std. + Kwargs passed to finalize the reduction such as ``ddof`` for var, std or ``q`` for quantile. Returns ------- @@ -1866,6 +1881,9 @@ def groupby_reduce( "See https://github.com/numbagg/numbagg/issues/121." ) + if func == "quantile" and (finalize_kwargs is None or "q" not in finalize_kwargs): + raise ValueError("Please pass `q` for quantile calculations.") + bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) nby = len(bys) by_is_dask = tuple(is_duck_dask_array(b) for b in bys) @@ -2034,7 +2052,7 @@ def groupby_reduce( result, groups = partial_agg( array, by_, - expected_groups=None if method == "blockwise" else expected_groups, + expected_groups=expected_groups, agg=agg, reindex=reindex, method=method, diff --git a/flox/xarray.py b/flox/xarray.py index bde5cc3f2..7e8c4f2b1 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -88,8 +88,11 @@ def xarray_reduce( Xarray object to reduce *by : DataArray or iterable of str or iterable of DataArray Variables with which to group by ``obj`` - func : str or Aggregation - Reduction method + func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \ + "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \ + "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \ + "first", "nanfirst", "last", "nanlast"} or Aggregation + Single function name or an Aggregation instance expected_groups : str or sequence expected group labels corresponding to each `by` variable isbin : iterable of bool @@ -167,7 +170,7 @@ def xarray_reduce( boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the original block size. Avoid that by using method="cohorts". By default, it is turned off for arg reductions. **finalize_kwargs - kwargs passed to the finalize function, like ``ddof`` for var, std. + kwargs passed to the finalize function, like ``ddof`` for var, std or ``q`` for quantile. 
Returns ------- diff --git a/pyproject.toml b/pyproject.toml index 364011007..5ca8e0317 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,8 @@ module=[ "matplotlib.*", "pandas", "setuptools", - "toolz" + "scipy.*", + "toolz", ] ignore_missing_imports = true diff --git a/tests/__init__.py b/tests/__init__.py index 1de52fb2b..f46319172 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -48,6 +48,7 @@ def LooseVersion(vstring): has_dask, requires_dask = _importorskip("dask") has_numba, requires_numba = _importorskip("numba") has_numbagg, requires_numbagg = _importorskip("numbagg") +has_scipy, requires_scipy = _importorskip("scipy") has_xarray, requires_xarray = _importorskip("xarray") diff --git a/tests/test_core.py b/tests/test_core.py index 7c2ae68c3..ca8656088 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -31,6 +31,7 @@ has_dask, raise_if_dask_computes, requires_dask, + requires_scipy, ) labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0]) @@ -50,6 +51,9 @@ def dask_array_ones(*args): return None +DEFAULT_QUANTILE = 0.9 +SCIPY_STATS_FUNCS = ("mode", "nanmode") +BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS ALL_FUNCS = ( "sum", "nansum", @@ -73,9 +77,11 @@ def dask_array_ones(*args): "any", "all", "nanlast", - pytest.param("median", marks=(pytest.mark.skip,)), - pytest.param("nanmedian", marks=(pytest.mark.skip,)), -) + "median", + "nanmedian", + "quantile", + "nanquantile", +) + tuple(pytest.param(func, marks=requires_scipy) for func in SCIPY_STATS_FUNCS) if TYPE_CHECKING: from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method @@ -84,12 +90,26 @@ def dask_array_ones(*args): def _get_array_func(func: str) -> Callable: if func == "count": - def npfunc(x): + def npfunc(x, **kwargs): x = np.asarray(x) return (~np.isnan(x)).sum() elif func in ["nanfirst", "nanlast"]: npfunc = getattr(xrutils, func) + + elif func in SCIPY_STATS_FUNCS: + import scipy.stats + + if "nan" in func: + func = func[3:] + nan_policy = "omit" + else: + nan_policy = "propagate" + + def npfunc(x, **kwargs): + spfunc = partial(getattr(scipy.stats, func), nan_policy=nan_policy) + return getattr(spfunc(x, **kwargs), func) + else: npfunc = getattr(np, func) @@ -205,7 +225,7 @@ def gen_array_by(size, func): @pytest.mark.parametrize("add_nan_by", [True, False]) @pytest.mark.parametrize("func", ALL_FUNCS) def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): - if "arg" in func and engine == "flox": + if ("arg" in func and engine == "flox") or (func in BLOCKWISE_FUNCS and chunks != -1): pytest.skip() array, by = gen_array_by(size, func) @@ -224,6 +244,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): finalize_kwargs = finalize_kwargs + [{"ddof": 1}, {"ddof": 0}] fill_value = np.nan tolerance = {"rtol": 1e-14, "atol": 1e-16} + elif "quantile" in func: + finalize_kwargs = [{"q": DEFAULT_QUANTILE}] + fill_value = None + tolerance = None else: fill_value = None tolerance = None @@ -246,15 +270,16 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): func_ = f"nan{func}" if "nan" not in func else func array_[..., nanmask] = np.nan expected = getattr(np, func_)(array_, axis=-1, **kwargs) - # elif func in ["first", "last"]: - # expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs) - elif func in ["nanfirst", "nanlast"]: - expected = getattr(xrutils, func)(array_[..., ~nanmask], axis=-1, **kwargs) else: - expected = getattr(np, func)(array_[..., 
~nanmask], axis=-1, **kwargs) + array_func = _get_array_func(func) + expected = array_func(array_[..., ~nanmask], axis=-1, **kwargs) for _ in range(nby): expected = np.expand_dims(expected, -1) + if func in BLOCKWISE_FUNCS: + assert chunks == -1 + flox_kwargs["method"] = "blockwise" + actual, *groups = groupby_reduce(array, *by, **flox_kwargs) assert actual.ndim == (array.ndim + nby - 1) assert expected.ndim == (array.ndim + nby - 1) @@ -265,7 +290,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): assert actual.dtype.kind == "i" assert_equal(actual, expected, tolerance) - if not has_dask or chunks is None: + if not has_dask or chunks is None or func in BLOCKWISE_FUNCS: continue params = list(itertools.product(["map-reduce"], [True, False, None])) @@ -396,7 +421,7 @@ def test_numpy_reduce_nd_md(): def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtype, engine, reindex): """Tests groupby_reduce with dask arrays against groupby_reduce with numpy arrays""" - if func in ["first", "last"]: + if func in ["first", "last"] or func in BLOCKWISE_FUNCS: pytest.skip() if "arg" in func and (engine == "flox" or reindex): @@ -551,7 +576,7 @@ def test_first_last_disallowed_dask(func): "axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)] ) def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine): - if "arg" in func and engine == "flox": + if ("arg" in func and engine == "flox") or func in BLOCKWISE_FUNCS: pytest.skip() if not isinstance(axis, int): @@ -847,7 +872,7 @@ def test_rechunk_for_cohorts(chunk_at, expected): def test_fill_value_behaviour(func, chunks, fill_value, engine): # fill_value = np.nan tests promotion of int counts to float # This is used by xarray - if func in ["all", "any"] or "arg" in func: + if (func in ["all", "any"] or "arg" in func) or func in BLOCKWISE_FUNCS: pytest.skip() npfunc = _get_array_func(func) @@ -906,8 +931,17 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype): @requires_dask @pytest.mark.parametrize("func", ALL_FUNCS) @pytest.mark.parametrize("axis", (-1, None)) -@pytest.mark.parametrize("method", ["blockwise", "cohorts", "map-reduce", "split-reduce"]) +@pytest.mark.parametrize("method", ["blockwise", "cohorts", "map-reduce"]) def test_cohorts_nd_by(func, method, axis, engine): + if ( + ("arg" in func and (axis is None or engine == "flox")) + or (method != "blockwise" and func in BLOCKWISE_FUNCS) + or (axis is None and ("first" in func or "last" in func)) + ): + pytest.skip() + if axis is not None and method != "map-reduce": + pytest.xfail() + o = dask.array.ones((3,), chunks=-1) o2 = dask.array.ones((2, 3), chunks=-1) @@ -918,20 +952,14 @@ def test_cohorts_nd_by(func, method, axis, engine): by[0, 4] = 31 array = np.broadcast_to(array, (2, 3) + array.shape) - if "arg" in func and (axis is None or engine == "flox"): - pytest.skip() - if func in ["any", "all"]: fill_value = False else: fill_value = -123 - if axis is not None and method != "map-reduce": - pytest.xfail() - if axis is None and ("first" in func or "last" in func): - pytest.skip() - kwargs = dict(func=func, engine=engine, method=method, axis=axis, fill_value=fill_value) + if "quantile" in func: + kwargs["finalize_kwargs"] = {"q": DEFAULT_QUANTILE} actual, groups = groupby_reduce(array, by, **kwargs) expected, sorted_groups = groupby_reduce(array.compute(), by, **kwargs) assert_equal(groups, sorted_groups) @@ -993,6 +1021,8 @@ def test_datetime_binning(): def test_bool_reductions(func, engine): 
if "arg" in func and engine == "flox": pytest.skip() + if "quantile" in func or "mode" in func: + pytest.skip() groups = np.array([1, 1, 1]) data = np.array([True, True, False]) npfunc = _get_array_func(func) @@ -1248,9 +1278,14 @@ def test_dtype(func, dtype, engine): pytest.skip() if "arg" in func or func in ["any", "all"]: pytest.skip() + + finalize_kwargs = {"q": DEFAULT_QUANTILE} if "quantile" in func else {} + arr = np.ones((4, 12), dtype=dtype) labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]) - actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64, engine=engine) + actual, _ = groupby_reduce( + arr, labels, func=func, dtype=np.float64, engine=engine, finalize_kwargs=finalize_kwargs + ) assert actual.dtype == np.dtype("float64") @@ -1393,6 +1428,33 @@ def test_validate_reindex() -> None: ) assert actual is False + with pytest.raises(ValueError): + _validate_reindex( + True, + "sum", + method="blockwise", + expected_groups=np.array([1, 2, 3]), + any_by_dask=False, + is_dask_array=True, + ) + + assert _validate_reindex( + True, + "sum", + method="blockwise", + expected_groups=np.array([1, 2, 3]), + any_by_dask=True, + is_dask_array=True, + ) + assert _validate_reindex( + None, + "sum", + method="blockwise", + expected_groups=np.array([1, 2, 3]), + any_by_dask=True, + is_dask_array=True, + ) + @requires_dask def test_1d_blockwise_sort_optimization():