Pass expected_groups to find_group_cohorts. #303

Merged · 3 commits · Dec 28, 2023
asv_bench/benchmarks/cohorts.py (13 additions, 2 deletions):
```diff
@@ -9,10 +9,14 @@ class Cohorts:
     """Time the core reduction function."""
 
     def setup(self, *args, **kwargs):
-        raise NotImplementedError
+        self.expected = pd.RangeIndex(self.by.max())
 
     def time_find_group_cohorts(self):
-        flox.core.find_group_cohorts(self.by, [self.array.chunks[ax] for ax in self.axis])
+        flox.core.find_group_cohorts(
+            self.by,
+            [self.array.chunks[ax] for ax in self.axis],
+            expected_groups=self.expected,
+        )
         # The cache clear fails dependably in CI
         # Not sure why
         try:
```
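What the benchmark now times, in isolation — a minimal sketch with made-up labels and chunk sizes, not the benchmark's real ERA5-sized data (note the benchmark's day-of-year labels start at 1, so there `pd.RangeIndex(self.by.max())` already covers them):

```python
import numpy as np
import pandas as pd

import flox.core

# Hypothetical factorized labels: four groups (0..3) across one chunked axis.
by = np.repeat(np.arange(4), 3)
chunks = [(3, 3, 3, 3)]  # four chunks of three along the grouped axis

# expected_groups tells find_group_cohorts how many labels exist, so it
# can size its chunk-by-label bitmask without scanning `by` for the max.
expected = pd.RangeIndex(by.max() + 1)
cohorts = flox.core.find_group_cohorts(by, chunks, expected_groups=expected)
```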
```diff
@@ -62,6 +66,7 @@ def setup(self, *args, **kwargs):
 
         self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
         self.axis = (-2, -1)
+        super().setup()
 
 
 class ERA5Dataset:
@@ -82,12 +87,14 @@ class ERA5DayOfYear(ERA5Dataset, Cohorts):
     def setup(self, *args, **kwargs):
         super().__init__()
         self.by = self.time.dt.dayofyear.values
+        super().setup()
 
 
 class ERA5DayOfYearRechunked(ERA5DayOfYear, Cohorts):
     def setup(self, *args, **kwargs):
         super().setup()
         self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 24))
+        super().setup()
 
 
 class ERA5MonthHour(ERA5Dataset, Cohorts):
@@ -102,6 +109,7 @@ def setup(self, *args, **kwargs):
         )
         # Add one so the rechunk code is simpler and makes sense
         self.by = ret[0][0] + 1
+        super().setup()
 
 
 class ERA5MonthHourRechunked(ERA5MonthHour, Cohorts):
@@ -118,6 +126,7 @@ def setup(self, *args, **kwargs):
         self.axis = (-1,)
         self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 4))
         self.by = self.time.dt.month.values
+        super().setup()
 
     def rechunk(self):
         self.array = flox.core.rechunk_for_cohorts(
@@ -138,6 +147,7 @@ def setup(self, *args, **kwargs):
         self.axis = (2,)
         self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, 1))
         self.by = self.time.dt.day.values
+        super().setup()
 
 
 def codes_for_resampling(group_as_index, freq):
@@ -159,3 +169,4 @@ def setup(self, *args, **kwargs):
         self.axis = (2,)
         self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, 10))
         self.by = codes_for_resampling(index, freq="5D")
+        super().setup()
```
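The rest of the file is the same mechanical change: each subclass assigns `self.by`, then calls `super().setup()` so the base `Cohorts.setup` can derive `self.expected` from the freshly built labels. A stripped-down illustration of the pattern — the classes below are hypothetical stand-ins, not part of the suite:

```python
import dask.array
import numpy as np
import pandas as pd


class CohortsBase:
    """Mirrors the benchmark base class: derives `expected` from `by`."""

    def setup(self, *args, **kwargs):
        self.expected = pd.RangeIndex(self.by.max())


class ToyDayOfYear(CohortsBase):
    def setup(self, *args, **kwargs):
        # Build the labels first...
        self.by = np.tile(np.arange(1, 367), 2)  # two years of day-of-year
        self.array = dask.array.ones(self.by.shape, chunks=100)
        self.axis = (-1,)
        # ...then let the base class compute `expected` from them.
        super().setup()
```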
flox/core.py (23 additions, 6 deletions):
```diff
@@ -215,7 +215,9 @@ def slices_from_chunks(chunks):
 
 
 @memoize
-def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
+def find_group_cohorts(
+    labels, chunks, merge: bool = True, expected_groups: None | pd.RangeIndex = None
+) -> dict:
     """
     Finds groups labels that occur together aka "cohorts"
 
```
```diff
@@ -246,7 +248,10 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     nchunks = math.prod(len(c) for c in chunks)
 
     # assumes that `labels` are factorized
-    nlabels = labels.max() + 1
+    if expected_groups is None:
+        nlabels = labels.max() + 1
+    else:
+        nlabels = expected_groups[-1] + 1
 
     labels = np.broadcast_to(labels, shape[-labels.ndim :])
 
```
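The point of the branch: `labels.max()` is an O(n) pass over every label, while a caller-supplied `pd.RangeIndex` yields the label count in O(1) from its last element. A small sketch of the two paths (sizes are illustrative):

```python
import numpy as np
import pandas as pd

labels = np.random.randint(0, 366, size=10_000_000)

# expected_groups is None: scan ten million labels for the max.
nlabels = labels.max() + 1

# expected_groups passed: read the answer off the index, no scan needed.
expected_groups = pd.RangeIndex(366)
nlabels = expected_groups[-1] + 1  # 366
```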
```diff
@@ -271,9 +276,18 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     cols_array = np.concatenate(cols)
     data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
     bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
+    CHUNK_AXIS, LABEL_AXIS = 0, 1
+
+    chunks_per_label = bitmask.sum(axis=CHUNK_AXIS)
+    # can happen when `expected_groups` is passed but not all labels are present
+    # (binning, resampling)
+    present_labels = chunks_per_label != 0
+    if not present_labels.all():
+        bitmask = bitmask[..., present_labels]
+
     label_chunks = {
         lab: bitmask.indices[slice(bitmask.indptr[lab], bitmask.indptr[lab + 1])]
-        for lab in range(nlabels)
+        for lab in range(bitmask.shape[-1])
     }
 
     ## numpy bitmask approach, faster than finding uniques, but lots of memory
```
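For readers unfamiliar with CSC internals: column `lab` of a `csc_array` is stored as `indices[indptr[lab]:indptr[lab + 1]]`, i.e. exactly the chunk numbers where that label occurs, which is why `label_chunks` can be built without densifying. The new filtering step drops all-zero columns first, so labels promised by `expected_groups` but never observed (binning, resampling) don't produce empty cohorts. A toy demonstration with a made-up bitmask:

```python
import numpy as np
from scipy.sparse import csc_array

# Hypothetical bitmask with nchunks=3, nlabels=4: entry (i, j) is True
# when label j occurs in chunk i. Label 2 never occurs, as happens when
# expected_groups includes a bin no data falls into.
dense = np.array(
    [
        [True, False, False, True],
        [True, True, False, False],
        [False, True, False, True],
    ]
)
bitmask = csc_array(dense)

chunks_per_label = bitmask.sum(axis=0)  # array([2, 2, 0, 2])
present = chunks_per_label != 0
bitmask = bitmask[..., present]         # drop label 2's empty column

label_chunks = {
    lab: bitmask.indices[bitmask.indptr[lab] : bitmask.indptr[lab + 1]]
    for lab in range(bitmask.shape[-1])
}
# {0: array([0, 1]), 1: array([1, 2]), 2: array([0, 2])}  (old label 3)
```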
```diff
@@ -303,9 +317,9 @@ def invert(x) -> tuple[np.ndarray, ...]:
     # If our dataset has chunksize one along the axis,
     # then no merging is possible.
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
-    one_group_per_chunk = (bitmask.sum(axis=1) == 1).all()
+    one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
     # every group is contained to one block, we should be using blockwise here.
-    every_group_one_block = (bitmask.sum(axis=0) == 1).all()
+    every_group_one_block = (chunks_per_label == 1).all()
     if every_group_one_block or one_group_per_chunk or single_chunks or not merge:
         return chunks_cohorts
 
```
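Reusing `chunks_per_label` avoids a second sparse reduction, and the named axis constants replace the bare `0`/`1`. The blockwise fast path it feeds is easy to picture — a sketch with a hypothetical diagonal bitmask:

```python
import numpy as np
from scipy.sparse import csc_array

# Three labels, three chunks, each label confined to its own chunk:
# grouping is already blockwise, so cohort merging can't improve it
# and find_group_cohorts can return early.
bitmask = csc_array(np.eye(3, dtype=bool))
chunks_per_label = bitmask.sum(axis=0)
every_group_one_block = (chunks_per_label == 1).all()
assert every_group_one_block
```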
```diff
@@ -1547,7 +1561,10 @@ def dask_groupby_agg(
 
     elif method == "cohorts":
         chunks_cohorts = find_group_cohorts(
-            by_input, [array.chunks[ax] for ax in axis], merge=True
+            by_input,
+            [array.chunks[ax] for ax in axis],
+            merge=True,
+            expected_groups=expected_groups,
         )
         reduced_ = []
         groups_ = []
```
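End to end, `expected_groups` supplied at the public entry point now reaches cohort detection. A hedged usage sketch — the arrays are made up, `flox.groupby_reduce` is the public API, and exact index handling is glossed over:

```python
import dask.array
import numpy as np
import pandas as pd

import flox

# Hypothetical "daily climatology": reduce over a chunked time axis,
# grouped by day-of-year labels.
by = np.tile(np.arange(1, 367), 2)  # two years of labels
array = dask.array.ones((10, by.size), chunks=(10, 24))

# With method="cohorts", the expected_groups passed here flows through
# dask_groupby_agg into find_group_cohorts, sizing the bitmask up front.
result, groups = flox.groupby_reduce(
    array,
    by,
    func="mean",
    expected_groups=pd.RangeIndex(1, 367),
    method="cohorts",
)
```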