From 35ab60d603afb996cbcb0dd00115db807117d07f Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 13 May 2022 10:51:37 +0200 Subject: [PATCH 001/118] initial commit and formatting cleanup --- .../raft/common/device_loads_stores.cuh | 376 ++++- cpp/include/raft/spatial/knn/ann.cuh | 24 +- cpp/include/raft/spatial/knn/ann.hpp | 76 +- cpp/include/raft/spatial/knn/ann_common.h | 2 + .../raft/spatial/knn/detail/ann_ivf_flat.cuh | 1364 ++++++++++++++++ .../knn/detail/ann_ivf_flat_kernel.cuh | 1411 +++++++++++++++++ .../knn/detail/ann_kmeans_balanced.cuh | 414 +++++ .../knn/detail/ann_quantized_faiss.cuh | 266 +++- .../raft/spatial/knn/detail/ann_utils.cuh | 586 +++++++ .../spatial/knn/detail/topk/radix_topk.cuh | 204 ++- cpp/test/CMakeLists.txt | 1 + cpp/test/spatial/ann_base_kernel.cuh | 88 + cpp/test/spatial/ann_ivf_flat.cu | 291 ++++ 13 files changed, 4944 insertions(+), 159 deletions(-) create mode 100644 cpp/include/raft/spatial/knn/detail/ann_ivf_flat.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/ann_ivf_flat_kernel.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/ann_utils.cuh create mode 100644 cpp/test/spatial/ann_base_kernel.cuh create mode 100644 cpp/test/spatial/ann_ivf_flat.cu diff --git a/cpp/include/raft/common/device_loads_stores.cuh b/cpp/include/raft/common/device_loads_stores.cuh index 41dc9cab08..0c4750aa69 100644 --- a/cpp/include/raft/common/device_loads_stores.cuh +++ b/cpp/include/raft/common/device_loads_stores.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -31,6 +31,121 @@ namespace raft { * @param[out] addr shared memory address (should be aligned to vector size) * @param[in] x data to be stored at this address */ +DI void sts(uint8_t* addr, const uint8_t& x) +{ + uint32_t x_int; + x_int = x; + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.u8 [%0], {%1};" : : "l"(s1), "r"(x_int)); +} +DI void sts(uint8_t* addr, const uint8_t (&x)[1]) +{ + uint32_t x_int[1]; + x_int[0] = x[0]; + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.u8 [%0], {%1};" : : "l"(s1), "r"(x_int[0])); +} +DI void sts(uint8_t* addr, const uint8_t (&x)[2]) +{ + uint32_t x_int[2]; + x_int[0] = x[0]; + x_int[1] = x[1]; + auto s2 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.v2.u8 [%0], {%1, %2};" : : "l"(s2), "r"(x_int[0]), "r"(x_int[1])); +} +DI void sts(uint8_t* addr, const uint8_t (&x)[4]) +{ + uint32_t x_int[4]; + x_int[0] = x[0]; + x_int[1] = x[1]; + x_int[2] = x[2]; + x_int[3] = x[3]; + auto s4 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.v4.u8 [%0], {%1, %2, %3, %4};" + : + : "l"(s4), "r"(x_int[0]), "r"(x_int[1]), "r"(x_int[2]), "r"(x_int[3])); +} + +DI void sts(int8_t* addr, const int8_t& x) +{ + int32_t x_int = x; + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.s8 [%0], {%1};" : : "l"(s1), "r"(x_int)); +} +DI void sts(int8_t* addr, const int8_t (&x)[1]) +{ + int32_t x_int[1]; + x_int[0] = x[0]; + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.s8 [%0], {%1};" : : "l"(s1), "r"(x_int[0])); +} +DI void sts(int8_t* addr, const int8_t (&x)[2]) +{ + int32_t x_int[2]; + x_int[0] = x[0]; + x_int[1] = x[1]; + auto s2 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.v2.s8 [%0], {%1, %2};" : : "l"(s2), "r"(x_int[0]), "r"(x_int[1])); +} +DI void sts(int8_t* addr, const int8_t (&x)[4]) +{ + int32_t x_int[4]; 
+ x_int[0] = x[0]; + x_int[1] = x[1]; + x_int[2] = x[2]; + x_int[3] = x[3]; + auto s4 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.v4.s8 [%0], {%1, %2, %3, %4};" + : + : "l"(s4), "r"(x_int[0]), "r"(x_int[1]), "r"(x_int[2]), "r"(x_int[3])); +} + +DI void sts(uint32_t* addr, const uint32_t& x) +{ + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.u32 [%0], {%1};" : : "l"(s1), "r"(x)); +} +DI void sts(uint32_t* addr, const uint32_t (&x)[1]) +{ + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.u32 [%0], {%1};" : : "l"(s1), "r"(x[0])); +} +DI void sts(uint32_t* addr, const uint32_t (&x)[2]) +{ + auto s2 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" : : "l"(s2), "r"(x[0]), "r"(x[1])); +} +DI void sts(uint32_t* addr, const uint32_t (&x)[4]) +{ + auto s4 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" + : + : "l"(s4), "r"(x[0]), "r"(x[1]), "r"(x[2]), "r"(x[3])); +} + +DI void sts(int32_t* addr, const int32_t& x) +{ + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.u32 [%0], {%1};" : : "l"(s1), "r"(x)); +} +DI void sts(int32_t* addr, const int32_t (&x)[1]) +{ + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.u32 [%0], {%1};" : : "l"(s1), "r"(x[0])); +} +DI void sts(int32_t* addr, const int32_t (&x)[2]) +{ + auto s2 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" : : "l"(s2), "r"(x[0]), "r"(x[1])); +} +DI void sts(int32_t* addr, const int32_t (&x)[4]) +{ + auto s4 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" + : + : "l"(s4), "r"(x[0]), "r"(x[1]), "r"(x[2]), "r"(x[3])); +} + DI void sts(float* addr, const float& x) { auto s1 = 
__cvta_generic_to_shared(reinterpret_cast(addr)); @@ -83,6 +198,152 @@ DI void sts(double* addr, const double (&x)[2]) * @param[in] addr shared memory address from where to load * (should be aligned to vector size) */ + +DI void lds(uint8_t& x, const uint8_t* addr) +{ + uint32_t x_int; + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.u8 {%0}, [%1];" : "=r"(x_int) : "l"(s1)); + x = x_int; +} +DI void lds(uint8_t (&x)[1], const uint8_t* addr) +{ + uint32_t x_int[1]; + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.u8 {%0}, [%1];" : "=r"(x_int[0]) : "l"(s1)); + x[0] = x_int[0]; +} +DI void lds(uint8_t (&x)[2], const uint8_t* addr) +{ + uint32_t x_int[2]; + auto s2 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.v2.u8 {%0, %1}, [%2];" : "=r"(x_int[0]), "=r"(x_int[1]) : "l"(s2)); + x[0] = x_int[0]; + x[1] = x_int[1]; +} +DI void lds(uint8_t (&x)[4], const uint8_t* addr) +{ + uint32_t x_int[4]; + auto s4 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.v4.u8 {%0, %1, %2, %3}, [%4];" + : "=r"(x_int[0]), "=r"(x_int[1]), "=r"(x_int[2]), "=r"(x_int[3]) + : "l"(s4)); + x[0] = x_int[0]; + x[1] = x_int[1]; + x[2] = x_int[2]; + x[3] = x_int[3]; +} + +DI void lds(int8_t& x, const int8_t* addr) +{ + int32_t x_int; + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.s8 {%0}, [%1];" : "=r"(x_int) : "l"(s1)); + x = x_int; +} +DI void lds(int8_t (&x)[1], const int8_t* addr) +{ + int32_t x_int[1]; + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.s8 {%0}, [%1];" : "=r"(x_int[0]) : "l"(s1)); + x[0] = x_int[0]; +} +DI void lds(int8_t (&x)[2], const int8_t* addr) +{ + int32_t x_int[2]; + auto s2 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.v2.s8 {%0, %1}, [%2];" : "=r"(x_int[0]), "=r"(x_int[1]) : "l"(s2)); + x[0] = x_int[0]; + x[1] = x_int[1]; +} 
+DI void lds(int8_t (&x)[4], const int8_t* addr) +{ + int32_t x_int[4]; + auto s4 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.v4.s8 {%0, %1, %2, %3}, [%4];" + : "=r"(x_int[0]), "=r"(x_int[1]), "=r"(x_int[2]), "=r"(x_int[3]) + : "l"(s4)); + x[0] = x_int[0]; + x[1] = x_int[1]; + x[2] = x_int[2]; + x[3] = x_int[3]; +} + +DI void lds(uint32_t (&x)[4], const uint32_t* addr) +{ + auto s4 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(x[0]), "=r"(x[1]), "=r"(x[2]), "=r"(x[3]) + : "l"(s4)); +} + +DI void lds(uint32_t (&x)[2], const uint32_t* addr) +{ + auto s2 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];" : "=r"(x[0]), "=r"(x[1]) : "l"(s2)); +} + +DI void lds(uint32_t (&x)[1], const uint32_t* addr) +{ + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(x[0]) : "l"(s1)); +} + +DI void lds(uint32_t& x, const uint32_t* addr) +{ + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(x) : "l"(s1)); +} + +DI void lds(int32_t (&x)[4], const int32_t* addr) +{ + auto s4 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(x[0]), "=r"(x[1]), "=r"(x[2]), "=r"(x[3]) + : "l"(s4)); +} + +DI void lds(int32_t (&x)[2], const int32_t* addr) +{ + auto s2 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];" : "=r"(x[0]), "=r"(x[1]) : "l"(s2)); +} + +DI void lds(int32_t (&x)[1], const int32_t* addr) +{ + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(x[0]) : "l"(s1)); +} + +DI void lds(int32_t& x, const int32_t* addr) +{ + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(x) : "l"(s1)); 
+} + +DI void lds(float& x, const float* addr) +{ + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.f32 {%0}, [%1];" : "=f"(x) : "l"(s1)); +} +DI void lds(float (&x)[1], const float* addr) +{ + auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.f32 {%0}, [%1];" : "=f"(x[0]) : "l"(s1)); +} +DI void lds(float (&x)[2], const float* addr) +{ + auto s2 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(x[0]), "=f"(x[1]) : "l"(s2)); +} +DI void lds(float (&x)[4], const float* addr) +{ + auto s4 = __cvta_generic_to_shared(reinterpret_cast(addr)); + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" + : "=f"(x[0]), "=f"(x[1]), "=f"(x[2]), "=f"(x[3]) + : "l"(s4)); +} + DI void lds(float& x, float* addr) { auto s1 = __cvta_generic_to_shared(reinterpret_cast(addr)); @@ -159,6 +420,119 @@ DI void ldg(double (&x)[2], const double* addr) { asm volatile("ld.global.cg.v2.f64 {%0, %1}, [%2];" : "=d"(x[0]), "=d"(x[1]) : "l"(addr)); } + +DI void ldg(uint32_t (&x)[4], const uint32_t* const& addr) +{ + asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(x[0]), "=r"(x[1]), "=r"(x[2]), "=r"(x[3]) + : "l"(addr)); +} + +DI void ldg(uint32_t (&x)[2], const uint32_t* const& addr) +{ + asm volatile("ld.global.cg.v2.u32 {%0, %1}, [%2];" : "=r"(x[0]), "=r"(x[1]) : "l"(addr)); +} + +DI void ldg(uint32_t (&x)[1], const uint32_t* const& addr) +{ + asm volatile("ld.global.cg.u32 %0, [%1];" : "=r"(x[0]) : "l"(addr)); +} + +DI void ldg(uint32_t& x, const uint32_t* const& addr) +{ + asm volatile("ld.global.cg.u32 %0, [%1];" : "=r"(x) : "l"(addr)); +} + +DI void ldg(int32_t (&x)[4], const int32_t* const& addr) +{ + asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(x[0]), "=r"(x[1]), "=r"(x[2]), "=r"(x[3]) + : "l"(addr)); +} + +DI void ldg(int32_t (&x)[2], const int32_t* const& addr) +{ + asm volatile("ld.global.cg.v2.u32 {%0, %1}, 
[%2];" : "=r"(x[0]), "=r"(x[1]) : "l"(addr)); +} + +DI void ldg(int32_t (&x)[1], const int32_t* const& addr) +{ + asm volatile("ld.global.cg.u32 %0, [%1];" : "=r"(x[0]) : "l"(addr)); +} + +DI void ldg(int32_t& x, const int32_t* const& addr) +{ + asm volatile("ld.global.cg.u32 %0, [%1];" : "=r"(x) : "l"(addr)); +} + +DI void ldg(uint8_t (&x)[4], const uint8_t* const& addr) +{ + uint32_t x_int[4]; + asm volatile("ld.global.cg.v4.u8 {%0, %1, %2, %3}, [%4];" + : "=r"(x_int[0]), "=r"(x_int[1]), "=r"(x_int[2]), "=r"(x_int[3]) + : "l"(addr)); + x[0] = x_int[0]; + x[1] = x_int[1]; + x[2] = x_int[2]; + x[3] = x_int[3]; +} + +DI void ldg(uint8_t (&x)[2], const uint8_t* const& addr) +{ + uint32_t x_int[2]; + asm volatile("ld.global.cg.v2.u8 {%0, %1}, [%2];" : "=r"(x_int[0]), "=r"(x_int[1]) : "l"(addr)); + x[0] = x_int[0]; + x[1] = x_int[1]; +} + +DI void ldg(uint8_t (&x)[1], const uint8_t* const& addr) +{ + uint32_t x_int; + asm volatile("ld.global.cg.u8 %0, [%1];" : "=r"(x_int) : "l"(addr)); + x[0] = x_int; +} + +DI void ldg(uint8_t& x, const uint8_t* const& addr) +{ + uint32_t x_int; + asm volatile("ld.global.cg.u8 %0, [%1];" : "=r"(x_int) : "l"(addr)); + x = x_int; +} + +DI void ldg(int8_t (&x)[4], const int8_t* const& addr) +{ + int x_int[4]; + asm volatile("ld.global.cg.v4.s8 {%0, %1, %2, %3}, [%4];" + : "=r"(x_int[0]), "=r"(x_int[1]), "=r"(x_int[2]), "=r"(x_int[3]) + : "l"(addr)); + x[0] = x_int[0]; + x[1] = x_int[1]; + x[2] = x_int[2]; + x[3] = x_int[3]; +} + +DI void ldg(int8_t (&x)[2], const int8_t* const& addr) +{ + int x_int[2]; + asm volatile("ld.global.cg.v2.s8 {%0, %1}, [%2];" : "=r"(x_int[0]), "=r"(x_int[1]) : "l"(addr)); + x[0] = x_int[0]; + x[1] = x_int[1]; +} + +DI void ldg(int8_t& x, const int8_t* const& addr) +{ + int x_int; + asm volatile("ld.global.cg.s8 %0, [%1];" : "=r"(x_int) : "l"(addr)); + x = x_int; +} + +DI void ldg(int8_t (&x)[1], const int8_t* const& addr) +{ + int x_int; + asm volatile("ld.global.cg.s8 %0, [%1];" : "=r"(x_int) : "l"(addr)); + 
x[0] = x_int; +} + /** @} */ } // namespace raft diff --git a/cpp/include/raft/spatial/knn/ann.cuh b/cpp/include/raft/spatial/knn/ann.cuh index 2ef2ae0fa4..4dfb1b6d89 100644 --- a/cpp/include/raft/spatial/knn/ann.cuh +++ b/cpp/include/raft/spatial/knn/ann.cuh @@ -14,9 +14,6 @@ * limitations under the License. */ -#ifndef __ANN_H -#define __ANN_H - #pragma once #include "ann_common.h" @@ -25,9 +22,7 @@ #include #include -namespace raft { -namespace spatial { -namespace knn { +namespace raft::spatial::knn { /** * @brief Flat C++ API function to build an approximate nearest neighbors index @@ -42,13 +37,13 @@ namespace knn { * @param[in] n number of rows in the index array * @param[in] D the dimensionality of the index array */ -template +template inline void approx_knn_build_index(raft::handle_t& handle, raft::spatial::knn::knnIndex* index, knnIndexParam* params, raft::distance::DistanceType metric, float metricArg, - float* index_array, + T* index_array, value_idx n, value_idx D) { @@ -68,20 +63,17 @@ inline void approx_knn_build_index(raft::handle_t& handle, * @param[in] query_array the query to perform a search with * @param[in] n number of rows in the query array */ -template +template inline void approx_knn_search(raft::handle_t& handle, float* distances, int64_t* indices, raft::spatial::knn::knnIndex* index, + knnIndexParam* params, value_idx k, - float* query_array, + T* query_array, value_idx n) { - detail::approx_knn_search(handle, distances, indices, index, k, query_array, n); + detail::approx_knn_search(handle, distances, indices, index, params, k, query_array, n); } -} // namespace knn -} // namespace spatial -} // namespace raft - -#endif \ No newline at end of file +} // namespace raft::spatial::knn diff --git a/cpp/include/raft/spatial/knn/ann.hpp b/cpp/include/raft/spatial/knn/ann.hpp index b6d3ca2976..516435271d 100644 --- a/cpp/include/raft/spatial/knn/ann.hpp +++ b/cpp/include/raft/spatial/knn/ann.hpp @@ -13,79 +13,11 @@ * See the License for the 
specific language governing permissions and * limitations under the License. */ -/** - * This file is deprecated and will be removed in release 22.06. - * Please use the cuh version instead. - */ - -#ifndef __ANN_H -#define __ANN_H #pragma once -#include "ann_common.h" -#include "detail/ann_quantized_faiss.cuh" - -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -/** - * @brief Flat C++ API function to build an approximate nearest neighbors index - * from an index array and a set of parameters. - * - * @param[in] handle RAFT handle - * @param[out] index index to be built - * @param[in] params parametrization of the index to be built - * @param[in] metric distance metric to use. Euclidean (L2) is used by default - * @param[in] metricArg metric argument - * @param[in] index_array the index array to build the index with - * @param[in] n number of rows in the index array - * @param[in] D the dimensionality of the index array - */ -template -inline void approx_knn_build_index(raft::handle_t& handle, - raft::spatial::knn::knnIndex* index, - knnIndexParam* params, - raft::distance::DistanceType metric, - float metricArg, - float* index_array, - value_idx n, - value_idx D) -{ - detail::approx_knn_build_index(handle, index, params, metric, metricArg, index_array, n, D); -} - -/** - * @brief Flat C++ API function to perform an approximate nearest neighbors - * search from previously built index and a query array - * - * @param[in] handle RAFT handle - * @param[out] distances distances of the nearest neighbors toward - * their query point - * @param[out] indices indices of the nearest neighbors - * @param[in] index index to perform a search with - * @param[in] k the number of nearest neighbors to search for - * @param[in] query_array the query to perform a search with - * @param[in] n number of rows in the query array - */ -template -inline void approx_knn_search(raft::handle_t& handle, - float* distances, - int64_t* indices, - 
raft::spatial::knn::knnIndex* index, - value_idx k, - float* query_array, - value_idx n) -{ - detail::approx_knn_search(handle, distances, indices, index, k, query_array, n); -} - -} // namespace knn -} // namespace spatial -} // namespace raft +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the cuh version instead.") -#endif \ No newline at end of file +#include "ann.cuh" diff --git a/cpp/include/raft/spatial/knn/ann_common.h b/cpp/include/raft/spatial/knn/ann_common.h index 5cdd6b1141..cfbde4bf21 100644 --- a/cpp/include/raft/spatial/knn/ann_common.h +++ b/cpp/include/raft/spatial/knn/ann_common.h @@ -18,6 +18,7 @@ #include +#include "detail/ann_ivf_flat.cuh" #include #include @@ -29,6 +30,7 @@ struct knnIndex { faiss::gpu::GpuIndex* index; raft::distance::DistanceType metric; float metricArg; + std::unique_ptr handle_; raft::spatial::knn::RmmGpuResources* gpu_res; int device; diff --git a/cpp/include/raft/spatial/knn/detail/ann_ivf_flat.cuh b/cpp/include/raft/spatial/knn/detail/ann_ivf_flat.cuh new file mode 100644 index 0000000000..1768bf1a1d --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ann_ivf_flat.cuh @@ -0,0 +1,1364 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "ann_kmeans_balanced.cuh" +#include "ann_utils.cuh" +#include "knn_brute_force_faiss.cuh" +#include +#include +#include +//#include "ann_ivf_flat.cuh" +#include "ann_ivf_flat_kernel.cuh" +#include "topk/radix_topk.cuh" + +#include "common_faiss.h" +#include "processing.hpp" + +#include "processing.hpp" +#include +#include + +//#include