diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh index f13dcd8cc6..0577d24349 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh @@ -41,6 +41,7 @@ #include #include +#include #include #include #include @@ -430,26 +431,18 @@ auto calculate_offsets_and_indices(IdxT n_rows, IdxT* data_indices, rmm::cuda_stream_view stream) -> uint32_t { - auto exec_policy = rmm::exec_policy(stream); - uint32_t max_cluster_size = 0; - rmm::device_scalar max_cluster_size_dev_buf(stream); - auto max_cluster_size_dev = max_cluster_size_dev_buf.data(); - update_device(max_cluster_size_dev, &max_cluster_size, 1, stream); + auto exec_policy = rmm::exec_policy(stream); // Calculate the offsets IdxT cumsum = 0; update_device(cluster_offsets, &cumsum, 1, stream); - thrust::inclusive_scan(exec_policy, - cluster_sizes, - cluster_sizes + n_lists, - cluster_offsets + 1, - [max_cluster_size_dev] __device__(IdxT s, uint32_t l) { - atomicMax(max_cluster_size_dev, l); - return s + l; - }); + thrust::inclusive_scan( + exec_policy, cluster_sizes, cluster_sizes + n_lists, cluster_offsets + 1, thrust::plus{}); update_host(&cumsum, cluster_offsets + n_lists, 1, stream); - update_host(&max_cluster_size, max_cluster_size_dev, 1, stream); + uint32_t max_cluster_size = + *thrust::max_element(exec_policy, cluster_sizes, cluster_sizes + n_lists); stream.synchronize(); RAFT_EXPECTS(cumsum == n_rows, "cluster sizes do not add up."); + RAFT_LOG_DEBUG("Max cluster size %d", max_cluster_size); rmm::device_uvector data_offsets_buf(n_lists, stream); auto data_offsets = data_offsets_buf.data(); copy(data_offsets, cluster_offsets, n_lists, stream);