Use threadpool for finding labels in chunk (#327)
* Use threadpool for finding labels in chunk

Great when we have lots of decent-size chunks, particularly the NWM
county groupby: 600ms -> 400ms. (A minimal sketch of the pattern appears after the commit notes below.)

```
| Before [0cccb90] <optimize-again>   | After [38fe8a6c] <threadpool>   |   Ratio | Benchmark (Parameter)                       |
|--------------------------------------|---------------------------------|---------|---------------------------------------------|
| 3.50±0.2ms                           | 2.93±0.07ms                     |    0.84 | cohorts.PerfectMonthly.time_graph_construct |
| 20.0±1ms                             | 9.66±1ms                        |    0.48 | cohorts.NWMMidwest.time_find_group_cohorts  |
```

* Add threshold

* Fix + comment

* Fix benchmark.

* Tweak threshold

* Small cleanup

* Comment

* Try single allocation

* Revert "Try single allocation"

This reverts commit c6b93367e2024e60d77af24a69d177670a040dfc.

* cleanup
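
For illustration only, here is a minimal, self-contained sketch of the pattern this commit adopts (the toy labels, chunk slices, and sizes are invented, and this is not flox's actual API): find each chunk's labels with a boolean presence mask instead of `np.unique`, and fan the per-chunk scans out over a `ThreadPoolExecutor`.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def chunk_unique(labels, slicer, nlabels):
    # Presence-mask trick: instead of calling np.unique, write True at each
    # label's position; the extra slot absorbs a -1 "missing value" sentinel.
    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
    label_is_present[labels[slicer].reshape(-1)] = True
    return np.arange(nlabels)[label_is_present[:-1]]


# Toy data: 200 elements with labels 0..7, split into four 50-element chunks.
labels = np.repeat(np.arange(8), 25)
nlabels = 8
slicers = [slice(start, start + 50) for start in range(0, 200, 50)]

# One task per chunk, mirroring the executor.submit pattern in the diff below.
with ThreadPoolExecutor() as executor:
    futures = [executor.submit(chunk_unique, labels, s, nlabels) for s in slicers]
    per_chunk_labels = [f.result() for f in futures]

print(per_chunk_labels)  # [array([0, 1]), array([2, 3]), array([4, 5]), array([6, 7])]
```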
dcherian authored May 2, 2024
1 parent eb3c0ef commit c398f4e
Showing 1 changed file with 43 additions and 7 deletions.

flox/core.py
```diff
@@ -8,6 +8,7 @@
 import warnings
 from collections import namedtuple
 from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial, reduce
 from itertools import product
 from numbers import Integral
@@ -261,6 +262,7 @@ def make_bitmask(rows, cols):
     assert isinstance(labels, np.ndarray)
     shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)
+    approx_chunk_size = math.prod(c[0] for c in chunks)

     # Shortcut for 1D with size-1 chunks
     if shape == (nchunks,):
@@ -271,21 +273,55 @@ def make_bitmask(rows, cols):

     labels = np.broadcast_to(labels, shape[-labels.ndim :])
-    cols = []
-    # Add one to handle the -1 sentinel value
-    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
     ilabels = np.arange(nlabels)
-    for region in slices_from_chunks(chunks):
+
+    def chunk_unique(labels, slicer, nlabels, label_is_present=None):
+        if label_is_present is None:
+            label_is_present = np.empty((nlabels + 1,), dtype=bool)
+        label_is_present[:] = False
+        subset = labels[slicer]
         # This is a quite fast way to find unique integers, when we know how many there are
         # inspired by a similar idea in numpy_groupies for first, last
         # instead of explicitly finding uniques, repeatedly write True to the same location
-        subset = labels[region]
         # The reshape is not strictly necessary but is about 100ms faster on a test problem.
         label_is_present[subset.reshape(-1)] = True
         # skip the -1 sentinel by slicing
         # Faster than np.argwhere by a lot
         uniques = ilabels[label_is_present[:-1]]
-        cols.append(uniques)
-        label_is_present[:] = False
+        return uniques
+
+    # TODO: refine this heuristic.
+    # The general idea is that with the threadpool, we repeatedly allocate memory
+    # for `label_is_present`. We trade that off against the parallelism across number of chunks.
+    # For large enough number of chunks (relative to number of labels), it makes sense to
+    # suffer the extra allocation in exchange for parallelism.
+    THRESHOLD = 2
+    if nlabels < THRESHOLD * approx_chunk_size:
+        logger.debug(
+            "Using threadpool since num_labels %s < %d * chunksize %s",
+            nlabels,
+            THRESHOLD,
+            approx_chunk_size,
+        )
+        with ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(chunk_unique, labels, slicer, nlabels)
+                for slicer in slices_from_chunks(chunks)
+            ]
+            cols = tuple(f.result() for f in futures)
+
+    else:
+        logger.debug(
+            "Using serial loop since num_labels %s > %d * chunksize %s",
+            nlabels,
+            THRESHOLD,
+            approx_chunk_size,
+        )
+        cols = []
+        # Add one to handle the -1 sentinel value
+        label_is_present = np.empty((nlabels + 1,), dtype=bool)
+        for region in slices_from_chunks(chunks):
+            uniques = chunk_unique(labels, region, nlabels, label_is_present)
+            cols.append(uniques)
     rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
     cols_array = np.concatenate(cols)

```
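
To make the THRESHOLD heuristic above concrete: the threaded path is only taken when the label count is small relative to a chunk's size, because every threaded call allocates its own `label_is_present` mask. A purely hypothetical check of that comparison (the numbers below are invented):

```python
# Hypothetical sizes, only to illustrate which branch the heuristic picks.
THRESHOLD = 2

def choose_path(nlabels: int, approx_chunk_size: int) -> str:
    return "threadpool" if nlabels < THRESHOLD * approx_chunk_size else "serial"

print(choose_path(nlabels=1_000, approx_chunk_size=500 * 500))   # threadpool: few labels, large chunks
print(choose_path(nlabels=1_000_000, approx_chunk_size=10_000))  # serial: per-call mask allocation dominates
```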
