diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index d6a5fddb98..eb44e58cb5 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -30,6 +30,7 @@ option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm option(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT "Include raft's ivf flat algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON) +option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON) option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON) option(RAFT_ANN_BENCH_SINGLE_EXE @@ -54,6 +55,7 @@ if(BUILD_CPU_ONLY) set(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OFF) set(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OFF) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF) + set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF) set(RAFT_ANN_BENCH_USE_GGNN OFF) else() # Disable faiss benchmarks on CUDA 12 since faiss is not yet CUDA 12-enabled. @@ -88,6 +90,7 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OR RAFT_ANN_BENCH_USE_RAFT_CAGRA + OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB ) set(RAFT_ANN_BENCH_USE_RAFT ON) endif() @@ -95,7 +98,7 @@ endif() # ################################################################################################## # * Fetch requirements ------------------------------------------------------------- -if(RAFT_ANN_BENCH_USE_HNSWLIB) +if(RAFT_ANN_BENCH_USE_HNSWLIB OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) include(cmake/thirdparty/get_hnswlib.cmake) endif() @@ -250,6 +253,20 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA) ) endif() +if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) + ConfigureAnnBench( + NAME + RAFT_CAGRA_HNSWLIB + PATH + bench/ann/src/raft/raft_cagra_hnswlib.cu + INCLUDES + ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib + LINKS + raft::compiled + CXXFLAGS "${HNSW_CXX_FLAGS}" + ) +endif() + set(RAFT_FAISS_TARGETS faiss::faiss) if(TARGET faiss::faiss_avx2) set(RAFT_FAISS_TARGETS faiss::faiss_avx2) diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index 364da81f77..921d72decc 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -91,6 +91,8 @@ class HnswLib : public ANN { return property; } + void set_base_layer_only() { appr_alg_->base_layer_only = true; } + private: void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const; diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h new file mode 100644 index 0000000000..479a90e3b5 --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -0,0 +1,231 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#define JSON_DIAGNOSTICS 1 +#include + +#undef WARP_SIZE +#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN +#include "raft_wrapper.h" +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT +#include "raft_ivf_flat_wrapper.h" +extern template class raft::bench::ann::RaftIvfFlatGpu; +extern template class raft::bench::ann::RaftIvfFlatGpu; +extern template class raft::bench::ann::RaftIvfFlatGpu; +#endif +#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ + defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) +#include "raft_ivf_pq_wrapper.h" +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ +extern template class raft::bench::ann::RaftIvfPQ; +extern template class raft::bench::ann::RaftIvfPQ; +extern template class raft::bench::ann::RaftIvfPQ; +#endif +#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) +#include "raft_cagra_wrapper.h" +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA +extern template class raft::bench::ann::RaftCagra; +extern template class raft::bench::ann::RaftCagra; +extern template class raft::bench::ann::RaftCagra; +#endif + +#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT +template +void parse_build_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftIvfFlatGpu::BuildParam& param) +{ + param.n_lists = conf.at("nlist"); + if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } + if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } +} + +template +void parse_search_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftIvfFlatGpu::SearchParam& param) +{ + param.ivf_flat_params.n_probes = conf.at("nprobe"); +} +#endif + +#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ + defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) +template +void parse_build_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftIvfPQ::BuildParam& param) +{ + if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); } + if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } + if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } + if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); } + if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); } + if (conf.contains("codebook_kind")) { + std::string kind = conf.at("codebook_kind"); + if (kind == "cluster") { + param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; + } else if (kind == "subspace") { + param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; + } else { + throw std::runtime_error("codebook_kind: '" + kind + + "', should be either 'cluster' or 'subspace'"); + } + } +} + +template +void parse_search_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftIvfPQ::SearchParam& param) +{ + if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); } + if (conf.contains("internalDistanceDtype")) { + std::string type = conf.at("internalDistanceDtype"); + if (type == "float") { + param.pq_param.internal_distance_dtype = CUDA_R_32F; + } else if (type == "half") { + param.pq_param.internal_distance_dtype = CUDA_R_16F; + } else { + throw std::runtime_error("internalDistanceDtype: '" + type + + "', should be either 'float' or 'half'"); + } + } else { + // set half as default type + param.pq_param.internal_distance_dtype = CUDA_R_16F; + } + + if (conf.contains("smemLutDtype")) { + std::string type = conf.at("smemLutDtype"); + if (type == "float") { + param.pq_param.lut_dtype = CUDA_R_32F; + } else if (type == "half") { + param.pq_param.lut_dtype = CUDA_R_16F; + } else if (type == "fp8") { + param.pq_param.lut_dtype = CUDA_R_8U; + } else { + throw std::runtime_error("smemLutDtype: '" + type + + "', should be either 'float', 'half' or 'fp8'"); + } + } else { + // set half as default + param.pq_param.lut_dtype = CUDA_R_16F; + } + if (conf.contains("refine_ratio")) { + param.refine_ratio = conf.at("refine_ratio"); + if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); } + } +} +#endif + +#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) +template +void parse_build_param(const nlohmann::json& conf, + raft::neighbors::experimental::nn_descent::index_params& param) +{ + if (conf.contains("graph_degree")) { param.graph_degree = conf.at("graph_degree"); } + if (conf.contains("intermediate_graph_degree")) { + param.intermediate_graph_degree = conf.at("intermediate_graph_degree"); + } + // we allow niter shorthand for max_iterations + if (conf.contains("niter")) { param.max_iterations = conf.at("niter"); } + if (conf.contains("max_iterations")) { param.max_iterations = conf.at("max_iterations"); } + if (conf.contains("termination_threshold")) { + param.termination_threshold = conf.at("termination_threshold"); + } +} + +nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf, + const std::string& prefix, + bool remove_prefix = true) +{ + nlohmann::json out; + for (auto& i : conf.items()) { + if (i.key().compare(0, prefix.size(), prefix) == 0) { + auto new_key = remove_prefix ? i.key().substr(prefix.size()) : i.key(); + out[new_key] = i.value(); + } + } + return out; +} + +template +void parse_build_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftCagra::BuildParam& param) +{ + if (conf.contains("graph_degree")) { + param.cagra_params.graph_degree = conf.at("graph_degree"); + param.cagra_params.intermediate_graph_degree = param.cagra_params.graph_degree * 2; + } + if (conf.contains("intermediate_graph_degree")) { + param.cagra_params.intermediate_graph_degree = conf.at("intermediate_graph_degree"); + } + if (conf.contains("graph_build_algo")) { + if (conf.at("graph_build_algo") == "IVF_PQ") { + param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ; + } else if (conf.at("graph_build_algo") == "NN_DESCENT") { + param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; + } + } + nlohmann::json ivf_pq_build_conf = collect_conf_with_prefix(conf, "ivf_pq_build_"); + if (!ivf_pq_build_conf.empty()) { + raft::neighbors::ivf_pq::index_params bparam; + parse_build_param(ivf_pq_build_conf, bparam); + param.ivf_pq_build_params = bparam; + } + nlohmann::json ivf_pq_search_conf = collect_conf_with_prefix(conf, "ivf_pq_search_"); + if (!ivf_pq_search_conf.empty()) { + typename raft::bench::ann::RaftIvfPQ::SearchParam sparam; + parse_search_param(ivf_pq_search_conf, sparam); + param.ivf_pq_search_params = sparam.pq_param; + param.ivf_pq_refine_rate = sparam.refine_ratio; + } + nlohmann::json nn_descent_conf = collect_conf_with_prefix(conf, "nn_descent_"); + if (!nn_descent_conf.empty()) { + raft::neighbors::experimental::nn_descent::index_params nn_param; + nn_param.intermediate_graph_degree = 1.5 * param.cagra_params.intermediate_graph_degree; + parse_build_param(nn_descent_conf, nn_param); + if (nn_param.graph_degree != param.cagra_params.intermediate_graph_degree) { + nn_param.graph_degree = param.cagra_params.intermediate_graph_degree; + } + param.nn_descent_params = nn_param; + } +} + +template +void parse_search_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftCagra::SearchParam& param) +{ + if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); } + if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); } + if (conf.contains("max_iterations")) { param.p.max_iterations = conf.at("max_iterations"); } + if (conf.contains("algo")) { + if (conf.at("algo") == "single_cta") { + param.p.algo = raft::neighbors::experimental::cagra::search_algo::SINGLE_CTA; + } else if (conf.at("algo") == "multi_cta") { + param.p.algo = raft::neighbors::experimental::cagra::search_algo::MULTI_CTA; + } else if (conf.at("algo") == "multi_kernel") { + param.p.algo = raft::neighbors::experimental::cagra::search_algo::MULTI_KERNEL; + } else if (conf.at("algo") == "auto") { + param.p.algo = raft::neighbors::experimental::cagra::search_algo::AUTO; + } else { + std::string tmp = conf.at("algo"); + THROW("Invalid value for algo: %s", tmp.c_str()); + } + } +} +#endif diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index fb7d83a829..f8c65a2d6e 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -16,6 +16,8 @@ #include "../common/ann_types.hpp" +#include "raft_ann_bench_param_parser.h" + #include #include #include @@ -26,219 +28,11 @@ #include #include -#undef WARP_SIZE -#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN -#include "raft_wrapper.h" -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT -#include "raft_ivf_flat_wrapper.h" -extern template class raft::bench::ann::RaftIvfFlatGpu; -extern template class raft::bench::ann::RaftIvfFlatGpu; -extern template class raft::bench::ann::RaftIvfFlatGpu; -#endif -#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) -#include "raft_ivf_pq_wrapper.h" -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ -extern template class raft::bench::ann::RaftIvfPQ; -extern template class raft::bench::ann::RaftIvfPQ; -extern template class raft::bench::ann::RaftIvfPQ; -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA -#include "raft_cagra_wrapper.h" -extern template class raft::bench::ann::RaftCagra; -extern template class raft::bench::ann::RaftCagra; -extern template class raft::bench::ann::RaftCagra; -#endif #define JSON_DIAGNOSTICS 1 #include namespace raft::bench::ann { -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfFlatGpu::BuildParam& param) -{ - param.n_lists = conf.at("nlist"); - if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } - if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfFlatGpu::SearchParam& param) -{ - param.ivf_flat_params.n_probes = conf.at("nprobe"); -} -#endif - -#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfPQ::BuildParam& param) -{ - if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); } - if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } - if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } - if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); } - if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); } - if (conf.contains("codebook_kind")) { - std::string kind = conf.at("codebook_kind"); - if (kind == "cluster") { - param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; - } else if (kind == "subspace") { - param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; - } else { - throw std::runtime_error("codebook_kind: '" + kind + - "', should be either 'cluster' or 'subspace'"); - } - } -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfPQ::SearchParam& param) -{ - if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); } - if (conf.contains("internalDistanceDtype")) { - std::string type = conf.at("internalDistanceDtype"); - if (type == "float") { - param.pq_param.internal_distance_dtype = CUDA_R_32F; - } else if (type == "half") { - param.pq_param.internal_distance_dtype = CUDA_R_16F; - } else { - throw std::runtime_error("internalDistanceDtype: '" + type + - "', should be either 'float' or 'half'"); - } - } else { - // set half as default type - param.pq_param.internal_distance_dtype = CUDA_R_16F; - } - - if (conf.contains("smemLutDtype")) { - std::string type = conf.at("smemLutDtype"); - if (type == "float") { - param.pq_param.lut_dtype = CUDA_R_32F; - } else if (type == "half") { - param.pq_param.lut_dtype = CUDA_R_16F; - } else if (type == "fp8") { - param.pq_param.lut_dtype = CUDA_R_8U; - } else { - throw std::runtime_error("smemLutDtype: '" + type + - "', should be either 'float', 'half' or 'fp8'"); - } - } else { - // set half as default - param.pq_param.lut_dtype = CUDA_R_16F; - } - if (conf.contains("refine_ratio")) { - param.refine_ratio = conf.at("refine_ratio"); - if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); } - } -} -#endif - -#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA -template -void parse_build_param(const nlohmann::json& conf, - raft::neighbors::experimental::nn_descent::index_params& param) -{ - if (conf.contains("graph_degree")) { param.graph_degree = conf.at("graph_degree"); } - if (conf.contains("intermediate_graph_degree")) { - param.intermediate_graph_degree = conf.at("intermediate_graph_degree"); - } - // we allow niter shorthand for max_iterations - if (conf.contains("niter")) { param.max_iterations = conf.at("niter"); } - if (conf.contains("max_iterations")) { param.max_iterations = conf.at("max_iterations"); } - if (conf.contains("termination_threshold")) { - param.termination_threshold = conf.at("termination_threshold"); - } -} - -nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf, - const std::string& prefix, - bool remove_prefix = true) -{ - nlohmann::json out; - for (auto& i : conf.items()) { - if (i.key().compare(0, prefix.size(), prefix) == 0) { - auto new_key = remove_prefix ? i.key().substr(prefix.size()) : i.key(); - out[new_key] = i.value(); - } - } - return out; -} - -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftCagra::BuildParam& param) -{ - if (conf.contains("graph_degree")) { - param.cagra_params.graph_degree = conf.at("graph_degree"); - param.cagra_params.intermediate_graph_degree = param.cagra_params.graph_degree * 2; - } - if (conf.contains("intermediate_graph_degree")) { - param.cagra_params.intermediate_graph_degree = conf.at("intermediate_graph_degree"); - } - if (conf.contains("graph_build_algo")) { - if (conf.at("graph_build_algo") == "IVF_PQ") { - param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ; - } else if (conf.at("graph_build_algo") == "NN_DESCENT") { - param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; - } - } - nlohmann::json ivf_pq_build_conf = collect_conf_with_prefix(conf, "ivf_pq_build_"); - if (!ivf_pq_build_conf.empty()) { - raft::neighbors::ivf_pq::index_params bparam; - parse_build_param(ivf_pq_build_conf, bparam); - param.ivf_pq_build_params = bparam; - } - nlohmann::json ivf_pq_search_conf = collect_conf_with_prefix(conf, "ivf_pq_search_"); - if (!ivf_pq_search_conf.empty()) { - typename raft::bench::ann::RaftIvfPQ::SearchParam sparam; - parse_search_param(ivf_pq_search_conf, sparam); - param.ivf_pq_search_params = sparam.pq_param; - param.ivf_pq_refine_rate = sparam.refine_ratio; - } - nlohmann::json nn_descent_conf = collect_conf_with_prefix(conf, "nn_descent_"); - if (!nn_descent_conf.empty()) { - raft::neighbors::experimental::nn_descent::index_params nn_param; - nn_param.intermediate_graph_degree = 1.5 * param.cagra_params.intermediate_graph_degree; - parse_build_param(nn_descent_conf, nn_param); - if (nn_param.graph_degree != param.cagra_params.intermediate_graph_degree) { - RAFT_LOG_WARN( - "nn_descent_graph_degree has to be equal to CAGRA intermediate_grpah_degree, overriding"); - nn_param.graph_degree = param.cagra_params.intermediate_graph_degree; - } - param.nn_descent_params = nn_param; - } -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftCagra::SearchParam& param) -{ - if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); } - if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); } - if (conf.contains("max_iterations")) { param.p.max_iterations = conf.at("max_iterations"); } - if (conf.contains("algo")) { - if (conf.at("algo") == "single_cta") { - param.p.algo = raft::neighbors::experimental::cagra::search_algo::SINGLE_CTA; - } else if (conf.at("algo") == "multi_cta") { - param.p.algo = raft::neighbors::experimental::cagra::search_algo::MULTI_CTA; - } else if (conf.at("algo") == "multi_kernel") { - param.p.algo = raft::neighbors::experimental::cagra::search_algo::MULTI_KERNEL; - } else if (conf.at("algo") == "auto") { - param.p.algo = raft::neighbors::experimental::cagra::search_algo::AUTO; - } else { - std::string tmp = conf.at("algo"); - THROW("Invalid value for algo: %s", tmp.c_str()); - } - } -} -#endif - template std::unique_ptr> create_algo(const std::string& algo, const std::string& distance, @@ -281,6 +75,7 @@ std::unique_ptr> create_algo(const std::string& algo, ann = std::make_unique>(metric, dim, param); } #endif + if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } return ann; @@ -318,6 +113,7 @@ std::unique_ptr::AnnSearchParam> create_search return param; } #endif + // else throw std::runtime_error("invalid algo: '" + algo + "'"); } diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu b/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu new file mode 100644 index 0000000000..ce6fa255b2 --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../common/ann_types.hpp" +#include "raft_ann_bench_param_parser.h" +#include "raft_cagra_hnswlib_wrapper.h" + +#include + +#define JSON_DIAGNOSTICS 1 +#include + +namespace raft::bench::ann { + +template +void parse_search_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftCagraHnswlib::SearchParam& param) +{ + param.ef = conf.at("ef"); + if (conf.contains("numThreads")) { param.num_threads = conf.at("numThreads"); } +} + +template +std::unique_ptr> create_algo(const std::string& algo, + const std::string& distance, + int dim, + const nlohmann::json& conf, + const std::vector& dev_list) +{ + // stop compiler warning; not all algorithms support multi-GPU so it may not be used + (void)dev_list; + + raft::bench::ann::Metric metric = parse_metric(distance); + std::unique_ptr> ann; + + if constexpr (std::is_same_v or std::is_same_v) { + if (algo == "raft_cagra_hnswlib") { + typename raft::bench::ann::RaftCagraHnswlib::BuildParam param; + parse_build_param(conf, param); + ann = std::make_unique>(metric, dim, param); + } + } + + if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } + + return ann; +} + +template +std::unique_ptr::AnnSearchParam> create_search_param( + const std::string& algo, const nlohmann::json& conf) +{ + if (algo == "raft_cagra_hnswlib") { + auto param = + std::make_unique::SearchParam>(); + parse_search_param(conf, *param); + return param; + } + + throw std::runtime_error("invalid algo: '" + algo + "'"); +} + +} // namespace raft::bench::ann + +REGISTER_ALGO_INSTANCE(float); +REGISTER_ALGO_INSTANCE(std::int8_t); +REGISTER_ALGO_INSTANCE(std::uint8_t); + +#ifdef ANN_BENCH_BUILD_MAIN +#include "../common/benchmark.hpp" +int main(int argc, char** argv) +{ + rmm::mr::cuda_memory_resource cuda_mr; + // Construct a resource that uses a coalescing best-fit pool allocator + rmm::mr::pool_memory_resource pool_mr{&cuda_mr}; + rmm::mr::set_current_device_resource( + &pool_mr); // Updates the current device resource pointer to `pool_mr` + rmm::mr::device_memory_resource* mr = + rmm::mr::get_current_device_resource(); // Points to `pool_mr` + return raft::bench::ann::run_main(argc, argv); +} +#endif diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h new file mode 100644 index 0000000000..432caecfcc --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../hnswlib/hnswlib_wrapper.h" +#include "raft_cagra_wrapper.h" +#include + +namespace raft::bench::ann { + +template +class RaftCagraHnswlib : public ANN { + public: + using typename ANN::AnnSearchParam; + using BuildParam = typename RaftCagra::BuildParam; + using SearchParam = typename HnswLib::SearchParam; + + RaftCagraHnswlib(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) + : ANN(metric, dim), + metric_(metric), + index_params_(param), + dimension_(dim), + handle_(cudaStreamPerThread) + { + } + + ~RaftCagraHnswlib() noexcept {} + + void build(const T* dataset, size_t nrow, cudaStream_t stream) final; + + void set_search_param(const AnnSearchParam& param) override; + + // TODO: if the number of results is less than k, the remaining elements of 'neighbors' + // will be filled with (size_t)-1 + void search(const T* queries, + int batch_size, + int k, + size_t* neighbors, + float* distances, + cudaStream_t stream = 0) const override; + + // to enable dataset access from GPU memory + AlgoProperty get_preference() const override + { + AlgoProperty property; + property.dataset_memory_type = MemoryType::HostMmap; + property.query_memory_type = MemoryType::Host; + return property; + } + void save(const std::string& file) const override; + void load(const std::string&) override; + + private: + raft::device_resources handle_; + Metric metric_; + BuildParam index_params_; + int dimension_; + + std::unique_ptr> cagra_build_; + std::unique_ptr> hnswlib_search_; + + Objective metric_objective_; +}; + +template +void RaftCagraHnswlib::build(const T* dataset, size_t nrow, cudaStream_t stream) +{ + if (not cagra_build_) { + cagra_build_ = std::make_unique>(metric_, dimension_, index_params_); + } + cagra_build_->build(dataset, nrow, stream); +} + +template +void RaftCagraHnswlib::set_search_param(const AnnSearchParam& param_) +{ + hnswlib_search_->set_search_param(param_); +} + +template +void RaftCagraHnswlib::save(const std::string& file) const +{ + cagra_build_->save_to_hnswlib(file); +} + +template +void RaftCagraHnswlib::load(const std::string& file) +{ + typename HnswLib::BuildParam param; + // these values don't matter since we don't build with HnswLib + param.M = 50; + param.ef_construction = 100; + if (not hnswlib_search_) { + hnswlib_search_ = std::make_unique>(metric_, dimension_, param); + } + hnswlib_search_->load(file); + hnswlib_search_->set_base_layer_only(); +} + +template +void RaftCagraHnswlib::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const +{ + hnswlib_search_->search(queries, batch_size, k, neighbors, distances); +} + +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 73fae027bc..bf526101be 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -98,6 +98,7 @@ class RaftCagra : public ANN { } void save(const std::string& file) const override; void load(const std::string&) override; + void save_to_hnswlib(const std::string& file) const; private: raft::device_resources handle_; @@ -143,7 +144,13 @@ void RaftCagra::set_search_dataset(const T* dataset, size_t nrow) template void RaftCagra::save(const std::string& file) const { - raft::neighbors::cagra::serialize(handle_, file, *index_, false); + raft::neighbors::cagra::serialize(handle_, file, *index_); +} + +template +void RaftCagra::save_to_hnswlib(const std::string& file) const +{ + raft::neighbors::cagra::serialize_to_hnswlib(handle_, file, *index_); } template diff --git a/cpp/cmake/patches/hnswlib.patch b/cpp/cmake/patches/hnswlib.patch new file mode 100644 index 0000000000..32c1537c58 --- /dev/null +++ b/cpp/cmake/patches/hnswlib.patch @@ -0,0 +1,130 @@ +diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h +index e95e0b5..f0fe50a 100644 +--- a/hnswlib/hnswalg.h ++++ b/hnswlib/hnswalg.h +@@ -3,6 +3,7 @@ + #include "visited_list_pool.h" + #include "hnswlib.h" + #include ++#include + #include + #include + #include +@@ -16,6 +17,8 @@ namespace hnswlib { + template + class HierarchicalNSW : public AlgorithmInterface { + public: ++ bool base_layer_only{false}; ++ int num_seeds=32; + static const tableint max_update_element_locks = 65536; + HierarchicalNSW(SpaceInterface *s) { + } +@@ -56,7 +59,7 @@ namespace hnswlib { + visited_list_pool_ = new VisitedListPool(1, max_elements); + + //initializations for special treatment of the first node +- enterpoint_node_ = -1; ++ enterpoint_node_ = std::numeric_limits::max(); + maxlevel_ = -1; + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); +@@ -527,7 +530,7 @@ namespace hnswlib { + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; +- if (cand < 0 || cand > max_elements_) ++ if (cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + +@@ -1067,7 +1070,7 @@ namespace hnswlib { + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; +- if (cand < 0 || cand > max_elements_) ++ if (cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { +@@ -1119,28 +1122,41 @@ namespace hnswlib { + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + +- for (int level = maxlevel_; level > 0; level--) { +- bool changed = true; +- while (changed) { +- changed = false; +- unsigned int *data; ++ if (base_layer_only) { ++ // You can increase the number of seeds when testing large-scale dataset, num_seeds = 48 for 100M-scale ++ for (int i = 0; i < num_seeds; i++) { ++ tableint obj = i * (max_elements_ / num_seeds); ++ dist_t dist = fstdistfunc_(query_data, getDataByInternalId(obj), dist_func_param_); ++ if (dist < curdist) { ++ curdist = dist; ++ currObj = obj; ++ } ++ } ++ } ++ else{ ++ for (int level = maxlevel_; level > 0; level--) { ++ bool changed = true; ++ while (changed) { ++ changed = false; ++ unsigned int *data; + +- data = (unsigned int *) get_linklist(currObj, level); +- int size = getListCount(data); +- metric_hops++; +- metric_distance_computations+=size; ++ data = (unsigned int *) get_linklist(currObj, level); ++ int size = getListCount(data); ++ metric_hops++; ++ metric_distance_computations+=size; + +- tableint *datal = (tableint *) (data + 1); +- for (int i = 0; i < size; i++) { +- tableint cand = datal[i]; +- if (cand < 0 || cand > max_elements_) +- throw std::runtime_error("cand error"); +- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); ++ tableint *datal = (tableint *) (data + 1); ++ for (int i = 0; i < size; i++) { ++ tableint cand = datal[i]; ++ if (cand > max_elements_) ++ throw std::runtime_error("cand error"); ++ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + +- if (d < curdist) { +- curdist = d; +- currObj = cand; +- changed = true; ++ if (d < curdist) { ++ curdist = d; ++ currObj = cand; ++ changed = true; ++ } + } + } + } +diff --git a/hnswlib/visited_list_pool.h b/hnswlib/visited_list_pool.h +index 5e1a4a5..4195ebd 100644 +--- a/hnswlib/visited_list_pool.h ++++ b/hnswlib/visited_list_pool.h +@@ -3,6 +3,7 @@ + #include + #include + #include ++#include + + namespace hnswlib { + typedef unsigned short int vl_type; +@@ -14,7 +15,7 @@ namespace hnswlib { + unsigned int numelements; + + VisitedList(int numelements1) { +- curV = -1; ++ curV = std::numeric_limits::max(); + numelements = numelements1; + mass = new vl_type[numelements]; + } diff --git a/cpp/cmake/thirdparty/get_hnswlib.cmake b/cpp/cmake/thirdparty/get_hnswlib.cmake index 94033e8333..a4ceacae38 100644 --- a/cpp/cmake/thirdparty/get_hnswlib.cmake +++ b/cpp/cmake/thirdparty/get_hnswlib.cmake @@ -26,6 +26,11 @@ function(find_and_configure_hnswlib) COMMAND git clone --branch=v0.6.2 https://github.com/nmslib/hnswlib.git hnswlib-src WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps ) + message("SOURCE ${CMAKE_CURRENT_SOURCE_DIR}") + execute_process ( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/hnswlib.patch + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src + ) endif () include(cmake/modules/FindAVX.cmake) diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index 0a806402d2..c801bc9eda 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -93,6 +93,70 @@ void serialize(raft::resources const& handle, detail::serialize(handle, filename, index, include_dataset); } +/** + * Write the CAGRA built index as a base layer HNSW index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cagra::build(...);` + * raft::serialize_to_hnswlib(handle, os, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index CAGRA index + * + */ +template +void serialize_to_hnswlib(raft::resources const& handle, + std::ostream& os, + const index& index) +{ + detail::serialize_to_hnswlib(handle, os, index); +} + +/** + * Write the CAGRA built index as a base layer HNSW index to file + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = cagra::build(...);` + * raft::serialize_to_hnswlib(handle, filename, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index CAGRA index + * + */ +template +void serialize_to_hnswlib(raft::resources const& handle, + const std::string& filename, + const index& index) +{ + detail::serialize_to_hnswlib(handle, filename, index); +} + /** * Load index from input stream * diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 8261f637e1..eb21b75d3a 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -16,12 +16,18 @@ #pragma once +#include +#include +#include #include +#include #include +#include #include #include #include +#include namespace raft::neighbors::cagra::detail { @@ -104,6 +110,129 @@ void serialize(raft::resources const& res, if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } } +template +void serialize_to_hnswlib(raft::resources const& res, + std::ostream& os, + const index& index_) +{ + common::nvtx::range fun_scope("cagra::serialize_to_hnswlib"); + RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u", + static_cast(index_.size()), + index_.dim()); + + // offset_level_0 + std::size_t offset_level_0 = 0; + os.write(reinterpret_cast(&offset_level_0), sizeof(std::size_t)); + // max_element + std::size_t max_element = index_.size(); + os.write(reinterpret_cast(&max_element), sizeof(std::size_t)); + // curr_element_count + std::size_t curr_element_count = index_.size(); + os.write(reinterpret_cast(&curr_element_count), sizeof(std::size_t)); + // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t, + // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) + + // dim * sizeof(data_t) + sizeof(labeltype) + auto size_data_per_element = + static_cast(index_.graph_degree() * 4 + 4 + index_.dim() * 4 + 8); + os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); + // label_offset + std::size_t label_offset = size_data_per_element - 8; + os.write(reinterpret_cast(&label_offset), sizeof(std::size_t)); + // offset_data + auto offset_data = static_cast(index_.graph_degree() * 4 + 4); + os.write(reinterpret_cast(&offset_data), sizeof(std::size_t)); + // max_level + int max_level = 1; + os.write(reinterpret_cast(&max_level), sizeof(int)); + // entrypoint_node + auto entrypoint_node = static_cast(index_.size() / 2); + os.write(reinterpret_cast(&entrypoint_node), sizeof(int)); + // max_M + auto max_M = static_cast(index_.graph_degree() / 2); + os.write(reinterpret_cast(&max_M), sizeof(std::size_t)); + // max_M0 + std::size_t max_M0 = index_.graph_degree(); + os.write(reinterpret_cast(&max_M0), sizeof(std::size_t)); + // M + auto M = static_cast(index_.graph_degree() / 2); + os.write(reinterpret_cast(&M), sizeof(std::size_t)); + // mult, can be anything + double mult = 0.42424242; + os.write(reinterpret_cast(&mult), sizeof(double)); + // efConstruction, can be anything + std::size_t efConstruction = 500; + os.write(reinterpret_cast(&efConstruction), sizeof(std::size_t)); + + auto dataset = index_.dataset(); + // Remove padding before saving the dataset + auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), + sizeof(T) * host_dataset.extent(1), + dataset.data_handle(), + sizeof(T) * dataset.stride(0), + sizeof(T) * host_dataset.extent(1), + dataset.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + + auto graph = index_.graph(); + auto host_graph = + raft::make_host_matrix(graph.extent(0), graph.extent(1)); + raft::copy(host_graph.data_handle(), + graph.data_handle(), + graph.size(), + raft::resource::get_cuda_stream(res)); + resource::sync_stream(res); + + // Write one dataset and graph row at a time + for (std::size_t i = 0; i < index_.size(); i++) { + auto graph_degree = static_cast(index_.graph_degree()); + os.write(reinterpret_cast(&graph_degree), sizeof(int)); + + for (std::size_t j = 0; j < index_.graph_degree(); ++j) { + auto graph_elem = host_graph(i, j); + os.write(reinterpret_cast(&graph_elem), sizeof(IdxT)); + } + + auto data_row = host_dataset.data_handle() + (index_.dim() * i); + if constexpr (std::is_same_v) { + for (std::size_t j = 0; j < index_.dim(); ++j) { + auto data_elem = host_dataset(i, j); + os.write(reinterpret_cast(&data_elem), sizeof(T)); + } + } else if constexpr (std::is_same_v or std::is_same_v) { + for (std::size_t j = 0; j < index_.dim(); ++j) { + auto data_elem = static_cast(host_dataset(i, j)); + os.write(reinterpret_cast(&data_elem), sizeof(int)); + } + } + + os.write(reinterpret_cast(&i), sizeof(std::size_t)); + } + + for (std::size_t i = 0; i < index_.size(); i++) { + // zeroes + auto zero = 0; + os.write(reinterpret_cast(&zero), sizeof(int)); + } + // delete [] host_graph; +} + +template +void serialize_to_hnswlib(raft::resources const& res, + const std::string& filename, + const index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + detail::serialize_to_hnswlib(res, of, index_); + + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } +} + /** Load an index from file. * * Experimental, both the API and the serialization format are subject to change. diff --git a/docs/source/ann_benchmarks_param_tuning.md b/docs/source/ann_benchmarks_param_tuning.md index cdc7958714..4c95b9e520 100644 --- a/docs/source/ann_benchmarks_param_tuning.md +++ b/docs/source/ann_benchmarks_param_tuning.md @@ -46,7 +46,7 @@ IVF-pq is an inverted-file index, which partitions the vectors into a series of ### `raft_cagra` -CAGRA uses a graph-based index, which creates an intermediate, approximate kNN graph using IVF-PQ and then further refining and optimizing to create a final kNN graph. This kNN graph is used by CAGRA as an index for search. +CAGRA uses a graph-based index, which creates an intermediate, approximate kNN graph using IVF-PQ and then further refining and optimizing to create a final kNN graph. This kNN graph is used by CAGRA as an index for search. | Parameter | Type | Required | Data Type | Default | Description | |-----------------------------|----------------|----------|----------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| @@ -83,6 +83,13 @@ Alternatively, if `graph_build_algo == "NN_DESCENT"`, then we can customize the | `nn_descent_max_iterations` | `build_param` | N | Positive Integer>0 | 20 | Alias for `nn_descent_niter` | | `nn_descent_termination_threshold` | `build_param` | N | Positive float>0 | 0.0001 | Termination threshold for NN descent. | +### `raft_cagra_hnswlib` +This is a benchmark that enables interoperability between `CAGRA` built `HNSW` search. It uses the `CAGRA` built graph as the base layer of an `hnswlib` index to search queries only within the base layer (this is enabled with a simple patch to `hnswlib`). + +`build_param` : Same as `build_param` of [CAGRA](#raft-cagra) + +`search_param` : Same as `search_param` of [hnswlib](#hnswlib) + ## FAISS Indexes ### `faiss_gpu_flat` @@ -152,7 +159,7 @@ Use FAISS IVF-PQ index on CPU ## HNSW - + ### `hnswlib` | Parameter | Type | Required | Data Type | Default | Description | diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml index 7ea360e0c9..e382bdcba6 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml +++ b/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml @@ -37,3 +37,6 @@ ggnn: hnswlib: executable: HNSWLIB_ANN_BENCH requires_gpu: false +raft_cagra_hnswlib: + executable: RAFT_CAGRA_HNSWLIB_ANN_BENCH + requires_gpu: true diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra.yaml index 0f80608eef..d8015da5c6 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra.yaml +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra.yaml @@ -10,8 +10,3 @@ groups: search: itopk: [32, 64, 128, 256, 512] search_width: [1, 2, 4, 8, 16, 32, 64] - - - - - diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra_hnswlib.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra_hnswlib.yaml new file mode 100644 index 0000000000..787675d65d --- /dev/null +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra_hnswlib.yaml @@ -0,0 +1,11 @@ +name: raft_cagra_hnswlib +constraints: + search: raft-ann-bench.constraints.hnswlib_search_constraints +groups: + base: + build: + graph_degree: [32, 64, 128, 256] + intermediate_graph_degree: [32, 64, 96, 128] + graph_build_algo: ["NN_DESCENT"] + search: + ef: [10, 20, 40, 60, 80, 120, 200, 400, 600, 800]