From 057fb374ceaee465748c755c5fe83c815459d669 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Wed, 14 Dec 2022 00:12:39 +0100 Subject: [PATCH] Fix race condition in `raft::random::discrete` (#1094) The stream needs to be passed to `cub::DeviceScan`. Authors: - Louis Sugy (https://github.com/Nyrio) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1094 --- cpp/include/raft/random/detail/rng_impl.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 0a843857e1..cd465e634a 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -261,10 +261,11 @@ std::enable_if_t> discrete(RngState& rng_state, // Compute the cumulative sums of the weights size_t temp_storage_bytes = 0; rmm::device_uvector weights_csum(len, stream); - cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, weights, weights_csum.data(), len); + cub::DeviceScan::InclusiveSum( + nullptr, temp_storage_bytes, weights, weights_csum.data(), len, stream); rmm::device_uvector temp_storage(temp_storage_bytes, stream); cub::DeviceScan::InclusiveSum( - temp_storage.data(), temp_storage_bytes, weights, weights_csum.data(), len); + temp_storage.data(), temp_storage_bytes, weights, weights_csum.data(), len, stream); // Sample indices with replacement RAFT_CALL_RNG_FUNC(rng_state,