diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 0a843857e1..cd465e634a 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -261,10 +261,11 @@ std::enable_if_t> discrete(RngState& rng_state, // Compute the cumulative sums of the weights size_t temp_storage_bytes = 0; rmm::device_uvector weights_csum(len, stream); - cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, weights, weights_csum.data(), len); + cub::DeviceScan::InclusiveSum( + nullptr, temp_storage_bytes, weights, weights_csum.data(), len, stream); rmm::device_uvector temp_storage(temp_storage_bytes, stream); cub::DeviceScan::InclusiveSum( - temp_storage.data(), temp_storage_bytes, weights, weights_csum.data(), len); + temp_storage.data(), temp_storage_bytes, weights, weights_csum.data(), len, stream); // Sample indices with replacement RAFT_CALL_RNG_FUNC(rng_state,