From eedb03a11f3b1c828fd5108af1b17b4e8282f0d4 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Thu, 24 Nov 2022 21:28:03 +0100 Subject: [PATCH] Fix concurrency issues in k-means++ --- cpp/include/raft/cluster/detail/kmeans.cuh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 3d71db96c5..060d05a333 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -226,7 +226,8 @@ void kmeansPlusPlus(const raft::handle_t& handle, temp_storage_bytes, costPerCandidate.data_handle(), minClusterIndexAndDistance.data(), - costPerCandidate.extent(0)); + costPerCandidate.extent(0), + stream); // Allocate temporary storage workspace.resize(temp_storage_bytes, stream); @@ -236,10 +237,12 @@ void kmeansPlusPlus(const raft::handle_t& handle, temp_storage_bytes, costPerCandidate.data_handle(), minClusterIndexAndDistance.data(), - costPerCandidate.extent(0)); + costPerCandidate.extent(0), + stream); int bestCandidateIdx = -1; raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); + handle.sync_stream(); /// <<< End of Step-3 >>> /// <<< Step-4 >>>: C = C U {x}