ANN bench: use different offset for each thread #1981

Merged

74 changes: 49 additions & 25 deletions cpp/bench/ann/src/common/benchmark.hpp
@@ -33,6 +33,7 @@
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <string>
#include <unistd.h>
#include <vector>
@@ -175,7 +176,6 @@ void bench_search(::benchmark::State& state,
std::shared_ptr<const Dataset<T>> dataset,
Objective metric_objective)
{
std::ptrdiff_t batch_offset = 0;
std::size_t queries_processed = 0;

const auto& sp_json = index.search_params[search_param_ix];
@@ -189,6 +189,20 @@
// Round down the query data to a multiple of the batch size to loop over full batches of data
const std::size_t query_set_size = (dataset->query_set_size() / n_queries) * n_queries;

if (dataset->query_set_size() < n_queries) {
std::stringstream msg;
msg << "Not enough queries in benchmark set. Expected " << n_queries << ", actual "
<< dataset->query_set_size();
return state.SkipWithError(msg.str());
}

// Each thread starts from a different offset, so that the queries they process do not
// overlap.
std::ptrdiff_t batch_offset = (state.thread_index() * n_queries) % query_set_size;
std::ptrdiff_t queries_stride = state.threads() * n_queries;
// Output is saved into a contiguous buffer (separate buffers for each thread).
std::ptrdiff_t out_offset = 0;

const T* query_set = nullptr;

if (!file_exists(index.file)) {
@@ -278,7 +292,6 @@ void bench_search(::benchmark::State& state,
{
nvtx_case nvtx{state.name()};

// TODO: Have the odd threads load the queries backwards just to rule out caching.
ANN<T>* algo = dynamic_cast<ANN<T>*>(current_algo.get());
for (auto _ : state) {
[[maybe_unused]] auto ntx_lap = nvtx.lap();
@@ -289,15 +302,16 @@
algo->search(query_set + batch_offset * dataset->dim(),
n_queries,
k,
neighbors->data + batch_offset * k,
distances->data + batch_offset * k,
neighbors->data + out_offset * k,
distances->data + out_offset * k,
gpu_timer.stream());
} catch (const std::exception& e) {
state.SkipWithError(std::string(e.what()));
}

// advance to the next batch
batch_offset = (batch_offset + n_queries) % query_set_size;
batch_offset = (batch_offset + queries_stride) % query_set_size;
out_offset = (out_offset + n_queries) % query_set_size;

queries_processed += n_queries;
}
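
For illustration, here is a minimal standalone sketch (not part of the PR) that simulates the offset bookkeeping above with hypothetical values for the batch size, query-set size, and thread count. It shows that, in any given iteration, the threads read different query batches, while each thread writes its results contiguously into its own output buffer:

```cpp
// Standalone sketch: simulate the per-thread offset scheme from bench_search()
// with hypothetical parameters. Not part of the PR.
#include <cstddef>
#include <cstdio>

int main()
{
  const std::ptrdiff_t n_queries      = 100;   // hypothetical batch size
  const std::ptrdiff_t query_set_size = 1000;  // already a multiple of n_queries
  const int threads                   = 4;     // hypothetical thread count

  for (int thread_index = 0; thread_index < threads; thread_index++) {
    // Same formulas as in the diff above.
    std::ptrdiff_t batch_offset         = (thread_index * n_queries) % query_set_size;
    const std::ptrdiff_t queries_stride = threads * n_queries;
    std::ptrdiff_t out_offset           = 0;

    std::printf("thread %d:", thread_index);
    for (int iter = 0; iter < 5; iter++) {
      // Read queries at batch_offset, write results at out_offset.
      std::printf(" reads %3td writes %3td |", batch_offset, out_offset);
      batch_offset = (batch_offset + queries_stride) % query_set_size;
      out_offset   = (out_offset + n_queries) % query_set_size;
    }
    std::printf("\n");
  }
  return 0;
}
```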
@@ -323,31 +337,41 @@ void bench_search(::benchmark::State& state,
// last thread to finish processing notifies all
if (processed_threads-- == 0) { cond_var.notify_all(); }

// Use the last thread as a sanity check that all the threads are working.
if (state.thread_index() == state.threads() - 1) {
// evaluate recall
if (dataset->max_k() >= k) {
const std::int32_t* gt = dataset->gt_set();
const std::uint32_t max_k = dataset->max_k();
buf<std::size_t> neighbors_host = neighbors->move(MemoryType::Host);
std::size_t rows = std::min(queries_processed, query_set_size);
std::size_t match_count = 0;
std::size_t total_count = rows * static_cast<size_t>(k);
for (std::size_t i = 0; i < rows; i++) {
for (std::uint32_t j = 0; j < k; j++) {
auto act_idx = std::int32_t(neighbors_host.data[i * k + j]);
for (std::uint32_t l = 0; l < k; l++) {
auto exp_idx = gt[i * max_k + l];
if (act_idx == exp_idx) {
match_count++;
break;
// Each thread calculates recall on its own partition of queries.
// evaluate recall
if (dataset->max_k() >= k) {
const std::int32_t* gt = dataset->gt_set();
const std::uint32_t max_k = dataset->max_k();
buf<std::size_t> neighbors_host = neighbors->move(MemoryType::Host);
std::size_t rows = std::min(queries_processed, query_set_size);
std::size_t match_count = 0;
std::size_t total_count = rows * static_cast<size_t>(k);

// We go through the ground truth with the same stride as the benchmark loop.
size_t out_offset = 0;
size_t batch_offset = (state.thread_index() * n_queries) % query_set_size;
while (out_offset < rows) {
for (std::size_t i = 0; i < n_queries; i++) {
size_t i_orig_idx = batch_offset + i;
size_t i_out_idx = out_offset + i;
if (i_out_idx < rows) {
for (std::uint32_t j = 0; j < k; j++) {
auto act_idx = std::int32_t(neighbors_host.data[i_out_idx * k + j]);
for (std::uint32_t l = 0; l < k; l++) {
auto exp_idx = gt[i_orig_idx * max_k + l];
if (act_idx == exp_idx) {
match_count++;
break;
}
}
}
}
}
double actual_recall = static_cast<double>(match_count) / static_cast<double>(total_count);
state.counters.insert({{"Recall", actual_recall}});
out_offset += n_queries;
batch_offset = (batch_offset + queries_stride) % query_set_size;
}
double actual_recall = static_cast<double>(match_count) / static_cast<double>(total_count);
state.counters.insert({"Recall", {actual_recall, benchmark::Counter::kAvgThreads}});
}
}

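For reference, a small standalone sketch (not part of the PR) of the index mapping that the recall loop above relies on: a row in one thread's contiguous output buffer is mapped back to the original query index by replaying the same strided traversal. The helper name and the numeric values below are hypothetical:

```cpp
// Standalone sketch: map a row of one thread's output buffer back to the
// original query index, using the same stride scheme as the benchmark loop.
// The helper and the values below are hypothetical, not part of the PR.
#include <cstddef>
#include <cstdio>

std::size_t original_query_index(std::size_t out_row,
                                 std::size_t thread_index,
                                 std::size_t threads,
                                 std::size_t n_queries,
                                 std::size_t query_set_size)
{
  const std::size_t batch          = out_row / n_queries;  // which batch of this thread
  const std::size_t within         = out_row % n_queries;  // position inside the batch
  const std::size_t queries_stride = threads * n_queries;
  const std::size_t batch_offset =
    (thread_index * n_queries + batch * queries_stride) % query_set_size;
  return batch_offset + within;  // index into the query set / ground truth
}

int main()
{
  // Hypothetical setup: 4 threads, batches of 100 queries, 1000 queries total.
  const std::size_t rows[] = {0, 99, 100, 250};
  for (std::size_t row : rows) {
    std::printf("thread 1, output row %zu -> query %zu\n",
                row,
                original_query_index(row, /*thread_index=*/1, /*threads=*/4,
                                     /*n_queries=*/100, /*query_set_size=*/1000));
  }
  return 0;
}
```

The per-thread recall values computed this way are combined by google benchmark via `benchmark::Counter::kAvgThreads`, which divides the aggregated counter by the number of threads, i.e. reports the per-thread average.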