Skip to content

Commit

Permalink
Exposing fused l2 knn to public APIs (#959)
Browse files Browse the repository at this point in the history
This is needed for the FAISS support, since we wrap the fused L2 k-NN inside our brute-force call, which currently depends on FAISS.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #959
  • Loading branch information
cjnolet authored Oct 28, 2022
1 parent d199a9f commit 37c1b3d
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 14 deletions.
80 changes: 79 additions & 1 deletion cpp/include/raft/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/spatial/knn/detail/knn_brute_force_faiss.cuh>
#include <raft/spatial/knn/detail/selection_faiss.cuh>

Expand Down Expand Up @@ -70,7 +72,7 @@ namespace raft::neighbors::brute_force {
* @param[in] n_samples number of rows in each partition
* @param[in] translations optional vector of starting global id mappings for each local partition
*/
template <typename idx_t, typename value_t>
template <typename value_t, typename idx_t>
inline void knn_merge_parts(
const raft::handle_t& handle,
raft::device_matrix_view<const value_t, idx_t, row_major> in_keys,
Expand Down Expand Up @@ -191,4 +193,80 @@ void knn(raft::handle_t const& handle,
metric_arg.value_or(2.0f));
}

/**
 * @brief Compute the k-nearest neighbors using L2 expanded/unexpanded distance.
 *
 * This is a specialized function for fusing the k-selection with the distance
 * computation when k <= 64. The value of k will be inferred from the number
 * of columns in the output matrices.
 *
 * All preconditions listed below are validated with RAFT_EXPECTS before the
 * underlying fused kernel wrapper is invoked.
 *
 * Usage example:
 * @code{.cpp}
 *  #include <raft/core/handle.hpp>
 *  #include <raft/neighbors/brute_force.cuh>
 *  #include <raft/distance/distance_types.hpp>
 *  using namespace raft::neighbors;
 *
 *  raft::handle_t handle;
 *  ...
 *  auto metric = raft::distance::DistanceType::L2SqrtExpanded;
 *  brute_force::fused_l2_knn(handle, index, search, indices, distances, metric);
 * @endcode
 *
 * @tparam value_t type of values
 * @tparam idx_t type of indices
 * @tparam idx_layout layout type of index matrix
 * @tparam query_layout layout type of query matrix
 * @param[in] handle raft handle for sharing expensive resources
 * @param[in] index input index array on device (size m * d)
 * @param[in] query input query array on device (size n * d)
 * @param[out] out_inds output indices array on device (size n * k), k <= 64
 * @param[out] out_dists output dists array on device (size n * k)
 * @param[in] metric type of distance computation to perform (must be a variant of L2)
 */
template <typename value_t, typename idx_t, typename idx_layout, typename query_layout>
void fused_l2_knn(const raft::handle_t& handle,
                  raft::device_matrix_view<const value_t, idx_t, idx_layout> index,
                  raft::device_matrix_view<const value_t, idx_t, query_layout> query,
                  raft::device_matrix_view<idx_t, idx_t, row_major> out_inds,
                  raft::device_matrix_view<value_t, idx_t, row_major> out_dists,
                  raft::distance::DistanceType metric)
{
  // Bound-check k on the un-narrowed extent: casting to int first could wrap a
  // very large extent into a value that passes the check.
  RAFT_EXPECTS(out_inds.extent(1) <= 64, "For fused k-selection, k must be <= 64");
  RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs");
  RAFT_EXPECTS(index.extent(1) == query.extent(1),
               "Number of columns in input matrices must be the same.");

  // The outputs are documented as (n_queries x k); reject views with a
  // different number of rows than the query matrix.
  RAFT_EXPECTS(out_inds.extent(0) == query.extent(0),
               "Number of rows in output indices and query matrices must match");
  RAFT_EXPECTS(out_dists.extent(0) == query.extent(0),
               "Number of rows in output distances and query matrices must match");

  RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded ||
                 metric == distance::DistanceType::L2Unexpanded ||
                 metric == distance::DistanceType::L2SqrtUnexpanded ||
                 metric == distance::DistanceType::L2SqrtExpanded,
               "Distance metric must be L2");

  RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout");
  RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout");

  // Safe to narrow now that k is known to be at most 64.
  const int k = static_cast<int>(out_inds.extent(1));

  const size_t n_index_rows = index.extent(0);
  const size_t n_query_rows = query.extent(0);
  const size_t D            = index.extent(1);

  const bool rowMajorIndex = raft::is_row_major(index);
  const bool rowMajorQuery = raft::is_row_major(query);

  raft::spatial::knn::detail::fusedL2Knn(D,
                                         out_inds.data_handle(),
                                         out_dists.data_handle(),
                                         index.data_handle(),
                                         query.data_handle(),
                                         n_index_rows,
                                         n_query_rows,
                                         k,
                                         rowMajorIndex,
                                         rowMajorQuery,
                                         handle.get_stream(),
                                         metric);
}

} // namespace raft::neighbors::brute_force
26 changes: 13 additions & 13 deletions cpp/test/neighbors/fused_l2_knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/StandardGpuResources.h>

#include <raft/core/device_mdspan.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/brute_force.cuh>
#include <raft/random/rng.cuh>
#include <raft/spatial/knn/detail/common_faiss.h>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/spatial/knn/knn.cuh>

#if defined RAFT_NN_COMPILED
Expand Down Expand Up @@ -131,18 +132,17 @@ class FusedL2KNNTest : public ::testing::TestWithParam<FusedL2KNNInputs> {
void testBruteForce()
{
launchFaissBfknn();
detail::fusedL2Knn(dim,
raft_indices_.data(),
raft_distances_.data(),
database.data(),
search_queries.data(),
num_db_vecs,
num_queries,
k_,
true,
true,
stream_,
metric);

auto index_view =
raft::make_device_matrix_view<const T, int64_t>(database.data(), num_db_vecs, dim);
auto query_view =
raft::make_device_matrix_view<const T, int64_t>(search_queries.data(), num_queries, dim);
auto out_indices_view =
raft::make_device_matrix_view<int64_t, int64_t>(raft_indices_.data(), num_queries, k_);
auto out_dists_view =
raft::make_device_matrix_view<T, int64_t>(raft_distances_.data(), num_queries, k_);
raft::neighbors::brute_force::fused_l2_knn(
handle_, index_view, query_view, out_indices_view, out_dists_view, metric);

// verify.
devArrMatchKnnPair(faiss_indices_.data(),
Expand Down

0 comments on commit 37c1b3d

Please sign in to comment.