Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ANN_BENCH: common AnnBase::index_type #2315

Merged
merged 4 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cpp/bench/ann/src/common/ann_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ struct AlgoProperty {

class AnnBase {
public:
using index_type = size_t;

inline AnnBase(Metric metric, int dim) : metric_(metric), dim_(dim) {}
virtual ~AnnBase() noexcept = default;

Expand Down Expand Up @@ -127,8 +129,11 @@ class ANN : public AnnBase {
virtual void set_search_param(const AnnSearchParam& param) = 0;
// TODO: this assumes that an algorithm can always return k results.
// This is not always possible.
virtual void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const = 0;
virtual void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const = 0;

virtual void save(const std::string& file) const = 0;
virtual void load(const std::string& file) = 0;
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ void bench_search(::benchmark::State& state,
/**
* Each thread will manage its own outputs
*/
using index_type = size_t;
using index_type = AnnBase::index_type;
constexpr size_t kAlignResultBuf = 64;
size_t result_elem_count = k * query_set_size;
result_elem_count =
Expand Down
9 changes: 6 additions & 3 deletions cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ class FaissCpu : public ANN<T> {

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const final;

AlgoProperty get_preference() const override
{
Expand Down Expand Up @@ -169,7 +172,7 @@ void FaissCpu<T>::set_search_param(const AnnSearchParam& param)

template <typename T>
void FaissCpu<T>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(faiss::idx_t),
"sizes of size_t and faiss::idx_t are different");
Expand Down
9 changes: 6 additions & 3 deletions cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@ class FaissGpu : public ANN<T>, public AnnGPU {

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const final;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand Down Expand Up @@ -196,7 +199,7 @@ void FaissGpu<T>::build(const T* dataset, size_t nrow)

template <typename T>
void FaissGpu<T>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(faiss::idx_t),
"sizes of size_t and faiss::idx_t are different");
Expand Down
16 changes: 11 additions & 5 deletions cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ class Ggnn : public ANN<T>, public AnnGPU {
void build(const T* dataset, size_t nrow) override { impl_->build(dataset, nrow); }

void set_search_param(const AnnSearchParam& param) override { impl_->set_search_param(param); }
// Runs a batched search by forwarding every argument unchanged to the wrapped
// implementation object (`impl_`); this wrapper adds no logic of its own.
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override
{
impl_->search(queries, batch_size, k, neighbors, distances);
}
Expand Down Expand Up @@ -123,8 +126,11 @@ class GgnnImpl : public ANN<T>, public AnnGPU {
void build(const T* dataset, size_t nrow) override;

void set_search_param(const AnnSearchParam& param) override;
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override;
[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { return stream_; }

void save(const std::string& file) const override;
Expand Down Expand Up @@ -243,7 +249,7 @@ void GgnnImpl<T, measure, D, KBuild, KQuery, S>::set_search_param(const AnnSearc

template <typename T, DistanceMeasure measure, int D, int KBuild, int KQuery, int S>
void GgnnImpl<T, measure, D, KBuild, KQuery, S>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(int64_t), "sizes of size_t and GGNN's KeyT are different");
if (k != KQuery) {
Expand Down
16 changes: 11 additions & 5 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ class HnswLib : public ANN<T> {
void build(const T* dataset, size_t nrow) override;

void set_search_param(const AnnSearchParam& param) override;
void search(
const T* query, int batch_size, int k, size_t* indices, float* distances) const override;
void search(const T* query,
int batch_size,
int k,
AnnBase::index_type* indices,
float* distances) const override;

void save(const std::string& path_to_index) const override;
void load(const std::string& path_to_index) override;
Expand All @@ -97,7 +100,10 @@ class HnswLib : public ANN<T> {
void set_base_layer_only() { appr_alg_->base_layer_only = true; }

private:
void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const;
void get_search_knn_results_(const T* query,
int k,
AnnBase::index_type* indices,
float* distances) const;

std::shared_ptr<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>> appr_alg_;
std::shared_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_;
Expand Down Expand Up @@ -176,7 +182,7 @@ void HnswLib<T>::set_search_param(const AnnSearchParam& param_)

template <typename T>
void HnswLib<T>::search(
const T* query, int batch_size, int k, size_t* indices, float* distances) const
const T* query, int batch_size, int k, AnnBase::index_type* indices, float* distances) const
{
auto f = [&](int i) {
// hnsw can only handle a single vector at a time.
Expand Down Expand Up @@ -217,7 +223,7 @@ void HnswLib<T>::load(const std::string& path_to_index)
template <typename T>
void HnswLib<T>::get_search_knn_results_(const T* query,
int k,
size_t* indices,
AnnBase::index_type* indices,
float* distances) const
{
auto result = appr_alg_->searchKnn(query, k);
Expand Down
72 changes: 72 additions & 0 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/operators.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/refine.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -166,4 +169,73 @@ inline configured_raft_resources::configured_raft_resources(configured_raft_reso
inline configured_raft_resources& configured_raft_resources::operator=(
configured_raft_resources&&) = default;

/**
 * @brief Refine search candidates via `raft::neighbors::refine`, with the dataset on device or
 *        on host.
 *
 * The branch is selected by `raft::get_device_for_address(dataset.data_handle())`:
 *  - result `>= 0` (device branch): the dataset, queries, candidates, `neighbors` and `distances`
 *    pointers are all wrapped into device matrix views and handed to `refine` directly, i.e. all
 *    of them are treated as device-accessible in this branch;
 *  - otherwise (host branch): the dataset is wrapped into a host view, the queries and candidates
 *    are copied into temporary host matrices with `raft::copy`, `refine` runs on the host views,
 *    and the results are copied from temporary host matrices into `neighbors` / `distances`.
 *
 * @tparam DatasetT    matrix view type of the dataset; its `value_type` is the data type
 * @tparam QueriesT    matrix view type of the queries; must have the same `value_type`
 * @tparam CandidatesT matrix view type of the candidates; its `value_type` must be
 *                     `AnnBase::index_type`
 *
 * @param[in]  res        raft resources; its CUDA stream is used for the host-branch copies
 * @param[in]  dataset    viewed as [dataset.extent(0), dataset.extent(1)]
 * @param[in]  queries    viewed as [batch_size, dim] = [queries.extent(0), queries.extent(1)]
 * @param[in]  candidates viewed as [batch_size, k0], where k0 = candidates.extent(1)
 * @param[in]  k          number of columns of the outputs
 * @param[out] neighbors  output indices, addressed as a [batch_size, k] matrix
 * @param[out] distances  output distances, addressed as a [batch_size, k] matrix
 * @param[in]  metric     distance type forwarded to `refine`
 *
 * @note In the host branch the output copies are only enqueued on the stream of `res`; this
 *       function does not synchronize the stream afterwards.
 */
template <typename DatasetT, typename QueriesT, typename CandidatesT>
void refine_helper(const raft::resources& res,
                   DatasetT dataset,
                   QueriesT queries,
                   CandidatesT candidates,
                   int k,
                   AnnBase::index_type* neighbors,
                   float* distances,
                   raft::distance::DistanceType metric)
{
  using data_type    = typename DatasetT::value_type;
  using index_type   = AnnBase::index_type;
  using extents_type = index_type;  // device-side refine requires this

  static_assert(std::is_same_v<data_type, typename QueriesT::value_type>,
                "queries must have the same value type as the dataset");
  static_assert(std::is_same_v<index_type, typename CandidatesT::value_type>,
                "candidates must be stored as AnnBase::index_type");

  extents_type batch_size = queries.extent(0);
  extents_type dim        = queries.extent(1);
  extents_type k0         = candidates.extent(1);

  if (raft::get_device_for_address(dataset.data_handle()) >= 0) {
    // Device branch: nothing is copied, every buffer is only re-wrapped into a view with the
    // extents type expected by the device-side refine.
    auto dataset_device = raft::make_device_matrix_view<const data_type, extents_type>(
      dataset.data_handle(), dataset.extent(0), dataset.extent(1));
    auto queries_device = raft::make_device_matrix_view<const data_type, extents_type>(
      queries.data_handle(), batch_size, dim);
    auto candidates_device = raft::make_device_matrix_view<const index_type, extents_type>(
      candidates.data_handle(), batch_size, k0);
    auto neighbors_device =
      raft::make_device_matrix_view<index_type, extents_type>(neighbors, batch_size, k);
    auto distances_device =
      raft::make_device_matrix_view<float, extents_type>(distances, batch_size, k);

    raft::neighbors::refine<index_type, data_type, float, extents_type>(res,
                                                                        dataset_device,
                                                                        queries_device,
                                                                        candidates_device,
                                                                        neighbors_device,
                                                                        distances_device,
                                                                        metric);
  } else {
    // Host branch: the dataset is used in place; everything else goes through temporary host
    // matrices owned by this scope.
    auto dataset_host = raft::make_host_matrix_view<const data_type, extents_type>(
      dataset.data_handle(), dataset.extent(0), dataset.extent(1));
    auto queries_host    = raft::make_host_matrix<data_type, extents_type>(batch_size, dim);
    auto candidates_host = raft::make_host_matrix<index_type, extents_type>(batch_size, k0);
    auto neighbors_host  = raft::make_host_matrix<index_type, extents_type>(batch_size, k);
    auto distances_host  = raft::make_host_matrix<float, extents_type>(batch_size, k);

    auto stream = raft::resource::get_cuda_stream(res);
    raft::copy(queries_host.data_handle(), queries.data_handle(), queries_host.size(), stream);
    raft::copy(
      candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream);

    raft::resource::sync_stream(res);  // wait for the queries and candidates
    raft::neighbors::refine<index_type, data_type, float, extents_type>(res,
                                                                        dataset_host,
                                                                        queries_host.view(),
                                                                        candidates_host.view(),
                                                                        neighbors_host.view(),
                                                                        distances_host.view(),
                                                                        metric);

    // NOTE(review): these copies are enqueued on `stream` and the temporary host matrices are
    // destroyed right after, without a stream sync. Presumably this relies on the host buffers
    // being ordinary pageable allocations (so the copy no longer needs them once the call
    // returns) -- confirm before switching the temporaries to pinned memory.
    raft::copy(neighbors, neighbors_host.data_handle(), neighbors_host.size(), stream);
    raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream);
  }
}

} // namespace raft::bench::ann
11 changes: 6 additions & 5 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ class RaftCagraHnswlib : public ANN<T>, public AnnGPU {

void set_search_param(const AnnSearchParam& param) override;

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand Down Expand Up @@ -99,7 +100,7 @@ void RaftCagraHnswlib<T, IdxT>::load(const std::string& file)

// Batched search entry point: forwards all arguments unchanged to the wrapped
// `hnswlib_search_` object, which performs the search.
template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
hnswlib_search_.search(queries, batch_size, k, neighbors, distances);
}
Expand Down
Loading
Loading