diff --git a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp index 14c53a4699..a54525f3e6 100644 --- a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp +++ b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -43,6 +44,58 @@ template %zu)", n_queries, orig_k, refined_k); auto suggested_n_threads = std::max(1, std::min(omp_get_num_procs(), omp_get_max_threads())); + + // If the number of queries is small, separate the distance calculation and + // the top-k calculation into separate loops, and apply finer-grained thread + // parallelism to the distance calculation loop. + if (n_queries < size_t(suggested_n_threads)) { + std::vector>> refined_pairs( + n_queries, std::vector>(orig_k)); + + // For efficiency, each thread should read a certain amount of array + // elements. The number of threads for distance computation is determined + // taking this into account. + auto n_elements = std::max(size_t(512), dim); + auto max_n_threads = raft::div_rounding_up_safe(n_queries * orig_k * dim, n_elements); + auto suggested_n_threads_for_distance = std::min(size_t(suggested_n_threads), max_n_threads); + + // The max number of threads for topk computation is the number of queries. + auto suggested_n_threads_for_topk = std::min(size_t(suggested_n_threads), n_queries); + + // Compute the refined distance using original dataset vectors +#pragma omp parallel for collapse(2) num_threads(suggested_n_threads_for_distance) + for (size_t i = 0; i < n_queries; i++) { + for (size_t j = 0; j < orig_k; j++) { + const DataT* query = queries.data_handle() + dim * i; + IdxT id = neighbor_candidates(i, j); + DistanceT distance = 0.0; + if (static_cast(id) >= n_rows) { + distance = std::numeric_limits::max(); + } else { + const DataT* row = dataset.data_handle() + dim * id; + for (size_t k = 0; k < dim; k++) { + distance += DC::template eval(query[k], row[k]); + } + } + refined_pairs[i][j] = std::make_tuple(distance, id); + } + } + + // Sort the query neighbors by their refined distances +#pragma omp parallel for num_threads(suggested_n_threads_for_topk) + for (size_t i = 0; i < n_queries; i++) { + std::sort(refined_pairs[i].begin(), refined_pairs[i].end()); + // Store first refined_k neighbors + for (size_t j = 0; j < refined_k; j++) { + indices(i, j) = std::get<1>(refined_pairs[i][j]); + if (distances.data_handle() != nullptr) { + distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[i][j])); + } + } + } + return; + } + if (size_t(suggested_n_threads) > n_queries) { suggested_n_threads = n_queries; } #pragma omp parallel num_threads(suggested_n_threads)