Fix concurrency issues in k-means++ initialization (#1048)
The cub calls were previously launched on the default stream, whereas the rest of the work was launched on the stream attached to the raft handle. A synchronization was also missing after the D2H copy of the index, which is used immediately afterwards.

Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1048
Nyrio authored Nov 29, 2022
1 parent 433972b commit 026b78e
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions cpp/include/raft/cluster/detail/kmeans.cuh
@@ -226,7 +226,8 @@ void kmeansPlusPlus(const raft::handle_t& handle,
 temp_storage_bytes,
 costPerCandidate.data_handle(),
 minClusterIndexAndDistance.data(),
-costPerCandidate.extent(0));
+costPerCandidate.extent(0),
+stream);

 // Allocate temporary storage
 workspace.resize(temp_storage_bytes, stream);
@@ -236,10 +237,12 @@ void kmeansPlusPlus(const raft::handle_t& handle,
 temp_storage_bytes,
 costPerCandidate.data_handle(),
 minClusterIndexAndDistance.data(),
-costPerCandidate.extent(0));
+costPerCandidate.extent(0),
+stream);

 int bestCandidateIdx = -1;
 raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream);
+handle.sync_stream();
 /// <<< End of Step-3 >>>

 /// <<< Step-4 >>>: C = C U {x}
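For readers unfamiliar with why the added `handle.sync_stream()` matters, here is a host-only C++ toy model of the missing-synchronization half of this fix. It is a sketch, not RAFT or CUDA code: `ToyStream`, `Result`, and `pick_best_candidate` are hypothetical names invented for illustration, and the toy stream defers all queued work until `sync()` is called. On a real CUDA stream the unsynchronized read is a race with an unpredictable outcome, not a guaranteed stale value. The model does not cover the other half of the fix (cub calls running on a different stream from the rest of the work).

```cpp
#include <cstddef>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical stand-in for a CUDA stream: tasks are queued in order and
// only execute when sync() is called. This is an assumption of the toy model.
class ToyStream {
 public:
  void enqueue(std::function<void()> task) { tasks_.push(std::move(task)); }

  // Run every queued task in submission order, then return.
  void sync()
  {
    while (!tasks_.empty()) {
      tasks_.front()();
      tasks_.pop();
    }
  }

 private:
  std::queue<std::function<void()>> tasks_;
};

struct Result {
  int before_sync;  // value read right after queuing the copy
  int after_sync;   // value read after synchronizing the stream
};

// Mirrors the shape of Step-3: reduce to the argmin, copy the index to the
// host, then read it.
Result pick_best_candidate(const std::vector<double>& cost_per_candidate)
{
  ToyStream stream;
  int device_key         = -1;  // plays the role of the key held in device memory
  int best_candidate_idx = -1;  // host-side destination of the copy

  // "Reduction": argmin over the costs, queued on the stream.
  stream.enqueue([&] {
    for (std::size_t i = 0; i < cost_per_candidate.size(); ++i) {
      if (device_key < 0 || cost_per_candidate[i] < cost_per_candidate[device_key]) {
        device_key = static_cast<int>(i);
      }
    }
  });

  // "D2H copy": also queued, ordered after the reduction.
  stream.enqueue([&] { best_candidate_idx = device_key; });

  Result r;
  r.before_sync = best_candidate_idx;  // read without synchronizing
  stream.sync();                       // the step this commit adds
  r.after_sync = best_candidate_idx;   // the copy has completed
  return r;
}
```

In this model, reading the index before `sync()` still sees the initial `-1`, which is why the value must not be consumed until the stream has been synchronized.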