Skip to content

Commit

Permalink
Remove faiss ANN code from knnIndex (#1121)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1121
  • Loading branch information
benfred authored Jan 20, 2023
1 parent 7215c8a commit f2bc24d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 127 deletions.
26 changes: 5 additions & 21 deletions cpp/include/raft/spatial/knn/ann_common.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,12 +22,10 @@

#include "detail/processing.hpp"
#include "ivf_flat_types.hpp"
#include <raft/neighbors/ivf_pq_types.hpp>

#include <raft/distance/distance_types.hpp>

#include <faiss/gpu/GpuIndex.h>
#include <raft/spatial/knn/faiss_mr.hpp>

namespace raft {
namespace spatial {
namespace knn {
Expand All @@ -36,13 +34,14 @@ struct knnIndex {
raft::distance::DistanceType metric;
float metricArg;
int nprobe;
std::unique_ptr<faiss::gpu::GpuIndex> index;
std::unique_ptr<MetricProcessor<float>> metric_processor;

std::unique_ptr<const ivf_flat::index<float, int64_t>> ivf_flat_float_;
std::unique_ptr<const ivf_flat::index<uint8_t, int64_t>> ivf_flat_uint8_t_;
std::unique_ptr<const ivf_flat::index<int8_t, int64_t>> ivf_flat_int8_t_;

std::unique_ptr<raft::spatial::knn::RmmGpuResources> gpu_res;
std::unique_ptr<const raft::neighbors::ivf_pq::index<int64_t>> ivf_pq;

int device;

template <typename T, typename IdxT>
Expand Down Expand Up @@ -70,16 +69,6 @@ inline auto knnIndex::ivf_flat<int8_t, int64_t>()
return ivf_flat_int8_t_;
}

/**
 * Scalar-quantizer variants selectable through IVFSQParam::qtype.
 *
 * The underlying type is unsigned int. Enumerators are numbered consecutively
 * from zero in declaration order; the values are written out here so the
 * numbering is explicit rather than implied.
 */
enum QuantizerType : unsigned int {
  QT_8bit         = 0,
  QT_4bit         = 1,
  QT_8bit_uniform = 2,
  QT_4bit_uniform = 3,
  QT_fp16         = 4,
  QT_8bit_direct  = 5,
  QT_6bit         = 6
};

/**
 * Polymorphic root of the legacy index-parameter hierarchy.
 *
 * The virtual destructor makes the type polymorphic, so derived parameter
 * structs can be told apart with dynamic_cast and safely destroyed through a
 * pointer to this base.
 */
struct knnIndexParam {
  virtual ~knnIndexParam() = default;
};
Expand All @@ -98,11 +87,6 @@ struct IVFPQParam : IVFParam {
bool usePrecomputedTables;
};

/**
 * Parameters for an IVF index that uses scalar quantization, extending the
 * common IVFParam settings.
 */
struct IVFSQParam : public IVFParam {
  /// Which scalar quantizer variant to use.
  QuantizerType qtype;
  /// Forwarded as the `encodeResidual` argument when the index is constructed.
  bool encodeResidual;
};

inline auto from_legacy_index_params(const IVFFlatParam& legacy,
raft::distance::DistanceType metric,
float metric_arg)
Expand Down
130 changes: 26 additions & 104 deletions cpp/include/raft/spatial/knn/detail/ann_quantized.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@

#include "../ann_common.h"
#include "../ivf_flat.cuh"
#include "knn_brute_force_faiss.cuh"

#include "common_faiss.h"
#include "processing.cuh"
#include <raft/core/operators.hpp>
#include <raft/util/cuda_utils.cuh>
Expand All @@ -29,83 +27,14 @@
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/label/classlabels.cuh>
#include <raft/spatial/knn/faiss_mr.hpp>
#include <raft/neighbors/ivf_pq.cuh>

#include <rmm/cuda_stream_view.hpp>

#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <faiss/gpu/GpuIndexIVFScalarQuantizer.h>

#include <thrust/iterator/transform_iterator.h>

namespace raft::spatial::knn::detail {

/**
 * Translate a QuantizerType into the faiss scalar quantizer type of the same
 * name.
 *
 * Every enumerator of QuantizerType has an explicit case (QT_4bit previously
 * had none and silently relied on the fallback cast).
 *
 * @param qtype quantizer selector to translate
 * @return the identically named faiss enumerator; a value outside the known
 *         enumerators is converted numerically
 */
inline faiss::ScalarQuantizer::QuantizerType build_faiss_qtype(QuantizerType qtype)
{
  using FaissQType = faiss::ScalarQuantizer::QuantizerType;

  switch (qtype) {
    case QuantizerType::QT_8bit: return FaissQType::QT_8bit;
    case QuantizerType::QT_4bit: return FaissQType::QT_4bit;
    case QuantizerType::QT_8bit_uniform: return FaissQType::QT_8bit_uniform;
    case QuantizerType::QT_4bit_uniform: return FaissQType::QT_4bit_uniform;
    case QuantizerType::QT_fp16: return FaissQType::QT_fp16;
    case QuantizerType::QT_8bit_direct: return FaissQType::QT_8bit_direct;
    case QuantizerType::QT_6bit: return FaissQType::QT_6bit;
    // Out-of-range input: keep the historical pass-through of the raw value.
    default: return static_cast<FaissQType>(qtype);
  }
}

/**
 * Create an (untrained) faiss GPU IVF-Flat index and store it in `index->index`.
 *
 * Reads `index->metric`, `index->device` and `index->gpu_res`, which the caller
 * must have set beforehand.
 *
 * @param index  knnIndex that receives the newly created faiss index
 * @param params IVF-Flat parameters; only `nlist` is read here
 * @param n      unused by this function
 * @param D      passed to faiss as the vector dimensionality
 */
template <typename IntType = int>
void approx_knn_ivfflat_build_index(knnIndex* index,
                                    const IVFFlatParam& params,
                                    IntType n,
                                    IntType D)
{
  const faiss::MetricType metric_for_faiss = build_faiss_metric(index->metric);

  faiss::gpu::GpuIndexIVFFlatConfig flat_config;
  flat_config.device = index->device;

  index->index = std::make_unique<faiss::gpu::GpuIndexIVFFlat>(
    index->gpu_res.get(), D, params.nlist, metric_for_faiss, flat_config);
}

/**
 * Create an (untrained) faiss GPU IVF-PQ index and store it in `index->index`.
 *
 * Reads `index->metric`, `index->device` and `index->gpu_res`, which the caller
 * must have set beforehand.
 *
 * @param index  knnIndex that receives the newly created faiss index
 * @param params IVF-PQ parameters (`nlist`, `M`, `n_bits`, `usePrecomputedTables`)
 * @param n      unused by this function
 * @param D      passed to faiss as the vector dimensionality
 */
template <typename IntType = int>
void approx_knn_ivfpq_build_index(knnIndex* index, const IVFPQParam& params, IntType n, IntType D)
{
  const faiss::MetricType metric_for_faiss = build_faiss_metric(index->metric);

  faiss::gpu::GpuIndexIVFPQConfig pq_config;
  pq_config.device               = index->device;
  pq_config.usePrecomputedTables = params.usePrecomputedTables;
  // Interleaved layout is enabled for every code width other than 8 bits.
  // NOTE(review): presumably a faiss requirement for non-8-bit codes - confirm.
  pq_config.interleavedLayout = (params.n_bits != 8);

  index->index = std::make_unique<faiss::gpu::GpuIndexIVFPQ>(
    index->gpu_res.get(), D, params.nlist, params.M, params.n_bits, metric_for_faiss, pq_config);
}

/**
 * Create an (untrained) faiss GPU IVF scalar-quantizer index and store it in
 * `index->index`.
 *
 * Reads `index->metric`, `index->device` and `index->gpu_res`, which the caller
 * must have set beforehand.
 *
 * @param index  knnIndex that receives the newly created faiss index
 * @param params IVF-SQ parameters (`nlist`, `qtype`, `encodeResidual`)
 * @param n      unused by this function
 * @param D      passed to faiss as the vector dimensionality
 */
template <typename IntType = int>
void approx_knn_ivfsq_build_index(knnIndex* index, const IVFSQParam& params, IntType n, IntType D)
{
  faiss::gpu::GpuIndexIVFScalarQuantizerConfig config;
  config.device = index->device;

  const faiss::MetricType faiss_metric                      = build_faiss_metric(index->metric);
  const faiss::ScalarQuantizer::QuantizerType faiss_qtype = build_faiss_qtype(params.qtype);

  // `config` must be handed to the constructor explicitly: it is a trailing
  // defaulted argument, and omitting it (as this function used to) discards the
  // device selected above in favour of a default-constructed config.
  index->index = std::make_unique<faiss::gpu::GpuIndexIVFScalarQuantizer>(index->gpu_res.get(),
                                                                          D,
                                                                          params.nlist,
                                                                          faiss_qtype,
                                                                          faiss_metric,
                                                                          params.encodeResidual,
                                                                          config);
}

/**
 * Tell whether `metric` is one of the distance types accepted for the
 * IVF-Flat path: the four L2 variants and inner product.
 *
 * @param metric distance type to check
 * @return true for L2Unexpanded, L2Expanded, L2SqrtExpanded, L2SqrtUnexpanded
 *         and InnerProduct; false for everything else
 */
inline bool ivf_flat_supported_metric(raft::distance::DistanceType metric)
{
  using raft::distance::DistanceType;

  const bool is_l2_variant =
    metric == DistanceType::L2Unexpanded || metric == DistanceType::L2Expanded ||
    metric == DistanceType::L2SqrtExpanded || metric == DistanceType::L2SqrtUnexpanded;

  return is_l2_variant || metric == DistanceType::InnerProduct;
}

template <typename T = float, typename IntType = int>
void approx_knn_build_index(const handle_t& handle,
knnIndex* index,
Expand All @@ -117,45 +46,42 @@ void approx_knn_build_index(const handle_t& handle,
IntType D)
{
auto stream = handle.get_stream();
index->index = nullptr;
index->metric = metric;
index->metricArg = metricArg;
if (dynamic_cast<const IVFParam*>(params)) {
index->nprobe = dynamic_cast<const IVFParam*>(params)->nprobe;
}
auto ivf_ft_pams = dynamic_cast<IVFFlatParam*>(params);
auto ivf_pq_pams = dynamic_cast<IVFPQParam*>(params);
auto ivf_sq_pams = dynamic_cast<IVFSQParam*>(params);

if constexpr (std::is_same_v<T, float>) {
index->metric_processor = create_processor<float>(metric, n, D, 0, false, stream);
// For cosine/correlation distance, the metric processor translates distance
// to inner product via pre/post processing - pass the translated metric to
// ANN index
if (metric == raft::distance::DistanceType::CosineExpanded ||
metric == raft::distance::DistanceType::CorrelationExpanded) {
metric = index->metric = raft::distance::DistanceType::InnerProduct;
}
}
if constexpr (std::is_same_v<T, float>) { index->metric_processor->preprocess(index_array); }

if (ivf_ft_pams && ivf_flat_supported_metric(metric)) {
if (ivf_ft_pams) {
auto new_params = from_legacy_index_params(*ivf_ft_pams, metric, metricArg);
index->ivf_flat<T, int64_t>() = std::make_unique<const ivf_flat::index<T, int64_t>>(
ivf_flat::build(handle, new_params, index_array, int64_t(n), D));
} else if (ivf_pq_pams) {
neighbors::ivf_pq::index_params params;
params.metric = metric;
params.metric_arg = metricArg;
params.n_lists = ivf_pq_pams->nlist;
params.pq_bits = ivf_pq_pams->n_bits;
params.pq_dim = ivf_pq_pams->M;
// TODO: handle ivf_pq_pams.usePrecomputedTables ?
index->ivf_pq = std::make_unique<const neighbors::ivf_pq::index<int64_t>>(
neighbors::ivf_pq::build(handle, params, index_array, int64_t(n), D));
} else {
RAFT_CUDA_TRY(cudaGetDevice(&(index->device)));
index->gpu_res.reset(new raft::spatial::knn::RmmGpuResources());
index->gpu_res->noTempMemory();
index->gpu_res->setDefaultStream(index->device, stream);
if (ivf_ft_pams) {
approx_knn_ivfflat_build_index(index, *ivf_ft_pams, n, D);
} else if (ivf_pq_pams) {
approx_knn_ivfpq_build_index(index, *ivf_pq_pams, n, D);
} else if (ivf_sq_pams) {
approx_knn_ivfsq_build_index(index, *ivf_sq_pams, n, D);
} else {
RAFT_FAIL("Unrecognized index type.");
}
if constexpr (std::is_same_v<T, float>) {
index->index->train(n, index_array);
index->index->add(n, index_array);
} else {
RAFT_FAIL("FAISS-based index supports only float data.");
}
RAFT_FAIL("Unrecognized index type.");
}

if constexpr (std::is_same_v<T, float>) { index->metric_processor->revert(index_array); }
Expand All @@ -170,26 +96,22 @@ void approx_knn_search(const handle_t& handle,
T* query_array,
IntType n)
{
auto faiss_ivf = dynamic_cast<GpuIndexIVF*>(index->index.get());
if (faiss_ivf) { faiss_ivf->setNumProbes(index->nprobe); }

if constexpr (std::is_same_v<T, float>) {
index->metric_processor->preprocess(query_array);
index->metric_processor->set_num_queries(k);
}

// search
if (faiss_ivf) {
if constexpr (std::is_same_v<T, float>) {
faiss_ivf->search(n, query_array, k, distances, indices);
} else {
RAFT_FAIL("FAISS-based index supports only float data.");
}
} else if (index->ivf_flat<T, int64_t>()) {
if (index->ivf_flat<T, int64_t>()) {
ivf_flat::search_params params;
params.n_probes = index->nprobe;
ivf_flat::search(
handle, params, *(index->ivf_flat<T, int64_t>()), query_array, n, k, indices, distances);
} else if (index->ivf_pq) {
neighbors::ivf_pq::search_params params;
params.n_probes = index->nprobe;
neighbors::ivf_pq::search(
handle, params, *index->ivf_pq, query_array, n, k, indices, distances);
} else {
RAFT_FAIL("The model is not trained");
}
Expand Down
2 changes: 0 additions & 2 deletions cpp/test/neighbors/ann_ivf_flat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {
ivfParams.nprobe = ps.nprobe;
ivfParams.nlist = ps.nlist;
raft::spatial::knn::knnIndex index;
index.index = nullptr;
index.gpu_res = nullptr;

approx_knn_build_index(handle_,
&index,
Expand Down

0 comments on commit f2bc24d

Please sign in to comment.