diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index 961cc76381..72df13d760 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,9 +18,6 @@ #include "ann_utils.cuh" -#include -#include - #include #include #include @@ -45,6 +42,11 @@ #include #include +#include +#include + +#include + namespace raft::spatial::knn::detail::kmeans { constexpr static inline const float kAdjustCentersWeight = 7.0f; @@ -170,35 +172,40 @@ inline void predict_float_core(const handle_t& handle, * * @param n_clusters number of clusters in kmeans clustering * @param n_rows dataset size - * @return a suggested minibatch size + * @param dim + * @param metric + * @param is_float float input requires less temporary buffers + * @return a suggested minibatch size and the expected memory cost per-row (in bytes) */ template -constexpr inline auto calc_minibatch_size(uint32_t n_clusters, - IdxT n_rows, - uint32_t dim, - raft::distance::DistanceType metric, - bool is_float) -> IdxT +constexpr auto calc_minibatch_size( + uint32_t n_clusters, IdxT n_rows, uint32_t dim, distance::DistanceType metric, bool is_float) + -> std::tuple { n_clusters = std::max(1, n_clusters); // Estimate memory needs per row (i.e element of the batch). - IdxT mem_per_row = 0; - /* fusedL2NN only needs one integer per row for a mutex. - * Other metrics require storing a distance matrix. */ - if (metric != raft::distance::DistanceType::L2Expanded && - metric != raft::distance::DistanceType::L2SqrtExpanded) { - mem_per_row += sizeof(float) * n_clusters; - } else { - mem_per_row += sizeof(int); + size_t mem_per_row = 0; + switch (metric) { + // fusedL2NN only needs one integer per row for a mutex. + case distance::DistanceType::L2Expanded: + case distance::DistanceType::L2SqrtExpanded: { + mem_per_row += sizeof(int); + } break; + // Other metrics require storing a distance matrix. + default: { + mem_per_row += sizeof(float) * n_clusters; + } } + // If we need to convert to float, space required for the converted batch. if (!is_float) { mem_per_row += sizeof(float) * dim; } // Heuristic: calculate the minibatch size in order to use at most 1GB of memory. IdxT minibatch_size = (1 << 30) / mem_per_row; - minibatch_size = 64 * ceildiv(minibatch_size, (IdxT)64); + minibatch_size = 64 * div_rounding_up_safe(minibatch_size, IdxT{64}); minibatch_size = std::min(minibatch_size, n_rows); - return minibatch_size; + return std::make_tuple(minibatch_size, mem_per_row); } /** @@ -383,7 +390,7 @@ void predict(const handle_t& handle, common::nvtx::range fun_scope( "kmeans::predict(%zu, %u)", static_cast(n_rows), n_clusters); if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } - IdxT max_minibatch_size = + auto [max_minibatch_size, _mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, metric, std::is_same_v); rmm::device_uvector cur_dataset( std::is_same_v ? 0 : max_minibatch_size * dim, stream, mr); @@ -972,9 +979,10 @@ void build_hierarchical(const handle_t& handle, rmm::mr::managed_memory_resource managed_memory; rmm::mr::device_memory_resource* device_memory = nullptr; - IdxT max_minibatch_size = + auto [max_minibatch_size, mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, metric, std::is_same_v); - auto pool_guard = raft::get_pool_memory_resource(device_memory, max_minibatch_size * dim * 4); + auto pool_guard = + raft::get_pool_memory_resource(device_memory, mem_per_row * size_t(max_minibatch_size)); if (pool_guard) { RAFT_LOG_DEBUG( "kmeans::build_hierarchical: using pool memory resource with initial size %zu bytes",