Skip to content

Commit

Permalink
style fix
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Jul 8, 2024
1 parent 1a559a6 commit 11d30da
Show file tree
Hide file tree
Showing 23 changed files with 513 additions and 383 deletions.
4 changes: 2 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu
Expand Down Expand Up @@ -592,8 +594,6 @@ if(RAFT_COMPILE_LIBRARY)
INTERFACE_POSITION_INDEPENDENT_CODE ON
)



foreach(target raft_lib raft_lib_static raft_objs)
target_link_libraries(
${target}
Expand Down
54 changes: 27 additions & 27 deletions cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,7 +39,7 @@ extern template class raft::bench::ann::RaftIvfPQ<uint8_t, int64_t>;
extern template class raft::bench::ann::RaftIvfPQ<int8_t, int64_t>;
#endif
#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \
defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA)
defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA)
#include "raft_cagra_wrapper.h"
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA
Expand Down Expand Up @@ -70,11 +70,11 @@ extern template class raft::bench::ann::RaftAnnMG_Cagra<int8_t, uint32_t>;
#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT)
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
typename raft::bench::ann::RaftIvfFlatGpu<T, IdxT>::BuildParam& param)
#else
#else
typename raft::bench::ann::RaftAnnMG_IvfFlat<T, IdxT>::BuildParam& param)
#endif
#endif
{
param.n_lists = conf.at("nlist");
if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); }
Expand All @@ -83,26 +83,26 @@ void parse_build_param(const nlohmann::json& conf,

template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
typename raft::bench::ann::RaftIvfFlatGpu<T, IdxT>::SearchParam& param)
#else
#else
typename raft::bench::ann::RaftAnnMG_IvfFlat<T, IdxT>::SearchParam& param)
#endif
#endif
{
param.ivf_flat_params.n_probes = conf.at("nprobe");
}
#endif

#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \
defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) || \
defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA)
defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \
defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA)
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ
#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ
typename raft::bench::ann::RaftAnnMG_IvfPq<T, IdxT>::BuildParam& param)
#else
#else
typename raft::bench::ann::RaftIvfPQ<T, IdxT>::BuildParam& param)
#endif
#endif
{
if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); }
if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); }
Expand All @@ -124,11 +124,11 @@ void parse_build_param(const nlohmann::json& conf,

template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ
#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ
typename raft::bench::ann::RaftAnnMG_IvfPq<T, IdxT>::SearchParam& param)
#else
#else
typename raft::bench::ann::RaftIvfPQ<T, IdxT>::SearchParam& param)
#endif
#endif
{
if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); }
if (conf.contains("internalDistanceDtype")) {
Expand Down Expand Up @@ -170,7 +170,7 @@ void parse_search_param(const nlohmann::json& conf,
#endif

#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \
defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA)
defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA)
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
raft::neighbors::experimental::nn_descent::index_params& param)
Expand Down Expand Up @@ -217,11 +217,11 @@ nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf,

template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA
#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA
typename raft::bench::ann::RaftAnnMG_Cagra<T, IdxT>::BuildParam& param)
#else
#else
typename raft::bench::ann::RaftCagra<T, IdxT>::BuildParam& param)
#endif
#endif
{
if (conf.contains("graph_degree")) {
param.cagra_params.graph_degree = conf.at("graph_degree");
Expand Down Expand Up @@ -285,11 +285,11 @@ raft::bench::ann::AllocatorType parse_allocator(std::string mem_type)

template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA
typename raft::bench::ann::RaftAnnMG_Cagra<T, IdxT>::SearchParam& param)
#else
typename raft::bench::ann::RaftCagra<T, IdxT>::SearchParam& param)
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA
typename raft::bench::ann::RaftAnnMG_Cagra<T, IdxT>::SearchParam& param)
#else
typename raft::bench::ann::RaftCagra<T, IdxT>::SearchParam& param)
#endif
{
if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); }
if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); }
Expand All @@ -309,14 +309,14 @@ void parse_search_param(const nlohmann::json& conf,
}
}

#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
if (conf.contains("graph_memory_type")) {
param.graph_mem = parse_allocator(conf.at("graph_memory_type"));
}
if (conf.contains("internal_dataset_memory_type")) {
param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type"));
}
#endif
#endif
// Same ratio as in IVF-PQ
param.refine_ratio = conf.value("refine_ratio", 1.0f);
}
Expand Down
10 changes: 5 additions & 5 deletions cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
* limitations under the License.
*/


#include "raft_ann_mg_cagra_wrapper.hpp"
#include <raft/neighbors/ann_mg_helpers.cuh>

#include <raft/comms/std_comms.hpp>
#include <raft/neighbors/ann_mg_helpers.cuh>

namespace raft::bench::ann {

template class RaftAnnMG_Cagra<float, uint32_t>;
template class RaftAnnMG_Cagra<uint8_t, uint32_t>;
template class RaftAnnMG_Cagra<int8_t, uint32_t>;
template class RaftAnnMG_Cagra<float, uint32_t>;
template class RaftAnnMG_Cagra<uint8_t, uint32_t>;
template class RaftAnnMG_Cagra<int8_t, uint32_t>;

} // namespace raft::bench::ann
61 changes: 40 additions & 21 deletions cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "raft_ann_mg_wrapper.hpp"

#include <raft/neighbors/cagra_mg.cuh>
#include <raft/neighbors/cagra_mg_serialize.cuh>
#include <raft/neighbors/nn_descent_types.hpp>
Expand All @@ -39,45 +40,51 @@ class RaftAnnMG_Cagra : public RaftAnnMG<T> {

struct BuildParam {
raft::neighbors::cagra::mg_index_params cagra_params;
std::optional<raft::neighbors::experimental::nn_descent::index_params> nn_descent_params = std::nullopt;
std::optional<raft::neighbors::experimental::nn_descent::index_params> nn_descent_params =
std::nullopt;
std::optional<float> ivf_pq_refine_rate = std::nullopt;
std::optional<raft::neighbors::ivf_pq::index_params> ivf_pq_build_params = std::nullopt;
std::optional<raft::neighbors::ivf_pq::search_params> ivf_pq_search_params = std::nullopt;
};

RaftAnnMG_Cagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1)
: RaftAnnMG<T>(metric, dim),
index_params_(param),
dimension_(dim)
: RaftAnnMG<T>(metric, dim), index_params_(param), dimension_(dim)
{
index_params_.cagra_params.add_data_on_build = true;
index_params_.cagra_params.mode = raft::neighbors::mg::parallel_mode::SHARDED;
index_params_.cagra_params.metric = parse_metric_type(metric);
index_params_.ivf_pq_build_params->metric = parse_metric_type(metric);
index_params_.cagra_params.add_data_on_build = true;
index_params_.cagra_params.mode = raft::neighbors::mg::parallel_mode::SHARDED;
index_params_.cagra_params.metric = parse_metric_type(metric);
index_params_.ivf_pq_build_params->metric = parse_metric_type(metric);
}

void build(const T* dataset, size_t nrow) final;
void set_search_param(const AnnSearchParam& param) override;
void search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void save(const std::string& file) const override;
void load(const std::string&) override;
std::unique_ptr<ANN<T>> copy() override;

private:
BuildParam index_params_;
raft::neighbors::cagra::search_params search_params_;
std::shared_ptr<raft::neighbors::mg::detail::ann_mg_index<raft::neighbors::cagra::index<T, IdxT>, T, IdxT>> index_;
std::shared_ptr<
raft::neighbors::mg::detail::ann_mg_index<raft::neighbors::cagra::index<T, IdxT>, T, IdxT>>
index_;
float refine_ratio_ = 1.0;
int dimension_;
};

template <typename T, typename IdxT>
void RaftAnnMG_Cagra<T, IdxT>::build(const T* dataset, size_t nrow)
{
const auto& handle = this->clique_->set_current_device_to_root_rank();
auto dataset_matrix = raft::make_host_matrix_view<const T, IdxT, row_major>(dataset, IdxT(nrow), IdxT(this->dimension_));
auto idx = raft::neighbors::mg::build<T, IdxT>(handle, *this->clique_, index_params_.cagra_params, dataset_matrix);
index_ = std::make_shared<raft::neighbors::mg::detail::ann_mg_index<raft::neighbors::cagra::index<T, IdxT>, T, IdxT>>(std::move(idx));
const auto& handle = this->clique_->set_current_device_to_root_rank();
auto dataset_matrix = raft::make_host_matrix_view<const T, IdxT, row_major>(
dataset, IdxT(nrow), IdxT(this->dimension_));
auto idx = raft::neighbors::mg::build<T, IdxT>(
handle, *this->clique_, index_params_.cagra_params, dataset_matrix);
index_ = std::make_shared<
raft::neighbors::mg::detail::ann_mg_index<raft::neighbors::cagra::index<T, IdxT>, T, IdxT>>(
std::move(idx));
return;
}

Expand All @@ -103,7 +110,9 @@ void RaftAnnMG_Cagra<T, IdxT>::load(const std::string& file)
{
const auto& handle = this->clique_->set_current_device_to_root_rank();
auto idx = raft::neighbors::mg::deserialize_cagra<T, IdxT>(handle, *this->clique_, file);
index_ = std::make_shared<raft::neighbors::mg::detail::ann_mg_index<raft::neighbors::cagra::index<T, IdxT>, T, IdxT>>(std::move(idx));
index_ = std::make_shared<
raft::neighbors::mg::detail::ann_mg_index<raft::neighbors::cagra::index<T, IdxT>, T, IdxT>>(
std::move(idx));
}

template <typename T, typename IdxT>
Expand All @@ -113,14 +122,24 @@ std::unique_ptr<ANN<T>> RaftAnnMG_Cagra<T, IdxT>::copy()
}

template <typename T, typename IdxT>
void RaftAnnMG_Cagra<T, IdxT>::search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
void RaftAnnMG_Cagra<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
{
const auto& handle = this->clique_->set_current_device_to_root_rank();
auto query_matrix = raft::make_host_matrix_view<const T, IdxT, row_major>(queries, IdxT(batch_size), IdxT(this->dimension_));
auto neighbors_matrix = raft::make_host_matrix_view<IdxT, IdxT, row_major>((IdxT*)neighbors, IdxT(batch_size), IdxT(k));
auto distances_matrix = raft::make_host_matrix_view<float, IdxT, row_major>(distances, IdxT(batch_size), IdxT(k));

raft::neighbors::mg::search<T, IdxT>(handle, *this->clique_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix);
auto query_matrix = raft::make_host_matrix_view<const T, IdxT, row_major>(
queries, IdxT(batch_size), IdxT(this->dimension_));
auto neighbors_matrix =
raft::make_host_matrix_view<IdxT, IdxT, row_major>((IdxT*)neighbors, IdxT(batch_size), IdxT(k));
auto distances_matrix =
raft::make_host_matrix_view<float, IdxT, row_major>(distances, IdxT(batch_size), IdxT(k));

raft::neighbors::mg::search<T, IdxT>(handle,
*this->clique_,
*index_,
search_params_,
query_matrix,
neighbors_matrix,
distances_matrix);
resource::sync_stream(handle);
return;
}
Expand Down
10 changes: 5 additions & 5 deletions cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
* limitations under the License.
*/


#include "raft_ann_mg_ivf_flat_wrapper.hpp"
#include <raft/neighbors/ann_mg_helpers.cuh>

#include <raft/comms/std_comms.hpp>
#include <raft/neighbors/ann_mg_helpers.cuh>

namespace raft::bench::ann {

template class RaftAnnMG_IvfFlat<float, int64_t>;
template class RaftAnnMG_IvfFlat<uint8_t, int64_t>;
template class RaftAnnMG_IvfFlat<int8_t, int64_t>;
template class RaftAnnMG_IvfFlat<float, int64_t>;
template class RaftAnnMG_IvfFlat<uint8_t, int64_t>;
template class RaftAnnMG_IvfFlat<int8_t, int64_t>;

} // namespace raft::bench::ann
Loading

0 comments on commit 11d30da

Please sign in to comment.