diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index 342eebe6e3..1cbd54cb7b 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -175,7 +176,6 @@ void bench_search(::benchmark::State& state, std::shared_ptr> dataset, Objective metric_objective) { - std::ptrdiff_t batch_offset = 0; std::size_t queries_processed = 0; const auto& sp_json = index.search_params[search_param_ix]; @@ -189,6 +189,20 @@ void bench_search(::benchmark::State& state, // Round down the query data to a multiple of the batch size to loop over full batches of data const std::size_t query_set_size = (dataset->query_set_size() / n_queries) * n_queries; + if (dataset->query_set_size() < n_queries) { + std::stringstream msg; + msg << "Not enough queries in benchmark set. Expected " << n_queries << ", actual " + << dataset->query_set_size(); + return state.SkipWithError(msg.str()); + } + + // Each thread start from a different offset, so that the queries that they process do not + // overlap. + std::ptrdiff_t batch_offset = (state.thread_index() * n_queries) % query_set_size; + std::ptrdiff_t queries_stride = state.threads() * n_queries; + // Output is saved into a contiguous buffer (separate buffers for each thread). + std::ptrdiff_t out_offset = 0; + const T* query_set = nullptr; if (!file_exists(index.file)) { @@ -278,7 +292,6 @@ void bench_search(::benchmark::State& state, { nvtx_case nvtx{state.name()}; - // TODO: Have the odd threads load the queries backwards just to rule out caching. ANN* algo = dynamic_cast*>(current_algo.get()); for (auto _ : state) { [[maybe_unused]] auto ntx_lap = nvtx.lap(); @@ -289,15 +302,16 @@ void bench_search(::benchmark::State& state, algo->search(query_set + batch_offset * dataset->dim(), n_queries, k, - neighbors->data + batch_offset * k, - distances->data + batch_offset * k, + neighbors->data + out_offset * k, + distances->data + out_offset * k, gpu_timer.stream()); } catch (const std::exception& e) { state.SkipWithError(std::string(e.what())); } // advance to the next batch - batch_offset = (batch_offset + n_queries) % query_set_size; + batch_offset = (batch_offset + queries_stride) % query_set_size; + out_offset = (out_offset + n_queries) % query_set_size; queries_processed += n_queries; } @@ -323,31 +337,41 @@ void bench_search(::benchmark::State& state, // last thread to finish processing notifies all if (processed_threads-- == 0) { cond_var.notify_all(); } - // Use the last thread as a sanity check that all the threads are working. - if (state.thread_index() == state.threads() - 1) { - // evaluate recall - if (dataset->max_k() >= k) { - const std::int32_t* gt = dataset->gt_set(); - const std::uint32_t max_k = dataset->max_k(); - buf neighbors_host = neighbors->move(MemoryType::Host); - std::size_t rows = std::min(queries_processed, query_set_size); - std::size_t match_count = 0; - std::size_t total_count = rows * static_cast(k); - for (std::size_t i = 0; i < rows; i++) { - for (std::uint32_t j = 0; j < k; j++) { - auto act_idx = std::int32_t(neighbors_host.data[i * k + j]); - for (std::uint32_t l = 0; l < k; l++) { - auto exp_idx = gt[i * max_k + l]; - if (act_idx == exp_idx) { - match_count++; - break; + // Each thread calculates recall on their partition of queries. + // evaluate recall + if (dataset->max_k() >= k) { + const std::int32_t* gt = dataset->gt_set(); + const std::uint32_t max_k = dataset->max_k(); + buf neighbors_host = neighbors->move(MemoryType::Host); + std::size_t rows = std::min(queries_processed, query_set_size); + std::size_t match_count = 0; + std::size_t total_count = rows * static_cast(k); + + // We go through the groundtruth with same stride as the benchmark loop. + size_t out_offset = 0; + size_t batch_offset = (state.thread_index() * n_queries) % query_set_size; + while (out_offset < rows) { + for (std::size_t i = 0; i < n_queries; i++) { + size_t i_orig_idx = batch_offset + i; + size_t i_out_idx = out_offset + i; + if (i_out_idx < rows) { + for (std::uint32_t j = 0; j < k; j++) { + auto act_idx = std::int32_t(neighbors_host.data[i_out_idx * k + j]); + for (std::uint32_t l = 0; l < k; l++) { + auto exp_idx = gt[i_orig_idx * max_k + l]; + if (act_idx == exp_idx) { + match_count++; + break; + } } } } } - double actual_recall = static_cast(match_count) / static_cast(total_count); - state.counters.insert({{"Recall", actual_recall}}); + out_offset += n_queries; + batch_offset = (batch_offset + queries_stride) % query_set_size; } + double actual_recall = static_cast(match_count) / static_cast(total_count); + state.counters.insert({"Recall", {actual_recall, benchmark::Counter::kAvgThreads}}); } }