Skip to content

Commit

Permalink
Improve parallelism of refine host (rapidsai#2059)
Browse files Browse the repository at this point in the history
This PR addresses rapidsai#2058 by changing the thread parallelism method.

In the first half of the `refine` process, the distance calculation is performed on all candidate vectors, i.e., the number of queries * the original top-k vectors. Since the distance calculations for each vector can be performed independently, this part is thread-parallelized assuming that maximum parallelism is the number of queries * original top-k. This means that even if the number of queries is 1, this part can be executed in thread parallel.

On the other hand, the second half of the `refine` process, the so-called top-k calculation, can be performed independently for each query, but it is difficult to thread parallelize the calculation for a given query, Therefore, this part is parallelized assuming the maximum parallelism is the number of queries, as in the current implementation.

Authors:
  - Akira Naruse (https://github.com/anaruse)
  - Corey J. Nolet (https://github.com/cjnolet)
  - William Hicks (https://github.com/wphicks)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#2059
  • Loading branch information
anaruse authored Jan 9, 2024
1 parent 6762fe5 commit 3b88d17
Showing 1 changed file with 54 additions and 1 deletion.
55 changes: 54 additions & 1 deletion cpp/include/raft/neighbors/detail/refine_host-inl.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
#include <raft/core/host_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/neighbors/detail/refine_common.hpp>
#include <raft/util/integer_utils.hpp>

#include <algorithm>
#include <omp.h>
Expand All @@ -43,6 +44,58 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
"neighbors::refine_host(%zu, %zu -> %zu)", n_queries, orig_k, refined_k);

auto suggested_n_threads = std::max(1, std::min(omp_get_num_procs(), omp_get_max_threads()));

// If the number of queries is small, separate the distance calculation and
// the top-k calculation into separate loops, and apply finer-grained thread
// parallelism to the distance calculation loop.
if (n_queries < size_t(suggested_n_threads)) {
std::vector<std::vector<std::tuple<DistanceT, IdxT>>> refined_pairs(
n_queries, std::vector<std::tuple<DistanceT, IdxT>>(orig_k));

// For efficiency, each thread should read a certain amount of array
// elements. The number of threads for distance computation is determined
// taking this into account.
auto n_elements = std::max(size_t(512), dim);
auto max_n_threads = raft::div_rounding_up_safe<size_t>(n_queries * orig_k * dim, n_elements);
auto suggested_n_threads_for_distance = std::min(size_t(suggested_n_threads), max_n_threads);

// The max number of threads for topk computation is the number of queries.
auto suggested_n_threads_for_topk = std::min(size_t(suggested_n_threads), n_queries);

// Compute the refined distance using original dataset vectors
#pragma omp parallel for collapse(2) num_threads(suggested_n_threads_for_distance)
for (size_t i = 0; i < n_queries; i++) {
for (size_t j = 0; j < orig_k; j++) {
const DataT* query = queries.data_handle() + dim * i;
IdxT id = neighbor_candidates(i, j);
DistanceT distance = 0.0;
if (static_cast<size_t>(id) >= n_rows) {
distance = std::numeric_limits<DistanceT>::max();
} else {
const DataT* row = dataset.data_handle() + dim * id;
for (size_t k = 0; k < dim; k++) {
distance += DC::template eval<DistanceT>(query[k], row[k]);
}
}
refined_pairs[i][j] = std::make_tuple(distance, id);
}
}

// Sort the query neighbors by their refined distances
#pragma omp parallel for num_threads(suggested_n_threads_for_topk)
for (size_t i = 0; i < n_queries; i++) {
std::sort(refined_pairs[i].begin(), refined_pairs[i].end());
// Store first refined_k neighbors
for (size_t j = 0; j < refined_k; j++) {
indices(i, j) = std::get<1>(refined_pairs[i][j]);
if (distances.data_handle() != nullptr) {
distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[i][j]));
}
}
}
return;
}

if (size_t(suggested_n_threads) > n_queries) { suggested_n_threads = n_queries; }

#pragma omp parallel num_threads(suggested_n_threads)
Expand Down

0 comments on commit 3b88d17

Please sign in to comment.