diff --git a/flox/core.py b/flox/core.py index 094a8867c..c7455a05a 100644 --- a/flox/core.py +++ b/flox/core.py @@ -8,6 +8,7 @@ import warnings from collections import namedtuple from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor from functools import partial, reduce from itertools import product from numbers import Integral @@ -241,27 +242,43 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels): assert isinstance(labels, np.ndarray) shape = tuple(sum(c) for c in chunks) nchunks = math.prod(len(c) for c in chunks) + approx_chunk_size = math.prod(c[0] for c in chunks) labels = np.broadcast_to(labels, shape[-labels.ndim :]) - - cols = [] - # Add one to handle the -1 sentinel value - label_is_present = np.zeros((nlabels + 1,), dtype=bool) ilabels = np.arange(nlabels) - for region in slices_from_chunks(chunks): + + def chunk_unique(labels, slicer, nlabels, label_is_present=None): + if label_is_present is None: + label_is_present = np.zeros((nlabels + 1,), dtype=bool) + subset = labels[slicer] # This is a quite fast way to find unique integers, when we know how many there are # inspired by a similar idea in numpy_groupies for first, last # instead of explicitly finding uniques, repeatedly write True to the same location - subset = labels[region] - # The reshape is not strictly necessary but is about 100ms faster on a test problem. label_is_present[subset.reshape(-1)] = True # skip the -1 sentinel by slicing # Faster than np.argwhere by a lot uniques = ilabels[label_is_present[:-1]] - cols.append(uniques) - label_is_present[:] = False - rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols)) + return uniques + + if nlabels < approx_chunk_size: + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(chunk_unique, labels, slicer, nlabels) + for slicer in slices_from_chunks(chunks) + ] + cols = tuple(f.result() for f in futures) + + else: + cols = [] + # Add one to handle the -1 sentinel value + label_is_present = np.zeros((nlabels + 1,), dtype=bool) + for region in slices_from_chunks(chunks): + uniques = chunk_unique(labels, region, nlabels, label_is_present) + cols.append(uniques) + label_is_present[:] = False + cols_array = np.concatenate(cols) + rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols)) data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape) bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))