diff --git a/cpp/src/groupby/hash/groupby.cu b/cpp/src/groupby/hash/groupby.cu index c07833520ab..90c869b8c58 100644 --- a/cpp/src/groupby/hash/groupby.cu +++ b/cpp/src/groupby/hash/groupby.cu @@ -512,18 +512,33 @@ rmm::device_uvector extract_populated_keys(map_type const& map, { rmm::device_uvector populated_keys(num_keys, stream); - auto get_key = [] __device__(auto const& element) { return element.first; }; // first = key - auto get_key_it = thrust::make_transform_iterator(map.data(), get_key); - auto key_used = [unused = map.get_unused_key()] __device__(auto key) { return key != unused; }; - - auto end_it = thrust::copy_if(rmm::exec_policy(stream), - get_key_it, - get_key_it + map.capacity(), - populated_keys.begin(), - key_used); - - populated_keys.resize(std::distance(populated_keys.begin(), end_it), stream); + auto const get_key = [] __device__(auto const& element) { return element.first; }; // first = key + auto const key_used = [unused = map.get_unused_key()] __device__(auto key) { + return key != unused; + }; + auto key_itr = thrust::make_transform_iterator(map.data(), get_key); + + // thrust::copy_if has a bug where it cannot iterate over int-max values + // so if map.capacity() > int-max we'll call thrust::copy_if in chunks instead + auto const copy_size = + std::min(map.capacity(), static_cast(std::numeric_limits::max())); + auto const key_end = key_itr + map.capacity(); + auto pop_keys_itr = populated_keys.begin(); + + std::size_t output_size = 0; + while (key_itr != key_end) { + auto const copy_end = static_cast(std::distance(key_itr, key_end)) <= copy_size + ? key_end + : key_itr + copy_size; + auto const end_it = + thrust::copy_if(rmm::exec_policy(stream), key_itr, copy_end, pop_keys_itr, key_used); + auto const copied = std::distance(pop_keys_itr, end_it); + pop_keys_itr += copied; + output_size += copied; + key_itr = copy_end; + } + populated_keys.resize(output_size, stream); return populated_keys; }