Skip to content

Commit

Permalink
Use threadpool for finding labels in chunk
Browse files Browse the repository at this point in the history
This is a big win when we have lots of decent-sized chunks — in particular the
NWM county groupby: 600ms -> 400ms.

```
| Before [0cccb90] <optimize-again>   | After [38fe8a6c] <threadpool>   |   Ratio | Benchmark (Parameter)                       |
|--------------------------------------|---------------------------------|---------|---------------------------------------------|
| 3.50±0.2ms                           | 2.93±0.07ms                     |    0.84 | cohorts.PerfectMonthly.time_graph_construct |
| 20.0±1ms                             | 9.66±1ms                        |    0.48 | cohorts.NWMMidwest.time_find_group_cohorts  |
```
  • Loading branch information
dcherian committed Jan 20, 2024
1 parent 3e0653f commit 247824d
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings
from collections import namedtuple
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from functools import partial, reduce
from itertools import product
from numbers import Integral
Expand Down Expand Up @@ -241,27 +242,43 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels):
assert isinstance(labels, np.ndarray)
shape = tuple(sum(c) for c in chunks)
nchunks = math.prod(len(c) for c in chunks)
approx_chunk_size = math.prod(c[0] for c in chunks)

labels = np.broadcast_to(labels, shape[-labels.ndim :])

cols = []
# Add one to handle the -1 sentinel value
label_is_present = np.zeros((nlabels + 1,), dtype=bool)
ilabels = np.arange(nlabels)
for region in slices_from_chunks(chunks):

def chunk_unique(labels, slicer, nlabels, label_is_present=None):
    """Return the sorted unique non-negative labels present in ``labels[slicer]``.

    Parameters
    ----------
    labels : np.ndarray
        Integer label array; ``-1`` is the "no label" sentinel.
    slicer : tuple of slice
        Region of ``labels`` (one chunk) to examine.
    nlabels : int
        Number of valid labels; values are assumed to lie in ``[-1, nlabels)``.
        TODO confirm this range against the caller.
    label_is_present : np.ndarray of bool, optional
        Scratch buffer of length ``nlabels + 1``. Pass a preallocated buffer to
        reuse it across *serial* calls; it is left dirty on return, so the
        caller is responsible for resetting it between uses. When omitted, a
        fresh buffer is allocated, which keeps this function free of shared
        state and therefore safe to run from a thread pool.

    Returns
    -------
    np.ndarray
        Sorted unique labels in the region, excluding the ``-1`` sentinel.
    """
    if label_is_present is None:
        # One extra slot so the -1 sentinel maps to the last entry via
        # negative indexing, avoiding a branch or a mask.
        label_is_present = np.zeros((nlabels + 1,), dtype=bool)
    subset = labels[slicer]
    # This is a quite fast way to find unique integers, when we know how many
    # there are: inspired by a similar idea in numpy_groupies for first/last.
    # Instead of explicitly finding uniques, repeatedly write True to the same
    # location. The reshape is not strictly necessary but is measurably faster.
    label_is_present[subset.reshape(-1)] = True
    # Slice off the last slot to skip the -1 sentinel; flatnonzero of the
    # boolean mask yields the sorted unique labels (faster than np.argwhere).
    return np.flatnonzero(label_is_present[:-1])

if nlabels < approx_chunk_size:
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(chunk_unique, labels, slicer, nlabels)
for slicer in slices_from_chunks(chunks)
]
cols = tuple(f.result() for f in futures)

else:
cols = []
# Add one to handle the -1 sentinel value
label_is_present = np.zeros((nlabels + 1,), dtype=bool)
for region in slices_from_chunks(chunks):
uniques = chunk_unique(labels, region, nlabels, label_is_present)
cols.append(uniques)
label_is_present[:] = False

cols_array = np.concatenate(cols)
rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))

Expand Down

0 comments on commit 247824d

Please sign in to comment.