Skip to content

Commit

Permalink
CAGRA ANN bench: parse build options for IVF-PQ build algo (rapidsai#…
Browse files Browse the repository at this point in the history
…1912)

This PR enables fine tuning the the parameters for CAGRA index building using the IVF-PQ build algo.

This is only affects ANN bechmarks (an example json file is added). The public API of CAGRA is not changed because:
- Advanced users can achieve the same effect by calling `build_knn_graph` and `optimize` individually.
- Long term we consider NN descent to be the preferred build algo, which does not need IVF-PQ parameters.

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1912
  • Loading branch information
tfeher authored and benfred committed Nov 8, 2023
1 parent 66a6af8 commit 1073340
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 84 deletions.
77 changes: 67 additions & 10 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <algorithm>
#include <cmath>
#include <memory>
#include <raft/core/logger.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <stdexcept>
#include <string>
Expand All @@ -35,8 +36,10 @@ extern template class raft::bench::ann::RaftIvfFlatGpu<float, int64_t>;
extern template class raft::bench::ann::RaftIvfFlatGpu<uint8_t, int64_t>;
extern template class raft::bench::ann::RaftIvfFlatGpu<int8_t, int64_t>;
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA)
#include "raft_ivf_pq_wrapper.h"
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
extern template class raft::bench::ann::RaftIvfPQ<float, int64_t>;
extern template class raft::bench::ann::RaftIvfPQ<uint8_t, int64_t>;
extern template class raft::bench::ann::RaftIvfPQ<int8_t, int64_t>;
Expand Down Expand Up @@ -70,12 +73,12 @@ void parse_search_param(const nlohmann::json& conf,
}
#endif

#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA)
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftIvfPQ<T, IdxT>::BuildParam& param)
{
param.n_lists = conf.at("nlist");
if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); }
if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); }
if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); }
if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); }
Expand All @@ -97,7 +100,7 @@ template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftIvfPQ<T, IdxT>::SearchParam& param)
{
param.pq_param.n_probes = conf.at("nprobe");
if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); }
if (conf.contains("internalDistanceDtype")) {
std::string type = conf.at("internalDistanceDtype");
if (type == "float") {
Expand Down Expand Up @@ -137,25 +140,79 @@ void parse_search_param(const nlohmann::json& conf,
#endif

#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
raft::neighbors::experimental::nn_descent::index_params& param)
{
if (conf.contains("graph_degree")) { param.graph_degree = conf.at("graph_degree"); }
if (conf.contains("intermediate_graph_degree")) {
param.intermediate_graph_degree = conf.at("intermediate_graph_degree");
}
// we allow niter shorthand for max_iterations
if (conf.contains("niter")) { param.max_iterations = conf.at("niter"); }
if (conf.contains("max_iterations")) { param.max_iterations = conf.at("max_iterations"); }
if (conf.contains("termination_threshold")) {
param.termination_threshold = conf.at("termination_threshold");
}
}

nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf,
const std::string& prefix,
bool remove_prefix = true)
{
nlohmann::json out;
for (auto& i : conf.items()) {
if (i.key().compare(0, prefix.size(), prefix) == 0) {
auto new_key = remove_prefix ? i.key().substr(prefix.size()) : i.key();
out[new_key] = i.value();
}
}
return out;
}

template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftCagra<T, IdxT>::BuildParam& param)
{
if (conf.contains("graph_degree")) {
param.graph_degree = conf.at("graph_degree");
param.intermediate_graph_degree = param.graph_degree * 2;
param.cagra_params.graph_degree = conf.at("graph_degree");
param.cagra_params.intermediate_graph_degree = param.cagra_params.graph_degree * 2;
}
if (conf.contains("intermediate_graph_degree")) {
param.intermediate_graph_degree = conf.at("intermediate_graph_degree");
param.cagra_params.intermediate_graph_degree = conf.at("intermediate_graph_degree");
}
if (conf.contains("graph_build_algo")) {
if (conf.at("graph_build_algo") == "IVF_PQ") {
param.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ;
param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ;
} else if (conf.at("graph_build_algo") == "NN_DESCENT") {
param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
}
}
nlohmann::json ivf_pq_build_conf = collect_conf_with_prefix(conf, "ivf_pq_build_");
if (!ivf_pq_build_conf.empty()) {
raft::neighbors::ivf_pq::index_params bparam;
parse_build_param<T, IdxT>(ivf_pq_build_conf, bparam);
param.ivf_pq_build_params = bparam;
}
nlohmann::json ivf_pq_search_conf = collect_conf_with_prefix(conf, "ivf_pq_search_");
if (!ivf_pq_search_conf.empty()) {
typename raft::bench::ann::RaftIvfPQ<T, IdxT>::SearchParam sparam;
parse_search_param<T, IdxT>(ivf_pq_search_conf, sparam);
param.ivf_pq_search_params = sparam.pq_param;
param.ivf_pq_refine_rate = sparam.refine_ratio;
}
nlohmann::json nn_descent_conf = collect_conf_with_prefix(conf, "nn_descent_");
if (!nn_descent_conf.empty()) {
raft::neighbors::experimental::nn_descent::index_params nn_param;
nn_param.intermediate_graph_degree = 1.5 * param.cagra_params.intermediate_graph_degree;
parse_build_param<T, IdxT>(nn_descent_conf, nn_param);
if (nn_param.graph_degree != param.cagra_params.intermediate_graph_degree) {
RAFT_LOG_WARN(
"nn_descent_graph_degree has to be equal to CAGRA intermediate_grpah_degree, overriding");
nn_param.graph_degree = param.cagra_params.intermediate_graph_degree;
}
param.nn_descent_params = nn_param;
}
if (conf.contains("nn_descent_niter")) { param.nn_descent_niter = conf.at("nn_descent_niter"); }
}

template <typename T, typename IdxT>
Expand Down
40 changes: 27 additions & 13 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <fstream>
#include <iostream>
#include <memory>
#include <optional>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/logger.hpp>
Expand All @@ -28,6 +29,9 @@
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/cagra_serialize.cuh>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/detail/cagra/cagra_build.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <raft/neighbors/nn_descent_types.hpp>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>
#include <stdexcept>
Expand All @@ -50,12 +54,20 @@ class RaftCagra : public ANN<T> {
auto needs_dataset() const -> bool override { return true; }
};

using BuildParam = raft::neighbors::cagra::index_params;
struct BuildParam {
raft::neighbors::cagra::index_params cagra_params;
std::optional<raft::neighbors::experimental::nn_descent::index_params> nn_descent_params =
std::nullopt;
std::optional<float> ivf_pq_refine_rate = std::nullopt;
std::optional<raft::neighbors::ivf_pq::index_params> ivf_pq_build_params = std::nullopt;
std::optional<raft::neighbors::ivf_pq::search_params> ivf_pq_search_params = std::nullopt;
};

RaftCagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1)
: ANN<T>(metric, dim), index_params_(param), dimension_(dim), handle_(cudaStreamPerThread)
{
index_params_.metric = parse_metric_type(metric);
index_params_.cagra_params.metric = parse_metric_type(metric);
index_params_.ivf_pq_build_params->metric = parse_metric_type(metric);
RAFT_CUDA_TRY(cudaGetDevice(&device_));
}

Expand Down Expand Up @@ -99,17 +111,19 @@ class RaftCagra : public ANN<T> {
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
{
if (raft::get_device_for_address(dataset) == -1) {
auto dataset_view =
raft::make_host_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
return;
} else {
auto dataset_view =
raft::make_device_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
return;
}
auto dataset_view =
raft::make_host_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);

auto& params = index_params_.cagra_params;

index_.emplace(raft::neighbors::cagra::detail::build(handle_,
params,
dataset_view,
index_params_.nn_descent_params,
index_params_.ivf_pq_refine_rate,
index_params_.ivf_pq_build_params,
index_params_.ivf_pq_search_params));
return;
}

template <typename T, typename IdxT>
Expand Down
59 changes: 2 additions & 57 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -224,22 +224,7 @@ void optimize(raft::resources const& res,
mdspan<IdxT, matrix_extent<int64_t>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, int64_t, row_major> new_graph)
{
using internal_IdxT = typename std::make_unsigned<IdxT>::type;

auto new_graph_internal = raft::make_host_matrix_view<internal_IdxT, int64_t>(
reinterpret_cast<internal_IdxT*>(new_graph.data_handle()),
new_graph.extent(0),
new_graph.extent(1));

using g_accessor_internal =
host_device_accessor<std::experimental::default_accessor<internal_IdxT>, memory_type::host>;
auto knn_graph_internal =
mdspan<internal_IdxT, matrix_extent<int64_t>, row_major, g_accessor_internal>(
reinterpret_cast<internal_IdxT*>(knn_graph.data_handle()),
knn_graph.extent(0),
knn_graph.extent(1));

cagra::detail::graph::optimize(res, knn_graph_internal, new_graph_internal);
detail::optimize(res, knn_graph, new_graph);
}

/**
Expand Down Expand Up @@ -290,47 +275,7 @@ index<T, IdxT> build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset)
{
size_t intermediate_degree = params.intermediate_graph_degree;
size_t graph_degree = params.graph_degree;
if (intermediate_degree >= static_cast<size_t>(dataset.extent(0))) {
RAFT_LOG_WARN(
"Intermediate graph degree cannot be larger than dataset size, reducing it to %lu",
dataset.extent(0));
intermediate_degree = dataset.extent(0) - 1;
}
if (intermediate_degree < graph_degree) {
RAFT_LOG_WARN(
"Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing "
"graph_degree.",
graph_degree,
intermediate_degree);
graph_degree = intermediate_degree;
}

std::optional<raft::host_matrix<IdxT, int64_t>> knn_graph(
raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), intermediate_degree));

if (params.build_algo == graph_build_algo::IVF_PQ) {
build_knn_graph(res, dataset, knn_graph->view());

} else {
// Use nn-descent to build CAGRA knn graph
auto nn_descent_params = experimental::nn_descent::index_params();
nn_descent_params.graph_degree = intermediate_degree;
nn_descent_params.intermediate_graph_degree = 1.5 * intermediate_degree;
nn_descent_params.max_iterations = params.nn_descent_niter;
build_knn_graph<T, IdxT>(res, dataset, knn_graph->view(), nn_descent_params);
}

auto cagra_graph = raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), graph_degree);

optimize<IdxT>(res, knn_graph->view(), cagra_graph.view());

// free intermediate graph before trying to create the index
knn_graph.reset();

// Construct an index from dataset and optimized knn graph.
return index<T, IdxT>(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view()));
return detail::build<T, IdxT, Accessor>(res, params, dataset);
}

/**
Expand Down
82 changes: 82 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,86 @@ void build_knn_graph(raft::resources const& res,
graph::sort_knn_graph(res, dataset, knn_graph_internal);
}

template <typename IdxT = uint32_t,
typename g_accessor =
host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
void optimize(raft::resources const& res,
mdspan<IdxT, matrix_extent<int64_t>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, int64_t, row_major> new_graph)
{
using internal_IdxT = typename std::make_unsigned<IdxT>::type;

auto new_graph_internal = raft::make_host_matrix_view<internal_IdxT, int64_t>(
reinterpret_cast<internal_IdxT*>(new_graph.data_handle()),
new_graph.extent(0),
new_graph.extent(1));

using g_accessor_internal =
host_device_accessor<std::experimental::default_accessor<internal_IdxT>, memory_type::host>;
auto knn_graph_internal =
mdspan<internal_IdxT, matrix_extent<int64_t>, row_major, g_accessor_internal>(
reinterpret_cast<internal_IdxT*>(knn_graph.data_handle()),
knn_graph.extent(0),
knn_graph.extent(1));

cagra::detail::graph::optimize(res, knn_graph_internal, new_graph_internal);
}

template <typename T,
typename IdxT = uint32_t,
typename Accessor =
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>>
index<T, IdxT> build(
raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
std::optional<experimental::nn_descent::index_params> nn_descent_params = std::nullopt,
std::optional<float> refine_rate = std::nullopt,
std::optional<ivf_pq::index_params> pq_build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
size_t intermediate_degree = params.intermediate_graph_degree;
size_t graph_degree = params.graph_degree;
if (intermediate_degree >= static_cast<size_t>(dataset.extent(0))) {
RAFT_LOG_WARN(
"Intermediate graph degree cannot be larger than dataset size, reducing it to %lu",
dataset.extent(0));
intermediate_degree = dataset.extent(0) - 1;
}
if (intermediate_degree < graph_degree) {
RAFT_LOG_WARN(
"Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing "
"graph_degree.",
graph_degree,
intermediate_degree);
graph_degree = intermediate_degree;
}

std::optional<raft::host_matrix<IdxT, int64_t>> knn_graph(
raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), intermediate_degree));

if (params.build_algo == graph_build_algo::IVF_PQ) {
build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params);

} else {
// Use nn-descent to build CAGRA knn graph
if (!nn_descent_params) {
nn_descent_params = experimental::nn_descent::index_params();
nn_descent_params->graph_degree = intermediate_degree;
nn_descent_params->intermediate_graph_degree = 1.5 * intermediate_degree;
nn_descent_params->max_iterations = params.nn_descent_niter;
}
build_knn_graph<T, IdxT>(res, dataset, knn_graph->view(), *nn_descent_params);
}

auto cagra_graph = raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), graph_degree);

optimize<IdxT>(res, knn_graph->view(), cagra_graph.view());

// free intermediate graph before trying to create the index
knn_graph.reset();

// Construct an index from dataset and optimized knn graph.
return index<T, IdxT>(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view()));
}
} // namespace raft::neighbors::cagra::detail
Loading

0 comments on commit 1073340

Please sign in to comment.