From b7734949c8b6ae00a6ccc726289fa2641b1d30c3 Mon Sep 17 00:00:00 2001
From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com>
Date: Thu, 21 Mar 2024 07:53:50 +0100
Subject: [PATCH] Add CAGRA-Q to ANN benchmarks (#2233)

Add the relevant options to the CAGRA parameter parser and refinement to the CAGRA ANN benchmark.
No changes to the library code.

NB: the new option won't work correctly until https://github.com/rapidsai/raft/pull/2206 is merged.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: https://github.com/rapidsai/raft/pull/2233
---
 .../src/raft/raft_ann_bench_param_parser.h    | 23 +++++
 cpp/bench/ann/src/raft/raft_cagra_wrapper.h   | 83 ++++++++++++++++++-
 2 files changed, 103 insertions(+), 3 deletions(-)

diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
index 2339677340..48bf1d70d8 100644
--- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
+++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
@@ -43,6 +43,7 @@ extern template class raft::bench::ann::RaftIvfPQ;
 #endif
 #ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA
 extern template class raft::bench::ann::RaftCagra;
+extern template class raft::bench::ann::RaftCagra;
 extern template class raft::bench::ann::RaftCagra;
 extern template class raft::bench::ann::RaftCagra;
 #endif
@@ -149,6 +150,20 @@ void parse_build_param(const nlohmann::json& conf,
   }
 }
 
+inline void parse_build_param(const nlohmann::json& conf, raft::neighbors::vpq_params& param)
+{
+  if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); }
+  if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); }
+  if (conf.contains("vq_n_centers")) { param.vq_n_centers = conf.at("vq_n_centers"); }
+  if (conf.contains("kmeans_n_iters")) { param.kmeans_n_iters = conf.at("kmeans_n_iters"); }
+  if (conf.contains("vq_kmeans_trainset_fraction")) {
+    param.vq_kmeans_trainset_fraction = conf.at("vq_kmeans_trainset_fraction");
+  }
+  if (conf.contains("pq_kmeans_trainset_fraction")) {
+    param.pq_kmeans_trainset_fraction = conf.at("pq_kmeans_trainset_fraction");
+  }
+}
+
 nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf,
                                         const std::string& prefix,
                                         bool remove_prefix = true)
@@ -204,6 +219,12 @@ void parse_build_param(const nlohmann::json& conf,
     }
     param.nn_descent_params = nn_param;
   }
+  nlohmann::json comp_search_conf = collect_conf_with_prefix(conf, "compression_");
+  if (!comp_search_conf.empty()) {
+    raft::neighbors::vpq_params vpq_pams;
+    parse_build_param(comp_search_conf, vpq_pams);
+    param.cagra_params.compression.emplace(vpq_pams);
+  }
 }
 
 raft::bench::ann::AllocatorType parse_allocator(std::string mem_type)
@@ -248,5 +269,7 @@ void parse_search_param(const nlohmann::json& conf,
   if (conf.contains("internal_dataset_memory_type")) {
     param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type"));
   }
+  // Same ratio as in IVF-PQ
+  param.refine_ratio = conf.value("refine_ratio", 1.0f);
 }
 #endif
diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h
index 25f7f93777..70fd22001e 100644
--- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h
+++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -56,6 +57,7 @@ class RaftCagra : public ANN, public AnnGPU {
 
   struct SearchParam : public AnnSearchParam {
     raft::neighbors::experimental::cagra::search_params p;
+    float refine_ratio;
     AllocatorType graph_mem = AllocatorType::Device;
     AllocatorType dataset_mem = AllocatorType::Device;
     auto needs_dataset() const -> bool override { return true; }
@@ -98,6 +100,8 @@ class RaftCagra : public ANN, public AnnGPU {
   // will be filled with (size_t)-1
   void search(
     const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
+  void search_base(
+    const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const;
 
   [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
   {
@@ -124,6 +128,7 @@ class RaftCagra : public ANN, public AnnGPU {
   raft::mr::cuda_huge_page_resource mr_huge_page_;
   AllocatorType graph_mem_;
   AllocatorType dataset_mem_;
+  float refine_ratio_;
   BuildParam index_params_;
   bool need_dataset_update_;
   raft::neighbors::cagra::search_params search_params_;
@@ -151,6 +156,9 @@ void RaftCagra::build(const T* dataset, size_t nrow)
 
   auto& params = index_params_.cagra_params;
 
+  // Do include the compressed dataset for the CAGRA-Q
+  bool shall_include_dataset = params.compression.has_value();
+
   index_ = std::make_shared>(
     std::move(raft::neighbors::cagra::detail::build(handle_,
                                                     params,
@@ -159,7 +167,7 @@ void RaftCagra::build(const T* dataset, size_t nrow)
                                                     index_params_.ivf_pq_refine_rate,
                                                     index_params_.ivf_pq_build_params,
                                                     index_params_.ivf_pq_search_params,
-                                                    false)));
+                                                    shall_include_dataset)));
 }
 
 inline std::string allocator_to_string(AllocatorType mem_type)
@@ -179,6 +187,7 @@ void RaftCagra::set_search_param(const AnnSearchParam& param)
 {
   auto search_param = dynamic_cast(param);
   search_params_ = search_param.p;
+  refine_ratio_ = search_param.refine_ratio;
   if (search_param.graph_mem != graph_mem_) {
     // Move graph to correct memory space
     graph_mem_ = search_param.graph_mem;
@@ -223,12 +232,16 @@ void RaftCagra::set_search_param(const AnnSearchParam& param)
 template
 void RaftCagra::set_search_dataset(const T* dataset, size_t nrow)
 {
+  using ds_idx_type = decltype(index_->data().n_rows());
+  bool is_vpq =
+    dynamic_cast*>(&index_->data()) ||
+    dynamic_cast*>(&index_->data());
   // It can happen that we are re-using a previous algo object which already has
   // the dataset set. Check if we need update.
   if (static_cast(input_dataset_v_->extent(0)) != nrow ||
       input_dataset_v_->data_handle() != dataset) {
     *input_dataset_v_ = make_device_matrix_view(dataset, nrow, this->dim_);
-    need_dataset_update_ = true;
+    need_dataset_update_ = !is_vpq;  // ignore update if this is a VPQ dataset.
   }
 }
 
@@ -258,7 +271,7 @@ std::unique_ptr> RaftCagra::copy()
 }
 
 template
-void RaftCagra::search(
+void RaftCagra::search_base(
   const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
 {
   IdxT* neighbors_IdxT;
@@ -286,4 +299,68 @@ void RaftCagra::search(
                       raft::resource::get_cuda_stream(handle_));
   }
 }
+
+template
+void RaftCagra::search(
+  const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
+{
+  auto k0 = static_cast(refine_ratio_ * k);
+  const bool disable_refinement = k0 <= static_cast(k);
+  const raft::resources& res = handle_;
+  auto stream = resource::get_cuda_stream(res);
+
+  if (disable_refinement) {
+    search_base(queries, batch_size, k, neighbors, distances);
+  } else {
+    auto candidate_ixs = raft::make_device_matrix(res, batch_size, k0);
+    auto candidate_dists = raft::make_device_matrix(res, batch_size, k0);
+    search_base(queries,
+                batch_size,
+                k0,
+                reinterpret_cast(candidate_ixs.data_handle()),
+                candidate_dists.data_handle());
+
+    if (raft::get_device_for_address(input_dataset_v_->data_handle()) >= 0) {
+      auto queries_v =
+        raft::make_device_matrix_view(queries, batch_size, dimension_);
+      auto neighours_v = raft::make_device_matrix_view(
+        reinterpret_cast(neighbors), batch_size, k);
+      auto distances_v = raft::make_device_matrix_view(distances, batch_size, k);
+      raft::neighbors::refine(
+        res,
+        *input_dataset_v_,
+        queries_v,
+        raft::make_const_mdspan(candidate_ixs.view()),
+        neighours_v,
+        distances_v,
+        index_->metric());
+    } else {
+      auto dataset_host = raft::make_host_matrix_view(
+        input_dataset_v_->data_handle(), input_dataset_v_->extent(0), input_dataset_v_->extent(1));
+      auto queries_host = raft::make_host_matrix(batch_size, dimension_);
+      auto candidates_host = raft::make_host_matrix(batch_size, k0);
+      auto neighbors_host = raft::make_host_matrix(batch_size, k);
+      auto distances_host = raft::make_host_matrix(batch_size, k);
+
+      raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream);
+      raft::copy(
+        candidates_host.data_handle(), candidate_ixs.data_handle(), candidates_host.size(), stream);
+
+      raft::resource::sync_stream(res);  // wait for the queries and candidates
+      raft::neighbors::refine(res,
+                              dataset_host,
+                              queries_host.view(),
+                              candidates_host.view(),
+                              neighbors_host.view(),
+                              distances_host.view(),
+                              index_->metric());
+
+      raft::copy(neighbors,
+                 reinterpret_cast(neighbors_host.data_handle()),
+                 neighbors_host.size(),
+                 stream);
+      raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream);
+    }
+  }
+}
 }  // namespace raft::bench::ann