Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve RBC eps-neighborhood query performance #2211

Merged
merged 13 commits into from
Mar 11, 2024
4 changes: 2 additions & 2 deletions cpp/include/raft/neighbors/ball_cover-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ void eps_nn(raft::resources const& handle,
query.extent(0),
adj.data_handle(),
vd.data_handle(),
spatial::knn::detail::EuclideanFunc<value_t, int_t>());
spatial::knn::detail::EuclideanSqFunc<value_t, int_t>());
}

/**
Expand Down Expand Up @@ -391,7 +391,7 @@ void eps_nn(raft::resources const& handle,
adj_ia.data_handle(),
adj_ja.data_handle(),
vd.data_handle(),
spatial::knn::detail::EuclideanFunc<value_t, int_t>());
spatial::knn::detail::EuclideanSqFunc<value_t, int_t>());
}

/**
Expand Down
12 changes: 12 additions & 0 deletions cpp/include/raft/neighbors/ball_cover_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class BallCoverIndex {
R_1nn_dists(raft::make_device_vector<value_t, matrix_idx>(handle, m_)),
R_closest_landmark_dists(raft::make_device_vector<value_t, matrix_idx>(handle, m_)),
R(raft::make_device_matrix<value_t, matrix_idx>(handle, sqrt(m_), n_)),
X_reordered(raft::make_device_matrix<value_t, matrix_idx>(handle, m_, n_)),
R_radius(raft::make_device_vector<value_t, matrix_idx>(handle, sqrt(m_))),
index_trained(false)
{
Expand All @@ -91,6 +92,8 @@ class BallCoverIndex {
R_1nn_dists(raft::make_device_vector<value_t, matrix_idx>(handle, X_.extent(0))),
R_closest_landmark_dists(raft::make_device_vector<value_t, matrix_idx>(handle, X_.extent(0))),
R(raft::make_device_matrix<value_t, matrix_idx>(handle, sqrt(X_.extent(0)), X_.extent(1))),
X_reordered(
raft::make_device_matrix<value_t, matrix_idx>(handle, X_.extent(0), X_.extent(1))),
R_radius(raft::make_device_vector<value_t, matrix_idx>(handle, sqrt(X_.extent(0)))),
index_trained(false)
{
Expand Down Expand Up @@ -120,6 +123,10 @@ class BallCoverIndex {
{
return R_closest_landmark_dists.view();
}
auto get_X_reordered() const -> raft::device_matrix_view<const value_t, matrix_idx, row_major>
{
return X_reordered.view();
}

raft::device_vector_view<value_idx, matrix_idx> get_R_indptr() { return R_indptr.view(); }
raft::device_vector_view<value_idx, matrix_idx> get_R_1nn_cols() { return R_1nn_cols.view(); }
Expand All @@ -130,6 +137,10 @@ class BallCoverIndex {
{
return R_closest_landmark_dists.view();
}
raft::device_matrix_view<value_t, matrix_idx, row_major> get_X_reordered()
{
return X_reordered.view();
}
raft::device_matrix_view<const value_t, matrix_idx, row_major> get_X() const { return X; }

raft::distance::DistanceType get_metric() const { return metric; }
Expand Down Expand Up @@ -160,6 +171,7 @@ class BallCoverIndex {
raft::device_vector<value_t, matrix_idx> R_radius;

raft::device_matrix<value_t, matrix_idx, row_major> R;
raft::device_matrix<value_t, matrix_idx, row_major> X_reordered;

protected:
bool index_trained;
Expand Down
34 changes: 10 additions & 24 deletions cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ void construct_landmark_1nn(raft::resources const& handle,
index.get_R_indptr().data_handle(),
index.n_landmarks + 1,
resource::get_cuda_stream(handle));

// reorder X to allow aligned access
raft::matrix::copy_rows<value_t, value_idx>(
handle, index.get_X(), index.get_X_reordered(), index.get_R_1nn_cols());
}

/**
Expand Down Expand Up @@ -339,12 +343,6 @@ void perform_rbc_query(raft::resources const& handle,
/**
* Perform eps-select
*
* a. Map 1 row to each warp/block
* b. Add closest k R points to heap
* c. Iterate through batches of R, having each thread in the warp load a set
* of distances y from R (only if d(q, r) < 3 * distance to closest r) and
* marking the distance to be computed between x, y only
* if knn[k].distance >= d(x_i, R_k) + d(R_k, y)
*/
template <typename value_idx,
typename value_t,
Expand All @@ -357,7 +355,7 @@ void perform_rbc_eps_nn_query(
const value_t* query,
value_int n_query_pts,
value_t eps,
const value_t* landmark_dists,
const value_t* landmarks,
dist_func dfunc,
bool* adj,
value_idx* vd)
Expand All @@ -369,7 +367,7 @@ void perform_rbc_eps_nn_query(
resource::sync_stream(handle);

rbc_eps_pass<value_idx, value_t, value_int, matrix_idx>(
handle, index, query, n_query_pts, eps, landmark_dists, dfunc, adj, vd);
handle, index, query, n_query_pts, eps, landmarks, dfunc, adj, vd);

resource::sync_stream(handle);
}
Expand All @@ -386,14 +384,14 @@ void perform_rbc_eps_nn_query(
value_int n_query_pts,
value_t eps,
value_int* max_k,
const value_t* landmark_dists,
const value_t* landmarks,
dist_func dfunc,
value_idx* adj_ia,
value_idx* adj_ja,
value_idx* vd)
{
rbc_eps_pass<value_idx, value_t, value_int, matrix_idx>(
handle, index, query, n_query_pts, eps, max_k, landmark_dists, dfunc, adj_ia, adj_ja, vd);
handle, index, query, n_query_pts, eps, max_k, landmarks, dfunc, adj_ia, adj_ja, vd);

resource::sync_stream(handle);
}
Expand Down Expand Up @@ -666,15 +664,9 @@ void rbc_eps_nn_query(raft::resources const& handle,
{
ASSERT(index.is_index_trained(), "index must be previously trained");

auto R_dists =
raft::make_device_matrix<value_t, matrix_idx>(handle, index.n_landmarks, n_query_pts);

// find all landmarks that might have points in range
compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle());

// query all points and write to adj
perform_rbc_eps_nn_query(
handle, index, query, n_query_pts, eps, R_dists.data_handle(), dfunc, adj, vd);
handle, index, query, n_query_pts, eps, index.get_R().data_handle(), dfunc, adj, vd);
}

template <typename value_idx = std::int64_t,
Expand All @@ -695,20 +687,14 @@ void rbc_eps_nn_query(raft::resources const& handle,
{
ASSERT(index.is_index_trained(), "index must be previously trained");

auto R_dists =
raft::make_device_matrix<value_t, matrix_idx>(handle, index.n_landmarks, n_query_pts);

// find all landmarks that might have points in range
compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle());

// query all points and write to adj
perform_rbc_eps_nn_query(handle,
index,
query,
n_query_pts,
eps,
max_k,
R_dists.data_handle(),
index.get_R().data_handle(),
dfunc,
adj_ia,
adj_ja,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::DistFunc);

instantiate_raft_spatial_knn_detail_rbc_eps_pass(
std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanFunc);
std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanSqFunc);

#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two
#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one
Expand Down
Loading
Loading