diff --git a/cpp/bench/ann/src/common/ann_types.hpp b/cpp/bench/ann/src/common/ann_types.hpp index 776d29a906..b010063dee 100644 --- a/cpp/bench/ann/src/common/ann_types.hpp +++ b/cpp/bench/ann/src/common/ann_types.hpp @@ -73,6 +73,8 @@ struct AlgoProperty { class AnnBase { public: + using index_type = size_t; + inline AnnBase(Metric metric, int dim) : metric_(metric), dim_(dim) {} virtual ~AnnBase() noexcept = default; @@ -127,8 +129,11 @@ class ANN : public AnnBase { virtual void set_search_param(const AnnSearchParam& param) = 0; // TODO: this assumes that an algorithm can always return k results. // This is not always possible. - virtual void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const = 0; + virtual void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const = 0; virtual void save(const std::string& file) const = 0; virtual void load(const std::string& file) = 0; diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index 1f27c9d6a4..8762ccd1fe 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -280,7 +280,7 @@ void bench_search(::benchmark::State& state, /** * Each thread will manage its own outputs */ - using index_type = size_t; + using index_type = AnnBase::index_type; constexpr size_t kAlignResultBuf = 64; size_t result_elem_count = k * query_set_size; result_elem_count = diff --git a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h index 407f7148df..3caca15b7f 100644 --- a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h @@ -88,8 +88,11 @@ class FaissCpu : public ANN { // TODO: if the number of results is less than k, the remaining elements of 'neighbors' // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const final; AlgoProperty get_preference() const override { @@ -169,7 +172,7 @@ void FaissCpu::set_search_param(const AnnSearchParam& param) template void FaissCpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(faiss::idx_t), "sizes of size_t and faiss::idx_t are different"); diff --git a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h index 633098fd1d..2effe631e5 100644 --- a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h @@ -111,8 +111,11 @@ class FaissGpu : public ANN, public AnnGPU { // TODO: if the number of results is less than k, the remaining elements of 'neighbors' // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const final; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -196,7 +199,7 @@ void FaissGpu::build(const T* dataset, size_t nrow) template void FaissGpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(faiss::idx_t), "sizes of size_t and faiss::idx_t are different"); diff --git a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh index c89f02d974..59cf3df806 100644 --- a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh +++ b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh @@ -58,8 +58,11 @@ class Ggnn : public ANN, public AnnGPU { void build(const T* dataset, size_t nrow) override { impl_->build(dataset, nrow); } void set_search_param(const AnnSearchParam& param) override { impl_->set_search_param(param); } - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override { impl_->search(queries, batch_size, k, neighbors, distances); } @@ -123,8 +126,11 @@ class GgnnImpl : public ANN, public AnnGPU { void build(const T* dataset, size_t nrow) override; void set_search_param(const AnnSearchParam& param) override; - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { return stream_; } void save(const std::string& file) const override; @@ -243,7 +249,7 @@ void GgnnImpl::set_search_param(const AnnSearc template void GgnnImpl::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(int64_t), "sizes of size_t and GGNN's KeyT are different"); if (k != KQuery) { diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index a8f7dd824f..5743632bf4 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -79,8 +79,11 @@ class HnswLib : public ANN { void build(const T* dataset, size_t nrow) override; void set_search_param(const AnnSearchParam& param) override; - void search( - const T* query, int batch_size, int k, size_t* indices, float* distances) const override; + void search(const T* query, + int batch_size, + int k, + AnnBase::index_type* indices, + float* distances) const override; void save(const std::string& path_to_index) const override; void load(const std::string& path_to_index) override; @@ -97,7 +100,10 @@ class HnswLib : public ANN { void set_base_layer_only() { appr_alg_->base_layer_only = true; } private: - void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const; + void get_search_knn_results_(const T* query, + int k, + AnnBase::index_type* indices, + float* distances) const; std::shared_ptr::type>> appr_alg_; std::shared_ptr::type>> space_; @@ -176,7 +182,7 @@ void HnswLib::set_search_param(const AnnSearchParam& param_) template void HnswLib::search( - const T* query, int batch_size, int k, size_t* indices, float* distances) const + const T* query, int batch_size, int k, AnnBase::index_type* indices, float* distances) const { auto f = [&](int i) { // hnsw can only handle a single vector at a time. @@ -217,7 +223,7 @@ void HnswLib::load(const std::string& path_to_index) template void HnswLib::get_search_knn_results_(const T* query, int k, - size_t* indices, + AnnBase::index_type* indices, float* distances) const { auto result = appr_alg_->searchKnn(query, k); diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h index 6cadb26736..ffe8f8717b 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h @@ -19,9 +19,12 @@ #include #include +#include +#include #include #include #include +#include #include #include @@ -166,4 +169,73 @@ inline configured_raft_resources::configured_raft_resources(configured_raft_reso inline configured_raft_resources& configured_raft_resources::operator=( configured_raft_resources&&) = default; +/** A helper to refine the neighbors when the data is on device or on host. */ +template +void refine_helper(const raft::resources& res, + DatasetT dataset, + QueriesT queries, + CandidatesT candidates, + int k, + AnnBase::index_type* neighbors, + float* distances, + raft::distance::DistanceType metric) +{ + using data_type = typename DatasetT::value_type; + using index_type = AnnBase::index_type; + using extents_type = index_type; // device-side refine requires this + + static_assert(std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + extents_type batch_size = queries.extent(0); + extents_type dim = queries.extent(1); + extents_type k0 = candidates.extent(1); + + if (raft::get_device_for_address(dataset.data_handle()) >= 0) { + auto dataset_device = raft::make_device_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + auto queries_device = raft::make_device_matrix_view( + queries.data_handle(), batch_size, dim); + auto candidates_device = raft::make_device_matrix_view( + candidates.data_handle(), batch_size, k0); + auto neighbors_device = + raft::make_device_matrix_view(neighbors, batch_size, k); + auto distances_device = + raft::make_device_matrix_view(distances, batch_size, k); + + raft::neighbors::refine(res, + dataset_device, + queries_device, + candidates_device, + neighbors_device, + distances_device, + metric); + } else { + auto dataset_host = raft::make_host_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + auto queries_host = raft::make_host_matrix(batch_size, dim); + auto candidates_host = raft::make_host_matrix(batch_size, k0); + auto neighbors_host = raft::make_host_matrix(batch_size, k); + auto distances_host = raft::make_host_matrix(batch_size, k); + + auto stream = resource::get_cuda_stream(res); + raft::copy(queries_host.data_handle(), queries.data_handle(), queries_host.size(), stream); + raft::copy( + candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); + + raft::resource::sync_stream(res); // wait for the queries and candidates + raft::neighbors::refine(res, + dataset_host, + queries_host.view(), + candidates_host.view(), + neighbors_host.view(), + distances_host.view(), + metric); + + raft::copy(neighbors, neighbors_host.data_handle(), neighbors_host.size(), stream); + raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); + } +} + } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h index ed9c120ed4..1c4b847d1a 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h @@ -41,10 +41,11 @@ class RaftCagraHnswlib : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -99,7 +100,7 @@ void RaftCagraHnswlib::load(const std::string& file) template void RaftCagraHnswlib::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { hnswlib_search_.search(queries, batch_size, k, neighbors, distances); } diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 46da8c52e6..0b892dec35 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -96,12 +96,16 @@ class RaftCagra : public ANN, public AnnGPU { void set_search_dataset(const T* dataset, size_t nrow) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; - void search_base( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; + void search_base(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -272,15 +276,18 @@ std::unique_ptr> RaftCagra::copy() template void RaftCagra::search_base( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + IdxT* neighbors_IdxT; - rmm::device_uvector neighbors_storage(0, resource::get_cuda_stream(handle_)); - if constexpr (std::is_same_v) { - neighbors_IdxT = neighbors; + std::optional> neighbors_storage{std::nullopt}; + if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { + neighbors_IdxT = reinterpret_cast(neighbors); } else { - neighbors_storage.resize(batch_size * k, resource::get_cuda_stream(handle_)); - neighbors_IdxT = neighbors_storage.data(); + neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); + neighbors_IdxT = neighbors_storage->data(); } auto queries_view = @@ -291,76 +298,36 @@ void RaftCagra::search_base( raft::neighbors::cagra::search( handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); - if constexpr (!std::is_same_v) { + if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { raft::linalg::unaryOp(neighbors, neighbors_IdxT, batch_size * k, - raft::cast_op(), + raft::cast_op(), raft::resource::get_cuda_stream(handle_)); } } template void RaftCagra::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { auto k0 = static_cast(refine_ratio_ * k); const bool disable_refinement = k0 <= static_cast(k); const raft::resources& res = handle_; - auto stream = resource::get_cuda_stream(res); if (disable_refinement) { search_base(queries, batch_size, k, neighbors, distances); } else { - auto candidate_ixs = raft::make_device_matrix(res, batch_size, k0); - auto candidate_dists = raft::make_device_matrix(res, batch_size, k0); - search_base(queries, - batch_size, - k0, - reinterpret_cast(candidate_ixs.data_handle()), - candidate_dists.data_handle()); - - if (raft::get_device_for_address(input_dataset_v_->data_handle()) >= 0) { - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, dimension_); - auto neighours_v = raft::make_device_matrix_view( - reinterpret_cast(neighbors), batch_size, k); - auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); - raft::neighbors::refine( - res, - *input_dataset_v_, - queries_v, - raft::make_const_mdspan(candidate_ixs.view()), - neighours_v, - distances_v, - index_->metric()); - } else { - auto dataset_host = raft::make_host_matrix_view( - input_dataset_v_->data_handle(), input_dataset_v_->extent(0), input_dataset_v_->extent(1)); - auto queries_host = raft::make_host_matrix(batch_size, dimension_); - auto candidates_host = raft::make_host_matrix(batch_size, k0); - auto neighbors_host = raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); - - raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream); - raft::copy( - candidates_host.data_handle(), candidate_ixs.data_handle(), candidates_host.size(), stream); - - raft::resource::sync_stream(res); // wait for the queries and candidates - raft::neighbors::refine(res, - dataset_host, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - index_->metric()); - - raft::copy(neighbors, - reinterpret_cast(neighbors_host.data_handle()), - neighbors_host.size(), - stream); - raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); - } + auto queries_v = + raft::make_device_matrix_view(queries, batch_size, dimension_); + auto candidate_ixs = + raft::make_device_matrix(res, batch_size, k0); + auto candidate_dists = + raft::make_device_matrix(res, batch_size, k0); + search_base( + queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists.data_handle()); + refine_helper( + res, *input_dataset_v_, queries_v, candidate_ixs, k, neighbors, distances, index_->metric()); } } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 48d2b9de80..83a3a63aba 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -61,10 +61,11 @@ class RaftIvfFlatGpu : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -131,17 +132,34 @@ std::unique_ptr> RaftIvfFlatGpu::copy() template void RaftIvfFlatGpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { - static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + + IdxT* neighbors_IdxT; + std::optional> neighbors_storage{std::nullopt}; + if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { + neighbors_IdxT = reinterpret_cast(neighbors); + } else { + neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); + neighbors_IdxT = neighbors_storage->data(); + } raft::neighbors::ivf_flat::search(handle_, search_params_, *index_, queries, batch_size, k, - (IdxT*)neighbors, + neighbors_IdxT, distances, resource::get_workspace_resource(handle_)); + if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { + raft::linalg::unaryOp(neighbors, + neighbors_IdxT, + batch_size * k, + raft::cast_op(), + raft::resource::get_cuda_stream(handle_)); + } } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 1d73bd2e51..7201467969 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -61,10 +61,16 @@ class RaftIvfPQ : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; void set_search_dataset(const T* dataset, size_t nrow) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const override; + void search_base(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -137,68 +143,61 @@ void RaftIvfPQ::set_search_dataset(const T* dataset, size_t nrow) dataset_ = raft::make_device_matrix_view(dataset, nrow, index_->dim()); } +template +void RaftIvfPQ::search_base( + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const +{ + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + + IdxT* neighbors_IdxT; + std::optional> neighbors_storage{std::nullopt}; + if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { + neighbors_IdxT = reinterpret_cast(neighbors); + } else { + neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); + neighbors_IdxT = neighbors_storage->data(); + } + + auto queries_view = + raft::make_device_matrix_view(queries, batch_size, dimension_); + auto neighbors_view = + raft::make_device_matrix_view(neighbors_IdxT, batch_size, k); + auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); + + raft::neighbors::ivf_pq::search( + handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); + + if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { + raft::linalg::unaryOp(neighbors, + neighbors_IdxT, + batch_size * k, + raft::cast_op(), + raft::resource::get_cuda_stream(handle_)); + } +} + template void RaftIvfPQ::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { - if (refine_ratio_ > 1.0f) { - uint32_t k0 = static_cast(refine_ratio_ * k); - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto distances_tmp = raft::make_device_matrix(handle_, batch_size, k0); - auto candidates = raft::make_device_matrix(handle_, batch_size, k0); - - raft::neighbors::ivf_pq::search( - handle_, search_params_, *index_, queries_v, candidates.view(), distances_tmp.view()); - - if (raft::get_device_for_address(dataset_.data_handle()) >= 0) { - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto neighbors_v = raft::make_device_matrix_view((IdxT*)neighbors, batch_size, k); - auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); - - raft::neighbors::refine(handle_, - dataset_, - queries_v, - candidates.view(), - neighbors_v, - distances_v, - index_->metric()); - } else { - auto queries_host = raft::make_host_matrix(batch_size, index_->dim()); - auto candidates_host = raft::make_host_matrix(batch_size, k0); - auto neighbors_host = raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); - - auto stream = resource::get_cuda_stream(handle_); - raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream); - raft::copy( - candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); - - auto dataset_v = raft::make_host_matrix_view( - dataset_.data_handle(), dataset_.extent(0), dataset_.extent(1)); - - raft::resource::sync_stream(handle_); // wait for the queries and candidates - raft::neighbors::refine(handle_, - dataset_v, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - index_->metric()); - - raft::copy(neighbors, (size_t*)neighbors_host.data_handle(), neighbors_host.size(), stream); - raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); - } + auto k0 = static_cast(refine_ratio_ * k); + const bool disable_refinement = k0 <= static_cast(k); + const raft::resources& res = handle_; + + if (disable_refinement) { + search_base(queries, batch_size, k, neighbors, distances); } else { auto queries_v = - raft::make_device_matrix_view(queries, batch_size, index_->dim()); - auto neighbors_v = - raft::make_device_matrix_view((IdxT*)neighbors, batch_size, k); - auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); - - raft::neighbors::ivf_pq::search( - handle_, search_params_, *index_, queries_v, neighbors_v, distances_v); + raft::make_device_matrix_view(queries, batch_size, dimension_); + auto candidate_ixs = + raft::make_device_matrix(res, batch_size, k0); + auto candidate_dists = + raft::make_device_matrix(res, batch_size, k0); + search_base( + queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists.data_handle()); + refine_helper( + res, dataset_, queries_v, candidate_ixs, k, neighbors, distances, index_->metric()); } } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_wrapper.h b/cpp/bench/ann/src/raft/raft_wrapper.h index 586b81ae06..2c996058b2 100644 --- a/cpp/bench/ann/src/raft/raft_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_wrapper.h @@ -56,10 +56,11 @@ class RaftGpu : public ANN, public AnnGPU { void set_search_param(const AnnSearchParam& param) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final; + void search(const T* queries, + int batch_size, + int k, + AnnBase::index_type* neighbors, + float* distances) const final; // to enable dataset access from GPU memory AlgoProperty get_preference() const override @@ -133,15 +134,16 @@ void RaftGpu::load(const std::string& file) template void RaftGpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const + const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const { auto queries_view = raft::make_device_matrix_view(queries, batch_size, this->dim_); - auto neighbors_view = raft::make_device_matrix_view(neighbors, batch_size, k); + auto neighbors_view = + raft::make_device_matrix_view(neighbors, batch_size, k); auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); - raft::neighbors::brute_force::search( + raft::neighbors::brute_force::search( handle_, *index_, queries_view, neighbors_view, distances_view); }