Skip to content

Commit

Permalink
use specializations in RBC code
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Feb 14, 2023
1 parent d97ddb8 commit 5905b2d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 34 deletions.
22 changes: 7 additions & 15 deletions cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include <raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh>

#include <raft/matrix/matrix.cuh>
#include <raft/neighbors/detail/knn_brute_force.cuh>
#include <raft/neighbors/brute_force.cuh>
#include <raft/random/rng.cuh>
#include <raft/sparse/convert/csr.cuh>

Expand Down Expand Up @@ -178,23 +178,15 @@ void k_closest_landmarks(raft::device_resources const& handle,
value_idx* R_knn_inds,
value_t* R_knn_dists)
{
// TODO: Add const to the brute-force knn inputs
std::vector<value_t*> input = {const_cast<value_t*>(index.get_R().data_handle())};
std::vector<value_int> sizes = {index.n_landmarks};
std::vector<raft::device_matrix_view<const value_t, value_int>> inputs = {index.get_R()};

raft::neighbors::detail::brute_force_knn_impl<value_int, value_idx>(
raft::neighbors::brute_force::knn<value_idx, value_t, value_int>(
handle,
input,
sizes,
index.n,
const_cast<value_t*>(query_pts),
n_query_pts,
R_knn_inds,
R_knn_dists,
inputs,
make_device_matrix_view(query_pts, n_query_pts, inputs[0].extent(1)),
make_device_matrix_view(R_knn_inds, n_query_pts, k),
make_device_matrix_view(R_knn_dists, n_query_pts, k),
k,
true,
true,
nullptr,
index.get_metric());
}

Expand Down
30 changes: 11 additions & 19 deletions cpp/test/neighbors/ball_cover.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/ball_cover.cuh>
#include <raft/neighbors/detail/knn_brute_force.cuh>
#include <raft/neighbors/brute_force.cuh>
#include <raft/random/make_blobs.cuh>
#include <raft/util/cudart_utils.hpp>
#if defined RAFT_NN_COMPILED
Expand Down Expand Up @@ -112,24 +112,16 @@ void compute_bfknn(const raft::device_resources& handle,
value_t* dists,
int64_t* inds)
{
std::vector<value_t*> input_vec = {const_cast<value_t*>(X1)};
std::vector<uint32_t> sizes_vec = {n_rows};

std::vector<int64_t>* translations = nullptr;

raft::neighbors::detail::brute_force_knn_impl<uint32_t, int64_t>(handle,
input_vec,
sizes_vec,
d,
const_cast<value_t*>(X2),
n_query_rows,
inds,
dists,
k,
true,
true,
translations,
metric);
std::vector<raft::device_matrix_view<const value_t, uint32_t>> input_vec = {
make_device_matrix_view(X1, n_rows, d)};

raft::neighbors::brute_force::knn(handle,
input_vec,
make_device_matrix_view(X2, n_query_rows, d),
make_device_matrix_view(inds, n_query_rows, k),
make_device_matrix_view(dists, n_query_rows, k),
k,
metric);
}

struct ToRadians {
Expand Down

0 comments on commit 5905b2d

Please sign in to comment.