From 10b4b803cc270d823d73597d114dcd892c896e7f Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 20 Jan 2023 10:22:43 +0100 Subject: [PATCH 1/5] Protect balanced k-means from allocating too large temporary buffers when the mesoclusters turn out to be unbalanced --- .../knn/detail/ann_kmeans_balanced.cuh | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index 72df13d760..060f1e940c 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -880,8 +880,8 @@ auto build_fine_clusters(const handle_t& handle, uint32_t n_clusters_done = 0; for (uint32_t i = 0; i < n_mesoclusters; i++) { uint32_t k = 0; - for (IdxT j = 0; j < n_rows; j++) { - if (labels_mptr[j] == (LabelT)i) { mc_trainset_ids[k++] = j; } + for (IdxT j = 0; j < n_rows && k < mesocluster_size_max; j++) { + if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; } } if (k != mesocluster_sizes[i]) RAFT_LOG_WARN("Incorrect mesocluster size at %d. %d vs %d", i, k, mesocluster_sizes[i]); @@ -896,19 +896,13 @@ auto build_fine_clusters(const handle_t& handle, "Number of fine clusters must be non-zero for a non-empty mesocluster"); } - utils::copy_selected((IdxT)mesocluster_sizes[i], - (IdxT)dim, - dataset_mptr, - mc_trainset_ids, - (IdxT)dim, - mc_trainset, - (IdxT)dim, - stream); + utils::copy_selected( + (IdxT)k, (IdxT)dim, dataset_mptr, mc_trainset_ids, (IdxT)dim, mc_trainset, (IdxT)dim, stream); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { thrust::gather(handle.get_thrust_policy(), mc_trainset_ids, - mc_trainset_ids + mesocluster_sizes[i], + mc_trainset_ids + k, dataset_norm_mptr, mc_trainset_norm); } @@ -917,7 +911,7 @@ auto build_fine_clusters(const handle_t& handle, n_iters, dim, mc_trainset, - mesocluster_sizes[i], + k, fine_clusters_nums[i], mc_trainset_ccenters.data(), mc_trainset_labels.data(), @@ -1036,10 +1030,19 @@ void build_hierarchical(const handle_t& handle, auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); - if (mesocluster_size_max * n_mesoclusters > 2 * n_rows) { - RAFT_LOG_WARN("build_hierarchical: built unbalanced mesoclusters"); + const auto mesocluster_size_max_balanced = + uint32_t(size_t{n_rows * 2lu} / std::max(n_mesoclusters, 1lu)); + if (mesocluster_size_max > mesocluster_size_max_balanced) { + RAFT_LOG_WARN( + "build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). " + "At most %u points will be used for training within each mesocluster. " + "Consider increasing the number of training iterations `n_iters`.", + mesocluster_size_max, + mesocluster_size_max_balanced, + mesocluster_size_max_balanced); RAFT_LOG_TRACE_VEC(mesocluster_sizes, n_mesoclusters); RAFT_LOG_TRACE_VEC(fine_clusters_nums.data(), n_mesoclusters); + mesocluster_size_max = mesocluster_size_max_balanced; } auto n_clusters_done = build_fine_clusters(handle, From 1ca100e9cbe917708c7647a612f138e0c740696a Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 20 Jan 2023 12:35:51 +0100 Subject: [PATCH 2/5] Add a note to the docstring regarding the updated behavior --- cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index 060f1e940c..61562b41da 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -839,6 +839,10 @@ inline auto arrange_fine_clusters(uint32_t n_clusters, * As a result, the fine clusters are what is returned by `build_hierarchical`; * this function returns the total number of fine clusters, which can be checked to be * the same as the requested number of clusters. + * + * Note: this function uses at most `fine_clusters_nums_max` points per mesocluster for training; + * if one of the clusters is larger than that (as given by `mesocluster_sizes`), the extra data + * is ignored and a warning is reported. */ template auto build_fine_clusters(const handle_t& handle, From 2747745d0b503850607f24cacc405a9586301c7c Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 20 Jan 2023 15:08:44 +0100 Subject: [PATCH 3/5] Fix an edge case of very small dataset --- cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index 61562b41da..d6d373c2fc 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -1034,8 +1034,8 @@ void build_hierarchical(const handle_t& handle, auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); - const auto mesocluster_size_max_balanced = - uint32_t(size_t{n_rows * 2lu} / std::max(n_mesoclusters, 1lu)); + const auto mesocluster_size_max_balanced = uint32_t( + div_rounding_up_safe(2lu * size_t{n_rows}, std::max(n_mesoclusters, 1lu))); if (mesocluster_size_max > mesocluster_size_max_balanced) { RAFT_LOG_WARN( "build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). " From 882f045ba80ecbd345a3aa8ced02b0b201097a50 Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 20 Jan 2023 15:59:45 +0100 Subject: [PATCH 4/5] Fix an edge case of very small dataset --- cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index d6d373c2fc..9d51e94751 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -1034,8 +1034,8 @@ void build_hierarchical(const handle_t& handle, auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); - const auto mesocluster_size_max_balanced = uint32_t( - div_rounding_up_safe(2lu * size_t{n_rows}, std::max(n_mesoclusters, 1lu))); + const auto mesocluster_size_max_balanced = uint32_t(div_rounding_up_safe( + 2lu * size_t{n_rows}, std::max(size_t{n_mesoclusters}, 1lu))); if (mesocluster_size_max > mesocluster_size_max_balanced) { RAFT_LOG_WARN( "build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). " From d247dabfb7ba2b5f849168dfc042b6937f0f5266 Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 20 Jan 2023 19:06:37 +0100 Subject: [PATCH 5/5] Fix typo bad conversion --- cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index 9d51e94751..c6a3aea0cf 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -1035,7 +1035,7 @@ void build_hierarchical(const handle_t& handle, arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); const auto mesocluster_size_max_balanced = uint32_t(div_rounding_up_safe( - 2lu * size_t{n_rows}, std::max(size_t{n_mesoclusters}, 1lu))); + 2lu * size_t(n_rows), std::max(size_t(n_mesoclusters), 1lu))); if (mesocluster_size_max > mesocluster_size_max_balanced) { RAFT_LOG_WARN( "build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). "