From 226cf9ec186e655589f31618a4773341392e27f0 Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Tue, 9 Jan 2024 15:21:01 +0900 Subject: [PATCH] Improve parallelism of refine host (#2059) This PR addresses https://github.com/rapidsai/raft/issues/2058 by changing the thread parallelism method. In the first half of the `refine` process, the distance calculation is performed on all candidate vectors, i.e., the number of queries * the original top-k vectors. Since the distance calculations for each vector can be performed independently, this part is thread-parallelized assuming that maximum parallelism is the number of queries * original top-k. This means that even if the number of queries is 1, this part can be executed in thread parallel. On the other hand, the second half of the `refine` process, the so-called top-k calculation, can be performed independently for each query, but it is difficult to thread parallelize the calculation for a given query, Therefore, this part is parallelized assuming the maximum parallelism is the number of queries, as in the current implementation. Authors: - Akira Naruse (https://github.com/anaruse) - Corey J. Nolet (https://github.com/cjnolet) - William Hicks (https://github.com/wphicks) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2059 --- .../raft/neighbors/detail/refine_host-inl.hpp | 55 ++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp index 14c53a4699..a54525f3e6 100644 --- a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp +++ b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -43,6 +44,58 @@ template %zu)", n_queries, orig_k, refined_k); auto suggested_n_threads = std::max(1, std::min(omp_get_num_procs(), omp_get_max_threads())); + + // If the number of queries is small, separate the distance calculation and + // the top-k calculation into separate loops, and apply finer-grained thread + // parallelism to the distance calculation loop. + if (n_queries < size_t(suggested_n_threads)) { + std::vector>> refined_pairs( + n_queries, std::vector>(orig_k)); + + // For efficiency, each thread should read a certain amount of array + // elements. The number of threads for distance computation is determined + // taking this into account. + auto n_elements = std::max(size_t(512), dim); + auto max_n_threads = raft::div_rounding_up_safe(n_queries * orig_k * dim, n_elements); + auto suggested_n_threads_for_distance = std::min(size_t(suggested_n_threads), max_n_threads); + + // The max number of threads for topk computation is the number of queries. + auto suggested_n_threads_for_topk = std::min(size_t(suggested_n_threads), n_queries); + + // Compute the refined distance using original dataset vectors +#pragma omp parallel for collapse(2) num_threads(suggested_n_threads_for_distance) + for (size_t i = 0; i < n_queries; i++) { + for (size_t j = 0; j < orig_k; j++) { + const DataT* query = queries.data_handle() + dim * i; + IdxT id = neighbor_candidates(i, j); + DistanceT distance = 0.0; + if (static_cast(id) >= n_rows) { + distance = std::numeric_limits::max(); + } else { + const DataT* row = dataset.data_handle() + dim * id; + for (size_t k = 0; k < dim; k++) { + distance += DC::template eval(query[k], row[k]); + } + } + refined_pairs[i][j] = std::make_tuple(distance, id); + } + } + + // Sort the query neighbors by their refined distances +#pragma omp parallel for num_threads(suggested_n_threads_for_topk) + for (size_t i = 0; i < n_queries; i++) { + std::sort(refined_pairs[i].begin(), refined_pairs[i].end()); + // Store first refined_k neighbors + for (size_t j = 0; j < refined_k; j++) { + indices(i, j) = std::get<1>(refined_pairs[i][j]); + if (distances.data_handle() != nullptr) { + distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[i][j])); + } + } + } + return; + } + if (size_t(suggested_n_threads) > n_queries) { suggested_n_threads = n_queries; } #pragma omp parallel num_threads(suggested_n_threads)