Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CAGRA build + HNSW search #1956

Merged
merged 20 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm
option(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT "Include raft's ivf flat algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_SINGLE_EXE
Expand All @@ -54,6 +55,7 @@ if(BUILD_CPU_ONLY)
set(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OFF)
set(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF)
set(RAFT_ANN_BENCH_USE_GGNN OFF)
else()
# Disable faiss benchmarks on CUDA 12 since faiss is not yet CUDA 12-enabled.
Expand Down Expand Up @@ -88,14 +90,15 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
)
set(RAFT_ANN_BENCH_USE_RAFT ON)
endif()

# ##################################################################################################
# * Fetch requirements -------------------------------------------------------------

if(RAFT_ANN_BENCH_USE_HNSWLIB)
if(RAFT_ANN_BENCH_USE_HNSWLIB OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
include(cmake/thirdparty/get_hnswlib.cmake)
endif()

Expand Down Expand Up @@ -250,6 +253,20 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA)
)
endif()

if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
ConfigureAnnBench(
NAME
RAFT_CAGRA_HNSWLIB
PATH
bench/ann/src/raft/raft_cagra_hnswlib.cu
INCLUDES
${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib
LINKS
raft::compiled
CXXFLAGS "${HNSW_CXX_FLAGS}"
)
endif()

set(RAFT_FAISS_TARGETS faiss::faiss)
if(TARGET faiss::faiss_avx2)
set(RAFT_FAISS_TARGETS faiss::faiss_avx2)
Expand Down
2 changes: 2 additions & 0 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class HnswLib : public ANN<T> {
return property;
}

void set_base_layer_only() { appr_alg_->base_layer_only = true; }

private:
void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const;

Expand Down
231 changes: 231 additions & 0 deletions cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#define JSON_DIAGNOSTICS 1
#include <nlohmann/json.hpp>

#undef WARP_SIZE
#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN
#include "raft_wrapper.h"
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
#include "raft_ivf_flat_wrapper.h"
extern template class raft::bench::ann::RaftIvfFlatGpu<float, int64_t>;
extern template class raft::bench::ann::RaftIvfFlatGpu<uint8_t, int64_t>;
extern template class raft::bench::ann::RaftIvfFlatGpu<int8_t, int64_t>;
#endif
#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \
defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
#include "raft_ivf_pq_wrapper.h"
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
extern template class raft::bench::ann::RaftIvfPQ<float, int64_t>;
extern template class raft::bench::ann::RaftIvfPQ<uint8_t, int64_t>;
extern template class raft::bench::ann::RaftIvfPQ<int8_t, int64_t>;
#endif
#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
#include "raft_cagra_wrapper.h"
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA
extern template class raft::bench::ann::RaftCagra<float, uint32_t>;
extern template class raft::bench::ann::RaftCagra<uint8_t, uint32_t>;
extern template class raft::bench::ann::RaftCagra<int8_t, uint32_t>;
#endif

#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftIvfFlatGpu<T, IdxT>::BuildParam& param)
{
param.n_lists = conf.at("nlist");
if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); }
if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); }
}

template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftIvfFlatGpu<T, IdxT>::SearchParam& param)
{
param.ivf_flat_params.n_probes = conf.at("nprobe");
}
#endif

#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \
defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftIvfPQ<T, IdxT>::BuildParam& param)
{
if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); }
if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); }
if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); }
if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); }
if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); }
if (conf.contains("codebook_kind")) {
std::string kind = conf.at("codebook_kind");
if (kind == "cluster") {
param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER;
} else if (kind == "subspace") {
param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE;
} else {
throw std::runtime_error("codebook_kind: '" + kind +
"', should be either 'cluster' or 'subspace'");
}
}
}

template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftIvfPQ<T, IdxT>::SearchParam& param)
{
if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); }
if (conf.contains("internalDistanceDtype")) {
std::string type = conf.at("internalDistanceDtype");
if (type == "float") {
param.pq_param.internal_distance_dtype = CUDA_R_32F;
} else if (type == "half") {
param.pq_param.internal_distance_dtype = CUDA_R_16F;
} else {
throw std::runtime_error("internalDistanceDtype: '" + type +
"', should be either 'float' or 'half'");
}
} else {
// set half as default type
param.pq_param.internal_distance_dtype = CUDA_R_16F;
}

if (conf.contains("smemLutDtype")) {
std::string type = conf.at("smemLutDtype");
if (type == "float") {
param.pq_param.lut_dtype = CUDA_R_32F;
} else if (type == "half") {
param.pq_param.lut_dtype = CUDA_R_16F;
} else if (type == "fp8") {
param.pq_param.lut_dtype = CUDA_R_8U;
} else {
throw std::runtime_error("smemLutDtype: '" + type +
"', should be either 'float', 'half' or 'fp8'");
}
} else {
// set half as default
param.pq_param.lut_dtype = CUDA_R_16F;
}
if (conf.contains("refine_ratio")) {
param.refine_ratio = conf.at("refine_ratio");
if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); }
}
}
#endif

#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
raft::neighbors::experimental::nn_descent::index_params& param)
{
if (conf.contains("graph_degree")) { param.graph_degree = conf.at("graph_degree"); }
if (conf.contains("intermediate_graph_degree")) {
param.intermediate_graph_degree = conf.at("intermediate_graph_degree");
}
// we allow niter shorthand for max_iterations
if (conf.contains("niter")) { param.max_iterations = conf.at("niter"); }
if (conf.contains("max_iterations")) { param.max_iterations = conf.at("max_iterations"); }
if (conf.contains("termination_threshold")) {
param.termination_threshold = conf.at("termination_threshold");
}
}

nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf,
const std::string& prefix,
bool remove_prefix = true)
{
nlohmann::json out;
for (auto& i : conf.items()) {
if (i.key().compare(0, prefix.size(), prefix) == 0) {
auto new_key = remove_prefix ? i.key().substr(prefix.size()) : i.key();
out[new_key] = i.value();
}
}
return out;
}

template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftCagra<T, IdxT>::BuildParam& param)
{
if (conf.contains("graph_degree")) {
param.cagra_params.graph_degree = conf.at("graph_degree");
param.cagra_params.intermediate_graph_degree = param.cagra_params.graph_degree * 2;
}
if (conf.contains("intermediate_graph_degree")) {
param.cagra_params.intermediate_graph_degree = conf.at("intermediate_graph_degree");
}
if (conf.contains("graph_build_algo")) {
if (conf.at("graph_build_algo") == "IVF_PQ") {
param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ;
} else if (conf.at("graph_build_algo") == "NN_DESCENT") {
param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
}
}
nlohmann::json ivf_pq_build_conf = collect_conf_with_prefix(conf, "ivf_pq_build_");
if (!ivf_pq_build_conf.empty()) {
raft::neighbors::ivf_pq::index_params bparam;
parse_build_param<T, IdxT>(ivf_pq_build_conf, bparam);
param.ivf_pq_build_params = bparam;
}
nlohmann::json ivf_pq_search_conf = collect_conf_with_prefix(conf, "ivf_pq_search_");
if (!ivf_pq_search_conf.empty()) {
typename raft::bench::ann::RaftIvfPQ<T, IdxT>::SearchParam sparam;
parse_search_param<T, IdxT>(ivf_pq_search_conf, sparam);
param.ivf_pq_search_params = sparam.pq_param;
param.ivf_pq_refine_rate = sparam.refine_ratio;
}
nlohmann::json nn_descent_conf = collect_conf_with_prefix(conf, "nn_descent_");
if (!nn_descent_conf.empty()) {
raft::neighbors::experimental::nn_descent::index_params nn_param;
nn_param.intermediate_graph_degree = 1.5 * param.cagra_params.intermediate_graph_degree;
parse_build_param<T, IdxT>(nn_descent_conf, nn_param);
if (nn_param.graph_degree != param.cagra_params.intermediate_graph_degree) {
nn_param.graph_degree = param.cagra_params.intermediate_graph_degree;
}
param.nn_descent_params = nn_param;
}
}

template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftCagra<T, IdxT>::SearchParam& param)
{
if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); }
if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); }
if (conf.contains("max_iterations")) { param.p.max_iterations = conf.at("max_iterations"); }
if (conf.contains("algo")) {
if (conf.at("algo") == "single_cta") {
param.p.algo = raft::neighbors::experimental::cagra::search_algo::SINGLE_CTA;
} else if (conf.at("algo") == "multi_cta") {
param.p.algo = raft::neighbors::experimental::cagra::search_algo::MULTI_CTA;
} else if (conf.at("algo") == "multi_kernel") {
param.p.algo = raft::neighbors::experimental::cagra::search_algo::MULTI_KERNEL;
} else if (conf.at("algo") == "auto") {
param.p.algo = raft::neighbors::experimental::cagra::search_algo::AUTO;
} else {
std::string tmp = conf.at("algo");
THROW("Invalid value for algo: %s", tmp.c_str());
}
}
}
#endif
Loading