From e15e42a766a83a7a749ec77a039d588fbf01a05d Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Thu, 28 Dec 2023 12:32:18 -0700
Subject: [PATCH 1/3] Pass expected_groups to find_group_cohorts.

Skips computing the `max` of labels, since `expected_groups` already
tells us the largest one.
---
 asv_bench/benchmarks/cohorts.py | 15 +++++++++++++--
 flox/core.py                    |  9 ++++++---
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/asv_bench/benchmarks/cohorts.py b/asv_bench/benchmarks/cohorts.py
index fbe2af0df..d8f9a7873 100644
--- a/asv_bench/benchmarks/cohorts.py
+++ b/asv_bench/benchmarks/cohorts.py
@@ -9,10 +9,14 @@ class Cohorts:
     """Time the core reduction function."""
 
     def setup(self, *args, **kwargs):
-        raise NotImplementedError
+        self.expected = pd.RangeIndex(self.by.max() + 1)
 
     def time_find_group_cohorts(self):
-        flox.core.find_group_cohorts(self.by, [self.array.chunks[ax] for ax in self.axis])
+        flox.core.find_group_cohorts(
+            self.by,
+            [self.array.chunks[ax] for ax in self.axis],
+            expected_groups=self.expected,
+        )
         # The cache clear fails dependably in CI
         # Not sure why
         try:
@@ -62,6 +66,7 @@ def setup(self, *args, **kwargs):
 
         self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
         self.axis = (-2, -1)
+        super().setup()
 
 
 class ERA5Dataset:
@@ -82,12 +87,14 @@ class ERA5DayOfYear(ERA5Dataset, Cohorts):
     def setup(self, *args, **kwargs):
         super().__init__()
         self.by = self.time.dt.dayofyear.values
+        super().setup()
 
 
 class ERA5DayOfYearRechunked(ERA5DayOfYear, Cohorts):
     def setup(self, *args, **kwargs):
         super().setup()
         self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 24))
+        super().setup()
 
 
 class ERA5MonthHour(ERA5Dataset, Cohorts):
@@ -102,6 +109,7 @@ def setup(self, *args, **kwargs):
         )
         # Add one so the rechunk code is simpler and makes sense
         self.by = ret[0][0] + 1
+        super().setup()
 
 
 class ERA5MonthHourRechunked(ERA5MonthHour, Cohorts):
@@ -118,6 +126,7 @@ def setup(self, *args, **kwargs):
         self.axis = (-1,)
         self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 4))
         self.by = self.time.dt.month.values
+        super().setup()
 
     def rechunk(self):
         self.array = flox.core.rechunk_for_cohorts(
@@ -138,6 +147,7 @@ def setup(self, *args, **kwargs):
         self.axis = (2,)
         self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, 1))
         self.by = self.time.dt.day.values
+        super().setup()
 
 
 def codes_for_resampling(group_as_index, freq):
@@ -159,3 +169,4 @@ def setup(self, *args, **kwargs):
         self.axis = (2,)
         self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, 10))
         self.by = codes_for_resampling(index, freq="5D")
+        super().setup()
diff --git a/flox/core.py b/flox/core.py
index 8edd7c22b..db584d552 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -215,7 +215,7 @@ def slices_from_chunks(chunks):
 
 
 @memoize
-def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
+def find_group_cohorts(labels, chunks, merge: bool = True, expected_groups: None | pd.RangeIndex = None) -> dict:
     """
     Finds groups labels that occur together aka "cohorts"
 
@@ -246,7 +246,10 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     nchunks = math.prod(len(c) for c in chunks)
 
     # assumes that `labels` are factorized
-    nlabels = labels.max() + 1
+    if expected_groups is None:
+        nlabels = labels.max() + 1
+    else:
+        nlabels = expected_groups[-1] + 1
 
     labels = np.broadcast_to(labels, shape[-labels.ndim :])
 
@@ -1547,7 +1550,7 @@ def dask_groupby_agg(
 
     elif method == "cohorts":
         chunks_cohorts = find_group_cohorts(
-            by_input, [array.chunks[ax] for ax in axis], merge=True
+            by_input, [array.chunks[ax] for ax in axis], merge=True, expected_groups=expected_groups,
         )
         reduced_ = []
         groups_ = []
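A note on PATCH 1/3: the point of threading `expected_groups` through is to
replace an O(n) scan of `labels` with an O(1) lookup on the RangeIndex. Below
is a minimal standalone sketch of the branch this patch adds to
`find_group_cohorts`; the helper name `nlabels_from` and the toy values are
invented for illustration and are not part of the patch:

    import numpy as np
    import pandas as pd

    def nlabels_from(labels, expected_groups=None):
        # `labels` are assumed factorized: integers in [0, nlabels).
        if expected_groups is None:
            # O(n): scan every element of `labels` for the largest label.
            return labels.max() + 1
        # O(1): a pd.RangeIndex already knows its last element.
        return expected_groups[-1] + 1

    labels = np.random.randint(0, 367, size=(721, 1440))  # day-of-year-like codes
    assert nlabels_from(labels) == labels.max() + 1
    assert nlabels_from(labels, pd.RangeIndex(367)) == 367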
From e19083f4cb85cb60d8d4bf31be9cf44b529113fc Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 28 Dec 2023 20:42:23 +0000
Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 flox/core.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/flox/core.py b/flox/core.py
index db584d552..f60956e7c 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -215,7 +215,9 @@ def slices_from_chunks(chunks):
 
 
 @memoize
-def find_group_cohorts(labels, chunks, merge: bool = True, expected_groups: None | pd.RangeIndex = None) -> dict:
+def find_group_cohorts(
+    labels, chunks, merge: bool = True, expected_groups: None | pd.RangeIndex = None
+) -> dict:
     """
     Finds groups labels that occur together aka "cohorts"
 
@@ -1550,7 +1552,10 @@ def dask_groupby_agg(
 
     elif method == "cohorts":
         chunks_cohorts = find_group_cohorts(
-            by_input, [array.chunks[ax] for ax in axis], merge=True, expected_groups=expected_groups,
+            by_input,
+            [array.chunks[ax] for ax in axis],
+            merge=True,
+            expected_groups=expected_groups,
         )
         reduced_ = []
         groups_ = []

From 9158b4e85055201b3517d19b3e09a085ac9210b0 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Thu, 28 Dec 2023 14:46:35 -0700
Subject: [PATCH 3/3] Fixes: handle expected labels absent from the data
---
 flox/core.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/flox/core.py b/flox/core.py
index f60956e7c..d818718ed 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -276,9 +276,18 @@ def find_group_cohorts(
     cols_array = np.concatenate(cols)
     data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
     bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
+    CHUNK_AXIS, LABEL_AXIS = 0, 1
+
+    chunks_per_label = bitmask.sum(axis=CHUNK_AXIS)
+    # can happen when `expected_groups` is passed but not all labels are present
+    # (binning, resampling)
+    present_labels = chunks_per_label != 0
+    if not present_labels.all():
+        bitmask = bitmask[..., present_labels]
+
     label_chunks = {
         lab: bitmask.indices[slice(bitmask.indptr[lab], bitmask.indptr[lab + 1])]
-        for lab in range(nlabels)
+        for lab in range(bitmask.shape[-1])
     }
 
     ## numpy bitmask approach, faster than finding uniques, but lots of memory
@@ -308,9 +317,9 @@ def invert(x) -> tuple[np.ndarray, ...]:
 
     # If our dataset has chunksize one along the axis,
     # then no merging is possible.
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
-    one_group_per_chunk = (bitmask.sum(axis=1) == 1).all()
+    one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
     # every group is contained to one block, we should be using blockwise here.
-    every_group_one_block = (bitmask.sum(axis=0) == 1).all()
+    every_group_one_block = (chunks_per_label == 1).all()
    if every_group_one_block or one_group_per_chunk or single_chunks or not merge:
         return chunks_cohorts
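To close, a self-contained toy of the chunk-by-label bitmask that PATCH 3/3
prunes; the shapes and label values are invented for illustration. When
`expected_groups` promises more labels than actually occur (binning,
resampling), the bitmask gains all-zero columns, and those are dropped before
cohorts are derived:

    import numpy as np
    from scipy.sparse import csc_array

    # Rows are chunks, columns are labels; entry (i, j) is True when label j
    # occurs in chunk i. Pretend expected_groups promised 6 labels...
    nchunks, nlabels = 4, 6
    rows = np.array([0, 0, 1, 2, 3, 3])
    cols = np.array([0, 1, 1, 3, 3, 5])  # ...but labels 2 and 4 never occur.
    data = np.ones(len(rows), dtype=bool)
    bitmask = csc_array((data, (rows, cols)), shape=(nchunks, nlabels))

    CHUNK_AXIS = 0
    chunks_per_label = bitmask.sum(axis=CHUNK_AXIS)  # chunks touched per label
    present_labels = chunks_per_label != 0           # [T, T, F, T, F, T]
    if not present_labels.all():
        bitmask = bitmask[:, present_labels]         # drop the all-zero columns

    print(bitmask.shape)  # (4, 4): only the four labels actually present remain

After the pruning, a column's position no longer equals the original label
value, which is why the comprehension in the patch switches from
range(nlabels) to range(bitmask.shape[-1]).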