diff --git a/cpp/include/raft/spatial/knn/ann_common.h b/cpp/include/raft/spatial/knn/ann_common.h index a0d79a1b77..0e9e323b84 100644 --- a/cpp/include/raft/spatial/knn/ann_common.h +++ b/cpp/include/raft/spatial/knn/ann_common.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,12 +22,10 @@ #include "detail/processing.hpp" #include "ivf_flat_types.hpp" +#include #include -#include -#include - namespace raft { namespace spatial { namespace knn { @@ -36,13 +34,14 @@ struct knnIndex { raft::distance::DistanceType metric; float metricArg; int nprobe; - std::unique_ptr index; std::unique_ptr> metric_processor; + std::unique_ptr> ivf_flat_float_; std::unique_ptr> ivf_flat_uint8_t_; std::unique_ptr> ivf_flat_int8_t_; - std::unique_ptr gpu_res; + std::unique_ptr> ivf_pq; + int device; template @@ -70,16 +69,6 @@ inline auto knnIndex::ivf_flat() return ivf_flat_int8_t_; } -enum QuantizerType : unsigned int { - QT_8bit, - QT_4bit, - QT_8bit_uniform, - QT_4bit_uniform, - QT_fp16, - QT_8bit_direct, - QT_6bit -}; - struct knnIndexParam { virtual ~knnIndexParam() {} }; @@ -98,11 +87,6 @@ struct IVFPQParam : IVFParam { bool usePrecomputedTables; }; -struct IVFSQParam : IVFParam { - QuantizerType qtype; - bool encodeResidual; -}; - inline auto from_legacy_index_params(const IVFFlatParam& legacy, raft::distance::DistanceType metric, float metric_arg) diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 975f1a0f89..f651e943e3 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -18,9 +18,7 @@ #include "../ann_common.h" #include "../ivf_flat.cuh" -#include "knn_brute_force_faiss.cuh" -#include "common_faiss.h" #include "processing.cuh" #include #include @@ -29,83 +27,14 @@ #include #include #include -#include +#include #include -#include -#include -#include -#include - #include namespace raft::spatial::knn::detail { -inline faiss::ScalarQuantizer::QuantizerType build_faiss_qtype(QuantizerType qtype) -{ - switch (qtype) { - case QuantizerType::QT_8bit: return faiss::ScalarQuantizer::QuantizerType::QT_8bit; - case QuantizerType::QT_8bit_uniform: - return faiss::ScalarQuantizer::QuantizerType::QT_8bit_uniform; - case QuantizerType::QT_4bit_uniform: - return faiss::ScalarQuantizer::QuantizerType::QT_4bit_uniform; - case QuantizerType::QT_fp16: return faiss::ScalarQuantizer::QuantizerType::QT_fp16; - case QuantizerType::QT_8bit_direct: - return faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct; - case QuantizerType::QT_6bit: return faiss::ScalarQuantizer::QuantizerType::QT_6bit; - default: return (faiss::ScalarQuantizer::QuantizerType)qtype; - } -} - -template -void approx_knn_ivfflat_build_index(knnIndex* index, - const IVFFlatParam& params, - IntType n, - IntType D) -{ - faiss::gpu::GpuIndexIVFFlatConfig config; - config.device = index->device; - faiss::MetricType faiss_metric = build_faiss_metric(index->metric); - index->index.reset( - new faiss::gpu::GpuIndexIVFFlat(index->gpu_res.get(), D, params.nlist, faiss_metric, config)); -} - -template -void approx_knn_ivfpq_build_index(knnIndex* index, const IVFPQParam& params, IntType n, IntType D) -{ - faiss::gpu::GpuIndexIVFPQConfig config; - config.device = index->device; - config.usePrecomputedTables = params.usePrecomputedTables; - config.interleavedLayout = params.n_bits != 8; - faiss::MetricType faiss_metric = build_faiss_metric(index->metric); - index->index.reset(new faiss::gpu::GpuIndexIVFPQ( - index->gpu_res.get(), D, params.nlist, params.M, params.n_bits, faiss_metric, config)); -} - -template -void approx_knn_ivfsq_build_index(knnIndex* index, const IVFSQParam& params, IntType n, IntType D) -{ - faiss::gpu::GpuIndexIVFScalarQuantizerConfig config; - config.device = index->device; - faiss::MetricType faiss_metric = build_faiss_metric(index->metric); - faiss::ScalarQuantizer::QuantizerType faiss_qtype = build_faiss_qtype(params.qtype); - index->index.reset(new faiss::gpu::GpuIndexIVFScalarQuantizer( - index->gpu_res.get(), D, params.nlist, faiss_qtype, faiss_metric, params.encodeResidual)); -} - -inline bool ivf_flat_supported_metric(raft::distance::DistanceType metric) -{ - switch (metric) { - case raft::distance::DistanceType::L2Unexpanded: - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - case raft::distance::DistanceType::InnerProduct: return true; - default: return false; - } -} - template void approx_knn_build_index(const handle_t& handle, knnIndex* index, @@ -117,7 +46,6 @@ void approx_knn_build_index(const handle_t& handle, IntType D) { auto stream = handle.get_stream(); - index->index = nullptr; index->metric = metric; index->metricArg = metricArg; if (dynamic_cast(params)) { @@ -125,37 +53,35 @@ void approx_knn_build_index(const handle_t& handle, } auto ivf_ft_pams = dynamic_cast(params); auto ivf_pq_pams = dynamic_cast(params); - auto ivf_sq_pams = dynamic_cast(params); if constexpr (std::is_same_v) { index->metric_processor = create_processor(metric, n, D, 0, false, stream); + // For cosine/correlation distance, the metric processor translates distance + // to inner product via pre/post processing - pass the translated metric to + // ANN index + if (metric == raft::distance::DistanceType::CosineExpanded || + metric == raft::distance::DistanceType::CorrelationExpanded) { + metric = index->metric = raft::distance::DistanceType::InnerProduct; + } } if constexpr (std::is_same_v) { index->metric_processor->preprocess(index_array); } - if (ivf_ft_pams && ivf_flat_supported_metric(metric)) { + if (ivf_ft_pams) { auto new_params = from_legacy_index_params(*ivf_ft_pams, metric, metricArg); index->ivf_flat() = std::make_unique>( ivf_flat::build(handle, new_params, index_array, int64_t(n), D)); + } else if (ivf_pq_pams) { + neighbors::ivf_pq::index_params params; + params.metric = metric; + params.metric_arg = metricArg; + params.n_lists = ivf_pq_pams->nlist; + params.pq_bits = ivf_pq_pams->n_bits; + params.pq_dim = ivf_pq_pams->M; + // TODO: handle ivf_pq_pams.usePrecomputedTables ? + index->ivf_pq = std::make_unique>( + neighbors::ivf_pq::build(handle, params, index_array, int64_t(n), D)); } else { - RAFT_CUDA_TRY(cudaGetDevice(&(index->device))); - index->gpu_res.reset(new raft::spatial::knn::RmmGpuResources()); - index->gpu_res->noTempMemory(); - index->gpu_res->setDefaultStream(index->device, stream); - if (ivf_ft_pams) { - approx_knn_ivfflat_build_index(index, *ivf_ft_pams, n, D); - } else if (ivf_pq_pams) { - approx_knn_ivfpq_build_index(index, *ivf_pq_pams, n, D); - } else if (ivf_sq_pams) { - approx_knn_ivfsq_build_index(index, *ivf_sq_pams, n, D); - } else { - RAFT_FAIL("Unrecognized index type."); - } - if constexpr (std::is_same_v) { - index->index->train(n, index_array); - index->index->add(n, index_array); - } else { - RAFT_FAIL("FAISS-based index supports only float data."); - } + RAFT_FAIL("Unrecognized index type."); } if constexpr (std::is_same_v) { index->metric_processor->revert(index_array); } @@ -170,26 +96,22 @@ void approx_knn_search(const handle_t& handle, T* query_array, IntType n) { - auto faiss_ivf = dynamic_cast(index->index.get()); - if (faiss_ivf) { faiss_ivf->setNumProbes(index->nprobe); } - if constexpr (std::is_same_v) { index->metric_processor->preprocess(query_array); index->metric_processor->set_num_queries(k); } // search - if (faiss_ivf) { - if constexpr (std::is_same_v) { - faiss_ivf->search(n, query_array, k, distances, indices); - } else { - RAFT_FAIL("FAISS-based index supports only float data."); - } - } else if (index->ivf_flat()) { + if (index->ivf_flat()) { ivf_flat::search_params params; params.n_probes = index->nprobe; ivf_flat::search( handle, params, *(index->ivf_flat()), query_array, n, k, indices, distances); + } else if (index->ivf_pq) { + neighbors::ivf_pq::search_params params; + params.n_probes = index->nprobe; + neighbors::ivf_pq::search( + handle, params, *index->ivf_pq, query_array, n, k, indices, distances); } else { RAFT_FAIL("The model is not trained"); } diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 86a62bb487..080e7551fa 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -107,8 +107,6 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { ivfParams.nprobe = ps.nprobe; ivfParams.nlist = ps.nlist; raft::spatial::knn::knnIndex index; - index.index = nullptr; - index.gpu_res = nullptr; approx_knn_build_index(handle_, &index,