diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index 93b2a83bbd..e327386d13 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -33,7 +33,7 @@ #include #include #include -#include + #include #include @@ -310,7 +310,7 @@ class sparse_knn_t { stream); // combine merge buffers only if there's more than 1 partition to combine - raft::spatial::knn::detail::knn_merge_parts( + raft::spatial::knn::knn_merge_parts( merge_buffer_dists, merge_buffer_indices, out_dists, out_indices, query_batcher.batch_rows(), 2, k, stream, trans.data()); } diff --git a/cpp/include/raft/spatial/knn/ann.hpp b/cpp/include/raft/spatial/knn/ann.hpp new file mode 100644 index 0000000000..77d7831b4a --- /dev/null +++ b/cpp/include/raft/spatial/knn/ann.hpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_common.h" +#include "detail/ann_quantized_faiss.cuh" + +#include +#include + +#include +#include + +namespace raft { +namespace spatial { +namespace knn { + +using deviceAllocator = raft::mr::device::allocator; + +/** + * @brief Flat C++ API function to build an approximate nearest neighbors index + * from an index array and a set of parameters. + * + * @param[in] handle RAFT handle + * @param[out] index index to be built + * @param[in] params parametrization of the index to be built + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + * @param[in] metricArg metric argument + * @param[in] index_array the index array to build the index with + * @param[in] n number of rows in the index array + * @param[in] D the dimensionality of the index array + */ +template +inline void approx_knn_build_index(raft::handle_t &handle, + raft::spatial::knn::knnIndex *index, + knnIndexParam *params, + raft::distance::DistanceType metric, + float metricArg, float *index_array, + value_idx n, value_idx D) { + detail::approx_knn_build_index(handle, index, params, metric, metricArg, + index_array, n, D); +} + +/** + * @brief Flat C++ API function to perform an approximate nearest neighbors + * search from previously built index and a query array + * + * @param[in] handle RAFT handle + * @param[out] distances distances of the nearest neighbors toward + * their query point + * @param[out] indices indices of the nearest neighbors + * @param[in] index index to perform a search with + * @param[in] k the number of nearest neighbors to search for + * @param[in] query_array the query to perform a search with + * @param[in] n number of rows in the query array + */ +template +inline void approx_knn_search(raft::handle_t &handle, float *distances, + int64_t *indices, + raft::spatial::knn::knnIndex *index, value_idx k, + float *query_array, value_idx n) { + detail::approx_knn_search(handle, distances, indices, index, k, query_array, + n); +} + +} // namespace knn +} // namespace spatial +} // namespace raft diff --git a/cpp/include/raft/spatial/knn/ann_common.h b/cpp/include/raft/spatial/knn/ann_common.h new file mode 100644 index 0000000000..6a6c7751c2 --- /dev/null +++ b/cpp/include/raft/spatial/knn/ann_common.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include + +namespace raft { +namespace spatial { +namespace knn { + +struct knnIndex { + faiss::gpu::GpuIndex *index; + raft::distance::DistanceType metric; + float metricArg; + + faiss::gpu::StandardGpuResources *gpu_res; + int device; + ~knnIndex() { + delete index; + delete gpu_res; + } +}; + +enum QuantizerType : unsigned int { + QT_8bit, + QT_4bit, + QT_8bit_uniform, + QT_4bit_uniform, + QT_fp16, + QT_8bit_direct, + QT_6bit +}; + +struct knnIndexParam { + virtual ~knnIndexParam() {} +}; + +struct IVFParam : knnIndexParam { + int nlist; + int nprobe; +}; + +struct IVFFlatParam : IVFParam {}; + +struct IVFPQParam : IVFParam { + int M; + int n_bits; + bool usePrecomputedTables; +}; + +struct IVFSQParam : IVFParam { + QuantizerType qtype; + bool encodeResidual; +}; + +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh new file mode 100644 index 0000000000..bb37fd03db --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../ann_common.h" + +#include +#include + +#include