diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 516ec2e762..c8bfe9b401 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -58,6 +58,8 @@ class RaftCagra : public ANN { void set_search_param(const AnnSearchParam& param) override; + void set_search_dataset(const T* dataset, size_t nrow) override; + // TODO: if the number of results is less than k, the remaining elements of 'neighbors' // will be filled with (size_t)-1 void search(const T* queries, @@ -120,7 +122,6 @@ void RaftCagra::build(const T* dataset, size_t nrow, cudaStream_t) raft::make_device_matrix_view(dataset, IdxT(nrow), dimension_); index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view)); } - return; } template @@ -128,21 +129,25 @@ void RaftCagra::set_search_param(const AnnSearchParam& param) { auto search_param = dynamic_cast(param); search_params_ = search_param.p; - return; +} + +template +void RaftCagra::set_search_dataset(const T* dataset, size_t nrow) +{ + index_->update_dataset(handle_, + raft::make_host_matrix_view(dataset, nrow, this->dim_)); } template void RaftCagra::save(const std::string& file) const { - raft::neighbors::cagra::serialize(handle_, file, *index_); - return; + raft::neighbors::cagra::serialize(handle_, file, *index_, false); } template void RaftCagra::load(const std::string& file) { index_ = raft::neighbors::cagra::deserialize(handle_, file); - return; } template @@ -175,6 +180,5 @@ void RaftCagra::search( } handle_.sync_stream(); - return; } } // namespace raft::bench::ann