diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index 72df13d760..c6a3aea0cf 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -839,6 +839,10 @@ inline auto arrange_fine_clusters(uint32_t n_clusters, * As a result, the fine clusters are what is returned by `build_hierarchical`; * this function returns the total number of fine clusters, which can be checked to be * the same as the requested number of clusters. + * + * Note: this function uses at most `fine_clusters_nums_max` points per mesocluster for training; + * if one of the clusters is larger than that (as given by `mesocluster_sizes`), the extra data + * is ignored and a warning is reported. */ template auto build_fine_clusters(const handle_t& handle, @@ -880,8 +884,8 @@ auto build_fine_clusters(const handle_t& handle, uint32_t n_clusters_done = 0; for (uint32_t i = 0; i < n_mesoclusters; i++) { uint32_t k = 0; - for (IdxT j = 0; j < n_rows; j++) { - if (labels_mptr[j] == (LabelT)i) { mc_trainset_ids[k++] = j; } + for (IdxT j = 0; j < n_rows && k < mesocluster_size_max; j++) { + if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; } } if (k != mesocluster_sizes[i]) RAFT_LOG_WARN("Incorrect mesocluster size at %d. %d vs %d", i, k, mesocluster_sizes[i]); @@ -896,19 +900,13 @@ auto build_fine_clusters(const handle_t& handle, "Number of fine clusters must be non-zero for a non-empty mesocluster"); } - utils::copy_selected((IdxT)mesocluster_sizes[i], - (IdxT)dim, - dataset_mptr, - mc_trainset_ids, - (IdxT)dim, - mc_trainset, - (IdxT)dim, - stream); + utils::copy_selected( + (IdxT)k, (IdxT)dim, dataset_mptr, mc_trainset_ids, (IdxT)dim, mc_trainset, (IdxT)dim, stream); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { thrust::gather(handle.get_thrust_policy(), mc_trainset_ids, - mc_trainset_ids + mesocluster_sizes[i], + mc_trainset_ids + k, dataset_norm_mptr, mc_trainset_norm); } @@ -917,7 +915,7 @@ auto build_fine_clusters(const handle_t& handle, n_iters, dim, mc_trainset, - mesocluster_sizes[i], + k, fine_clusters_nums[i], mc_trainset_ccenters.data(), mc_trainset_labels.data(), @@ -1036,10 +1034,19 @@ void build_hierarchical(const handle_t& handle, auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); - if (mesocluster_size_max * n_mesoclusters > 2 * n_rows) { - RAFT_LOG_WARN("build_hierarchical: built unbalanced mesoclusters"); + const auto mesocluster_size_max_balanced = uint32_t(div_rounding_up_safe( + 2lu * size_t(n_rows), std::max(size_t(n_mesoclusters), 1lu))); + if (mesocluster_size_max > mesocluster_size_max_balanced) { + RAFT_LOG_WARN( + "build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). " + "At most %u points will be used for training within each mesocluster. " + "Consider increasing the number of training iterations `n_iters`.", + mesocluster_size_max, + mesocluster_size_max_balanced, + mesocluster_size_max_balanced); RAFT_LOG_TRACE_VEC(mesocluster_sizes, n_mesoclusters); RAFT_LOG_TRACE_VEC(fine_clusters_nums.data(), n_mesoclusters); + mesocluster_size_max = mesocluster_size_max_balanced; } auto n_clusters_done = build_fine_clusters(handle,