From 5a62ce3e751e010c18f636e4ba595a3861342f88 Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 10 Mar 2022 11:43:59 +0100 Subject: [PATCH] Adapt to the new bench --- cpp/bench/spatial/selection.cu | 51 +++++++++++++--------------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/cpp/bench/spatial/selection.cu b/cpp/bench/spatial/selection.cu index 644b983a7e..4b15a00b7d 100644 --- a/cpp/bench/spatial/selection.cu +++ b/cpp/bench/spatial/selection.cu @@ -33,31 +33,20 @@ struct params { }; template -struct selection : public Fixture { - selection(const std::string& name, const params& p) : Fixture(name), params_(p) {} - - protected: - void allocateBuffers(const ::benchmark::State& state) override - { - auto in_len = params_.n_inputs * params_.input_len; - alloc(in_dists_, in_len, false); - alloc(in_ids_, in_len, false); - alloc(out_dists_, params_.n_inputs * params_.k, false); - alloc(out_ids_, params_.n_inputs * params_.k, false); - - raft::sparse::iota_fill(in_ids_, IdxT(params_.n_inputs), IdxT(params_.input_len), stream); - raft::random::Rng(42).uniform(in_dists_, in_len, KeyT(-1.0), KeyT(1.0), stream); - } - - void deallocateBuffers(const ::benchmark::State& state) override +struct selection : public fixture { + explicit selection(const params& p) + : params_(p), + in_dists_(p.n_inputs * p.input_len, stream), + in_ids_(p.n_inputs * p.input_len, stream), + out_dists_(p.n_inputs * p.k, stream), + out_ids_(p.n_inputs * p.k, stream) { - dealloc(in_dists_, params_.n_inputs * params_.input_len); - dealloc(in_ids_, params_.n_inputs * params_.input_len); - dealloc(out_dists_, params_.n_inputs * params_.k); - dealloc(out_ids_, params_.n_inputs * params_.k); + raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream); + raft::random::Rng(42).uniform( + in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0), stream); } - void runBenchmark(::benchmark::State& state) override + void run_benchmark(::benchmark::State& state) override { rmm::mr::cuda_memory_resource cuda_mr; rmm::mr::pool_memory_resource pool_mr{ @@ -67,13 +56,13 @@ struct selection : public Fixture { std::ostringstream label_stream; label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k; state.SetLabel(label_stream.str()); - loopOnState(state, [this]() { - raft::spatial::knn::select_k(in_dists_, - in_ids_, + loop_on_state(state, [this]() { + raft::spatial::knn::select_k(in_dists_.data(), + in_ids_.data(), params_.n_inputs, params_.input_len, - out_dists_, - out_ids_, + out_dists_.data(), + out_ids_.data(), params_.select_min, params_.k, stream, @@ -86,9 +75,9 @@ struct selection : public Fixture { } private: - params params_; - KeyT *in_dists_, *out_dists_; - IdxT *in_ids_, *out_ids_; + const params params_; + rmm::device_uvector in_dists_, out_dists_; + rmm::device_uvector in_ids_, out_ids_; }; const std::vector kInputs{ @@ -116,7 +105,7 @@ const std::vector kInputs{ namespace BENCHMARK_PRIVATE_NAME(selection) \ { \ using SelectK = selection; \ - RAFT_BENCH_REGISTER(params, SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \ + RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \ } SELECTION_REGISTER(float, int, FAISS);