From 026b78effb4da3335f1eef7c294fe360197f8543 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Tue, 29 Nov 2022 20:55:57 +0100 Subject: [PATCH] Fix concurrency issues in k-means++ initialization (#1048) The cub calls were previously launched on the default stream whereas the rest was launched in the stream attached to the raft handle, and there was a missing synchronization after a D2H copy of the index that is used immediately after. Authors: - Louis Sugy (https://github.com/Nyrio) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1048 --- cpp/include/raft/cluster/detail/kmeans.cuh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 3d71db96c5..060d05a333 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -226,7 +226,8 @@ void kmeansPlusPlus(const raft::handle_t& handle, temp_storage_bytes, costPerCandidate.data_handle(), minClusterIndexAndDistance.data(), - costPerCandidate.extent(0)); + costPerCandidate.extent(0), + stream); // Allocate temporary storage workspace.resize(temp_storage_bytes, stream); @@ -236,10 +237,12 @@ void kmeansPlusPlus(const raft::handle_t& handle, temp_storage_bytes, costPerCandidate.data_handle(), minClusterIndexAndDistance.data(), - costPerCandidate.extent(0)); + costPerCandidate.extent(0), + stream); int bestCandidateIdx = -1; raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); + handle.sync_stream(); /// <<< End of Step-3 >>> /// <<< Step-4 >>>: C = C U {x}