Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exposing fused l2 knn to public APIs #959

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion cpp/include/raft/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/spatial/knn/detail/knn_brute_force_faiss.cuh>
#include <raft/spatial/knn/detail/selection_faiss.cuh>

Expand Down Expand Up @@ -70,7 +72,7 @@ namespace raft::neighbors::brute_force {
* @param[in] n_samples number of rows in each partition
* @param[in] translations optional vector of starting global id mappings for each local partition
*/
template <typename idx_t, typename value_t>
template <typename value_t, typename idx_t>
inline void knn_merge_parts(
const raft::handle_t& handle,
raft::device_matrix_view<const value_t, idx_t, row_major> in_keys,
Expand Down Expand Up @@ -191,4 +193,81 @@ void knn(raft::handle_t const& handle,
metric_arg.value_or(2.0f));
}

/**
* @brief Compute the k-nearest neighbors using L2 expanded/unexpanded distance.
*
* This is a specialized function for fusing the k-selection with the distance
* computation when k < 64. The value of k will be inferred from the number
* of columns in the output matrices.
*
* Usage example:
* @code{.cpp}
* #include <raft/core/handle.hpp>
* #include <raft/neighbors/brute_force.cuh>
* #include <raft/distance/distance_types.hpp>
* using namespace raft::neighbors;
*
* raft::handle_t handle;
* ...
* int k = 10;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
* auto metric = raft::distance::DistanceType::L2SqrtExpanded;
* brute_force::fused_l2_knn(handle, index, search, indices, distances, metric);
* @endcode

* @tparam value_t type of values
* @tparam idx_t type of indices
* @tparam idx_layout layout type of index matrix
* @tparam query_layout layout type of query matrix
* @param[in] handle raft handle for sharing expensive resources
* @param[in] index input index array on device (size m * d)
* @param[in] query input query array on device (size n * d)
* @param[out] out_inds output indices array on device (size n * k)
* @param[out] out_dists output dists array on device (size n * k)
* @param[in] metric type of distance computation to perform (must be a variant of L2)
*/
template <typename value_t, typename idx_t, typename idx_layout, typename query_layout>
void fused_l2_knn(const raft::handle_t& handle,
raft::device_matrix_view<const value_t, idx_t, idx_layout> index,
raft::device_matrix_view<const value_t, idx_t, query_layout> query,
raft::device_matrix_view<idx_t, idx_t, row_major> out_inds,
raft::device_matrix_view<value_t, idx_t, row_major> out_dists,
raft::distance::DistanceType metric)
{
int k = out_inds.extent(1);
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

RAFT_EXPECTS(k <= 64, "For fused k-selection, k must be < 64");
RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs");
RAFT_EXPECTS(index.extent(1) == query.extent(1),
"Number of columns in input matrices must be the same.");

RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded ||
metric == distance::DistanceType::L2Unexpanded ||
metric == distance::DistanceType::L2SqrtUnexpanded ||
metric == distance::DistanceType::L2SqrtExpanded,
"Distance metric must be L2");

size_t n_index_rows = index.extent(0);
size_t n_query_rows = query.extent(0);
size_t D = index.extent(1);

RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout");
RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout");

bool rowMajorIndex = raft::is_row_major(index);
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
bool rowMajorQuery = raft::is_row_major(query);
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

raft::spatial::knn::detail::fusedL2Knn(D,
out_inds.data_handle(),
out_dists.data_handle(),
index.data_handle(),
query.data_handle(),
n_index_rows,
n_query_rows,
k,
rowMajorIndex,
rowMajorQuery,
handle.get_stream(),
metric);
}

} // namespace raft::neighbors::brute_force
26 changes: 13 additions & 13 deletions cpp/test/neighbors/fused_l2_knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/StandardGpuResources.h>

#include <raft/core/device_mdspan.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/brute_force.cuh>
#include <raft/random/rng.cuh>
#include <raft/spatial/knn/detail/common_faiss.h>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/spatial/knn/knn.cuh>

#if defined RAFT_NN_COMPILED
Expand Down Expand Up @@ -131,18 +132,17 @@ class FusedL2KNNTest : public ::testing::TestWithParam<FusedL2KNNInputs> {
void testBruteForce()
{
launchFaissBfknn();
detail::fusedL2Knn(dim,
raft_indices_.data(),
raft_distances_.data(),
database.data(),
search_queries.data(),
num_db_vecs,
num_queries,
k_,
true,
true,
stream_,
metric);

auto index_view =
raft::make_device_matrix_view<const T, int64_t>(database.data(), num_db_vecs, dim);
auto query_view =
raft::make_device_matrix_view<const T, int64_t>(search_queries.data(), num_queries, dim);
auto out_indices_view =
raft::make_device_matrix_view<int64_t, int64_t>(raft_indices_.data(), num_queries, k_);
auto out_dists_view =
raft::make_device_matrix_view<T, int64_t>(raft_distances_.data(), num_queries, k_);
raft::neighbors::brute_force::fused_l2_knn(
handle_, index_view, query_view, out_indices_view, out_dists_view, metric);

// verify.
devArrMatchKnnPair(faiss_indices_.data(),
Expand Down