From 079420db17382d11560f5833bbfa1e1983e7a2f5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 27 Oct 2022 15:48:34 -0400 Subject: [PATCH 1/4] Adding fused l2 knn to public APIs --- cpp/include/raft/neighbors/brute_force.cuh | 64 +++++++++++++++++++++- cpp/test/neighbors/fused_l2_knn.cu | 26 ++++----- 2 files changed, 76 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 772ccb67d2..4e5d7d084d 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -17,6 +17,8 @@ #pragma once #include +#include +#include #include #include @@ -70,7 +72,7 @@ namespace raft::neighbors::brute_force { * @param[in] n_samples number of rows in each partition * @param[in] translations optional vector of starting global id mappings for each local partition */ -template +template inline void knn_merge_parts( const raft::handle_t& handle, raft::device_matrix_view in_keys, @@ -191,4 +193,64 @@ void knn(raft::handle_t const& handle, metric_arg.value_or(2.0f)); } +/** + * @brief Compute the k-nearest neighbors using L2 expanded/unexpanded distance. + * + * This is a specialized function for fusing the k-selection with the distance + * computation when k < 64. The value of k will be inferred from the number + * of columns in the output matrices. + * + * @tparam value_t + * @tparam idx_t, + * @param[in] handle raft handle for sharing expensive resources + * @param[in] index input index array on device (size m * d) + * @param[in] query input query array on device (size n * d) + * @param[out] out_inds output indices array on device (size n * k) + * @param[out] out_dists output dists array on device (size n * k) + */ +template +void fused_l2_knn(const raft::handle_t& handle, + raft::device_matrix_view index, + raft::device_matrix_view query, + raft::device_matrix_view out_inds, + raft::device_matrix_view out_dists, + raft::distance::DistanceType metric) +{ + int k = out_inds.extent(1); + + RAFT_EXPECTS(k <= 64, "For fused k-selection, k must be < 64"); + RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(index.extent(1) == query.extent(1), + "Number of columns in input matrices must be the same."); + + RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded || + metric == distance::DistanceType::L2Unexpanded || + metric == distance::DistanceType::L2SqrtUnexpanded || + metric == distance::DistanceType::L2SqrtExpanded, + "Distance metric must be L2"); + + size_t n_index_rows = index.extent(0); + size_t n_query_rows = query.extent(0); + size_t D = index.extent(1); + + RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout"); + RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout"); + + bool rowMajorIndex = raft::is_row_major(index); + bool rowMajorQuery = raft::is_row_major(query); + + raft::spatial::knn::detail::fusedL2Knn(D, + out_inds.data_handle(), + out_dists.data_handle(), + index.data_handle(), + query.data_handle(), + n_index_rows, + n_query_rows, + k, + rowMajorIndex, + rowMajorQuery, + handle.get_stream(), + metric); +} + } // namespace raft::neighbors::brute_force diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index b22d10bf54..8df193d53d 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -19,10 +19,11 @@ #include #include +#include #include +#include #include #include -#include #include #if defined RAFT_NN_COMPILED @@ -131,18 +132,17 @@ class FusedL2KNNTest : public ::testing::TestWithParam { void testBruteForce() { launchFaissBfknn(); - detail::fusedL2Knn(dim, - raft_indices_.data(), - raft_distances_.data(), - database.data(), - search_queries.data(), - num_db_vecs, - num_queries, - k_, - true, - true, - stream_, - metric); + + auto index_view = + raft::make_device_matrix_view(database.data(), num_db_vecs, dim); + auto query_view = + raft::make_device_matrix_view(search_queries.data(), num_queries, dim); + auto out_indices_view = + raft::make_device_matrix_view(raft_indices_.data(), num_queries, k_); + auto out_dists_view = + raft::make_device_matrix_view(raft_distances_.data(), num_queries, k_); + raft::neighbors::brute_force::fused_l2_knn( + handle_, index_view, query_view, out_indices_view, out_dists_view, metric); // verify. devArrMatchKnnPair(faiss_indices_.data(), From 5004f2d502f420e90ee2badfb28fca37b1b32559 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 27 Oct 2022 22:32:35 -0400 Subject: [PATCH 2/4] Fixing docs --- cpp/include/raft/neighbors/brute_force.cuh | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 4e5d7d084d..4e673c1372 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -200,13 +200,30 @@ void knn(raft::handle_t const& handle, * computation when k < 64. The value of k will be inferred from the number * of columns in the output matrices. * - * @tparam value_t - * @tparam idx_t, + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::handle_t handle; + * ... + * int k = 10; + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * brute_force::fused_l2_knn(handle, index, search, indices, distances, metric); + * @endcode + + * @tparam value_t type of values + * @tparam idx_t type of indices + * @tparam idx_layout layout type of index matrix + * @tparam query_layout layout type of query matrix * @param[in] handle raft handle for sharing expensive resources * @param[in] index input index array on device (size m * d) * @param[in] query input query array on device (size n * d) * @param[out] out_inds output indices array on device (size n * k) * @param[out] out_dists output dists array on device (size n * k) + * @param[in] metric type of distance computation to perform (must be a variant of L2) */ template void fused_l2_knn(const raft::handle_t& handle, From 3726fde12eb7dbb8c2409709b73f99ed6f56d5a5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 28 Oct 2022 13:09:39 -0400 Subject: [PATCH 3/4] Consexpt --- cpp/include/raft/neighbors/brute_force.cuh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 4e673c1372..8f2e5c5ddc 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -209,7 +209,6 @@ void knn(raft::handle_t const& handle, * * raft::handle_t handle; * ... - * int k = 10; * auto metric = raft::distance::DistanceType::L2SqrtExpanded; * brute_force::fused_l2_knn(handle, index, search, indices, distances, metric); * @endcode @@ -233,7 +232,7 @@ void fused_l2_knn(const raft::handle_t& handle, raft::device_matrix_view out_dists, raft::distance::DistanceType metric) { - int k = out_inds.extent(1); + int k = static_cast(out_inds.extent(1)); RAFT_EXPECTS(k <= 64, "For fused k-selection, k must be < 64"); RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs"); @@ -253,8 +252,8 @@ void fused_l2_knn(const raft::handle_t& handle, RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout"); RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout"); - bool rowMajorIndex = raft::is_row_major(index); - bool rowMajorQuery = raft::is_row_major(query); + constexpr bool rowMajorIndex = raft::is_row_major(index); + constexpr bool rowMajorQuery = raft::is_row_major(query); raft::spatial::knn::detail::fusedL2Knn(D, out_inds.data_handle(), From fb384693f3dff2b582079c11282363a4d301dbf7 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 28 Oct 2022 13:58:16 -0400 Subject: [PATCH 4/4] Fixing compile error --- cpp/include/raft/neighbors/brute_force.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 8f2e5c5ddc..96cd5f11c5 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -252,8 +252,8 @@ void fused_l2_knn(const raft::handle_t& handle, RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout"); RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout"); - constexpr bool rowMajorIndex = raft::is_row_major(index); - constexpr bool rowMajorQuery = raft::is_row_major(query); + const bool rowMajorIndex = raft::is_row_major(index); + const bool rowMajorQuery = raft::is_row_major(query); raft::spatial::knn::detail::fusedL2Knn(D, out_inds.data_handle(),