diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index d47de1eeac..d1cf66928c 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -19,10 +19,13 @@ #include #include #include +#include #include #include #include #include +#include +#include #include #include #include @@ -58,6 +61,7 @@ class RaftCagra : public ANN { void set_search_param(const AnnSearchParam& param) override; + void set_search_dataset(const T* dataset, size_t nrow) override; // TODO: if the number of results is less than k, the remaining elements of 'neighbors' // will be filled with (size_t)-1 void search(const T* queries, @@ -85,6 +89,7 @@ class RaftCagra : public ANN { raft::device_resources handle_; BuildParam index_params_; raft::neighbors::cagra::search_params search_params_; + raft::device_matrix graph_; std::optional> index_; int device_; int dimension_; @@ -96,7 +101,8 @@ RaftCagra::RaftCagra(Metric metric, int dim, const BuildParam& param) : ANN(metric, dim), index_params_(param), dimension_(dim), - mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull) + mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull), + graph_(make_device_matrix(handle_, 0, 0)) { rmm::mr::set_current_device_resource(&mr_); index_params_.metric = parse_metric_type(metric); @@ -129,15 +135,81 @@ void RaftCagra::set_search_param(const AnnSearchParam& param) template void RaftCagra::save(const std::string& file) const { - raft::neighbors::cagra::serialize(handle_, file, *index_); + // 1 orig serialization: save both dataset and knn graph into the file + // raft::neighbors::cagra::serialize(handle_, file, *index_); + + // 2. Saving only knn graph + // We use numpy serialization format. + // std::ofstream of(file, std::ios::out | std::ios::binary); + // serialize_mdspan(handle_, of, index_->graph()); + // of.close(); + + // 3. 
Orig CAGRA type of serialization, saving only knn-graph + size_t degree = index_->graph_degree(); + + std::ofstream of(file, std::ios::out | std::ios::binary); + std::size_t size = index_->size(); + + of.write(reinterpret_cast(&size), sizeof(size)); + of.write(reinterpret_cast(&degree), sizeof(degree)); + + auto graph_h = make_host_matrix(size, degree); + raft::copy(graph_h.data_handle(), + index_->graph().data_handle(), + index_->graph().size(), + resource::get_cuda_stream(handle_)); + resource::sync_stream(handle_); + + of.write(reinterpret_cast(graph_h.data_handle()), graph_h.size() * sizeof(IdxT)); + + of.close(); return; } template void RaftCagra::load(const std::string& file) { - index_ = raft::neighbors::cagra::deserialize(handle_, file); - return; + // 1. Original index saving method, we load both dataset and index + // index_ = raft::neighbors::cagra::deserialize(handle_, file); + + // // 2. Read only knn_graph. In theory this could also load a file saved with numpy.save() + // std::ifstream is(file, std::ios::in | std::ios::binary); + // raft::detail::numpy_serializer::header_t header = + // raft::detail::numpy_serializer::read_header(is); is.seekg(0); /* rewind*/ graph_ = + // make_device_matrix(handle_, header.shape[0], header.shape[1]); + // deserialize_mdspan(handle_, is, graph_.view()); + // is.close(); + + // 3. 
Read only knn graph, using Cagra's knn file format + std::ifstream ifs(file, std::ios::in | std::ios::binary); + if (!ifs) { + throw std::runtime_error("File not exist : " + file + " (`" + __func__ + "` in " + __FILE__ + + ")"); + } + + std::size_t size, degree; + + ifs.read(reinterpret_cast(&size), sizeof(size)); + ifs.read(reinterpret_cast(&degree), sizeof(degree)); + + auto graph_h = make_host_matrix(size, degree); + graph_ = make_device_matrix(handle_, size, degree); + + for (std::size_t i = 0; i < size; i++) { + ifs.read(reinterpret_cast(graph_h.data_handle() + i * degree), sizeof(IdxT) * degree); + } + ifs.close(); + raft::copy( + graph_.data_handle(), graph_h.data_handle(), graph_.size(), resource::get_cuda_stream(handle_)); + resource::sync_stream(handle_); +} + +template +void RaftCagra::set_search_dataset(const T* dataset, size_t nrow) +{ + auto dataset_v = raft::make_host_matrix_view(dataset, nrow, this->dim_); + index_.emplace( + handle_, parse_metric_type(this->metric_), dataset_v, make_const_mdspan(graph_.view())); } template