From c8f012dbf906a855cd148f4bf4fde0ff5d785005 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Mon, 12 Dec 2022 17:40:01 +0100 Subject: [PATCH] Solve race condition in raft::random::discrete --- cpp/include/raft/random/detail/rng_impl.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 0a843857e1..cd465e634a 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -261,10 +261,11 @@ std::enable_if_t> discrete(RngState& rng_state, // Compute the cumulative sums of the weights size_t temp_storage_bytes = 0; rmm::device_uvector weights_csum(len, stream); - cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, weights, weights_csum.data(), len); + cub::DeviceScan::InclusiveSum( + nullptr, temp_storage_bytes, weights, weights_csum.data(), len, stream); rmm::device_uvector temp_storage(temp_storage_bytes, stream); cub::DeviceScan::InclusiveSum( - temp_storage.data(), temp_storage_bytes, weights, weights_csum.data(), len); + temp_storage.data(), temp_storage_bytes, weights, weights_csum.data(), len, stream); // Sample indices with replacement RAFT_CALL_RNG_FUNC(rng_state,