Skip to content

Commit

Permalink
use matrix::select_k in bfknn call
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Feb 14, 2023
1 parent 8eaba84 commit fe728e9
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/matrix/select_k.cuh>
#include <raft/spatial/knn/detail/faiss_select/DistanceUtils.h>
#include <raft/spatial/knn/detail/faiss_select/Select.cuh>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/spatial/knn/detail/haversine_distance.cuh>
#include <raft/spatial/knn/detail/processing.cuh>
#include <raft/spatial/knn/detail/selection_faiss.cuh>
#include <set>
#include <thrust/iterator/transform_iterator.h>

Expand Down Expand Up @@ -214,15 +213,16 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
true,
metric_arg);

detail::select_k<IndexType, ElementType>(temp_distances.data(),
nullptr,
current_query_size,
current_centroid_size,
distances + i * k,
indices + i * k,
true,
current_k,
stream);
matrix::select_k<ElementType, IndexType>(
handle,
raft::make_device_matrix_view<const ElementType, size_t, row_major>(
temp_distances.data(), current_query_size, current_centroid_size),
std::nullopt,
raft::make_device_matrix_view<ElementType, size_t, row_major>(
distances + i * k, current_query_size, k),
raft::make_device_matrix_view<IndexType, size_t, row_major>(
indices + i * k, current_query_size, k),
true);

// if we're tiling over columns, we need to do a couple things to fix up
// the output of select_k
Expand Down Expand Up @@ -254,15 +254,17 @@ void tiled_brute_force_knn(const raft::device_resources& handle,

if (tile_cols != n) {
// select the actual top-k items here from the temporary output
detail::select_k<IndexType, ElementType>(temp_out_distances.data(),
temp_out_indices.data(),
current_query_size,
temp_out_cols,
distances + i * k,
indices + i * k,
true,
k,
stream);
matrix::select_k<ElementType, IndexType>(
handle,
raft::make_device_matrix_view<const ElementType, size_t, row_major>(
temp_out_distances.data(), current_query_size, temp_out_cols),
raft::make_device_matrix_view<const IndexType, size_t, row_major>(
temp_out_indices.data(), current_query_size, temp_out_cols),
raft::make_device_matrix_view<ElementType, size_t, row_major>(
distances + i * k, current_query_size, k),
raft::make_device_matrix_view<IndexType, size_t, row_major>(
indices + i * k, current_query_size, k),
true);
}
}
}
Expand Down

0 comments on commit fe728e9

Please sign in to comment.