Skip to content

Commit

Permalink
Adapt to the new bench
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 10, 2022
1 parent f97d7cd commit 5a62ce3
Showing 1 changed file with 20 additions and 31 deletions.
51 changes: 20 additions & 31 deletions cpp/bench/spatial/selection.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,20 @@ struct params {
};

template <typename KeyT, typename IdxT, raft::spatial::knn::SelectKAlgo Algo>
struct selection : public Fixture {
selection(const std::string& name, const params& p) : Fixture(name), params_(p) {}

protected:
void allocateBuffers(const ::benchmark::State& state) override
{
auto in_len = params_.n_inputs * params_.input_len;
alloc(in_dists_, in_len, false);
alloc(in_ids_, in_len, false);
alloc(out_dists_, params_.n_inputs * params_.k, false);
alloc(out_ids_, params_.n_inputs * params_.k, false);

raft::sparse::iota_fill(in_ids_, IdxT(params_.n_inputs), IdxT(params_.input_len), stream);
raft::random::Rng(42).uniform(in_dists_, in_len, KeyT(-1.0), KeyT(1.0), stream);
}

void deallocateBuffers(const ::benchmark::State& state) override
struct selection : public fixture {
explicit selection(const params& p)
: params_(p),
in_dists_(p.n_inputs * p.input_len, stream),
in_ids_(p.n_inputs * p.input_len, stream),
out_dists_(p.n_inputs * p.k, stream),
out_ids_(p.n_inputs * p.k, stream)
{
dealloc(in_dists_, params_.n_inputs * params_.input_len);
dealloc(in_ids_, params_.n_inputs * params_.input_len);
dealloc(out_dists_, params_.n_inputs * params_.k);
dealloc(out_ids_, params_.n_inputs * params_.k);
raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream);
raft::random::Rng(42).uniform(
in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0), stream);
}

void runBenchmark(::benchmark::State& state) override
void run_benchmark(::benchmark::State& state) override
{
rmm::mr::cuda_memory_resource cuda_mr;
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
Expand All @@ -67,13 +56,13 @@ struct selection : public Fixture {
std::ostringstream label_stream;
label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k;
state.SetLabel(label_stream.str());
loopOnState(state, [this]() {
raft::spatial::knn::select_k<IdxT, KeyT>(in_dists_,
in_ids_,
loop_on_state(state, [this]() {
raft::spatial::knn::select_k<IdxT, KeyT>(in_dists_.data(),
in_ids_.data(),
params_.n_inputs,
params_.input_len,
out_dists_,
out_ids_,
out_dists_.data(),
out_ids_.data(),
params_.select_min,
params_.k,
stream,
Expand All @@ -86,9 +75,9 @@ struct selection : public Fixture {
}

private:
params params_;
KeyT *in_dists_, *out_dists_;
IdxT *in_ids_, *out_ids_;
const params params_;
rmm::device_uvector<KeyT> in_dists_, out_dists_;
rmm::device_uvector<IdxT> in_ids_, out_ids_;
};

const std::vector<params> kInputs{
Expand Down Expand Up @@ -116,7 +105,7 @@ const std::vector<params> kInputs{
namespace BENCHMARK_PRIVATE_NAME(selection) \
{ \
using SelectK = selection<KeyT, IdxT, raft::spatial::knn::SelectKAlgo::Algo>; \
RAFT_BENCH_REGISTER(params, SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \
}

SELECTION_REGISTER(float, int, FAISS);
Expand Down

0 comments on commit 5a62ce3

Please sign in to comment.