Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CAGRA build + HNSW search #1956

Merged
merged 20 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm
option(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT "Include raft's ivf flat algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_SINGLE_EXE
Expand All @@ -54,6 +55,7 @@ if(BUILD_CPU_ONLY)
set(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OFF)
set(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF)
set(RAFT_ANN_BENCH_USE_GGNN OFF)
else()
# Disable faiss benchmarks on CUDA 12 since faiss is not yet CUDA 12-enabled.
Expand Down Expand Up @@ -88,14 +90,15 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
)
set(RAFT_ANN_BENCH_USE_RAFT ON)
endif()

# ##################################################################################################
# * Fetch requirements -------------------------------------------------------------

if(RAFT_ANN_BENCH_USE_HNSWLIB)
if(RAFT_ANN_BENCH_USE_HNSWLIB OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
include(cmake/thirdparty/get_hnswlib.cmake)
endif()

Expand Down Expand Up @@ -250,6 +253,21 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA)
)
endif()

if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
ConfigureAnnBench(
NAME
RAFT_CAGRA_HNSWLIB
PATH
bench/ann/src/raft/raft_benchmark.cu
$<$<BOOL:${RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB}>:bench/ann/src/raft/raft_cagra_hnswlib.cu>
INCLUDES
${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib
LINKS
raft::compiled
CXXFLAGS "${HNSW_CXX_FLAGS}"
)
endif()

set(RAFT_FAISS_TARGETS faiss::faiss)
if(TARGET faiss::faiss_avx2)
set(RAFT_FAISS_TARGETS faiss::faiss_avx2)
Expand Down
51 changes: 51 additions & 0 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ extern template class raft::bench::ann::RaftCagra<float, uint32_t>;
extern template class raft::bench::ann::RaftCagra<uint8_t, uint32_t>;
extern template class raft::bench::ann::RaftCagra<int8_t, uint32_t>;
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
#include "raft_cagra_hnswlib_wrapper.h"
// extern template class raft::bench::ann::RaftCagraHnswlib<float, uint32_t>;
divyegala marked this conversation as resolved.
Show resolved Hide resolved
// extern template class raft::bench::ann::RaftCagraHnswlib<uint8_t, uint32_t>;
// extern template class raft::bench::ann::RaftCagraHnswlib<int8_t, uint32_t>;
#endif
#define JSON_DIAGNOSTICS 1
#include <nlohmann/json.hpp>

Expand Down Expand Up @@ -182,6 +188,37 @@ void parse_search_param(const nlohmann::json& conf,
}
#endif

#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftCagraHnswlib<T, IdxT>::BuildParam& param)
{
if (conf.contains("graph_degree")) {
divyegala marked this conversation as resolved.
Show resolved Hide resolved
param.graph_degree = conf.at("graph_degree");
param.intermediate_graph_degree = param.graph_degree * 2;
}
if (conf.contains("intermediate_graph_degree")) {
param.intermediate_graph_degree = conf.at("intermediate_graph_degree");
}
if (conf.contains("graph_build_algo")) {
if (conf.at("graph_build_algo") == "IVF_PQ") {
param.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ;
} else if (conf.at("graph_build_algo") == "NN_DESCENT") {
param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
}
}
if (conf.contains("nn_descent_niter")) { param.nn_descent_niter = conf.at("nn_descent_niter"); }
}

template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftCagraHnswlib<T, IdxT>::SearchParam& param)
{
param.ef = conf.at("ef");
if (conf.contains("numThreads")) { param.num_threads = conf.at("numThreads"); }
divyegala marked this conversation as resolved.
Show resolved Hide resolved
}
#endif

template <typename T>
std::unique_ptr<raft::bench::ann::ANN<T>> create_algo(const std::string& algo,
const std::string& distance,
Expand Down Expand Up @@ -223,6 +260,13 @@ std::unique_ptr<raft::bench::ann::ANN<T>> create_algo(const std::string& algo,
parse_build_param<T, uint32_t>(conf, param);
ann = std::make_unique<raft::bench::ann::RaftCagra<T, uint32_t>>(metric, dim, param);
}
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
if (algo == "raft_cagra_hnswlib") {
typename raft::bench::ann::RaftCagraHnswlib<T, uint32_t>::BuildParam param;
parse_build_param<T, uint32_t>(conf, param);
ann = std::make_unique<raft::bench::ann::RaftCagraHnswlib<T, uint32_t>>(metric, dim, param);
}
#endif
if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); }

Expand Down Expand Up @@ -260,6 +304,13 @@ std::unique_ptr<typename raft::bench::ann::ANN<T>::AnnSearchParam> create_search
parse_search_param<T, uint32_t>(conf, *param);
return param;
}
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
if (algo == "raft_cagra_hnswlib") {
auto param = std::make_unique<typename raft::bench::ann::RaftCagraHnswlib<T, uint32_t>::SearchParam>();
parse_search_param<T, uint32_t>(conf, *param);
return param;
}
#endif
// else
throw std::runtime_error("invalid algo: '" + algo + "'");
Expand Down
22 changes: 22 additions & 0 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// /*
// * Copyright (c) 2023, NVIDIA CORPORATION.
// *
// * Licensed under the Apache License, Version 2.0 (the "License");
// * you may not use this file except in compliance with the License.
// * You may obtain a copy of the License at
// *
// * http://www.apache.org/licenses/LICENSE-2.0
// *
// * Unless required by applicable law or agreed to in writing, software
// * distributed under the License is distributed on an "AS IS" BASIS,
// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// * See the License for the specific language governing permissions and
// * limitations under the License.
// */
// #include "raft_cagra_hnswlib_wrapper.h"

// namespace raft::bench::ann {
divyegala marked this conversation as resolved.
Show resolved Hide resolved
// template class RaftCagraHnswlib<uint8_t, uint32_t>;
// template class RaftCagraHnswlib<int8_t, uint32_t>;
// template class RaftCagraHnswlib<float, uint32_t>;
// } // namespace raft::bench::ann
216 changes: 216 additions & 0 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cassert>
#include <fstream>
#include <iostream>
#include <memory>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/operators.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/cagra_serialize.cuh>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>
#include <stdexcept>
#include <string>
#include <type_traits>

#include "../common/ann_types.hpp"
#include "../common/thread_pool.hpp"
#include "raft_ann_bench_utils.h"
#include <raft/util/cudart_utils.hpp>

#include <hnswlib.h>

namespace raft::bench::ann {

template <typename T>
struct hnsw_dist_t {
using type = void;
};

template <>
struct hnsw_dist_t<float> {
using type = float;
};

template <>
struct hnsw_dist_t<uint8_t> {
using type = int;
};

template <>
struct hnsw_dist_t<int8_t> {
using type = int;
};

template <typename T, typename IdxT>
class RaftCagraHnswlib : public ANN<T> {
public:
using typename ANN<T>::AnnSearchParam;

struct SearchParam : public AnnSearchParam {
int ef;
int num_threads = 1;
};

using BuildParam = raft::neighbors::cagra::index_params;

RaftCagraHnswlib(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1)
: ANN<T>(metric, dim), index_params_(param), dimension_(dim), handle_(cudaStreamPerThread)
{
index_params_.metric = parse_metric_type(metric);
RAFT_CUDA_TRY(cudaGetDevice(&device_));
}

~RaftCagraHnswlib() noexcept {}

void build(const T* dataset, size_t nrow, cudaStream_t stream) final;

void set_search_param(const AnnSearchParam& param) override;

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(const T* queries,
divyegala marked this conversation as resolved.
Show resolved Hide resolved
int batch_size,
int k,
size_t* neighbors,
float* distances,
cudaStream_t stream = 0) const override;

// to enable dataset access from GPU memory
AlgoProperty get_preference() const override
{
AlgoProperty property;
property.dataset_memory_type = MemoryType::HostMmap;
property.query_memory_type = MemoryType::Host;
return property;
}
void save(const std::string& file) const override;
void load(const std::string&) override;

private:
void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const;

raft::device_resources handle_;
BuildParam index_params_;
std::optional<raft::neighbors::cagra::index<T, IdxT>> index_;
int device_;
int dimension_;

std::unique_ptr<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>> appr_alg_;
std::unique_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_;
int num_threads_;
std::unique_ptr<FixedThreadPool> thread_pool_;

Objective metric_objective_;
};

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
{
if (raft::get_device_for_address(dataset) == -1) {
auto dataset_view =
raft::make_host_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
return;
} else {
auto dataset_view =
raft::make_device_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
return;
}
}

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::set_search_param(const AnnSearchParam& param_)
{
auto param = dynamic_cast<const SearchParam&>(param_);
appr_alg_->ef_ = param.ef;
divyegala marked this conversation as resolved.
Show resolved Hide resolved
metric_objective_ = param.metric_objective;

bool use_pool = (metric_objective_ == Objective::LATENCY && param.num_threads > 1) &&
(!thread_pool_ || num_threads_ != param.num_threads);
if (use_pool) {
num_threads_ = param.num_threads;
thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_);
}
}

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::save(const std::string& file) const
{
raft::neighbors::cagra::serialize_to_hnswlib<T, IdxT>(handle_, file, *index_);
}

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::load(const std::string& file)
{
if constexpr (std::is_same_v<T, float>) {
if (static_cast<Metric>(index_params_.metric) == Metric::kInnerProduct) {
space_ = std::make_unique<hnswlib::InnerProductSpace>(dimension_);
} else {
space_ = std::make_unique<hnswlib::L2Space>(dimension_);
}
} else if constexpr (std::is_same_v<T, uint8_t>) {
space_ = std::make_unique<hnswlib::L2SpaceI>(dimension_);
}

appr_alg_ = std::make_unique<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
space_.get(), file);
appr_alg_->base_layer_only = true;
}

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const
{
auto f = [&](int i) {
// hnsw can only handle a single vector at a time.
get_search_knn_results_(queries + i * dimension_, k, neighbors + i * k, distances + i * k);
};
if (metric_objective_ == Objective::LATENCY) {
thread_pool_->submit(f, batch_size);
} else {
for (int i = 0; i < batch_size; i++) {
f(i);
}
}
}

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::get_search_knn_results_(const T* query,
int k,
size_t* indices,
float* distances) const
{
auto result = appr_alg_->searchKnn(query, k);
assert(result.size() >= static_cast<size_t>(k));

for (int i = k - 1; i >= 0; --i) {
indices[i] = result.top().second;
distances[i] = result.top().first;
result.pop();
}
}

} // namespace raft::bench::ann
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ void RaftCagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::save(const std::string& file) const
{
raft::neighbors::cagra::serialize(handle_, file, *index_, false);
raft::neighbors::cagra::serialize<T, IdxT>(handle_, file, *index_);
}

template <typename T, typename IdxT>
Expand Down
6 changes: 3 additions & 3 deletions cpp/cmake/modules/ConfigureCUDA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ endif()
# Be very strict when compiling with GCC as host compiler (and thus more lenient when compiling with
# clang)
if(CMAKE_COMPILER_IS_GNUCXX)
list(APPEND RAFT_CXX_FLAGS -Wall -Werror -Wno-unknown-pragmas -Wno-error=deprecated-declarations)
list(APPEND RAFT_CUDA_FLAGS -Xcompiler=-Wall,-Werror,-Wno-error=deprecated-declarations)
# list(APPEND RAFT_CXX_FLAGS -Wall -Werror -Wno-unknown-pragmas -Wno-error=deprecated-declarations)
divyegala marked this conversation as resolved.
Show resolved Hide resolved
# list(APPEND RAFT_CUDA_FLAGS -Xcompiler=-Wall,-Werror,-Wno-error=deprecated-declarations)

# set warnings as errors
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.2.0)
list(APPEND RAFT_CUDA_FLAGS -Werror=all-warnings)
# list(APPEND RAFT_CUDA_FLAGS -Werror=all-warnings)
endif()
endif()

Expand Down
Loading
Loading