diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 772ccb67d2..96cd5f11c5 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -17,6 +17,8 @@ #pragma once #include +#include +#include #include #include @@ -70,7 +72,7 @@ namespace raft::neighbors::brute_force { * @param[in] n_samples number of rows in each partition * @param[in] translations optional vector of starting global id mappings for each local partition */ -template +template inline void knn_merge_parts( const raft::handle_t& handle, raft::device_matrix_view in_keys, @@ -191,4 +193,80 @@ void knn(raft::handle_t const& handle, metric_arg.value_or(2.0f)); } +/** + * @brief Compute the k-nearest neighbors using L2 expanded/unexpanded distance. + * + * This is a specialized function for fusing the k-selection with the distance + * computation when k < 64. The value of k will be inferred from the number + * of columns in the output matrices. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::handle_t handle; + * ... + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * brute_force::fused_l2_knn(handle, index, search, indices, distances, metric); + * @endcode + + * @tparam value_t type of values + * @tparam idx_t type of indices + * @tparam idx_layout layout type of index matrix + * @tparam query_layout layout type of query matrix + * @param[in] handle raft handle for sharing expensive resources + * @param[in] index input index array on device (size m * d) + * @param[in] query input query array on device (size n * d) + * @param[out] out_inds output indices array on device (size n * k) + * @param[out] out_dists output dists array on device (size n * k) + * @param[in] metric type of distance computation to perform (must be a variant of L2) + */ +template +void fused_l2_knn(const raft::handle_t& handle, + raft::device_matrix_view index, + raft::device_matrix_view query, + raft::device_matrix_view out_inds, + raft::device_matrix_view out_dists, + raft::distance::DistanceType metric) +{ + int k = static_cast(out_inds.extent(1)); + + RAFT_EXPECTS(k <= 64, "For fused k-selection, k must be < 64"); + RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(index.extent(1) == query.extent(1), + "Number of columns in input matrices must be the same."); + + RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded || + metric == distance::DistanceType::L2Unexpanded || + metric == distance::DistanceType::L2SqrtUnexpanded || + metric == distance::DistanceType::L2SqrtExpanded, + "Distance metric must be L2"); + + size_t n_index_rows = index.extent(0); + size_t n_query_rows = query.extent(0); + size_t D = index.extent(1); + + RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout"); + RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout"); + + const bool rowMajorIndex = raft::is_row_major(index); + const bool rowMajorQuery = raft::is_row_major(query); + + raft::spatial::knn::detail::fusedL2Knn(D, + out_inds.data_handle(), + out_dists.data_handle(), + index.data_handle(), + query.data_handle(), + n_index_rows, + n_query_rows, + k, + rowMajorIndex, + rowMajorQuery, + handle.get_stream(), + metric); +} + } // namespace raft::neighbors::brute_force diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index b22d10bf54..8df193d53d 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -19,10 +19,11 @@ #include #include +#include #include +#include #include #include -#include #include #if defined RAFT_NN_COMPILED @@ -131,18 +132,17 @@ class FusedL2KNNTest : public ::testing::TestWithParam { void testBruteForce() { launchFaissBfknn(); - detail::fusedL2Knn(dim, - raft_indices_.data(), - raft_distances_.data(), - database.data(), - search_queries.data(), - num_db_vecs, - num_queries, - k_, - true, - true, - stream_, - metric); + + auto index_view = + raft::make_device_matrix_view(database.data(), num_db_vecs, dim); + auto query_view = + raft::make_device_matrix_view(search_queries.data(), num_queries, dim); + auto out_indices_view = + raft::make_device_matrix_view(raft_indices_.data(), num_queries, k_); + auto out_dists_view = + raft::make_device_matrix_view(raft_distances_.data(), num_queries, k_); + raft::neighbors::brute_force::fused_l2_knn( + handle_, index_view, query_view, out_indices_view, out_dists_view, metric); // verify. devArrMatchKnnPair(faiss_indices_.data(),