diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 3d71db96c5..060d05a333 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -226,7 +226,8 @@ void kmeansPlusPlus(const raft::handle_t& handle, temp_storage_bytes, costPerCandidate.data_handle(), minClusterIndexAndDistance.data(), - costPerCandidate.extent(0)); + costPerCandidate.extent(0), + stream); // Allocate temporary storage workspace.resize(temp_storage_bytes, stream); @@ -236,10 +237,12 @@ void kmeansPlusPlus(const raft::handle_t& handle, temp_storage_bytes, costPerCandidate.data_handle(), minClusterIndexAndDistance.data(), - costPerCandidate.extent(0)); + costPerCandidate.extent(0), + stream); int bestCandidateIdx = -1; raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); + handle.sync_stream(); /// <<< End of Step-3 >>> /// <<< Step-4 >>>: C = C U {x}