diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index c784df6d49..abef8830d4 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -30,7 +30,6 @@ function(find_and_configure_raft) SOURCE_SUBDIR cpp OPTIONS "BUILD_TESTS OFF" - ) message(VERBOSE "CUML: Using RAFT located in ${raft_SOURCE_DIR}") diff --git a/cpp/include/cuml/neighbors/knn.hpp b/cpp/include/cuml/neighbors/knn.hpp index 4bdfa1a713..80e807ac80 100644 --- a/cpp/include/cuml/neighbors/knn.hpp +++ b/cpp/include/cuml/neighbors/knn.hpp @@ -16,59 +16,14 @@ #pragma once -#include -#include #include +#include namespace raft { class handle_t; } namespace ML { -struct knnIndex { - faiss::gpu::GpuIndex *index; - raft::distance::DistanceType metric; - float metricArg; - - faiss::gpu::StandardGpuResources *gpu_res; - int device; - ~knnIndex() { - delete index; - delete gpu_res; - } -}; - -typedef enum { - QT_8bit, - QT_4bit, - QT_8bit_uniform, - QT_4bit_uniform, - QT_fp16, - QT_8bit_direct, - QT_6bit -} QuantizerType; - -struct knnIndexParam { - virtual ~knnIndexParam() {} -}; - -struct IVFParam : knnIndexParam { - int nlist; - int nprobe; -}; - -struct IVFFlatParam : IVFParam {}; - -struct IVFPQParam : IVFParam { - int M; - int n_bits; - bool usePrecomputedTables; -}; - -struct IVFSQParam : IVFParam { - QuantizerType qtype; - bool encodeResidual; -}; /** * @brief Flat C++ API function to perform a brute force knn on @@ -112,8 +67,9 @@ void brute_force_knn(const raft::handle_t &handle, std::vector &input, * @param[in] n number of rows in the index array * @param[in] D the dimensionality of the index array */ -void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index, - ML::knnIndexParam *params, +void approx_knn_build_index(raft::handle_t &handle, + raft::spatial::knn::knnIndex *index, + raft::spatial::knn::knnIndexParam *params, raft::distance::DistanceType metric, float metricArg, float *index_array, int n, int D); @@ -131,8 +87,8 @@ void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index, * @param[in] n number of rows in the query array */ void approx_knn_search(raft::handle_t &handle, float *distances, - int64_t *indices, ML::knnIndex *index, int k, - float *query_array, int n); + int64_t *indices, raft::spatial::knn::knnIndex *index, + int k, float *query_array, int n); /** * @brief Flat C++ API function to perform a knn classification using a diff --git a/cpp/src/knn/knn.cu b/cpp/src/knn/knn.cu index adc0591518..15a5201d14 100644 --- a/cpp/src/knn/knn.cu +++ b/cpp/src/knn/knn.cu @@ -20,6 +20,7 @@ #include #include