From 185e933e123c043c488d252513f1360c40487af4 Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Fri, 23 Jun 2023 23:42:29 +0900 Subject: [PATCH] Add RAFT ANN benchmark for CAGRA (#1552) This PR adds the RAFT ANN benchmark for CAGRA Authors: - tsuki (https://github.com/enp1s0) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1552 --- cpp/bench/ann/CMakeLists.txt | 18 ++- cpp/bench/ann/conf/bigann-100M.json | 28 ++++ cpp/bench/ann/conf/deep-100M.json | 29 ++++ cpp/bench/ann/conf/sift-128-euclidean.json | 34 +++- cpp/bench/ann/src/common/benchmark.hpp | 113 +++++++------ cpp/bench/ann/src/faiss/faiss_benchmark.cu | 2 +- cpp/bench/ann/src/raft/raft_benchmark.cu | 52 +++--- cpp/bench/ann/src/raft/raft_cagra.cu | 26 +++ cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 166 ++++++++++++++++++++ 9 files changed, 395 insertions(+), 73 deletions(-) create mode 100644 cpp/bench/ann/src/raft/raft_cagra.cu create mode 100644 cpp/bench/ann/src/raft/raft_cagra_wrapper.h diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index a14018a15d..b3198de984 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -21,6 +21,7 @@ option(RAFT_ANN_BENCH_USE_FAISS_IVF_PQ "Include faiss' ivf pq algorithm in bench option(RAFT_ANN_BENCH_USE_RAFT_BFKNN "Include raft's brute-force knn algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT "Include raft's ivf flat algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchmark" ON) +option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON) option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON) @@ -36,8 +37,9 @@ endif() set(RAFT_ANN_BENCH_USE_RAFT OFF) if(RAFT_ANN_BENCH_USE_RAFT_BFKNN - OR RAFT_ANN_BENCH_USE_RAFT_IVFPQ - OR RAFT_ANN_BENCH_USE_RAFT_IVFFLAT + OR RAFT_ANN_BENCH_USE_RAFT_IVF_PQ + OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT + OR RAFT_ANN_BENCH_USE_RAFT_CAGRA ) set(RAFT_ANN_BENCH_USE_RAFT ON) endif() @@ -136,14 +138,24 @@ endif() if(RAFT_ANN_BENCH_USE_RAFT) ConfigureAnnBench( NAME - RAFT_IVF_PQ + RAFT PATH bench/ann/src/raft/raft_benchmark.cu $<$:bench/ann/src/raft/raft_ivf_pq.cu> $<$:bench/ann/src/raft/raft_ivf_flat.cu> + $<$:bench/ann/src/raft/raft_cagra.cu> LINKS raft::compiled ) + add_compile_definitions( + RAFT $<$:RAFT_ANN_BENCH_USE_RAFT_IVF_PQ> + ) + add_compile_definitions( + RAFT $<$:RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT> + ) + add_compile_definitions( + RAFT $<$:RAFT_ANN_BENCH_USE_RAFT_CAGRA> + ) endif() if(RAFT_ANN_BENCH_USE_FAISS) diff --git a/cpp/bench/ann/conf/bigann-100M.json b/cpp/bench/ann/conf/bigann-100M.json index 5f16f3378d..0ff7df4776 100644 --- a/cpp/bench/ann/conf/bigann-100M.json +++ b/cpp/bench/ann/conf/bigann-100M.json @@ -168,7 +168,35 @@ "search_result_file" : "result/bigann-100M/ivf_flat/nlist100K" }, + { + "name" : "cagra.dim32", + "algo" : "cagra", + "build_param": { + "index_dim" : 32 + }, + "file" : "index/bigann-100M/cagra/dim32", + "search_params" : [ + "itopk": 32, + "itopk": 64, + "itopk": 128 + ], + "search_result_file" : "result/bigann-100M/cagra/dim32" + }, + { + "name" : "cagra.dim64", + "algo" : "cagra", + "build_param": { + "index_dim" : 64 + }, + "file" : "index/bigann-100M/cagra/dim64", + "search_params" : [ + "itopk": 32, + "itopk": 64, + "itopk": 128 + ], + "search_result_file" : "result/bigann-100M/cagra/dim64" + } ] } diff --git a/cpp/bench/ann/conf/deep-100M.json b/cpp/bench/ann/conf/deep-100M.json index b3a945d50e..fb29556ead 100644 --- a/cpp/bench/ann/conf/deep-100M.json +++ b/cpp/bench/ann/conf/deep-100M.json @@ -218,6 +218,35 @@ "search_result_file" : "result/deep-100M/ivf_flat/nlist100K" }, + { + "name" : "cagra.dim32", + "algo" : "raft_cagra", + "build_param": { + "index_dim" : 32 + }, + "file" : "index/deep-100M/cagra/dim32", + "search_params" : [ + {"itopk": 32}, + {"itopk": 64}, + {"itopk": 128} + ], + "search_result_file" : "result/deep-100M/cagra/dim32" + }, + + { + "name" : "cagra.dim64", + "algo" : "raft_cagra", + "build_param": { + "index_dim" : 64 + }, + "file" : "index/deep-100M/cagra/dim64", + "search_params" : [ + {"itopk": 32}, + {"itopk": 64}, + {"itopk": 128} + ], + "search_result_file" : "result/deep-100M/cagra/dim64" + } ] } diff --git a/cpp/bench/ann/conf/sift-128-euclidean.json b/cpp/bench/ann/conf/sift-128-euclidean.json index 476c363ecd..98983fd62e 100644 --- a/cpp/bench/ann/conf/sift-128-euclidean.json +++ b/cpp/bench/ann/conf/sift-128-euclidean.json @@ -90,8 +90,8 @@ - - { + + { "name": "raft_bfknn", "algo": "raft_bfknn", "build_param": {}, @@ -1316,6 +1316,36 @@ } ], "search_result_file": "result/sift-128-euclidean/raft_ivf_flat/nlist16384" + }, + + { + "name" : "cagra.dim32", + "algo" : "raft_cagra", + "build_param": { + "index_dim" : 32 + }, + "file" : "index/sift-128-euclidean/cagra/dim32", + "search_params" : [ + {"itopk": 32}, + {"itopk": 64}, + {"itopk": 128} + ], + "search_result_file" : "result/sift-128-euclidean/cagra/dim32" + }, + + { + "name" : "cagra.dim64", + "algo" : "raft_cagra", + "build_param": { + "index_dim" : 64 + }, + "file" : "index/sift-128-euclidean/cagra/dim64", + "search_params" : [ + {"itopk": 32}, + {"itopk": 64}, + {"itopk": 128} + ], + "search_result_file" : "result/sift-128-euclidean/cagra/dim64" } ] } diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index c34b95010f..28df4640ee 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -30,6 +30,8 @@ #include #include +#include + #include "benchmark_util.hpp" #include "conf.h" #include "dataset.h" @@ -108,8 +110,8 @@ inline bool mkdir(const std::vector& dirs) } inline bool check(const std::vector& indices, - bool build_mode, - bool force_overwrite) + const bool build_mode, + const bool force_overwrite) { std::vector files_should_exist; std::vector dirs_should_exist; @@ -119,7 +121,7 @@ inline bool check(const std::vector& indices, output_files.push_back(index.file); output_files.push_back(index.file + ".txt"); - auto pos = index.file.rfind('/'); + const auto pos = index.file.rfind('/'); if (pos != std::string::npos) { dirs_should_exist.push_back(index.file.substr(0, pos)); } } else { files_should_exist.push_back(index.file); @@ -128,7 +130,7 @@ inline bool check(const std::vector& indices, output_files.push_back(index.search_result_file + ".0.ibin"); output_files.push_back(index.search_result_file + ".0.txt"); - auto pos = index.search_result_file.rfind('/'); + const auto pos = index.search_result_file.rfind('/'); if (pos != std::string::npos) { dirs_should_exist.push_back(index.search_result_file.substr(0, pos)); } @@ -149,7 +151,7 @@ inline void write_build_info(const std::string& file_prefix, const std::string& name, const std::string& algo, const std::string& build_param, - float build_time) + const float build_time) { std::ofstream ofs(file_prefix + ".txt"); if (!ofs) { throw std::runtime_error("can't open build info file: " + file_prefix + ".txt"); } @@ -175,13 +177,13 @@ void build(const Dataset* dataset, const std::vector& i for (const auto& index : indices) { log_info("creating algo '%s', param=%s", index.algo.c_str(), index.build_param.dump().c_str()); - auto algo = create_algo(index.algo, - dataset->distance(), - dataset->dim(), - index.refine_ratio, - index.build_param, - index.dev_list); - auto algo_property = algo->get_property(); + const auto algo = create_algo(index.algo, + dataset->distance(), + dataset->dim(), + index.refine_ratio, + index.build_param, + index.dev_list); + const auto algo_property = algo->get_property(); const T* base_set_ptr = nullptr; if (algo_property.dataset_memory_type == MemoryType::Host) { @@ -203,7 +205,7 @@ void build(const Dataset* dataset, const std::vector& i Timer timer; algo->build(base_set_ptr, dataset->base_set_size(), stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - float elapsed_ms = timer.elapsed_ms(); + const float elapsed_ms = timer.elapsed_ms(); #ifdef NVTX nvtxRangePop(); #endif @@ -232,15 +234,17 @@ inline void write_search_result(const std::string& file_prefix, const std::string& algo, const std::string& build_param, const std::string& search_param, - int batch_size, - int run_count, - int k, + std::size_t batch_size, + unsigned run_count, + unsigned k, float search_time_average, float search_time_p99, float search_time_p999, + float query_per_second, const int* neighbors, size_t query_set_size) { + log_info("throughput : %e [QPS]", query_per_second); std::ofstream ofs(file_prefix + ".txt"); if (!ofs) { throw std::runtime_error("can't open search result file: " + file_prefix + ".txt"); } ofs << "dataset: " << dataset << "\n" @@ -254,13 +258,16 @@ inline void write_search_result(const std::string& file_prefix, << "batch_size: " << batch_size << "\n" << "run_count: " << run_count << "\n" << "k: " << k << "\n" + << "query_per_second: " << query_per_second << "\n" << "average_search_time: " << search_time_average << endl; + if (search_time_p99 != std::numeric_limits::max()) { ofs << "p99_search_time: " << search_time_p99 << endl; } if (search_time_p999 != std::numeric_limits::max()) { ofs << "p999_search_time: " << search_time_p999 << endl; } + ofs.close(); if (!ofs) { throw std::runtime_error("can't write to search result file: " + file_prefix + ".txt"); @@ -280,15 +287,15 @@ inline void search(const Dataset* dataset, const std::vectorname().c_str(), dataset->query_set_size()); - const T* query_set = dataset->query_set(); + const T* const query_set = dataset->query_set(); // query set is usually much smaller than base set, so load it eagerly - const T* d_query_set = dataset->query_set_on_gpu(); - size_t query_set_size = dataset->query_set_size(); + const T* const d_query_set = dataset->query_set_on_gpu(); + const size_t query_set_size = dataset->query_set_size(); // currently all indices has same batch_size, k and run_count - const int batch_size = indices[0].batch_size; - const int k = indices[0].k; - const int run_count = indices[0].run_count; + const std::size_t batch_size = indices[0].batch_size; + const unsigned k = indices[0].k; + const unsigned run_count = indices[0].run_count; log_info( "basic search parameters: batch_size = %d, k = %d, run_count = %d", batch_size, k, run_count); if (query_set_size % batch_size != 0) { @@ -297,10 +304,10 @@ inline void search(const Dataset* dataset, const std::vector search_times; search_times.reserve(num_batches); std::size_t* d_neighbors; @@ -310,13 +317,13 @@ inline void search(const Dataset* dataset, const std::vector(index.algo, - dataset->distance(), - dataset->dim(), - index.refine_ratio, - index.build_param, - index.dev_list); - auto algo_property = algo->get_property(); + const auto algo = create_algo(index.algo, + dataset->distance(), + dataset->dim(), + index.refine_ratio, + index.build_param, + index.dev_list); + const auto algo_property = algo->get_property(); log_info("loading index '%s' from file '%s'", index.name.c_str(), index.file.c_str()); algo->load(index.file); @@ -349,7 +356,7 @@ inline void search(const Dataset* dataset, const std::vector(index.algo, index.search_params[i]); + const auto p_param = create_search_param(index.algo, index.search_params[i]); algo->set_search_param(*p_param); log_info("search with param: %s", index.search_params[i].dump().c_str()); @@ -364,11 +371,13 @@ inline void search(const Dataset* dataset, const std::vector::max(); float best_search_time_p99 = std::numeric_limits::max(); float best_search_time_p999 = std::numeric_limits::max(); - for (int run = 0; run < run_count; ++run) { + float total_search_time = 0; + for (unsigned run = 0; run < run_count; ++run) { log_info("run %d / %d", run + 1, run_count); for (std::size_t batch_id = 0; batch_id < num_batches; ++batch_id) { - std::size_t row = batch_id * batch_size; - int actual_batch_size = (batch_id == num_batches - 1) ? query_set_size - row : batch_size; + const std::size_t row = batch_id * batch_size; + const std::size_t actual_batch_size = + (batch_id == num_batches - 1) ? query_set_size - row : batch_size; RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); #ifdef NVTX string nvtx_label = "batch" + to_string(batch_id); @@ -389,7 +398,7 @@ inline void search(const Dataset* dataset, const std::vector* dataset, const std::vector= 100) { std::sort(search_times.begin(), search_times.end()); - auto calc_percentile_pos = [](float percentile, size_t N) { + const auto calc_percentile_pos = [](float percentile, size_t N) { return static_cast(std::ceil(percentile / 100.0 * N)) - 1; }; - float search_time_p99 = search_times[calc_percentile_pos(99, search_times.size())]; - best_search_time_p99 = std::min(best_search_time_p99, search_time_p99); + const float search_time_p99 = search_times[calc_percentile_pos(99, search_times.size())]; + best_search_time_p99 = std::min(best_search_time_p99, search_time_p99); if (search_times.size() >= 1000) { - float search_time_p999 = search_times[calc_percentile_pos(99.9, search_times.size())]; - best_search_time_p999 = std::min(best_search_time_p999, search_time_p999); + const float search_time_p999 = + search_times[calc_percentile_pos(99.9, search_times.size())]; + best_search_time_p999 = std::min(best_search_time_p999, search_time_p999); } } search_times.clear(); } RAFT_CUDA_TRY(cudaDeviceSynchronize()); RAFT_CUDA_TRY(cudaPeekAtLastError()); + const auto query_per_second = + (run_count * raft::round_down_safe(query_set_size, batch_size)) / total_search_time; if (algo_property.query_memory_type == MemoryType::Device) { RAFT_CUDA_TRY(cudaMemcpy(neighbors, @@ -436,7 +450,7 @@ inline void search(const Dataset* dataset, const std::vector* dataset, const std::vector -inline int dispatch_benchmark(Configuration& conf, - std::string& index_patterns, +inline int dispatch_benchmark(const Configuration& conf, + const std::string& index_patterns, bool force_overwrite, bool only_check, bool build_mode, bool search_mode) { try { - auto dataset_conf = conf.get_dataset_conf(); + const auto dataset_conf = conf.get_dataset_conf(); BinDataset dataset(dataset_conf.name, dataset_conf.base_file, diff --git a/cpp/bench/ann/src/faiss/faiss_benchmark.cu b/cpp/bench/ann/src/faiss/faiss_benchmark.cu index 294da9a14f..0aa4e76103 100644 --- a/cpp/bench/ann/src/faiss/faiss_benchmark.cu +++ b/cpp/bench/ann/src/faiss/faiss_benchmark.cu @@ -147,4 +147,4 @@ std::unique_ptr::AnnSearchParam> create_search #include "../common/benchmark.hpp" -int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } \ No newline at end of file +int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index baff1b1c45..22204c2b61 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -40,6 +40,12 @@ extern template class raft::bench::ann::RaftIvfPQ; extern template class raft::bench::ann::RaftIvfPQ; extern template class raft::bench::ann::RaftIvfPQ; #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA +#include "raft_cagra_wrapper.h" +extern template class raft::bench::ann::RaftCagra; +extern template class raft::bench::ann::RaftCagra; +extern template class raft::bench::ann::RaftCagra; +#endif #define JSON_DIAGNOSTICS 1 #include @@ -117,28 +123,24 @@ void parse_search_param(const nlohmann::json& conf, } #endif -template class Algo> -std::unique_ptr> make_algo(raft::bench::ann::Metric metric, - int dim, - const nlohmann::json& conf) +#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA +template +void parse_build_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftCagra::BuildParam& param) { - typename Algo::BuildParam param; - parse_build_param(conf, param); - return std::make_unique>(metric, dim, param); + if (conf.contains("index_dim")) { + param.graph_degree = conf.at("index_dim"); + param.intermediate_graph_degree = param.graph_degree * 2; + } } -template class Algo> -std::unique_ptr> make_algo(raft::bench::ann::Metric metric, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) +template +void parse_search_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftCagra::SearchParam& param) { - typename Algo::BuildParam param; - parse_build_param(conf, param); - - (void)dev_list; - return std::make_unique>(metric, dim, param); + param.itopk_size = conf.at("itopk"); } +#endif template std::unique_ptr> create_algo(const std::string& algo, @@ -176,6 +178,13 @@ std::unique_ptr> create_algo(const std::string& algo, ann = std::make_unique>(metric, dim, param, refine_ratio); } +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA + if (algo == "raft_cagra") { + typename raft::bench::ann::RaftCagra::BuildParam param; + parse_build_param(conf, param); + ann = std::make_unique>(metric, dim, param); + } #endif if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } @@ -207,6 +216,13 @@ std::unique_ptr::AnnSearchParam> create_search parse_search_param(conf, *param); return param; } +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA + if (algo == "raft_cagra") { + auto param = std::make_unique::SearchParam>(); + parse_search_param(conf, *param); + return param; + } #endif // else throw std::runtime_error("invalid algo: '" + algo + "'"); @@ -216,4 +232,4 @@ std::unique_ptr::AnnSearchParam> create_search #include "../common/benchmark.hpp" -int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } \ No newline at end of file +int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } diff --git a/cpp/bench/ann/src/raft/raft_cagra.cu b/cpp/bench/ann/src/raft/raft_cagra.cu new file mode 100644 index 0000000000..b375af0526 --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_cagra.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "raft_cagra_wrapper.h" + +#ifdef RAFT_COMPILED +#include +#endif + +namespace raft::bench::ann { +template class RaftCagra; +template class RaftCagra; +template class RaftCagra; +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h new file mode 100644 index 0000000000..399fd6a0a8 --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../common/ann_types.hpp" +#include "raft_ann_bench_utils.h" +#include + +namespace raft::bench::ann { + +template +class RaftCagra : public ANN { + public: + using typename ANN::AnnSearchParam; + + struct SearchParam : public AnnSearchParam { + unsigned itopk_size; + }; + + using BuildParam = raft::neighbors::experimental::cagra::index_params; + + RaftCagra(Metric metric, int dim, const BuildParam& param); + + void build(const T* dataset, size_t nrow, cudaStream_t stream) final; + + void set_search_param(const AnnSearchParam& param) override; + + // TODO: if the number of results is less than k, the remaining elements of 'neighbors' + // will be filled with (size_t)-1 + void search(const T* queries, + int batch_size, + int k, + size_t* neighbors, + float* distances, + cudaStream_t stream = 0) const override; + + // to enable dataset access from GPU memory + AlgoProperty get_property() const override + { + AlgoProperty property; + property.dataset_memory_type = MemoryType::Device; + property.query_memory_type = MemoryType::Device; + property.need_dataset_when_search = true; + return property; + } + void save(const std::string& file) const override; + void load(const std::string&) override; + + private: + raft::device_resources handle_; + BuildParam index_params_; + raft::neighbors::experimental::cagra::search_params search_params_; + std::optional> index_; + int device_; + int dimension_; + rmm::mr::pool_memory_resource mr_; +}; + +template +RaftCagra::RaftCagra(Metric metric, int dim, const BuildParam& param) + : ANN(metric, dim), + index_params_(param), + dimension_(dim), + mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull) +{ + rmm::mr::set_current_device_resource(&mr_); + index_params_.metric = parse_metric_type(metric); + RAFT_CUDA_TRY(cudaGetDevice(&device_)); +} + +template +void RaftCagra::build(const T* dataset, size_t nrow, cudaStream_t) +{ + auto dataset_view = raft::make_device_matrix_view(dataset, IdxT(nrow), dimension_); + index_.emplace(raft::neighbors::experimental::cagra::build(handle_, index_params_, dataset_view)); + return; +} + +template +void RaftCagra::set_search_param(const AnnSearchParam& param) +{ + return; +} + +template +void RaftCagra::save(const std::string& file) const +{ + raft::neighbors::experimental::cagra::serialize(handle_, file, *index_); + return; +} + +template +void RaftCagra::load(const std::string& file) +{ + index_ = raft::neighbors::experimental::cagra::deserialize(handle_, file); + return; +} + +template +void RaftCagra::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const +{ + IdxT* neighbors_IdxT; + rmm::device_uvector neighbors_storage(0, resource::get_cuda_stream(handle_)); + if constexpr (std::is_same::value) { + neighbors_IdxT = neighbors; + } else { + neighbors_storage.resize(batch_size * k, resource::get_cuda_stream(handle_)); + neighbors_IdxT = neighbors_storage.data(); + } + + auto queries_view = raft::make_device_matrix_view(queries, batch_size, dimension_); + auto neighbors_view = raft::make_device_matrix_view(neighbors_IdxT, batch_size, k); + auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); + + raft::neighbors::experimental::cagra::search_params search_params; + search_params.max_queries = batch_size; + search_params.itopk_size = search_params_.max_queries; + raft::neighbors::experimental::cagra::search( + handle_, search_params, *index_, queries_view, neighbors_view, distances_view); + + if (!std::is_same::value) { + raft::linalg::unaryOp(neighbors, + neighbors_IdxT, + batch_size * k, + raft::cast_op(), + resource::get_cuda_stream(handle_)); + } + + handle_.sync_stream(); + return; +} +} // namespace raft::bench::ann