Add scans (#370)
* Add scans

* grouped reduce

* Some fixes.

* Updates for ffill

* Better ffill

* Support numpy

* cleanup

* more tests

* Fix ffill

* [WIP] expand tests

* Fixes. We need two versions of binary_op

* Fix ffill again

* Disable cumsum for now.

* Fixes.

* Fix tests: Remove overflowing test cases, proper fill_value

* typing

* Fix tests

* Try and avoid some roundoff error

* Skip float32 for cumsum

* fix min deps test

* Another fix

* Silence warnings

* Cleanup

* Add docs

* fix

* bfill

* Fix test

* hypothesis: Better CI profile

* Small change.

* Add hypothesis to all envs

* Generate chunking along all dimensions

* lint

* more guards

* more guards

* fix

* Fix typing

* cleanup

* fix mypy

* Add comments
dcherian authored Jul 27, 2024
1 parent 04338d4 commit 1c10b74
Showing 15 changed files with 688 additions and 49 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -62,7 +62,7 @@ jobs:
       - name: Run Tests
         id: status
         run: |
-          pytest -n auto --cov=./ --cov-report=xml
+          pytest -n auto --cov=./ --cov-report=xml --hypothesis-profile ci
       - name: Upload code coverage to Codecov
         uses: codecov/[email protected]
         with:
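The `--hypothesis-profile ci` flag selects a Hypothesis settings profile that the test suite must register, typically in `conftest.py` or the tests package. The profile's actual contents are not shown in this diff; a minimal sketch with assumed settings:

from hypothesis import HealthCheck, settings

# Hypothetical "ci" profile; the values below are illustrative, not flox's.
settings.register_profile(
    "ci",
    max_examples=1000,  # exercise more cases than the default profile
    deadline=None,  # avoid flaky per-example timeouts on shared CI runners
    suppress_health_check=[HealthCheck.too_slow],
)

# Selected at run time by: pytest --hypothesis-profile ci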
1 change: 1 addition & 0 deletions ci/minimal-requirements.yml
@@ -3,6 +3,7 @@ channels:
   - conda-forge
 dependencies:
   - codecov
+  - hypothesis
   - pip
   - pytest
   - pytest-cov
1 change: 1 addition & 0 deletions ci/no-dask.yml
@@ -4,6 +4,7 @@
 dependencies:
   - codecov
   - pandas
+  - hypothesis
   - cftime
   - numpy>=1.22
   - scipy
1 change: 1 addition & 0 deletions ci/no-numba.yml
@@ -7,6 +7,7 @@ dependencies:
   - cftime
   - codecov
   - dask-core
+  - hypothesis
   - pandas
   - numpy>=1.22
   - scipy
3 changes: 3 additions & 0 deletions docs/source/api.rst
@@ -10,6 +10,7 @@ Functions
    :toctree: generated/
 
    groupby_reduce
+   groupby_scan
    xarray.xarray_reduce
 
 Rechunking
@@ -40,5 +41,7 @@ Aggregation Objects
    :toctree: generated/
 
    Aggregation
+   Scan
+
    aggregations.sum_
    aggregations.nansum
2 changes: 1 addition & 1 deletion docs/source/user-stories/hourly-climatology.html
@@ -14033,7 +14033,7 @@
 "result = _execute_task(task, cache)\n",
 "return func(*(_execute_task(a, cache) for a in args))\n",
 "ret = self.first(*args, **kwargs)\n",
-"group_idx, array = _prepare_for_flox(group_idx, array)\n",
+"group_idx, array, _ = _prepare_for_flox(group_idx, array)\n",
 "return group_idx, found_groups, grp_shape, ngroups, size, props\n",
 "",
 "return _wrapfunc(a, 'searchsorted', v, side=side, sorter=sorter)\n",
2 changes: 1 addition & 1 deletion flox/__init__.py
@@ -3,7 +3,7 @@
 """Top-level module for flox ."""
 from . import cache
 from .aggregations import Aggregation  # noqa
-from .core import groupby_reduce, rechunk_for_blockwise, rechunk_for_cohorts  # noqa
+from .core import groupby_reduce, groupby_scan, rechunk_for_blockwise, rechunk_for_cohorts  # noqa
 
 
 def _get_version():
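For orientation, a usage sketch of the newly exported `groupby_scan`. Its full signature lives in `flox/core.py` (not shown in this diff); this assumes it mirrors `groupby_reduce`'s `(array, *by, func=...)` convention and returns an array of the same shape as the input:

import numpy as np

import flox

array = np.array([1.0, 2.0, np.nan, 4.0, 5.0])
labels = np.array([0, 0, 1, 1, 0])

# cumulative sum within each label, skipping NaNs; shape is preserved
result = flox.groupby_scan(array, labels, func="nancumsum")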
29 changes: 28 additions & 1 deletion flox/aggregate_flox.py
@@ -14,11 +14,12 @@ def _prepare_for_flox(group_idx, array):
     issorted = (group_idx[:-1] <= group_idx[1:]).all()
     if issorted:
         ordered_array = array
+        perm = slice(None)
     else:
         perm = group_idx.argsort(kind="stable")
         group_idx = group_idx[..., perm]
         ordered_array = array[..., perm]
-    return group_idx, ordered_array
+    return group_idx, ordered_array, perm
 
 
 def _lerp(a, b, *, t, dtype, out=None):
@@ -226,3 +227,29 @@ def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None
     with np.errstate(invalid="ignore", divide="ignore"):
         out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0)
     return out
+
+
+def ffill(group_idx, array, *, axis, **kwargs):
+    group_idx, array, perm = _prepare_for_flox(group_idx, array)
+    shape = array.shape
+    ndim = array.ndim
+    assert axis == (ndim - 1), (axis, ndim - 1)
+
+    flag = np.concatenate((np.array([True], like=array), group_idx[1:] != group_idx[:-1]))
+    (group_starts,) = flag.nonzero()
+
+    # https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array
+    mask = np.isnan(array)
+    # modified from the SO answer, just reset the index at the start of every group!
+    mask[..., np.asarray(group_starts)] = False
+
+    idx = np.where(mask, 0, np.arange(shape[axis]))
+    np.maximum.accumulate(idx, axis=axis, out=idx)
+    slc = [
+        np.arange(k)[tuple([slice(None) if dim == i else np.newaxis for dim in range(ndim)])]
+        for i, k in enumerate(shape)
+    ]
+    slc[axis] = idx
+
+    invert_perm = slice(None) if isinstance(perm, slice) else np.argsort(perm, kind="stable")
+    return array[tuple(slc)][..., invert_perm]
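The index-propagation trick in `ffill` is easiest to see in one dimension. A standalone sketch on made-up data, showing how resetting the mask at group starts keeps fills from crossing group boundaries:

import numpy as np

group_idx = np.array([0, 0, 0, 1, 1, 1])  # sorted group labels
array = np.array([1.0, np.nan, np.nan, np.nan, 5.0, np.nan])

starts = np.concatenate(([True], group_idx[1:] != group_idx[:-1]))
mask = np.isnan(array)
mask[starts.nonzero()[0]] = False  # never fill across a group start

# NaN positions contribute 0, so the running maximum carries forward
# the index of the most recent valid (or group-initial) element.
idx = np.where(mask, 0, np.arange(array.size))
np.maximum.accumulate(idx, out=idx)
print(array[idx])  # [ 1.  1.  1. nan  5.  5.]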
181 changes: 179 additions & 2 deletions flox/aggregations.py
@@ -3,11 +3,13 @@
 import copy
 import logging
 import warnings
+from collections.abc import Sequence
 from dataclasses import dataclass
 from functools import cached_property, partial
 from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
 
 import numpy as np
+import pandas as pd
 from numpy.typing import ArrayLike, DTypeLike
 
 from . import aggregate_flox, aggregate_npg, xrutils
@@ -19,6 +21,7 @@
 
 
 logger = logging.getLogger("flox")
+T_ScanBinaryOpMode = Literal["apply_binary_op", "concat_then_scan"]
 
 
 def _is_arg_reduction(func: str | Aggregation) -> bool:
@@ -63,6 +66,9 @@ def generic_aggregate(
     dtype=None,
     **kwargs,
 ):
+    if func == "identity":
+        return array
+
     if engine == "flox":
         try:
             method = getattr(aggregate_flox, func)
@@ -567,7 +573,171 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
 mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
 nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
 
-aggregations = {
+
+@dataclass
+class Scan:
+    # This dataclass is separate from Aggregations since there's not much in common
+    # between reductions and scans
+    name: str
+    # binary operation (e.g. np.add)
+    # Must be None for mode="concat_then_scan"
+    binary_op: Callable | None
+    # in-memory grouped scan function (e.g. cumsum)
+    scan: str
+    # Grouped reduction that yields the last result of the scan (e.g. sum)
+    reduction: str
+    # Identity element
+    identity: Any
+    # dtype of result
+    dtype: Any = None
+    # "Mode" of applying binary op.
+    # for np.add we apply the op directly to the `state` array and the `current` array.
+    # for ffill, bfill we concat `state` to `current` and then run the scan again.
+    mode: T_ScanBinaryOpMode = "apply_binary_op"
+    preprocess: Callable | None = None
+    finalize: Callable | None = None
+
+
+def concatenate(arrays: Sequence[AlignedArrays], axis=-1, out=None) -> AlignedArrays:
+    group_idx = np.concatenate([a.group_idx for a in arrays], axis=axis)
+    array = np.concatenate([a.array for a in arrays], axis=axis)
+    return AlignedArrays(array=array, group_idx=group_idx)
+
+
+@dataclass
+class AlignedArrays:
+    """Simple Xarray DataArray type data class with two aligned arrays."""
+
+    array: np.ndarray
+    group_idx: np.ndarray
+
+    def __post_init__(self):
+        assert self.array.shape[-1] == self.group_idx.size
+
+    def last(self) -> AlignedArrays:
+        from flox.core import chunk_reduce
+
+        reduced = chunk_reduce(
+            self.array,
+            self.group_idx,
+            func=("nanlast",),
+            axis=-1,
+            # TODO: automate?
+            engine="flox",
+            dtype=self.array.dtype,
+            fill_value=_get_fill_value(self.array.dtype, dtypes.NA),
+            expected_groups=None,
+        )
+        return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"])
+
+
+@dataclass
+class ScanState:
+    """Dataclass representing intermediates for scan."""
+
+    # last value of each group seen so far
+    state: AlignedArrays | None
+    # intermediate result
+    result: AlignedArrays | None
+
+    def __post_init__(self):
+        assert (self.state is not None) or (self.result is not None)
+
+
+def reverse(a: AlignedArrays) -> AlignedArrays:
+    a.group_idx = a.group_idx[::-1]
+    a.array = a.array[::-1]
+    return a
+
+
+def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan) -> ScanState:
+    from .core import reindex_
+
+    assert left_state.state is not None
+    left = left_state.state
+    right = right_state.result if right_state.result is not None else right_state.state
+    assert right is not None
+
+    if agg.mode == "apply_binary_op":
+        assert agg.binary_op is not None
+        # Implements groupby binary operation.
+        reindexed = reindex_(
+            left.array,
+            from_=pd.Index(left.group_idx),
+            # can't use right.group_idx since we need to do the indexing later
+            to=pd.RangeIndex(right.group_idx.max() + 1),
+            fill_value=agg.identity,
+            axis=-1,
+        )
+        result = AlignedArrays(
+            array=agg.binary_op(reindexed[..., right.group_idx], right.array),
+            group_idx=right.group_idx,
+        )
+
+    elif agg.mode == "concat_then_scan":
+        # Implements the binary op portion of the scan as a concatenate-then-scan.
+        # This is useful for `ffill`, and presumably more generalized scans.
+        assert agg.binary_op is None
+        concat = concatenate([left, right], axis=-1)
+        final_value = generic_aggregate(
+            concat.group_idx,
+            concat.array,
+            func=agg.scan,
+            axis=concat.array.ndim - 1,
+            engine="flox",
+            fill_value=agg.identity,
+        )
+        result = AlignedArrays(
+            array=final_value[..., left.group_idx.size :], group_idx=right.group_idx
+        )
+    else:
+        raise ValueError(f"Unknown binary op application mode: {agg.mode!r}")
+
+    # This is quite important. We need to update the state seen so far and propagate that.
+    # So we must account for what we know when entering this function: i.e. `left`
+    # TODO: this is a bit wasteful since it will sort again, but for now let's focus on
+    # correctness and DRY
+    lasts = concatenate([left, result]).last()
+
+    return ScanState(
+        state=lasts,
+        # The binary op is called on the results of the reduction too when building up the tree.
+        # We need to be careful and assign those results only to `state` and not the final result.
+        # Up above, `result` is privileged when it exists.
+        result=None if right_state.result is None else result,
+    )
+
+
+# TODO: numpy_groupies cumsum is broken when NaNs are present.
+# cumsum = Scan("cumsum", binary_op=np.add, reduction="sum", scan="cumsum", identity=0)
+nancumsum = Scan("nancumsum", binary_op=np.add, reduction="nansum", scan="nancumsum", identity=0)
+# ffill uses the identity for scan, and then at the binary-op state,
+# we concatenate the blockwise-reduced values with the original block,
+# and then execute the scan
+# TODO: consider adding chunk="identity" here, like with reductions as an optimization
+ffill = Scan(
+    "ffill",
+    binary_op=None,
+    reduction="nanlast",
+    scan="ffill",
+    identity=np.nan,
+    mode="concat_then_scan",
+)
+bfill = Scan(
+    "bfill",
+    binary_op=None,
+    reduction="nanlast",
+    scan="ffill",
+    identity=np.nan,
+    mode="concat_then_scan",
+    preprocess=reverse,
+    finalize=reverse,
+)
+# TODO: not implemented in numpy_groupies
+# cumprod = Scan("cumprod", binary_op=np.multiply, preop="prod", scan="cumprod")
+
+
+AGGREGATIONS: dict[str, Aggregation | Scan] = {
     "any": any_,
     "all": all_,
     "count": count,
@@ -599,6 +769,10 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
     "nanquantile": nanquantile,
     "mode": mode,
     "nanmode": nanmode,
+    # "cumsum": cumsum,
+    "nancumsum": nancumsum,
+    "ffill": ffill,
+    "bfill": bfill,
 }
 
 
@@ -610,11 +784,14 @@ def _initialize_aggregation(
     min_count: int,
     finalize_kwargs: dict[Any, Any] | None,
 ) -> Aggregation:
+    agg: Aggregation
     if not isinstance(func, Aggregation):
         try:
             # TODO: need better interface
             # we set dtype, fillvalue on reduction later. so deepcopy now
-            agg = copy.deepcopy(aggregations[func])
+            agg_ = copy.deepcopy(AGGREGATIONS[func])
+            assert isinstance(agg_, Aggregation)
+            agg = agg_
         except KeyError:
             raise NotImplementedError(f"Reduction {func!r} not implemented yet")
     elif isinstance(func, Aggregation):
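To see what the `apply_binary_op` mode of `scan_binary_op` does: each chunk is first scanned blockwise, then the carried state (the last scanned value per group) is aligned to the next chunk's labels and combined with the binary op. A self-contained toy version of a two-chunk grouped cumsum on made-up data; the real implementation uses `reindex_` and `chunk_reduce` rather than these loops:

import numpy as np

def blockwise_grouped_cumsum(array, group_idx):
    # per-chunk grouped cumsum, assuming sorted integer labels
    out = np.empty_like(array)
    for g in np.unique(group_idx):
        sel = group_idx == g
        out[sel] = np.cumsum(array[sel])
    return out

# one logical array split into two chunks
a1, g1 = np.array([1.0, 2.0, 3.0]), np.array([0, 0, 1])
a2, g2 = np.array([4.0, 5.0, 6.0]), np.array([0, 1, 1])

s1 = blockwise_grouped_cumsum(a1, g1)  # [1. 3. 3.]
s2 = blockwise_grouped_cumsum(a2, g2)  # [4. 5. 11.]

# state: last scanned value per group from chunk 1, reindexed to a
# dense range of groups with the identity (0 for add) as fill value
state = np.zeros(g2.max() + 1)
for g in np.unique(g1):
    state[g] = s1[g1 == g][-1]

# binary op: broadcast the carried state onto chunk 2 by its labels
print(s1, state[g2] + s2)  # [1. 3. 3.] [7. 8. 14.]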
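`bfill` above reuses the `ffill` machinery by reversing its input on the way in (`preprocess=reverse`) and its output on the way out (`finalize=reverse`). A one-dimensional, group-free sketch of the same idea:

import numpy as np

def ffill_1d(array):
    idx = np.where(np.isnan(array), 0, np.arange(array.size))
    np.maximum.accumulate(idx, out=idx)
    return array[idx]

def bfill_1d(array):
    # backward fill is forward fill on the reversed array, reversed back
    return ffill_1d(array[::-1])[::-1]

print(bfill_1d(np.array([np.nan, 2.0, np.nan, 4.0])))  # [2. 2. 4. 4.]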