Skip to content

Commit

Permalink
Add CAGRA-Q to ANN benchmarks (#2233)
Browse files Browse the repository at this point in the history
Add the compression (VPQ) build options to the CAGRA parameter parser, and add an optional refinement step (controlled by `refine_ratio`) to the CAGRA ANN benchmark.
No changes to the library code.

NB: the new option won't work correctly until #2206 is merged.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #2233
  • Loading branch information
achirkin authored Mar 21, 2024
1 parent de7341e commit b773494
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 3 deletions.
23 changes: 23 additions & 0 deletions cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ extern template class raft::bench::ann::RaftIvfPQ<int8_t, int64_t>;
#endif
// Explicit instantiation declarations for the CAGRA benchmark wrapper, compiled in only
// when RAFT_ANN_BENCH_USE_RAFT_CAGRA is defined. `extern template` suppresses implicit
// instantiation of these specializations in this translation unit.
// NOTE(review): presumably the matching explicit instantiations live in a separate
// source file - confirm.
#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA
extern template class raft::bench::ann::RaftCagra<float, uint32_t>;
extern template class raft::bench::ann::RaftCagra<half, uint32_t>;
extern template class raft::bench::ann::RaftCagra<uint8_t, uint32_t>;
extern template class raft::bench::ann::RaftCagra<int8_t, uint32_t>;
#endif
Expand Down Expand Up @@ -149,6 +150,20 @@ void parse_build_param(const nlohmann::json& conf,
}
}

/**
 * Fill VPQ compression parameters from a JSON configuration.
 *
 * Only the keys that are present in `conf` are applied; every other field of
 * `param` keeps the value it had on entry.
 *
 * @param conf  JSON object that may contain any of: "pq_bits", "pq_dim",
 *              "vq_n_centers", "kmeans_n_iters", "vq_kmeans_trainset_fraction",
 *              "pq_kmeans_trainset_fraction".
 * @param param VPQ parameters updated in place.
 */
inline void parse_build_param(const nlohmann::json& conf, raft::neighbors::vpq_params& param)
{
  // Assign conf[key] to `field` when the key exists; the JSON value is converted
  // to the field's own type by the assignment.
  const auto assign_if_present = [&conf](const char* key, auto& field) {
    if (!conf.contains(key)) { return; }
    field = conf.at(key);
  };

  assign_if_present("pq_bits", param.pq_bits);
  assign_if_present("pq_dim", param.pq_dim);
  assign_if_present("vq_n_centers", param.vq_n_centers);
  assign_if_present("kmeans_n_iters", param.kmeans_n_iters);
  assign_if_present("vq_kmeans_trainset_fraction", param.vq_kmeans_trainset_fraction);
  assign_if_present("pq_kmeans_trainset_fraction", param.pq_kmeans_trainset_fraction);
}

nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf,
const std::string& prefix,
bool remove_prefix = true)
Expand Down Expand Up @@ -204,6 +219,12 @@ void parse_build_param(const nlohmann::json& conf,
}
param.nn_descent_params = nn_param;
}
nlohmann::json comp_search_conf = collect_conf_with_prefix(conf, "compression_");
if (!comp_search_conf.empty()) {
raft::neighbors::vpq_params vpq_pams;
parse_build_param(comp_search_conf, vpq_pams);
param.cagra_params.compression.emplace(vpq_pams);
}
}

raft::bench::ann::AllocatorType parse_allocator(std::string mem_type)
Expand Down Expand Up @@ -248,5 +269,7 @@ void parse_search_param(const nlohmann::json& conf,
if (conf.contains("internal_dataset_memory_type")) {
param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type"));
}
// Same ratio as in IVF-PQ
param.refine_ratio = conf.value("refine_ratio", 1.0f);
}
#endif
83 changes: 80 additions & 3 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/cagra_serialize.cuh>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>
#include <raft/neighbors/detail/cagra/cagra_build.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <raft/neighbors/nn_descent_types.hpp>
Expand Down Expand Up @@ -56,6 +57,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {

struct SearchParam : public AnnSearchParam {
raft::neighbors::experimental::cagra::search_params p;
float refine_ratio;
AllocatorType graph_mem = AllocatorType::Device;
AllocatorType dataset_mem = AllocatorType::Device;
auto needs_dataset() const -> bool override { return true; }
Expand Down Expand Up @@ -98,6 +100,8 @@ class RaftCagra : public ANN<T>, public AnnGPU {
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search_base(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand All @@ -124,6 +128,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {
raft::mr::cuda_huge_page_resource mr_huge_page_;
AllocatorType graph_mem_;
AllocatorType dataset_mem_;
float refine_ratio_;
BuildParam index_params_;
bool need_dataset_update_;
raft::neighbors::cagra::search_params search_params_;
Expand Down Expand Up @@ -151,6 +156,9 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow)

auto& params = index_params_.cagra_params;

// Do include the compressed dataset for the CAGRA-Q
bool shall_include_dataset = params.compression.has_value();

index_ = std::make_shared<raft::neighbors::cagra::index<T, IdxT>>(
std::move(raft::neighbors::cagra::detail::build(handle_,
params,
Expand All @@ -159,7 +167,7 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow)
index_params_.ivf_pq_refine_rate,
index_params_.ivf_pq_build_params,
index_params_.ivf_pq_search_params,
false)));
shall_include_dataset)));
}

inline std::string allocator_to_string(AllocatorType mem_type)
Expand All @@ -179,6 +187,7 @@ void RaftCagra<T, IdxT>::set_search_param(const AnnSearchParam& param)
{
auto search_param = dynamic_cast<const SearchParam&>(param);
search_params_ = search_param.p;
refine_ratio_ = search_param.refine_ratio;
if (search_param.graph_mem != graph_mem_) {
// Move graph to correct memory space
graph_mem_ = search_param.graph_mem;
Expand Down Expand Up @@ -223,12 +232,16 @@ void RaftCagra<T, IdxT>::set_search_param(const AnnSearchParam& param)
/**
 * Register the dataset that searches should run against.
 *
 * The stored view `input_dataset_v_` is re-pointed at `dataset` only when the
 * pointer or the row count differs from what is already stored. In that case
 * `need_dataset_update_` is set to true unless the dataset held by the index is
 * a VPQ dataset (half or float), for which the update is skipped.
 *
 * @param dataset pointer to the first element of the dataset (row-major, `dim_` columns)
 * @param nrow    number of rows in `dataset`
 */
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
{
  // It can happen that we are re-using a previous algo object which already has
  // the dataset set; in that case there is nothing to update.
  const bool same_dataset = static_cast<size_t>(input_dataset_v_->extent(0)) == nrow &&
                            input_dataset_v_->data_handle() == dataset;
  if (same_dataset) { return; }

  *input_dataset_v_ = make_device_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);

  // Probe the dynamic type of the dataset stored in the index; this is only
  // needed once we know the view has actually changed.
  using ds_idx_type  = decltype(index_->data().n_rows());
  const auto* stored = &index_->data();
  const bool is_vpq =
    dynamic_cast<const raft::neighbors::vpq_dataset<half, ds_idx_type>*>(stored) != nullptr ||
    dynamic_cast<const raft::neighbors::vpq_dataset<float, ds_idx_type>*>(stored) != nullptr;

  // Ignore the update if this is a VPQ dataset.
  need_dataset_update_ = !is_vpq;
}

Expand Down Expand Up @@ -258,7 +271,7 @@ std::unique_ptr<ANN<T>> RaftCagra<T, IdxT>::copy()
}

template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::search(
void RaftCagra<T, IdxT>::search_base(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
{
IdxT* neighbors_IdxT;
Expand Down Expand Up @@ -286,4 +299,68 @@ void RaftCagra<T, IdxT>::search(
raft::resource::get_cuda_stream(handle_));
}
}

/**
 * Search with optional refinement.
 *
 * When `refine_ratio_ * k` (truncated) is larger than `k`, `search_base` is asked
 * for that many candidates per query and the candidates are passed to
 * `raft::neighbors::refine` together with the input dataset view to produce the
 * final `k` neighbors and distances. Otherwise the call is forwarded to
 * `search_base` unchanged.
 *
 * @param queries    query vectors, `batch_size` rows of `dimension_` elements
 * @param batch_size number of queries
 * @param k          number of neighbors to return per query
 * @param neighbors  output indices, `batch_size * k` elements
 * @param distances  output distances, `batch_size * k` elements
 */
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::search(
  const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
{
  // Compare in floating point before converting: casting a negative or NaN
  // product to size_t is undefined behaviour, so such ratios simply disable
  // refinement instead.
  const auto k_sz       = static_cast<size_t>(k);
  const float k0_scaled = refine_ratio_ * k;
  const size_t k0 = (k0_scaled > static_cast<float>(k)) ? static_cast<size_t>(k0_scaled) : k_sz;
  const bool disable_refinement = k0 <= k_sz;
  const raft::resources& res    = handle_;
  auto stream                   = resource::get_cuda_stream(res);

  if (disable_refinement) {
    search_base(queries, batch_size, k, neighbors, distances);
    return;
  }

  // Over-fetch k0 candidates per query; the buffers are reinterpreted between
  // int64_t and size_t because search_base writes size_t indices.
  auto candidate_ixs   = raft::make_device_matrix<int64_t, int64_t>(res, batch_size, k0);
  auto candidate_dists = raft::make_device_matrix<float, int64_t>(res, batch_size, k0);
  search_base(queries,
              batch_size,
              k0,
              reinterpret_cast<size_t*>(candidate_ixs.data_handle()),
              candidate_dists.data_handle());

  // A non-negative device id is taken to mean the dataset can be refined on the device.
  if (raft::get_device_for_address(input_dataset_v_->data_handle()) >= 0) {
    auto queries_v =
      raft::make_device_matrix_view<const T, int64_t>(queries, batch_size, dimension_);
    auto neighbors_v = raft::make_device_matrix_view<int64_t, int64_t>(
      reinterpret_cast<int64_t*>(neighbors), batch_size, k);
    auto distances_v = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);
    raft::neighbors::refine<int64_t, T, float, int64_t>(
      res,
      *input_dataset_v_,
      queries_v,
      raft::make_const_mdspan(candidate_ixs.view()),
      neighbors_v,
      distances_v,
      index_->metric());
  } else {
    // Otherwise refine on the host: stage queries and candidates in host buffers,
    // refine there, then copy the results back to the output pointers.
    auto dataset_host = raft::make_host_matrix_view<const T, int64_t>(
      input_dataset_v_->data_handle(), input_dataset_v_->extent(0), input_dataset_v_->extent(1));
    auto queries_host    = raft::make_host_matrix<T, int64_t>(batch_size, dimension_);
    auto candidates_host = raft::make_host_matrix<int64_t, int64_t>(batch_size, k0);
    auto neighbors_host  = raft::make_host_matrix<int64_t, int64_t>(batch_size, k);
    auto distances_host  = raft::make_host_matrix<float, int64_t>(batch_size, k);

    raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream);
    raft::copy(
      candidates_host.data_handle(), candidate_ixs.data_handle(), candidates_host.size(), stream);

    raft::resource::sync_stream(res);  // wait for the queries and candidates
    raft::neighbors::refine<int64_t, T, float, int64_t>(res,
                                                        dataset_host,
                                                        queries_host.view(),
                                                        candidates_host.view(),
                                                        neighbors_host.view(),
                                                        distances_host.view(),
                                                        index_->metric());

    // NOTE(review): the host result buffers are destroyed on return while these
    // copies are only enqueued on `stream`; presumably the copy from pageable
    // host memory is staged before raft::copy returns - confirm.
    raft::copy(neighbors,
               reinterpret_cast<size_t*>(neighbors_host.data_handle()),
               neighbors_host.size(),
               stream);
    raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream);
  }
}
} // namespace raft::bench::ann

0 comments on commit b773494

Please sign in to comment.