From c0f3bf7653b5f306b916a56b8b56d62a2c3ba57f Mon Sep 17 00:00:00 2001 From: achirkin Date: Tue, 14 May 2024 15:19:04 +0200 Subject: [PATCH 1/2] Introduce AnnBase::index_type for the output neighbors indices --- cpp/bench/ann/src/common/ann_types.hpp | 9 +- cpp/bench/ann/src/common/benchmark.hpp | 16 ++-- cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h | 9 +- cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h | 9 +- cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh | 16 ++-- cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h | 16 ++-- .../ann/src/raft/raft_cagra_hnswlib_wrapper.h | 11 +-- cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 87 +++++++++++-------- .../ann/src/raft/raft_ivf_flat_wrapper.h | 21 +++-- cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h | 27 ++++-- cpp/bench/ann/src/raft/raft_wrapper.h | 16 ++-- 11 files changed, 146 insertions(+), 91 deletions(-) diff --git a/cpp/bench/ann/src/common/ann_types.hpp b/cpp/bench/ann/src/common/ann_types.hpp index c6213059dc..ee45480b36 100644 --- a/cpp/bench/ann/src/common/ann_types.hpp +++ b/cpp/bench/ann/src/common/ann_types.hpp @@ -73,6 +73,8 @@ struct AlgoProperty { class AnnBase { public: + using index_type = size_t; + inline AnnBase(Metric metric, int dim) : metric_(metric), dim_(dim) {} virtual ~AnnBase() noexcept = default; @@ -118,8 +120,11 @@ class ANN : public AnnBase { virtual void set_search_param(const AnnSearchParam& param) = 0; // TODO: this assumes that an algorithm can always return k results. // This is not always possible. - virtual void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const = 0; + virtual void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const = 0; virtual void save(const std::string& file) const = 0; virtual void load(const std::string& file) = 0; diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index d7bcd17a00..cfd36f1469 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -282,8 +282,8 @@ void bench_search(::benchmark::State& state, */ std::shared_ptr> distances = std::make_shared>(current_algo_props->query_memory_type, k * query_set_size); - std::shared_ptr> neighbors = - std::make_shared>(current_algo_props->query_memory_type, k * query_set_size); + std::shared_ptr> neighbors = std::make_shared>( + current_algo_props->query_memory_type, k * query_set_size); { nvtx_case nvtx{state.name()}; @@ -338,12 +338,12 @@ void bench_search(::benchmark::State& state, // Each thread calculates recall on their partition of queries. // evaluate recall if (dataset->max_k() >= k) { - const std::int32_t* gt = dataset->gt_set(); - const std::uint32_t max_k = dataset->max_k(); - buf neighbors_host = neighbors->move(MemoryType::Host); - std::size_t rows = std::min(queries_processed, query_set_size); - std::size_t match_count = 0; - std::size_t total_count = rows * static_cast(k); + const std::int32_t* gt = dataset->gt_set(); + const std::uint32_t max_k = dataset->max_k(); + buf neighbors_host = neighbors->move(MemoryType::Host); + std::size_t rows = std::min(queries_processed, query_set_size); + std::size_t match_count = 0; + std::size_t total_count = rows * static_cast(k); // We go through the groundtruth with same stride as the benchmark loop. size_t out_offset = 0; diff --git a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h index 407f7148df..3caca15b7f 100644 --- a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h @@ -88,8 +88,11 @@ class FaissCpu : public ANN { // TODO: if the number of results is less than k, the remaining elements of 'neighbors' // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const final; AlgoProperty get_preference() const override { @@ -169,7 +172,7 @@ void FaissCpu::set_search_param(const AnnSearchParam& param) template void FaissCpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(faiss::idx_t), "sizes of size_t and faiss::idx_t are different"); diff --git a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h index 633098fd1d..2effe631e5 100644 --- a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h @@ -111,8 +111,11 @@ class FaissGpu : public ANN, public AnnGPU { // TODO: if the number of results is less than k, the remaining elements of 'neighbors' // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const final; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -196,7 +199,7 @@ void FaissGpu::build(const T* dataset, size_t nrow) template void FaissGpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(faiss::idx_t), "sizes of size_t and faiss::idx_t are different"); diff --git a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh index c89f02d974..59cf3df806 100644 --- a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh +++ b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh @@ -58,8 +58,11 @@ class Ggnn : public ANN, public AnnGPU { void build(const T* dataset, size_t nrow) override { impl_->build(dataset, nrow); } void set_search_param(const AnnSearchParam& param) override { impl_->set_search_param(param); } - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override { impl_->search(queries, batch_size, k, neighbors, distances); } @@ -123,8 +126,11 @@ class GgnnImpl : public ANN, public AnnGPU { void build(const T* dataset, size_t nrow) override; void set_search_param(const AnnSearchParam& param) override; - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { return stream_; } void save(const std::string& file) const override; @@ -243,7 +249,7 @@ void GgnnImpl::set_search_param(const AnnSearc template void GgnnImpl::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(int64_t), "sizes of size_t and GGNN's KeyT are different"); if (k != KQuery) { diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index a8f7dd824f..5743632bf4 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -79,8 +79,11 @@ class HnswLib : public ANN { void build(const T* dataset, size_t nrow) override; void set_search_param(const AnnSearchParam& param) override; - void search( - const T* query, int batch_size, int k, size_t* indices, float* distances) const override; + void search(const T* query, + int batch_size, + int k, + AnnBase::index_type* indices, + float* distances) const override; void save(const std::string& path_to_index) const override; void load(const std::string& path_to_index) override; @@ -97,7 +100,10 @@ class HnswLib : public ANN { void set_base_layer_only() { appr_alg_->base_layer_only = true; } private: - void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const; + void get_search_knn_results_(const T* query, + int k, + AnnBase::index_type* indices, + float* distances) const; std::shared_ptr::type>> appr_alg_; std::shared_ptr::type>> space_; @@ -176,7 +182,7 @@ void HnswLib::set_search_param(const AnnSearchParam& param_) template void HnswLib::search( - const T* query, int batch_size, int k, size_t* indices, float* distances) const + const T* query, int batch_size, int k, AnnBase::index_type* indices, float* distances) const { auto f = [&](int i) { // hnsw can only handle a single vector at a time. @@ -217,7 +223,7 @@ void HnswLib::load(const std::string& path_to_index) template void HnswLib::get_search_knn_results_(const T* query, int k, - size_t* indices, + AnnBase::index_type* indices, float* distances) const { auto result = appr_alg_->searchKnn(query, k); diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h index ed9c120ed4..1c4b847d1a 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h @@ -41,10 +41,11 @@ class RaftCagraHnswlib : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -99,7 +100,7 @@ void RaftCagraHnswlib::load(const std::string& file) template void RaftCagraHnswlib::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { hnswlib_search_.search(queries, batch_size, k, neighbors, distances); } diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 46da8c52e6..fc0f30c2d1 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -96,12 +96,16 @@ class RaftCagra : public ANN, public AnnGPU { void set_search_dataset(const T* dataset, size_t nrow) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; - void search_base( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; + void search_base(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -272,15 +276,18 @@ std::unique_ptr> RaftCagra::copy() template void RaftCagra::search_base( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + IdxT* neighbors_IdxT; - rmm::device_uvector neighbors_storage(0, resource::get_cuda_stream(handle_)); - if constexpr (std::is_same_v) { - neighbors_IdxT = neighbors; + std::optional> neighbors_storage{std::nullopt}; + if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { + neighbors_IdxT = reinterpret_cast(neighbors); } else { - neighbors_storage.resize(batch_size * k, resource::get_cuda_stream(handle_)); - neighbors_IdxT = neighbors_storage.data(); + neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); + neighbors_IdxT = neighbors_storage->data(); } auto queries_view = @@ -291,18 +298,18 @@ void RaftCagra::search_base( raft::neighbors::cagra::search( handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); - if constexpr (!std::is_same_v) { + if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { raft::linalg::unaryOp(neighbors, neighbors_IdxT, batch_size * k, - raft::cast_op(), + raft::cast_op(), raft::resource::get_cuda_stream(handle_)); } } template void RaftCagra::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { auto k0 = static_cast(refine_ratio_ * k); const bool disable_refinement = k0 <= static_cast(k); @@ -312,21 +319,24 @@ void RaftCagra::search( if (disable_refinement) { search_base(queries, batch_size, k, neighbors, distances); } else { - auto candidate_ixs = raft::make_device_matrix(res, batch_size, k0); - auto candidate_dists = raft::make_device_matrix(res, batch_size, k0); + auto candidate_ixs = + raft::make_device_matrix(res, batch_size, k0); + auto candidate_dists = + raft::make_device_matrix(res, batch_size, k0); search_base(queries, batch_size, k0, - reinterpret_cast(candidate_ixs.data_handle()), + reinterpret_cast(candidate_ixs.data_handle()), candidate_dists.data_handle()); if (raft::get_device_for_address(input_dataset_v_->data_handle()) >= 0) { - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, dimension_); - auto neighours_v = raft::make_device_matrix_view( - reinterpret_cast(neighbors), batch_size, k); - auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); - raft::neighbors::refine( + auto queries_v = raft::make_device_matrix_view( + queries, batch_size, dimension_); + auto neighours_v = raft::make_device_matrix_view( + reinterpret_cast(neighbors), batch_size, k); + auto distances_v = + raft::make_device_matrix_view(distances, batch_size, k); + raft::neighbors::refine( res, *input_dataset_v_, queries_v, @@ -335,28 +345,31 @@ void RaftCagra::search( distances_v, index_->metric()); } else { - auto dataset_host = raft::make_host_matrix_view( + auto dataset_host = raft::make_host_matrix_view( input_dataset_v_->data_handle(), input_dataset_v_->extent(0), input_dataset_v_->extent(1)); - auto queries_host = raft::make_host_matrix(batch_size, dimension_); - auto candidates_host = raft::make_host_matrix(batch_size, k0); - auto neighbors_host = raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); + auto queries_host = raft::make_host_matrix(batch_size, dimension_); + auto candidates_host = + raft::make_host_matrix(batch_size, k0); + auto neighbors_host = + raft::make_host_matrix(batch_size, k); + auto distances_host = raft::make_host_matrix(batch_size, k); raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream); raft::copy( candidates_host.data_handle(), candidate_ixs.data_handle(), candidates_host.size(), stream); raft::resource::sync_stream(res); // wait for the queries and candidates - raft::neighbors::refine(res, - dataset_host, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - index_->metric()); + raft::neighbors::refine( + res, + dataset_host, + queries_host.view(), + candidates_host.view(), + neighbors_host.view(), + distances_host.view(), + index_->metric()); raft::copy(neighbors, - reinterpret_cast(neighbors_host.data_handle()), + reinterpret_cast(neighbors_host.data_handle()), neighbors_host.size(), stream); raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 48d2b9de80..a42cd3223a 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -61,10 +61,11 @@ class RaftIvfFlatGpu : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -131,16 +132,22 @@ std::unique_ptr> RaftIvfFlatGpu::copy() template void RaftIvfFlatGpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { - static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + static_assert(sizeof(AnnBase::index_type) == sizeof(IdxT), + "IdxT is incompatible with the index_type"); + // Assuming the returned and the required index types have the same size, we can just coerce the + // pointers to avoid extra mapping pass over the results. + // TODO: add a linalg::map() over the result indices if the type representations do not match. raft::neighbors::ivf_flat::search(handle_, search_params_, *index_, queries, batch_size, k, - (IdxT*)neighbors, + reinterpret_cast(neighbors), distances, resource::get_workspace_resource(handle_)); } diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 1d73bd2e51..412d432592 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -61,10 +61,11 @@ class RaftIvfPQ : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; void set_search_dataset(const T* dataset, size_t nrow) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -139,8 +140,12 @@ void RaftIvfPQ::set_search_dataset(const T* dataset, size_t nrow) template void RaftIvfPQ::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + static_assert(sizeof(AnnBase::index_type) == sizeof(IdxT), + "IdxT is incompatible with the index_type"); if (refine_ratio_ > 1.0f) { uint32_t k0 = static_cast(refine_ratio_ * k); auto queries_v = @@ -154,7 +159,8 @@ void RaftIvfPQ::search( if (raft::get_device_for_address(dataset_.data_handle()) >= 0) { auto queries_v = raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto neighbors_v = raft::make_device_matrix_view((IdxT*)neighbors, batch_size, k); + auto neighbors_v = raft::make_device_matrix_view( + reinterpret_cast(neighbors), batch_size, k); auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); raft::neighbors::refine(handle_, @@ -187,14 +193,17 @@ void RaftIvfPQ::search( distances_host.view(), index_->metric()); - raft::copy(neighbors, (size_t*)neighbors_host.data_handle(), neighbors_host.size(), stream); + raft::copy(reinterpret_cast(neighbors), + neighbors_host.data_handle(), + neighbors_host.size(), + stream); raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); } } else { auto queries_v = raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto neighbors_v = - raft::make_device_matrix_view((IdxT*)neighbors, batch_size, k); + auto neighbors_v = raft::make_device_matrix_view( + reinterpret_cast(neighbors), batch_size, k); auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); raft::neighbors::ivf_pq::search( diff --git a/cpp/bench/ann/src/raft/raft_wrapper.h b/cpp/bench/ann/src/raft/raft_wrapper.h index 586b81ae06..2c996058b2 100644 --- a/cpp/bench/ann/src/raft/raft_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_wrapper.h @@ -56,10 +56,11 @@ class RaftGpu : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const final; // to enable dataset access from GPU memory AlgoProperty get_preference() const override @@ -133,15 +134,16 @@ void RaftGpu::load(const std::string& file) template void RaftGpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { auto queries_view = raft::make_device_matrix_view(queries, batch_size, this->dim_); - auto neighbors_view = raft::make_device_matrix_view(neighbors, batch_size, k); + auto neighbors_view = + raft::make_device_matrix_view(neighbors, batch_size, k); auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); - raft::neighbors::brute_force::search( + raft::neighbors::brute_force::search( handle_, *index_, queries_view, neighbors_view, distances_view); } From dd66b5c509e529a3ed9831e73d67d66945601922 Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 15 May 2024 15:52:30 +0200 Subject: [PATCH 2/2] Adapt raft algos to work when compiled with AnnBase::index_type = uint32_t --- cpp/bench/ann/src/raft/raft_ann_bench_utils.h | 72 +++++++++++ cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 58 +-------- .../ann/src/raft/raft_ivf_flat_wrapper.h | 23 +++- cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h | 116 ++++++++---------- 4 files changed, 148 insertions(+), 121 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h index 6cadb26736..ffe8f8717b 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h @@ -19,9 +19,12 @@ #include #include +#include +#include #include #include #include +#include #include #include @@ -166,4 +169,73 @@ inline configured_raft_resources::configured_raft_resources(configured_raft_reso inline configured_raft_resources& configured_raft_resources::operator=( configured_raft_resources&&) = default; +/** A helper to refine the neighbors when the data is on device or on host. */ +template +void refine_helper(const raft::resources& res, + DatasetT dataset, + QueriesT queries, + CandidatesT candidates, + int k, + AnnBase::index_type* neighbors, + float* distances, + raft::distance::DistanceType metric) +{ + using data_type = typename DatasetT::value_type; + using index_type = AnnBase::index_type; + using extents_type = index_type; // device-side refine requires this + + static_assert(std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + extents_type batch_size = queries.extent(0); + extents_type dim = queries.extent(1); + extents_type k0 = candidates.extent(1); + + if (raft::get_device_for_address(dataset.data_handle()) >= 0) { + auto dataset_device = raft::make_device_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + auto queries_device = raft::make_device_matrix_view( + queries.data_handle(), batch_size, dim); + auto candidates_device = raft::make_device_matrix_view( + candidates.data_handle(), batch_size, k0); + auto neighbors_device = + raft::make_device_matrix_view(neighbors, batch_size, k); + auto distances_device = + raft::make_device_matrix_view(distances, batch_size, k); + + raft::neighbors::refine(res, + dataset_device, + queries_device, + candidates_device, + neighbors_device, + distances_device, + metric); + } else { + auto dataset_host = raft::make_host_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + auto queries_host = raft::make_host_matrix(batch_size, dim); + auto candidates_host = raft::make_host_matrix(batch_size, k0); + auto neighbors_host = raft::make_host_matrix(batch_size, k); + auto distances_host = raft::make_host_matrix(batch_size, k); + + auto stream = resource::get_cuda_stream(res); + raft::copy(queries_host.data_handle(), queries.data_handle(), queries_host.size(), stream); + raft::copy( + candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); + + raft::resource::sync_stream(res); // wait for the queries and candidates + raft::neighbors::refine(res, + dataset_host, + queries_host.view(), + candidates_host.view(), + neighbors_host.view(), + distances_host.view(), + metric); + + raft::copy(neighbors, neighbors_host.data_handle(), neighbors_host.size(), stream); + raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); + } +} + } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index fc0f30c2d1..0b892dec35 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -314,66 +314,20 @@ void RaftCagra::search( auto k0 = static_cast(refine_ratio_ * k); const bool disable_refinement = k0 <= static_cast(k); const raft::resources& res = handle_; - auto stream = resource::get_cuda_stream(res); if (disable_refinement) { search_base(queries, batch_size, k, neighbors, distances); } else { + auto queries_v = + raft::make_device_matrix_view(queries, batch_size, dimension_); auto candidate_ixs = raft::make_device_matrix(res, batch_size, k0); auto candidate_dists = raft::make_device_matrix(res, batch_size, k0); - search_base(queries, - batch_size, - k0, - reinterpret_cast(candidate_ixs.data_handle()), - candidate_dists.data_handle()); - - if (raft::get_device_for_address(input_dataset_v_->data_handle()) >= 0) { - auto queries_v = raft::make_device_matrix_view( - queries, batch_size, dimension_); - auto neighours_v = raft::make_device_matrix_view( - reinterpret_cast(neighbors), batch_size, k); - auto distances_v = - raft::make_device_matrix_view(distances, batch_size, k); - raft::neighbors::refine( - res, - *input_dataset_v_, - queries_v, - raft::make_const_mdspan(candidate_ixs.view()), - neighours_v, - distances_v, - index_->metric()); - } else { - auto dataset_host = raft::make_host_matrix_view( - input_dataset_v_->data_handle(), input_dataset_v_->extent(0), input_dataset_v_->extent(1)); - auto queries_host = raft::make_host_matrix(batch_size, dimension_); - auto candidates_host = - raft::make_host_matrix(batch_size, k0); - auto neighbors_host = - raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); - - raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream); - raft::copy( - candidates_host.data_handle(), candidate_ixs.data_handle(), candidates_host.size(), stream); - - raft::resource::sync_stream(res); // wait for the queries and candidates - raft::neighbors::refine( - res, - dataset_host, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - index_->metric()); - - raft::copy(neighbors, - reinterpret_cast(neighbors_host.data_handle()), - neighbors_host.size(), - stream); - raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); - } + search_base( + queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists.data_handle()); + refine_helper( + res, *input_dataset_v_, queries_v, candidate_ixs, k, neighbors, distances, index_->metric()); } } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index a42cd3223a..83a3a63aba 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -136,19 +136,30 @@ void RaftIvfFlatGpu::search( { static_assert(std::is_integral_v); static_assert(std::is_integral_v); - static_assert(sizeof(AnnBase::index_type) == sizeof(IdxT), - "IdxT is incompatible with the index_type"); - // Assuming the returned and the required index types have the same size, we can just coerce the - // pointers to avoid extra mapping pass over the results. - // TODO: add a linalg::map() over the result indices if the type representations do not match. + + IdxT* neighbors_IdxT; + std::optional> neighbors_storage{std::nullopt}; + if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { + neighbors_IdxT = reinterpret_cast(neighbors); + } else { + neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); + neighbors_IdxT = neighbors_storage->data(); + } raft::neighbors::ivf_flat::search(handle_, search_params_, *index_, queries, batch_size, k, - reinterpret_cast(neighbors), + neighbors_IdxT, distances, resource::get_workspace_resource(handle_)); + if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { + raft::linalg::unaryOp(neighbors, + neighbors_IdxT, + batch_size * k, + raft::cast_op(), + raft::resource::get_cuda_stream(handle_)); + } } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 412d432592..7201467969 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -66,6 +66,11 @@ class RaftIvfPQ : public ANN, public AnnGPU { int k, AnnBase::index_type* neighbors, float* distances) const override; + void search_base(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -139,75 +144,60 @@ void RaftIvfPQ::set_search_dataset(const T* dataset, size_t nrow) } template -void RaftIvfPQ::search( +void RaftIvfPQ::search_base( const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { static_assert(std::is_integral_v); static_assert(std::is_integral_v); - static_assert(sizeof(AnnBase::index_type) == sizeof(IdxT), - "IdxT is incompatible with the index_type"); - if (refine_ratio_ > 1.0f) { - uint32_t k0 = static_cast(refine_ratio_ * k); - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto distances_tmp = raft::make_device_matrix(handle_, batch_size, k0); - auto candidates = raft::make_device_matrix(handle_, batch_size, k0); - - raft::neighbors::ivf_pq::search( - handle_, search_params_, *index_, queries_v, candidates.view(), distances_tmp.view()); - - if (raft::get_device_for_address(dataset_.data_handle()) >= 0) { - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto neighbors_v = raft::make_device_matrix_view( - reinterpret_cast(neighbors), batch_size, k); - auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); - - raft::neighbors::refine(handle_, - dataset_, - queries_v, - candidates.view(), - neighbors_v, - distances_v, - index_->metric()); - } else { - auto queries_host = raft::make_host_matrix(batch_size, index_->dim()); - auto candidates_host = raft::make_host_matrix(batch_size, k0); - auto neighbors_host = raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); - - auto stream = resource::get_cuda_stream(handle_); - raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream); - raft::copy( - candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); - - auto dataset_v = raft::make_host_matrix_view( - dataset_.data_handle(), dataset_.extent(0), dataset_.extent(1)); - - raft::resource::sync_stream(handle_); // wait for the queries and candidates - raft::neighbors::refine(handle_, - dataset_v, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - index_->metric()); - - raft::copy(reinterpret_cast(neighbors), - neighbors_host.data_handle(), - neighbors_host.size(), - stream); - raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); - } + + IdxT* neighbors_IdxT; + std::optional> neighbors_storage{std::nullopt}; + if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { + neighbors_IdxT = reinterpret_cast(neighbors); } else { - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto neighbors_v = raft::make_device_matrix_view( - reinterpret_cast(neighbors), batch_size, k); - auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); + neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); + neighbors_IdxT = neighbors_storage->data(); + } + + auto queries_view = + raft::make_device_matrix_view(queries, batch_size, dimension_); + auto neighbors_view = + raft::make_device_matrix_view(neighbors_IdxT, batch_size, k); + auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); + + raft::neighbors::ivf_pq::search( + handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); + + if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { + raft::linalg::unaryOp(neighbors, + neighbors_IdxT, + batch_size * k, + raft::cast_op(), + raft::resource::get_cuda_stream(handle_)); + } +} - raft::neighbors::ivf_pq::search( - handle_, search_params_, *index_, queries_v, neighbors_v, distances_v); +template +void RaftIvfPQ::search( + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const +{ + auto k0 = static_cast(refine_ratio_ * k); + const bool disable_refinement = k0 <= static_cast(k); + const raft::resources& res = handle_; + + if (disable_refinement) { + search_base(queries, batch_size, k, neighbors, distances); + } else { + auto queries_v = + raft::make_device_matrix_view(queries, batch_size, dimension_); + auto candidate_ixs = + raft::make_device_matrix(res, batch_size, k0); + auto candidate_dists = + raft::make_device_matrix(res, batch_size, k0); + search_base( + queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists.data_handle()); + refine_helper( + res, dataset_, queries_v, candidate_ixs, k, neighbors, distances, index_->metric()); } } } // namespace raft::bench::ann