Skip to content

Commit

Permalink
Fixing hnswlib in latency mode (rapidsai#1959)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: rapidsai#1959
  • Loading branch information
cjnolet authored and benfred committed Nov 8, 2023
1 parent 0b45774 commit ebd611e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
12 changes: 8 additions & 4 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void bench_search(::benchmark::State& state,
* Make sure the first thread loads the algo and dataset
*/
if (state.thread_index() == 0) {
std::lock_guard lk(init_mutex);
std::unique_lock lk(init_mutex);
// algo is static to cache it between close search runs to save time on index loading
static std::string index_file = "";
if (index.file != index_file) {
Expand Down Expand Up @@ -249,9 +249,11 @@ void bench_search(::benchmark::State& state,
query_set = dataset->query_set(current_algo_props->query_memory_type);
cond_var.notify_all();
} else {
// All other threads will wait for the first thread to initialize the algo.
std::unique_lock lk(init_mutex);
cond_var.wait(lk, [] { return current_algo_props.get() != nullptr; });
// All other threads will wait for the first thread to initialize the algo.

cond_var.wait(
lk, [] { return current_algo_props.get() != nullptr && current_algo.get() != nullptr; });
// gbench ensures that all threads are synchronized at the start of the benchmark loop.
// We are accessing shared variables (like current_algo, current_algo_probs) before the
// benchmark loop, therefore the synchronization here is necessary.
Expand Down Expand Up @@ -292,6 +294,7 @@ void bench_search(::benchmark::State& state,

// advance to the next batch
batch_offset = (batch_offset + n_queries) % query_set_size;

queries_processed += n_queries;
}
}
Expand Down Expand Up @@ -410,7 +413,6 @@ void register_search(std::shared_ptr<const Dataset<T>> dataset,
auto* b = ::benchmark::RegisterBenchmark(
index.name + suf, bench_search<T>, index, i, dataset, metric_objective)
->Unit(benchmark::kMillisecond)
->ThreadRange(threads[0], threads[1])
/**
* The following are important for getting accuracy QPS measurements on both CPU
* and GPU These make sure that
Expand All @@ -420,6 +422,8 @@ void register_search(std::shared_ptr<const Dataset<T>> dataset,
*/
->MeasureProcessCPUTime()
->UseRealTime();

if (metric_objective == Objective::THROUGHPUT) { b->ThreadRange(threads[0], threads[1]); }
}
}
}
Expand Down
13 changes: 5 additions & 8 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ void HnswLib<T>::build(const T* dataset, size_t nrow, cudaStream_t)
char buf[20];
std::time_t now = std::time(nullptr);
std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now));

printf("%s building %zu / %zu\n", buf, i, items_per_thread);
fflush(stdout);
}
Expand All @@ -163,13 +162,11 @@ void HnswLib<T>::set_search_param(const AnnSearchParam& param_)
auto param = dynamic_cast<const SearchParam&>(param_);
appr_alg_->ef_ = param.ef;
metric_objective_ = param.metric_objective;
num_threads_ = param.num_threads;

bool use_pool = (metric_objective_ == Objective::LATENCY && param.num_threads > 1) &&
(!thread_pool_ || num_threads_ != param.num_threads);
if (use_pool) {
num_threads_ = param.num_threads;
thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_);
}
// Create a pool if multiple query threads have been set and the pool hasn't been created already
bool create_pool = (metric_objective_ == Objective::LATENCY && num_threads_ > 1 && !thread_pool_);
if (create_pool) { thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_); }
}

template <typename T>
Expand All @@ -180,7 +177,7 @@ void HnswLib<T>::search(
// hnsw can only handle a single vector at a time.
get_search_knn_results_(query + i * dim_, k, indices + i * k, distances + i * k);
};
if (metric_objective_ == Objective::LATENCY) {
if (metric_objective_ == Objective::LATENCY && num_threads_ > 1) {
thread_pool_->submit(f, batch_size);
} else {
for (int i = 0; i < batch_size; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,36 +39,42 @@
dims: 960
base_file: gist-960-euclidean/base.fbin
query_file: gist-960-euclidean/query.fbin
groundtruth_neighbors_file: gist-960-euclidean/groundtruth.neighbors.ibin
distance: euclidean

- name: glove-50-angular
dims: 50
base_file: glove-50-angular/base.fbin
query_file: glove-50-angular/query.fbin
groundtruth_neighbors_file: glove-50-angular/groundtruth.neighbors.ibin
distance: euclidean

- name: glove-50-inner
dims: 50
base_file: glove-50-inner/base.fbin
query_file: glove-50-inner/query.fbin
groundtruth_neighbors_file: glove-50-inner/groundtruth.neighbors.ibin
distance: euclidean

- name: glove-100-angular
dims: 100
base_file: glove-100-angular/base.fbin
query_file: glove-100-angular/query.fbin
groundtruth_neighbors_file: glove-100-angular/groundtruth.neighbors.ibin
distance: euclidean

- name: glove-100-inner
dims: 100
base_file: glove-100-inner/base.fbin
query_file: glove-100-inner/query.fbin
groundtruth_neighbors_file: glove-100-inner/groundtruth.neighbors.ibin
distance: euclidean

- name: lastfm-65-angular
dims: 65
base_file: lastfm-65-angular/base.fbin
query_file: lastfm-65-angular/query.fbin
groundtruth_neighbors_file: lastfm-65-angular/groundtruth.neighbors.ibin
distance: euclidean

- name: mnist-784-euclidean
Expand Down

0 comments on commit ebd611e

Please sign in to comment.