diff --git a/cpp/src/kmeans/kmeans_mg_impl.cuh b/cpp/src/kmeans/kmeans_mg_impl.cuh index 0fdfc6b826..c17cc9467b 100644 --- a/cpp/src/kmeans/kmeans_mg_impl.cuh +++ b/cpp/src/kmeans/kmeans_mg_impl.cuh @@ -636,7 +636,7 @@ void fit(const raft::handle_t& handle, centroids.extent(0), itr_wt, itr_wt, - wtInCluster.size(), + wtInCluster.extent(0), newCentroids.data_handle(), [=] __device__(raft::KeyValuePair map) { // predicate // copy when the # of samples in the cluster is 0