Significantly faster cohorts detection. #272

Merged · 10 commits · Oct 11, 2023
14 changes: 13 additions & 1 deletion asv_bench/benchmarks/cohorts.py
@@ -1,8 +1,10 @@
 import dask
 import numpy as np
 import pandas as pd
+import xarray as xr

 import flox
+from flox.xarray import xarray_reduce


 class Cohorts:
@@ -12,7 +14,7 @@ def setup(self, *args, **kwargs):
         raise NotImplementedError

     def time_find_group_cohorts(self):
-        flox.core.find_group_cohorts(self.by, self.array.chunks)
+        flox.core.find_group_cohorts(self.by, [self.array.chunks[ax] for ax in self.axis])
         # The cache clear fails dependably in CI
         # Not sure why
         try:
@@ -125,3 +127,13 @@ class PerfectMonthlyRechunked(PerfectMonthly):
     def setup(self, *args, **kwargs):
         super().setup()
         super().rechunk()
+
+
+def time_cohorts_era5_single():
+    TIME = 900  # 92044 in Google ARCO ERA5
+    da = xr.DataArray(
+        dask.array.ones((TIME, 721, 1440), chunks=(1, -1, -1)),
+        dims=("time", "lat", "lon"),
+        coords=dict(time=pd.date_range("1959-01-01", freq="6H", periods=TIME)),
+    )
+    xarray_reduce(da, da.time.dt.day, method="cohorts", func="any")
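
For readers outside this PR: a minimal sketch of the code path the benchmark above exercises. The call signature matches the diff below; the labels, chunking, and expected output are my illustrative assumptions, not part of this change.

```python
import dask.array
import numpy as np

import flox

# Hypothetical input: 12 points along one axis, chunked in blocks of 3,
# with labels laid out so each chunk holds exactly one label.
labels = np.repeat(np.arange(4), 3)       # [0 0 0 1 1 1 2 2 2 3 3 3]
array = dask.array.ones((12,), chunks=3)  # chunks: ((3, 3, 3, 3),)

# Returns a dict mapping tuples of chunk indices to the labels found in
# exactly those chunks; for this input each label should map to its own
# single chunk, e.g. {(0,): [0], (1,): [1], (2,): [2], (3,): [3]}.
cohorts = flox.core.find_group_cohorts(labels, array.chunks)
```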
21 changes: 15 additions & 6 deletions flox/core.py
@@ -208,15 +208,20 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     # 1. First subset the array appropriately
     axis = range(-labels.ndim, 0)
     # Easier to create a dask array and use the .blocks property
-    array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
+    array = dask.array.empty(tuple(sum(c) for c in chunks), chunks=chunks)
     labels = np.broadcast_to(labels, array.shape[-labels.ndim :])

     # Iterate over each block and create a new block of same shape with "chunk number"
     shape = tuple(array.blocks.shape[ax] for ax in axis)
-    blocks = np.empty(math.prod(shape), dtype=object)
-    for idx, block in enumerate(array.blocks.ravel()):
-        blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
-    which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)
+    # Use a numpy object array to enable assignment in the loop
+    # TODO: is it possible to just use a nested list?
+    # That is what we need for `np.block`
+    blocks = np.empty(shape, dtype=object)
+    array_chunks = tuple(np.array(c) for c in array.chunks)
+    for idx, blockindex in enumerate(np.ndindex(array.numblocks)):
+        chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex))
+        blocks[blockindex] = np.full(chunkshape, idx)
+    which_chunk = np.block(blocks.tolist()).reshape(-1)

     raveled = labels.reshape(-1)
     # these are chunks where a label is present
@@ -229,7 +234,11 @@ def invert(x) -> tuple[np.ndarray, ...]:

     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())

-    if merge:
+    # If our dataset has chunksize one along the axis,
+    # then no merging is possible.
+    single_chunks = all((ac == 1).all() for ac in array_chunks)
+
+    if merge and not single_chunks:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
             sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
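
A standalone illustration (my own sketch, not code from this PR) of the new chunk-numbering step: every element of a chunked array is labelled with the flat index of the chunk it lives in. Note the loop needs only the chunk shapes from `array.chunks`, not the dask block objects the old code iterated over, which is presumably where much of the speedup comes from.

```python
import dask.array
import numpy as np

# Tiny array: shape (4, 6) split into four (2, 3) chunks.
array = dask.array.empty((4, 6), chunks=(2, 3))

shape = array.numblocks                 # (2, 2) blocks
blocks = np.empty(shape, dtype=object)  # object dtype allows per-block assignment
array_chunks = tuple(np.array(c) for c in array.chunks)
for idx, blockindex in enumerate(np.ndindex(shape)):
    # Shape of this particular chunk, computed from chunk sizes alone
    chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex))
    blocks[blockindex] = np.full(chunkshape, idx)

# np.block stitches the constant per-chunk arrays back to full shape
which_chunk = np.block(blocks.tolist())
print(which_chunk)
# [[0 0 0 1 1 1]
#  [0 0 0 1 1 1]
#  [2 2 2 3 3 3]
#  [2 2 2 3 3 3]]
```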
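
And a sketch of how the chunk map turns into cohorts, with made-up labels standing in for `label_chunks`: labels occupying exactly the same set of chunks are grouped together, which is what the `tlz.groupby(invert, ...)` call above does. The new `single_chunks` guard then skips the merge pass entirely when every chunk has size one along the reduction axes, where (per the comment in the diff) no merging is possible.

```python
import toolz as tlz

# Each label mapped to the tuple of chunk indices it occupies
# (hypothetical values standing in for `label_chunks` above).
label_chunks = {
    "jan": (0, 2),
    "feb": (0, 2),
    "mar": (1, 3),
}

def invert(label) -> tuple:
    return label_chunks[label]

# Labels sharing an identical chunk footprint end up in one cohort
chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
print(chunks_cohorts)
# {(0, 2): ['jan', 'feb'], (1, 3): ['mar']}
```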