Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve parallelism of refine host #2059

Merged
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions cpp/include/raft/neighbors/detail/refine_host-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,35 +43,43 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
"neighbors::refine_host(%zu, %zu -> %zu)", n_queries, orig_k, refined_k);

auto suggested_n_threads = std::max(1, std::min(omp_get_num_procs(), omp_get_max_threads()));
if (size_t(suggested_n_threads) > n_queries) { suggested_n_threads = n_queries; }

#pragma omp parallel num_threads(suggested_n_threads)
{
std::vector<std::tuple<DistanceT, IdxT>> refined_pairs(orig_k);
for (size_t i = omp_get_thread_num(); i < n_queries; i += omp_get_num_threads()) {
// Compute the refined distance using original dataset vectors
// For efficiency, each thread should read a certain amount of array elements.
// The number of threads for distance computation is determined taking this into account.
constexpr int n_elements = 512;
size_t max_n_threads = ((n_queries * orig_k * dim) + n_elements - 1) / n_elements;
anaruse marked this conversation as resolved.
Show resolved Hide resolved
auto suggested_n_threads_for_distance = std::min(size_t(suggested_n_threads), max_n_threads);

// The max number of threads for topk computation is the number of queries.
auto suggested_n_threads_for_topk = std::min(size_t(suggested_n_threads), n_queries);

std::vector<std::vector<std::tuple<DistanceT, IdxT>>>
refined_pairs(n_queries, std::vector<std::tuple<DistanceT, IdxT>>(orig_k));
achirkin marked this conversation as resolved.
Show resolved Hide resolved

// Compute the refined distance using original dataset vectors
#pragma omp parallel for collapse(2) num_threads(suggested_n_threads_for_distance)
for (size_t i = 0; i < n_queries; i++) {
for (size_t j = 0; j < orig_k; j++) {
const DataT* query = queries.data_handle() + dim * i;
for (size_t j = 0; j < orig_k; j++) {
IdxT id = neighbor_candidates(i, j);
DistanceT distance = 0.0;
if (static_cast<size_t>(id) >= n_rows) {
distance = std::numeric_limits<DistanceT>::max();
} else {
const DataT* row = dataset.data_handle() + dim * id;
for (size_t k = 0; k < dim; k++) {
distance += DC::template eval<DistanceT>(query[k], row[k]);
}
}
refined_pairs[j] = std::make_tuple(distance, id);
IdxT id = neighbor_candidates(i, j);
const DataT* row = dataset.data_handle() + dim * id;
DistanceT distance = 0.0;
for (size_t k = 0; k < dim; k++) {
distance += DC::template eval<DistanceT>(query[k], row[k]);
}
// Sort the query neighbors by their refined distances
std::sort(refined_pairs.begin(), refined_pairs.end());
// Store first refined_k neighbors
for (size_t j = 0; j < refined_k; j++) {
indices(i, j) = std::get<1>(refined_pairs[j]);
if (distances.data_handle() != nullptr) {
distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[j]));
}
refined_pairs[i][j] = std::make_tuple(distance, id);
}
}

// Sort the query neighbors by their refined distances
#pragma omp parallel for num_threads(suggested_n_threads_for_topk)
for (size_t i = 0; i < n_queries; i++) {
std::sort(refined_pairs[i].begin(), refined_pairs[i].end());
// Store first refined_k neighbors
for (size_t j = 0; j < refined_k; j++) {
indices(i, j) = std::get<1>(refined_pairs[i][j]);
if (distances.data_handle() != nullptr) {
distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[i][j]));
}
}
}
Expand Down