Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ann-bench: miscellaneous improvements #1808

Merged
16 changes: 12 additions & 4 deletions cpp/bench/ann/src/faiss/faiss_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,27 @@

namespace raft::bench::ann {

template <typename T>
void parse_base_build_param(const nlohmann::json& conf,
typename raft::bench::ann::FaissGpu<T>::BuildParam& param)
{
param.nlist = conf.at("nlist");
if (conf.contains("ratio")) { param.ratio = conf.at("ratio"); }
}

template <typename T>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::FaissGpuIVFFlat<T>::BuildParam& param)
{
param.nlist = conf.at("nlist");
parse_base_build_param<T>(conf, param);
}

template <typename T>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::FaissGpuIVFPQ<T>::BuildParam& param)
{
param.nlist = conf.at("nlist");
param.M = conf.at("M");
parse_base_build_param<T>(conf, param);
param.M = conf.at("M");
if (conf.contains("usePrecomputed")) {
param.usePrecomputed = conf.at("usePrecomputed");
} else {
Expand All @@ -59,7 +67,7 @@ template <typename T>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::FaissGpuIVFSQ<T>::BuildParam& param)
{
param.nlist = conf.at("nlist");
parse_base_build_param<T>(conf, param);
param.quantizer_type = conf.at("quantizer_type");
}

Expand Down
75 changes: 49 additions & 26 deletions cpp/bench/ann/src/faiss/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "../common/ann_types.hpp"

#include <raft/core/logger.hpp>
#include <raft/util/cudart_utils.hpp>

#include <faiss/IndexFlat.h>
Expand Down Expand Up @@ -85,7 +86,23 @@ class FaissGpu : public ANN<T> {
float refine_ratio = 1.0;
};

FaissGpu(Metric metric, int dim, int nlist);
struct BuildParam {
int nlist = 1;
int ratio = 2;
};

FaissGpu(Metric metric, int dim, const BuildParam& param)
: ANN<T>(metric, dim),
metric_type_(parse_metric_type(metric)),
nlist_{param.nlist},
training_sample_fraction_{1.0 / double(param.ratio)}
{
static_assert(std::is_same_v<T, float>, "faiss support only float type");
RAFT_CUDA_TRY(cudaGetDevice(&device_));
RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming));
faiss_default_stream_ = gpu_resource_.getDefaultStream(device_);
}

virtual ~FaissGpu() noexcept { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); }

void build(const T* dataset, size_t nrow, cudaStream_t stream = 0) final;
Expand Down Expand Up @@ -131,23 +148,35 @@ class FaissGpu : public ANN<T> {
int device_;
cudaEvent_t sync_{nullptr};
cudaStream_t faiss_default_stream_{nullptr};
double training_sample_fraction_;
};

template <typename T>
FaissGpu<T>::FaissGpu(Metric metric, int dim, int nlist)
: ANN<T>(metric, dim), metric_type_(parse_metric_type(metric)), nlist_(nlist)
{
static_assert(std::is_same_v<T, float>, "faiss support only float type");
RAFT_CUDA_TRY(cudaGetDevice(&device_));
RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming));
faiss_default_stream_ = gpu_resource_.getDefaultStream(device_);
}

template <typename T>
void FaissGpu<T>::build(const T* dataset, size_t nrow, cudaStream_t stream)
{
OmpSingleThreadScope omp_single_thread;

auto index_ivf = dynamic_cast<faiss::gpu::GpuIndexIVF*>(index_.get());
if (index_ivf != nullptr) {
// set the min/max training size for clustering to use the whole provided training set.
double trainset_size = training_sample_fraction_ * static_cast<double>(nrow);
double points_per_centroid = trainset_size / static_cast<double>(nlist_);
int max_ppc = std::ceil(points_per_centroid);
int min_ppc = std::floor(points_per_centroid);
if (min_ppc < index_ivf->cp.min_points_per_centroid) {
RAFT_LOG_WARN(
"The suggested training set size %zu (data size %zu, training sample ratio %f) yields %d "
"points per cluster (n_lists = %d). This is smaller than the FAISS default "
"min_points_per_centroid = %d.",
static_cast<size_t>(trainset_size),
nrow,
training_sample_fraction_,
min_ppc,
nlist_,
index_ivf->cp.min_points_per_centroid);
}
index_ivf->cp.max_points_per_centroid = max_ppc;
index_ivf->cp.min_points_per_centroid = min_ppc;
}
index_->train(nrow, dataset); // faiss::gpu::GpuIndexFlat::train() will do nothing
assert(index_->is_trained);
index_->add(nrow, dataset);
Expand Down Expand Up @@ -208,12 +237,9 @@ void FaissGpu<T>::load_(const std::string& file)
template <typename T>
class FaissGpuIVFFlat : public FaissGpu<T> {
public:
struct BuildParam {
int nlist;
};
using typename FaissGpu<T>::BuildParam;

FaissGpuIVFFlat(Metric metric, int dim, const BuildParam& param)
: FaissGpu<T>(metric, dim, param.nlist)
FaissGpuIVFFlat(Metric metric, int dim, const BuildParam& param) : FaissGpu<T>(metric, dim, param)
{
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = this->device_;
Expand All @@ -234,15 +260,13 @@ class FaissGpuIVFFlat : public FaissGpu<T> {
template <typename T>
class FaissGpuIVFPQ : public FaissGpu<T> {
public:
struct BuildParam {
int nlist;
struct BuildParam : public FaissGpu<T>::BuildParam {
int M;
bool useFloat16;
bool usePrecomputed;
};

FaissGpuIVFPQ(Metric metric, int dim, const BuildParam& param)
: FaissGpu<T>(metric, dim, param.nlist)
FaissGpuIVFPQ(Metric metric, int dim, const BuildParam& param) : FaissGpu<T>(metric, dim, param)
{
faiss::gpu::GpuIndexIVFPQConfig config;
config.useFloat16LookupTables = param.useFloat16;
Expand Down Expand Up @@ -271,13 +295,11 @@ class FaissGpuIVFPQ : public FaissGpu<T> {
template <typename T>
class FaissGpuIVFSQ : public FaissGpu<T> {
public:
struct BuildParam {
int nlist;
struct BuildParam : public FaissGpu<T>::BuildParam {
std::string quantizer_type;
};

FaissGpuIVFSQ(Metric metric, int dim, const BuildParam& param)
: FaissGpu<T>(metric, dim, param.nlist)
FaissGpuIVFSQ(Metric metric, int dim, const BuildParam& param) : FaissGpu<T>(metric, dim, param)
{
faiss::ScalarQuantizer::QuantizerType qtype;
if (param.quantizer_type == "fp16") {
Expand Down Expand Up @@ -310,7 +332,8 @@ class FaissGpuIVFSQ : public FaissGpu<T> {
template <typename T>
class FaissGpuFlat : public FaissGpu<T> {
public:
FaissGpuFlat(Metric metric, int dim) : FaissGpu<T>(metric, dim, 0)
FaissGpuFlat(Metric metric, int dim)
: FaissGpu<T>(metric, dim, typename FaissGpu<T>::BuildParam{})
{
faiss::gpu::GpuIndexFlatConfig config;
config.device = this->device_;
Expand Down
16 changes: 12 additions & 4 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ void parse_build_param(const nlohmann::json& conf,
{
param.n_lists = conf.at("nlist");
if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); }
if (conf.contains("ratio")) {
param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio");
std::cout << "kmeans_trainset_fraction " << param.kmeans_trainset_fraction;
}
if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); }
}

template <typename T, typename IdxT>
Expand All @@ -82,6 +79,17 @@ void parse_build_param(const nlohmann::json& conf,
if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); }
if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); }
if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); }
if (conf.contains("codebook_kind")) {
std::string kind = conf.at("codebook_kind");
achirkin marked this conversation as resolved.
Show resolved Hide resolved
if (kind == "cluster") {
param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER;
} else if (kind == "subspace") {
param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE;
} else {
throw std::runtime_error("codebook_kind: '" + kind +
"', should be either 'cluster' or 'subspace'");
}
}
}

template <typename T, typename IdxT>
Expand Down
41 changes: 24 additions & 17 deletions cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,14 @@ class RaftIvfPQ : public ANN<T> {
rmm::mr::set_current_device_resource(&mr_);
index_params_.metric = parse_metric_type(metric);
RAFT_CUDA_TRY(cudaGetDevice(&device_));
RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming));
}

~RaftIvfPQ() noexcept { rmm::mr::set_current_device_resource(mr_.get_upstream()); }
~RaftIvfPQ() noexcept
{
RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_));
rmm::mr::set_current_device_resource(mr_.get_upstream());
}

void build(const T* dataset, size_t nrow, cudaStream_t stream) final;

Expand Down Expand Up @@ -96,13 +101,20 @@ class RaftIvfPQ : public ANN<T> {
// `mr_` must go first to make sure it dies last
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> mr_;
raft::device_resources handle_;
cudaEvent_t sync_{nullptr};
BuildParam index_params_;
raft::neighbors::ivf_pq::search_params search_params_;
std::optional<raft::neighbors::ivf_pq::index<IdxT>> index_;
int device_;
int dimension_;
float refine_ratio_ = 1.0;
raft::device_matrix_view<const T, IdxT> dataset_;

void stream_wait(cudaStream_t stream) const
{
RAFT_CUDA_TRY(cudaEventRecord(sync_, resource::get_cuda_stream(handle_)));
RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_));
}
};

template <typename T, typename IdxT>
Expand All @@ -121,12 +133,12 @@ void RaftIvfPQ<T, IdxT>::load(const std::string& file)
}

template <typename T, typename IdxT>
void RaftIvfPQ<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
void RaftIvfPQ<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t stream)
{
auto dataset_v = raft::make_device_matrix_view<const T, IdxT>(dataset, IdxT(nrow), dim_);

index_.emplace(raft::runtime::neighbors::ivf_pq::build(handle_, index_params_, dataset_v));
return;
stream_wait(stream);
}

template <typename T, typename IdxT>
Expand Down Expand Up @@ -176,16 +188,14 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries,
neighbors_v,
distances_v,
index_->metric());
stream_wait(stream); // RAFT stream -> bench stream
} else {
auto queries_host = raft::make_host_matrix<T, IdxT>(batch_size, index_->dim());
auto candidates_host = raft::make_host_matrix<IdxT, IdxT>(batch_size, k0);
auto neighbors_host = raft::make_host_matrix<IdxT, IdxT>(batch_size, k);
auto distances_host = raft::make_host_matrix<float, IdxT>(batch_size, k);

raft::copy(queries_host.data_handle(),
queries,
queries_host.size(),
resource::get_cuda_stream(handle_));
raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream);
raft::copy(candidates_host.data_handle(),
candidates.data_handle(),
candidates_host.size(),
Expand All @@ -194,6 +204,10 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries,
auto dataset_v = raft::make_host_matrix_view<const T, IdxT>(
dataset_.data_handle(), dataset_.extent(0), dataset_.extent(1));

// wait for the queries to copy to host in 'stream` and for IVF-PQ::search to finish
RAFT_CUDA_TRY(cudaEventRecord(sync_, resource::get_cuda_stream(handle_)));
RAFT_CUDA_TRY(cudaEventRecord(sync_, stream));
RAFT_CUDA_TRY(cudaEventSynchronize(sync_));
raft::runtime::neighbors::refine(handle_,
dataset_v,
queries_host.view(),
Expand All @@ -202,14 +216,8 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries,
distances_host.view(),
index_->metric());

raft::copy(neighbors,
(size_t*)neighbors_host.data_handle(),
neighbors_host.size(),
resource::get_cuda_stream(handle_));
raft::copy(distances,
distances_host.data_handle(),
distances_host.size(),
resource::get_cuda_stream(handle_));
raft::copy(neighbors, (size_t*)neighbors_host.data_handle(), neighbors_host.size(), stream);
raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream);
}
} else {
auto queries_v =
Expand All @@ -219,8 +227,7 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries,

raft::runtime::neighbors::ivf_pq::search(
handle_, search_params_, *index_, queries_v, neighbors_v, distances_v);
stream_wait(stream); // RAFT stream -> bench stream
}
resource::sync_stream(handle_);
return;
}
} // namespace raft::bench::ann
Loading