Use threadpool for finding labels in chunk #327

Merged · 13 commits · May 2, 2024
flox/core.py: 50 changes (43 additions, 7 deletions)
@@ -8,6 +8,7 @@
 import warnings
 from collections import namedtuple
 from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial, reduce
 from itertools import product
 from numbers import Integral
@@ -255,6 +256,7 @@ def make_bitmask(rows, cols):
     assert isinstance(labels, np.ndarray)
     shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)
+    approx_chunk_size = math.prod(c[0] for c in chunks)
 
     # Shortcut for 1D with size-1 chunks
     if shape == (nchunks,):
@@ -265,21 +267,55 @@ def make_bitmask(rows, cols):

     labels = np.broadcast_to(labels, shape[-labels.ndim :])
-    cols = []
-    # Add one to handle the -1 sentinel value
-    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
     ilabels = np.arange(nlabels)
-    for region in slices_from_chunks(chunks):
+
+    def chunk_unique(labels, slicer, nlabels, label_is_present=None):
+        if label_is_present is None:
+            label_is_present = np.empty((nlabels + 1,), dtype=bool)
+        label_is_present[:] = False
+        subset = labels[slicer]
         # This is a quite fast way to find unique integers, when we know how many there are
         # inspired by a similar idea in numpy_groupies for first, last
         # instead of explicitly finding uniques, repeatedly write True to the same location
-        subset = labels[region]
         # The reshape is not strictly necessary but is about 100ms faster on a test problem.
         label_is_present[subset.reshape(-1)] = True
         # skip the -1 sentinel by slicing
         # Faster than np.argwhere by a lot
         uniques = ilabels[label_is_present[:-1]]
-        cols.append(uniques)
-        label_is_present[:] = False
+        return uniques
+
+    # TODO: refine this heuristic.
+    # The general idea is that with the threadpool, we repeatedly allocate memory
+    # for `label_is_present`. We trade that off against the parallelism across number of chunks.
+    # For large enough number of chunks (relative to number of labels), it makes sense to
+    # suffer the extra allocation in exchange for parallelism.
+    THRESHOLD = 2
+    if nlabels < THRESHOLD * approx_chunk_size:
+        logger.debug(
+            "Using threadpool since num_labels %s < %d * chunksize %s",
+            nlabels,
+            THRESHOLD,
+            approx_chunk_size,
+        )
+        with ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(chunk_unique, labels, slicer, nlabels)
+                for slicer in slices_from_chunks(chunks)
+            ]
+            cols = tuple(f.result() for f in futures)
+
+    else:
+        logger.debug(
+            "Using serial loop since num_labels %s > %d * chunksize %s",
+            nlabels,
+            THRESHOLD,
+            approx_chunk_size,
+        )
+        cols = []
+        # Add one to handle the -1 sentinel value
+        label_is_present = np.empty((nlabels + 1,), dtype=bool)
+        for region in slices_from_chunks(chunks):
+            uniques = chunk_unique(labels, region, nlabels, label_is_present)
+            cols.append(uniques)
     rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
     cols_array = np.concatenate(cols)

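A note for readers on the trick inside chunk_unique: rather than calling np.unique, it writes True into a boolean presence mask at each label's position, then reads the set labels back out, which is O(n) with no sort or hash. A minimal standalone sketch of the idea (the example array and the nlabels value are made up for illustration):

import numpy as np

# Integer group labels in [0, nlabels); -1 is a "not in any group" sentinel.
labels = np.array([[3, 0, 3], [-1, 1, 0]])
nlabels = 4

# One extra slot: writes at index -1 land in the final position,
# which is sliced off before reading the result.
label_is_present = np.zeros((nlabels + 1,), dtype=bool)
ilabels = np.arange(nlabels)

# Duplicate labels simply overwrite the same location.
label_is_present[labels.reshape(-1)] = True

# Drop the sentinel slot, then pull out the labels that appeared.
uniques = ilabels[label_is_present[:-1]]
print(uniques)  # [0 1 3]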
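The dispatch heuristic can likewise be sketched end to end. The helper below mirrors the diff but is simplified: slicers stands in for slices_from_chunks(chunks), and unique_per_chunk is a hypothetical name, not flox API:

from concurrent.futures import ThreadPoolExecutor

import numpy as np


def chunk_unique(labels, slicer, nlabels, label_is_present=None):
    # Presence-mask uniques, as in the diff above.
    if label_is_present is None:
        label_is_present = np.empty((nlabels + 1,), dtype=bool)
    label_is_present[:] = False
    label_is_present[labels[slicer].reshape(-1)] = True
    return np.arange(nlabels)[label_is_present[:-1]]


def unique_per_chunk(labels, slicers, nlabels, approx_chunk_size, threshold=2):
    if nlabels < threshold * approx_chunk_size:
        # Mask is small relative to a chunk: pay for one allocation per task
        # and let the per-chunk work run concurrently (NumPy releases the GIL
        # for much of this work, so threads can overlap).
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(chunk_unique, labels, s, nlabels) for s in slicers]
            return tuple(f.result() for f in futures)
    # Mask is large relative to a chunk: reallocating it per task would dominate,
    # so reuse a single buffer in a serial loop.
    label_is_present = np.empty((nlabels + 1,), dtype=bool)
    return tuple(chunk_unique(labels, s, nlabels, label_is_present) for s in slicers)


# Example: 100 elements in four chunks of 25, with 10 possible labels;
# 10 < 2 * 25, so this takes the threadpool path.
labels = np.random.randint(0, 10, size=(100,))
slicers = [slice(i, i + 25) for i in range(0, 100, 25)]
print(unique_per_chunk(labels, slicers, nlabels=10, approx_chunk_size=25))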