Skip to content

Commit

Permalink
balanced-k-means: fix a too large initial memory pool size (#1148)
Browse files Browse the repository at this point in the history
`calc_minibatch_size` decides on the batch size under assumption that the workspace shouldn't exceed 1GB. It takes into account that fewer extra buffers are needed when the data type `T` is float. However, we don't take this into account when setting the initial memory pool size immediately after calculating `max_minibatch_size`. As a result, under some conditions, the algorithm attempts to allocate more memory than available. This PR sets the limit of the initial pool size to 1GB to fix the issue.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1148
  • Loading branch information
achirkin authored Jan 19, 2023
1 parent a7399cb commit 7215c8a
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,9 +18,6 @@

#include "ann_utils.cuh"

#include <thrust/gather.h>
#include <thrust/transform.h>

#include <raft/cluster/detail/kmeans_common.cuh>
#include <raft/common/nvtx.hpp>
#include <raft/core/cudart_utils.hpp>
Expand All @@ -45,6 +42,11 @@
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <thrust/gather.h>
#include <thrust/transform.h>

#include <tuple>

namespace raft::spatial::knn::detail::kmeans {

constexpr static inline const float kAdjustCentersWeight = 7.0f;
Expand Down Expand Up @@ -170,35 +172,40 @@ inline void predict_float_core(const handle_t& handle,
*
* @param n_clusters number of clusters in kmeans clustering
* @param n_rows dataset size
* @return a suggested minibatch size
* @param dim
* @param metric
* @param is_float float input requires less temporary buffers
* @return a suggested minibatch size and the expected memory cost per-row (in bytes)
*/
template <typename IdxT>
constexpr inline auto calc_minibatch_size(uint32_t n_clusters,
IdxT n_rows,
uint32_t dim,
raft::distance::DistanceType metric,
bool is_float) -> IdxT
constexpr auto calc_minibatch_size(
uint32_t n_clusters, IdxT n_rows, uint32_t dim, distance::DistanceType metric, bool is_float)
-> std::tuple<IdxT, size_t>
{
n_clusters = std::max<uint32_t>(1, n_clusters);

// Estimate memory needs per row (i.e element of the batch).
IdxT mem_per_row = 0;
/* fusedL2NN only needs one integer per row for a mutex.
* Other metrics require storing a distance matrix. */
if (metric != raft::distance::DistanceType::L2Expanded &&
metric != raft::distance::DistanceType::L2SqrtExpanded) {
mem_per_row += sizeof(float) * n_clusters;
} else {
mem_per_row += sizeof(int);
size_t mem_per_row = 0;
switch (metric) {
// fusedL2NN only needs one integer per row for a mutex.
case distance::DistanceType::L2Expanded:
case distance::DistanceType::L2SqrtExpanded: {
mem_per_row += sizeof(int);
} break;
// Other metrics require storing a distance matrix.
default: {
mem_per_row += sizeof(float) * n_clusters;
}
}

// If we need to convert to float, space required for the converted batch.
if (!is_float) { mem_per_row += sizeof(float) * dim; }

// Heuristic: calculate the minibatch size in order to use at most 1GB of memory.
IdxT minibatch_size = (1 << 30) / mem_per_row;
minibatch_size = 64 * ceildiv(minibatch_size, (IdxT)64);
minibatch_size = 64 * div_rounding_up_safe(minibatch_size, IdxT{64});
minibatch_size = std::min<IdxT>(minibatch_size, n_rows);
return minibatch_size;
return std::make_tuple(minibatch_size, mem_per_row);
}

/**
Expand Down Expand Up @@ -383,7 +390,7 @@ void predict(const handle_t& handle,
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"kmeans::predict(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }
IdxT max_minibatch_size =
auto [max_minibatch_size, _mem_per_row] =
calc_minibatch_size(n_clusters, n_rows, dim, metric, std::is_same_v<T, float>);
rmm::device_uvector<float> cur_dataset(
std::is_same_v<T, float> ? 0 : max_minibatch_size * dim, stream, mr);
Expand Down Expand Up @@ -972,9 +979,10 @@ void build_hierarchical(const handle_t& handle,

rmm::mr::managed_memory_resource managed_memory;
rmm::mr::device_memory_resource* device_memory = nullptr;
IdxT max_minibatch_size =
auto [max_minibatch_size, mem_per_row] =
calc_minibatch_size(n_clusters, n_rows, dim, metric, std::is_same_v<T, float>);
auto pool_guard = raft::get_pool_memory_resource(device_memory, max_minibatch_size * dim * 4);
auto pool_guard =
raft::get_pool_memory_resource(device_memory, mem_per_row * size_t(max_minibatch_size));
if (pool_guard) {
RAFT_LOG_DEBUG(
"kmeans::build_hierarchical: using pool memory resource with initial size %zu bytes",
Expand Down

0 comments on commit 7215c8a

Please sign in to comment.