Skip to content

Commit

Permalink
Don't serialize dataset with CAGRA bench
Browse files Browse the repository at this point in the history
As an alternative to rapidsai#1743, this uses the `include_dataset=False` parameter
of cagra::serialize to avoid writing the dataset to disk with the index.
This lets us avoid writing a second copy of the dataset, since it is
already available in a separate file.
  • Loading branch information
benfred committed Aug 27, 2023
1 parent 08a1fad commit 770ee58
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class RaftCagra : public ANN<T> {

void set_search_param(const AnnSearchParam& param) override;

void set_search_dataset(const T* dataset, size_t nrow) override;

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(const T* queries,
Expand Down Expand Up @@ -120,29 +122,32 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
raft::make_device_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
}
return;
}

// Stores the CAGRA search parameters carried by `param` in `search_params_`.
//
// `param` must refer to a SearchParam: the reference form of dynamic_cast
// throws std::bad_cast when the dynamic type does not match.
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::set_search_param(const AnnSearchParam& param)
{
  // Bind by const reference: a plain `auto` would drop the reference and copy
  // the whole SearchParam object just to read its `p` member.
  const auto& search_param = dynamic_cast<const SearchParam&>(param);
  search_params_           = search_param.p;
}

// Hands `dataset` to the index via `update_dataset`, wrapped in a host matrix
// view of `nrow` rows by `this->dim_` columns. The view is built directly on
// the caller's pointer.
//
// NOTE(review): `index_` is dereferenced unconditionally, so this presumably
// must be called only after build() or load() has populated it -- confirm
// against callers.
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
{
  auto host_dataset_view =
    raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
  index_->update_dataset(handle_, host_dataset_view);
}

// Serializes the index held in `index_` to `file`.
//
// The trailing `false` is the include-dataset flag of cagra::serialize: the
// dataset is deliberately left out of the index file because it is already
// available in a separate file, which avoids writing a second copy to disk.
//
// NOTE(review): `index_` is dereferenced unconditionally, so this presumably
// requires that an index has been built or loaded first -- confirm.
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::save(const std::string& file) const
{
  raft::neighbors::cagra::serialize(handle_, file, *index_, false);
}

// Replaces the contents of `index_` with the index deserialized from `file`.
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::load(const std::string& file)
{
  namespace cagra = raft::neighbors::cagra;
  index_          = cagra::deserialize<T, IdxT>(handle_, file);
}

template <typename T, typename IdxT>
Expand Down Expand Up @@ -175,6 +180,5 @@ void RaftCagra<T, IdxT>::search(
}

handle_.sync_stream();
return;
}
} // namespace raft::bench::ann

0 comments on commit 770ee58

Please sign in to comment.