diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index 6888340b4d..fb7d83a829 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -35,8 +36,10 @@ extern template class raft::bench::ann::RaftIvfFlatGpu; extern template class raft::bench::ann::RaftIvfFlatGpu; extern template class raft::bench::ann::RaftIvfFlatGpu; #endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ +#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) #include "raft_ivf_pq_wrapper.h" +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ extern template class raft::bench::ann::RaftIvfPQ; extern template class raft::bench::ann::RaftIvfPQ; extern template class raft::bench::ann::RaftIvfPQ; @@ -70,12 +73,12 @@ void parse_search_param(const nlohmann::json& conf, } #endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ +#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) template void parse_build_param(const nlohmann::json& conf, typename raft::bench::ann::RaftIvfPQ::BuildParam& param) { - param.n_lists = conf.at("nlist"); + if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); } if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); } @@ -97,7 +100,7 @@ template void parse_search_param(const nlohmann::json& conf, typename raft::bench::ann::RaftIvfPQ::SearchParam& param) { - param.pq_param.n_probes = conf.at("nprobe"); + if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); } if (conf.contains("internalDistanceDtype")) { std::string type = conf.at("internalDistanceDtype"); if (type == "float") { @@ -137,25 +140,79 @@ void parse_search_param(const nlohmann::json& conf, 
#endif #ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA +template +void parse_build_param(const nlohmann::json& conf, + raft::neighbors::experimental::nn_descent::index_params& param) +{ + if (conf.contains("graph_degree")) { param.graph_degree = conf.at("graph_degree"); } + if (conf.contains("intermediate_graph_degree")) { + param.intermediate_graph_degree = conf.at("intermediate_graph_degree"); + } + // we allow niter shorthand for max_iterations + if (conf.contains("niter")) { param.max_iterations = conf.at("niter"); } + if (conf.contains("max_iterations")) { param.max_iterations = conf.at("max_iterations"); } + if (conf.contains("termination_threshold")) { + param.termination_threshold = conf.at("termination_threshold"); + } +} + +nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf, + const std::string& prefix, + bool remove_prefix = true) +{ + nlohmann::json out; + for (auto& i : conf.items()) { + if (i.key().compare(0, prefix.size(), prefix) == 0) { + auto new_key = remove_prefix ? 
i.key().substr(prefix.size()) : i.key(); + out[new_key] = i.value(); + } + } + return out; +} + template void parse_build_param(const nlohmann::json& conf, typename raft::bench::ann::RaftCagra::BuildParam& param) { if (conf.contains("graph_degree")) { - param.graph_degree = conf.at("graph_degree"); - param.intermediate_graph_degree = param.graph_degree * 2; + param.cagra_params.graph_degree = conf.at("graph_degree"); + param.cagra_params.intermediate_graph_degree = param.cagra_params.graph_degree * 2; } if (conf.contains("intermediate_graph_degree")) { - param.intermediate_graph_degree = conf.at("intermediate_graph_degree"); + param.cagra_params.intermediate_graph_degree = conf.at("intermediate_graph_degree"); } if (conf.contains("graph_build_algo")) { if (conf.at("graph_build_algo") == "IVF_PQ") { - param.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ; + param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ; } else if (conf.at("graph_build_algo") == "NN_DESCENT") { - param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; + param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; + } + } + nlohmann::json ivf_pq_build_conf = collect_conf_with_prefix(conf, "ivf_pq_build_"); + if (!ivf_pq_build_conf.empty()) { + raft::neighbors::ivf_pq::index_params bparam; + parse_build_param(ivf_pq_build_conf, bparam); + param.ivf_pq_build_params = bparam; + } + nlohmann::json ivf_pq_search_conf = collect_conf_with_prefix(conf, "ivf_pq_search_"); + if (!ivf_pq_search_conf.empty()) { + typename raft::bench::ann::RaftIvfPQ::SearchParam sparam; + parse_search_param(ivf_pq_search_conf, sparam); + param.ivf_pq_search_params = sparam.pq_param; + param.ivf_pq_refine_rate = sparam.refine_ratio; + } + nlohmann::json nn_descent_conf = collect_conf_with_prefix(conf, "nn_descent_"); + if (!nn_descent_conf.empty()) { + raft::neighbors::experimental::nn_descent::index_params nn_param; + 
nn_param.intermediate_graph_degree = 1.5 * param.cagra_params.intermediate_graph_degree; + parse_build_param(nn_descent_conf, nn_param); + if (nn_param.graph_degree != param.cagra_params.intermediate_graph_degree) { + RAFT_LOG_WARN( + "nn_descent_graph_degree has to be equal to CAGRA intermediate_grpah_degree, overriding"); + nn_param.graph_degree = param.cagra_params.intermediate_graph_degree; } + param.nn_descent_params = nn_param; } - if (conf.contains("nn_descent_niter")) { param.nn_descent_niter = conf.at("nn_descent_niter"); } } template diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index f1c8154b7c..73fae027bc 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -28,6 +29,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -50,12 +54,20 @@ class RaftCagra : public ANN { auto needs_dataset() const -> bool override { return true; } }; - using BuildParam = raft::neighbors::cagra::index_params; + struct BuildParam { + raft::neighbors::cagra::index_params cagra_params; + std::optional nn_descent_params = + std::nullopt; + std::optional ivf_pq_refine_rate = std::nullopt; + std::optional ivf_pq_build_params = std::nullopt; + std::optional ivf_pq_search_params = std::nullopt; + }; RaftCagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) : ANN(metric, dim), index_params_(param), dimension_(dim), handle_(cudaStreamPerThread) { - index_params_.metric = parse_metric_type(metric); + index_params_.cagra_params.metric = parse_metric_type(metric); + index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); RAFT_CUDA_TRY(cudaGetDevice(&device_)); } @@ -99,17 +111,19 @@ class RaftCagra : public ANN { template void RaftCagra::build(const T* dataset, size_t nrow, cudaStream_t) { - if 
(raft::get_device_for_address(dataset) == -1) { - auto dataset_view = - raft::make_host_matrix_view(dataset, IdxT(nrow), dimension_); - index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view)); - return; - } else { - auto dataset_view = - raft::make_device_matrix_view(dataset, IdxT(nrow), dimension_); - index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view)); - return; - } + auto dataset_view = + raft::make_host_matrix_view(dataset, IdxT(nrow), dimension_); + + auto& params = index_params_.cagra_params; + + index_.emplace(raft::neighbors::cagra::detail::build(handle_, + params, + dataset_view, + index_params_.nn_descent_params, + index_params_.ivf_pq_refine_rate, + index_params_.ivf_pq_build_params, + index_params_.ivf_pq_search_params)); + return; } template diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 1efb4da95e..384ed05e1f 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -224,22 +224,7 @@ void optimize(raft::resources const& res, mdspan, row_major, g_accessor> knn_graph, raft::host_matrix_view new_graph) { - using internal_IdxT = typename std::make_unsigned::type; - - auto new_graph_internal = raft::make_host_matrix_view( - reinterpret_cast(new_graph.data_handle()), - new_graph.extent(0), - new_graph.extent(1)); - - using g_accessor_internal = - host_device_accessor, memory_type::host>; - auto knn_graph_internal = - mdspan, row_major, g_accessor_internal>( - reinterpret_cast(knn_graph.data_handle()), - knn_graph.extent(0), - knn_graph.extent(1)); - - cagra::detail::graph::optimize(res, knn_graph_internal, new_graph_internal); + detail::optimize(res, knn_graph, new_graph); } /** @@ -290,47 +275,7 @@ index build(raft::resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset) { - size_t intermediate_degree = params.intermediate_graph_degree; - size_t graph_degree = params.graph_degree; 
- if (intermediate_degree >= static_cast(dataset.extent(0))) { - RAFT_LOG_WARN( - "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", - dataset.extent(0)); - intermediate_degree = dataset.extent(0) - 1; - } - if (intermediate_degree < graph_degree) { - RAFT_LOG_WARN( - "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " - "graph_degree.", - graph_degree, - intermediate_degree); - graph_degree = intermediate_degree; - } - - std::optional> knn_graph( - raft::make_host_matrix(dataset.extent(0), intermediate_degree)); - - if (params.build_algo == graph_build_algo::IVF_PQ) { - build_knn_graph(res, dataset, knn_graph->view()); - - } else { - // Use nn-descent to build CAGRA knn graph - auto nn_descent_params = experimental::nn_descent::index_params(); - nn_descent_params.graph_degree = intermediate_degree; - nn_descent_params.intermediate_graph_degree = 1.5 * intermediate_degree; - nn_descent_params.max_iterations = params.nn_descent_niter; - build_knn_graph(res, dataset, knn_graph->view(), nn_descent_params); - } - - auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); - - optimize(res, knn_graph->view(), cagra_graph.view()); - - // free intermediate graph before trying to create the index - knn_graph.reset(); - - // Construct an index from dataset and optimized knn graph. 
- return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); + return detail::build(res, params, dataset); } /** diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 40024a3deb..ddaf77a22f 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -264,4 +264,86 @@ void build_knn_graph(raft::resources const& res, graph::sort_knn_graph(res, dataset, knn_graph_internal); } +template , memory_type::host>> +void optimize(raft::resources const& res, + mdspan, row_major, g_accessor> knn_graph, + raft::host_matrix_view new_graph) +{ + using internal_IdxT = typename std::make_unsigned::type; + + auto new_graph_internal = raft::make_host_matrix_view( + reinterpret_cast(new_graph.data_handle()), + new_graph.extent(0), + new_graph.extent(1)); + + using g_accessor_internal = + host_device_accessor, memory_type::host>; + auto knn_graph_internal = + mdspan, row_major, g_accessor_internal>( + reinterpret_cast(knn_graph.data_handle()), + knn_graph.extent(0), + knn_graph.extent(1)); + + cagra::detail::graph::optimize(res, knn_graph_internal, new_graph_internal); +} + +template , memory_type::host>> +index build( + raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset, + std::optional nn_descent_params = std::nullopt, + std::optional refine_rate = std::nullopt, + std::optional pq_build_params = std::nullopt, + std::optional search_params = std::nullopt) +{ + size_t intermediate_degree = params.intermediate_graph_degree; + size_t graph_degree = params.graph_degree; + if (intermediate_degree >= static_cast(dataset.extent(0))) { + RAFT_LOG_WARN( + "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", + dataset.extent(0)); + intermediate_degree = dataset.extent(0) - 1; + } + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + 
"Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + std::optional> knn_graph( + raft::make_host_matrix(dataset.extent(0), intermediate_degree)); + + if (params.build_algo == graph_build_algo::IVF_PQ) { + build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params); + + } else { + // Use nn-descent to build CAGRA knn graph + if (!nn_descent_params) { + nn_descent_params = experimental::nn_descent::index_params(); + nn_descent_params->graph_degree = intermediate_degree; + nn_descent_params->intermediate_graph_degree = 1.5 * intermediate_degree; + nn_descent_params->max_iterations = params.nn_descent_niter; + } + build_knn_graph(res, dataset, knn_graph->view(), *nn_descent_params); + } + + auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); + + optimize(res, knn_graph->view(), cagra_graph.view()); + + // free intermediate graph before trying to create the index + knn_graph.reset(); + + // Construct an index from dataset and optimized knn graph. + return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); +} } // namespace raft::neighbors::cagra::detail diff --git a/docs/source/ann_benchmarks_param_tuning.md b/docs/source/ann_benchmarks_param_tuning.md index d787a96955..cdc7958714 100644 --- a/docs/source/ann_benchmarks_param_tuning.md +++ b/docs/source/ann_benchmarks_param_tuning.md @@ -53,7 +53,6 @@ CAGRA uses a graph-based index, which creates an intermediate, approximate kNN g | `graph_degree` | `build_param` | N | Positive Integer >0 | 64 | Degree of the final kNN graph index. | | `intermediate_graph_degree` | `build_param` | N | Positive Integer >0 | 128 | Degree of the intermediate kNN graph. 
| | `graph_build_algo` | `build_param` | N | ["IVF_PQ", "NN_DESCENT"] | "IVF_PQ" | Algorithm to use for search | -| `nn_descent_niter` | `build_param` | N | Positive Integer>0 | 20 | Number of iterations if using NN_DESCENT. | | `dataset_memory_type` | `build_param` | N | ["device", "host", "mmap"] | "device" | What memory type should the dataset reside? | | `query_memory_type` | `search_params` | N | ["device", "host", "mmap"] | "device | What memory type should the queries reside? | | `itopk` | `search_wdith` | N | Positive Integer >0 | 64 | Number of intermediate search results retained during the search. Higher values improve search accuracy at the cost of speed. | @@ -61,6 +60,28 @@ CAGRA uses a graph-based index, which creates an intermediate, approximate kNN g | `max_iterations` | `search_param` | N | Integer >=0 | 0 | Upper limit of search iterations. Auto select when 0. | | `algo` | `search_param` | N | string | "auto" | Algorithm to use for search. Possible values: {"auto", "single_cta", "multi_cta", "multi_kernel"} | +To fine-tune CAGRA index building, we can customize the IVF-PQ index builder options using the following settings. These take effect only if `graph_build_algo == "IVF_PQ"`. It is recommended to experiment with a separate IVF-PQ index to find the config that gives the highest QPS for large batch sizes. Recall does not need to be very high, since CAGRA further optimizes the kNN neighbor graph. Some of the default values are derived from the dataset shape, which is assumed to be [n_vecs, dim]. 
+ +| Parameter | Type | Required | Data Type | Default | Description | +|------------------------|----------------|---|----------------------------------|---------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `ivf_pq_build_nlist` | `build_param` | N | Positive Integer >0 | n_vecs / 2500 | Number of clusters to partition the vectors into. Larger values put fewer points into each cluster, but increase index build time because more clusters need to be trained. | +| `ivf_pq_build_niter` | `build_param` | N | Positive Integer >0 | 25 | Number of k-means iterations to use when training the clusters. | +| `ivf_pq_build_ratio` | `build_param` | N | Positive Integer >0 | 10 | `1/ratio` is the fraction of the dataset points that is used to train the clusters. | +| `ivf_pq_build_pq_dim` | `build_param` | N | Positive Integer. Multiple of 8. | dim/2 rounded up to a multiple of 8 | Dimensionality of the vector after product quantization. When 0, a heuristic is used to select this value. `pq_dim` * `pq_bits` must be a multiple of 8. | +| `ivf_pq_build_pq_bits` | `build_param` | N | Positive Integer. [4-8] | 8 | Bit length of the vector element after quantization. | +| `ivf_pq_build_codebook_kind` | `build_param` | N | ["cluster", "subspace"] | "subspace" | Type of codebook. See the [API docs](https://docs.rapids.ai/api/raft/nightly/cpp_api/neighbors_ivf_pq/#_CPPv412codebook_gen) for more detail. | +| `ivf_pq_search_nprobe` | `build_param` | N | Positive Integer >0 | min(2*dim, nlist) | The number of closest clusters to search for each query vector. | +| `ivf_pq_search_internalDistanceDtype` | `build_param` | N | [`float`, `half`] | `fp8` | The precision to use for the distance computations. Lower precision can increase performance at the cost of accuracy. 
| +| `ivf_pq_search_smemLutDtype` | `build_param` | N | [`float`, `half`, `fp8`] | `half` | The precision to use for the lookup table in shared memory. Lower precision can increase performance at the cost of accuracy. | +| `ivf_pq_search_refine_ratio` | `build_param` | N | Positive Number >=0 | 2 | `refine_ratio * k` nearest neighbors are queried from the index initially and an additional refinement step improves recall by selecting only the best `k` neighbors. | + +Alternatively, if `graph_build_algo == "NN_DESCENT"`, then we can customize the following parameters: + +| Parameter | Type | Required | Data Type | Default | Description | +|-----------------------------|----------------|----------|----------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `nn_descent_niter` | `build_param` | N | Positive Integer>0 | 20 | Number of NN Descent iterations. | +| `nn_descent_intermediate_graph_degree` | `build_param` | N | Positive Integer>0 | `intermediate_graph_degree` * 1.5 | Intermediate graph degree during NN descent iterations. | +| `nn_descent_max_iterations` | `build_param` | N | Positive Integer>0 | 20 | Alias for `nn_descent_niter`. | +| `nn_descent_termination_threshold` | `build_param` | N | Positive float>0 | 0.0001 | Termination threshold for NN descent. 
| ## FAISS Indexes diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/wiki_all_1M.json b/python/raft-ann-bench/src/raft-ann-bench/run/conf/wiki_all_1M.json index 6eb72a65a1..2d1ec1e322 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/wiki_all_1M.json +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/wiki_all_1M.json @@ -1,9 +1,10 @@ { "dataset": { "name": "wiki_all_1M", - "base_file": "wiki_all_1M/base.88M.fbin", + "base_file": "wiki_all_1M/base.1M.fbin", + "subset_size": 1000000, "query_file": "wiki_all_1M/queries.fbin", - "groundtruth_neighbors_file": "wiki_all_1M/groundtruth.88M.neighbors.ibin", + "groundtruth_neighbors_file": "wiki_all_1M/groundtruth.1M.neighbors.ibin", "distance": "euclidean" }, "search_basic_param": { @@ -169,7 +170,22 @@ { "name": "raft_cagra.dim32.multi_cta", "algo": "raft_cagra", - "build_param": { "graph_degree": 32, "intermediate_graph_degree": 48 }, + "build_param": { "graph_degree": 32, + "intermediate_graph_degree": 48, + "graph_build_algo": "NN_DESCENT", + "ivf_pq_build_pq_dim": 32, + "ivf_pq_build_pq_bits": 8, + "ivf_pq_build_nlist": 16384, + "ivf_pq_build_niter": 10, + "ivf_pq_build_ratio": 10, + "ivf_pq_search_nprobe": 30, + "ivf_pq_search_internalDistanceDtype": "half", + "ivf_pq_search_smemLutDtype": "half", + "ivf_pq_search_refine_ratio": 8, + "nn_descent_max_iterations": 10, + "nn_descent_intermediate_graph_degree": 72, + "nn_descent_termination_threshold": 0.001 + }, "file": "wiki_all_1M/raft_cagra/dim32.ibin", "search_params": [ { "itopk": 32, "search_width": 1, "max_iterations": 0, "algo": "multi_cta" },