Skip to content

Commit

Permalink
Protect balanced k-means out-of-memory in some cases (rapidsai#1161)
Browse files Browse the repository at this point in the history
There's no guarantee that our balanced k-means implementation always produces balanced clusters. In the first stage, when mesoclusters are trained, the biggest cluster can grow larger than half of all input data. This becomes a problem at the second stage, when in `build_fine_clusters`, the mesocluster data is copied in a temporary buffer. If size is too big, there may be not enough memory on the device. A quick workaround:

 1. Expand the error reporting (RAFT_LOG_WARN)
 2. Artificially limit the mesocluster size in the event of highly unbalanced clustering

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: rapidsai#1161
  • Loading branch information
achirkin authored and ahendriksen committed Jan 23, 2023
1 parent fd09b1c commit 185cf10
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,10 @@ inline auto arrange_fine_clusters(uint32_t n_clusters,
* As a result, the fine clusters are what is returned by `build_hierarchical`;
* this function returns the total number of fine clusters, which can be checked to be
* the same as the requested number of clusters.
*
* Note: this function uses at most `fine_clusters_nums_max` points per mesocluster for training;
* if one of the clusters is larger than that (as given by `mesocluster_sizes`), the extra data
* is ignored and a warning is reported.
*/
template <typename T, typename IdxT, typename LabelT>
auto build_fine_clusters(const handle_t& handle,
Expand Down Expand Up @@ -880,8 +884,8 @@ auto build_fine_clusters(const handle_t& handle,
uint32_t n_clusters_done = 0;
for (uint32_t i = 0; i < n_mesoclusters; i++) {
uint32_t k = 0;
for (IdxT j = 0; j < n_rows; j++) {
if (labels_mptr[j] == (LabelT)i) { mc_trainset_ids[k++] = j; }
for (IdxT j = 0; j < n_rows && k < mesocluster_size_max; j++) {
if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; }
}
if (k != mesocluster_sizes[i])
RAFT_LOG_WARN("Incorrect mesocluster size at %d. %d vs %d", i, k, mesocluster_sizes[i]);
Expand All @@ -896,19 +900,13 @@ auto build_fine_clusters(const handle_t& handle,
"Number of fine clusters must be non-zero for a non-empty mesocluster");
}

utils::copy_selected((IdxT)mesocluster_sizes[i],
(IdxT)dim,
dataset_mptr,
mc_trainset_ids,
(IdxT)dim,
mc_trainset,
(IdxT)dim,
stream);
utils::copy_selected(
(IdxT)k, (IdxT)dim, dataset_mptr, mc_trainset_ids, (IdxT)dim, mc_trainset, (IdxT)dim, stream);
if (metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded) {
thrust::gather(handle.get_thrust_policy(),
mc_trainset_ids,
mc_trainset_ids + mesocluster_sizes[i],
mc_trainset_ids + k,
dataset_norm_mptr,
mc_trainset_norm);
}
Expand All @@ -917,7 +915,7 @@ auto build_fine_clusters(const handle_t& handle,
n_iters,
dim,
mc_trainset,
mesocluster_sizes[i],
k,
fine_clusters_nums[i],
mc_trainset_ccenters.data(),
mc_trainset_labels.data(),
Expand Down Expand Up @@ -1036,10 +1034,19 @@ void build_hierarchical(const handle_t& handle,
auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] =
arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes);

if (mesocluster_size_max * n_mesoclusters > 2 * n_rows) {
RAFT_LOG_WARN("build_hierarchical: built unbalanced mesoclusters");
const auto mesocluster_size_max_balanced = uint32_t(div_rounding_up_safe<size_t>(
2lu * size_t(n_rows), std::max<size_t>(size_t(n_mesoclusters), 1lu)));
if (mesocluster_size_max > mesocluster_size_max_balanced) {
RAFT_LOG_WARN(
"build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). "
"At most %u points will be used for training within each mesocluster. "
"Consider increasing the number of training iterations `n_iters`.",
mesocluster_size_max,
mesocluster_size_max_balanced,
mesocluster_size_max_balanced);
RAFT_LOG_TRACE_VEC(mesocluster_sizes, n_mesoclusters);
RAFT_LOG_TRACE_VEC(fine_clusters_nums.data(), n_mesoclusters);
mesocluster_size_max = mesocluster_size_max_balanced;
}

auto n_clusters_done = build_fine_clusters<T, IdxT, LabelT>(handle,
Expand Down

0 comments on commit 185cf10

Please sign in to comment.