From 2179ed61000412d50b86c08fa673276725b71dbe Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 4 Mar 2024 22:50:44 +0000 Subject: [PATCH 01/10] optimize rbc eps-NN query --- cpp/include/raft/neighbors/ball_cover-inl.cuh | 4 +- .../raft/neighbors/ball_cover_types.hpp | 12 + .../raft/spatial/knn/detail/ball_cover.cuh | 42 +- .../knn/detail/ball_cover/registers-ext.cuh | 2 +- .../knn/detail/ball_cover/registers-inl.cuh | 690 ++++++++++++++---- .../knn/detail/ball_cover/registers_types.cuh | 17 +- .../ball_cover/registers_00_generate.py | 4 +- .../registers_eps_pass_euclidean.cu | 2 +- 8 files changed, 583 insertions(+), 190 deletions(-) diff --git a/cpp/include/raft/neighbors/ball_cover-inl.cuh b/cpp/include/raft/neighbors/ball_cover-inl.cuh index cdf7c30e89..0c202865ca 100644 --- a/cpp/include/raft/neighbors/ball_cover-inl.cuh +++ b/cpp/include/raft/neighbors/ball_cover-inl.cuh @@ -332,7 +332,7 @@ void eps_nn(raft::resources const& handle, query.extent(0), adj.data_handle(), vd.data_handle(), - spatial::knn::detail::EuclideanFunc()); + spatial::knn::detail::EuclideanSqFunc()); } /** @@ -391,7 +391,7 @@ void eps_nn(raft::resources const& handle, adj_ia.data_handle(), adj_ja.data_handle(), vd.data_handle(), - spatial::knn::detail::EuclideanFunc()); + spatial::knn::detail::EuclideanSqFunc()); } /** diff --git a/cpp/include/raft/neighbors/ball_cover_types.hpp b/cpp/include/raft/neighbors/ball_cover_types.hpp index dc96f0d45b..0b3d6ec51c 100644 --- a/cpp/include/raft/neighbors/ball_cover_types.hpp +++ b/cpp/include/raft/neighbors/ball_cover_types.hpp @@ -67,6 +67,7 @@ class BallCoverIndex { R_1nn_dists(raft::make_device_vector(handle, m_)), R_closest_landmark_dists(raft::make_device_vector(handle, m_)), R(raft::make_device_matrix(handle, sqrt(m_), n_)), + X_reordered(raft::make_device_matrix(handle, m_, n_)), R_radius(raft::make_device_vector(handle, sqrt(m_))), index_trained(false) { @@ -91,6 +92,8 @@ class BallCoverIndex { 
R_1nn_dists(raft::make_device_vector(handle, X_.extent(0))), R_closest_landmark_dists(raft::make_device_vector(handle, X_.extent(0))), R(raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1))), + X_reordered( + raft::make_device_matrix(handle, X_.extent(0), X_.extent(1))), R_radius(raft::make_device_vector(handle, sqrt(X_.extent(0)))), index_trained(false) { @@ -120,6 +123,10 @@ class BallCoverIndex { { return R_closest_landmark_dists.view(); } + auto get_X_reordered() const -> raft::device_matrix_view + { + return X_reordered.view(); + } raft::device_vector_view get_R_indptr() { return R_indptr.view(); } raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); } @@ -130,6 +137,10 @@ class BallCoverIndex { { return R_closest_landmark_dists.view(); } + raft::device_matrix_view get_X_reordered() + { + return X_reordered.view(); + } raft::device_matrix_view get_X() const { return X; } raft::distance::DistanceType get_metric() const { return metric; } @@ -160,6 +171,7 @@ class BallCoverIndex { raft::device_vector R_radius; raft::device_matrix R; + raft::device_matrix X_reordered; protected: bool index_trained; diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 879f54fd81..b8a622bf0d 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -167,6 +167,10 @@ void construct_landmark_1nn(raft::resources const& handle, index.get_R_indptr().data_handle(), index.n_landmarks + 1, resource::get_cuda_stream(handle)); + + // reorder X to allow aligned access + raft::matrix::copy_rows( + handle, index.get_X(), index.get_X_reordered(), index.get_R_1nn_cols()); } /** @@ -339,12 +343,6 @@ void perform_rbc_query(raft::resources const& handle, /** * Perform eps-select * - * a. Map 1 row to each warp/block - * b. Add closest k R points to heap - * c. 
Iterate through batches of R, having each thread in the warp load a set - * of distances y from R (only if d(q, r) < 3 * distance to closest r) and - * marking the distance to be computed between x, y only - * if knn[k].distance >= d(x_i, R_k) + d(R_k, y) */ template ( - handle, index, query, n_query_pts, eps, landmark_dists, dfunc, adj, vd); + handle, index, query, n_query_pts, eps, landmarks, dfunc, adj, vd); resource::sync_stream(handle); } @@ -386,14 +384,14 @@ void perform_rbc_eps_nn_query( value_int n_query_pts, value_t eps, value_int* max_k, - const value_t* landmark_dists, + const value_t* landmarks, dist_func dfunc, value_idx* adj_ia, value_idx* adj_ja, value_idx* vd) { rbc_eps_pass( - handle, index, query, n_query_pts, eps, max_k, landmark_dists, dfunc, adj_ia, adj_ja, vd); + handle, index, query, n_query_pts, eps, max_k, landmarks, dfunc, adj_ia, adj_ja, vd); resource::sync_stream(handle); } @@ -418,6 +416,14 @@ void rbc_build_index(raft::resources const& handle, BallCoverIndex& index, distance_func dfunc) { + { + /** flush the L2 cache - Hopper at 50MB */ + size_t l2_cache_size = 50 * 1024 * 1024; + auto scratch_buf_ = rmm::device_buffer(l2_cache_size * 3, resource::get_cuda_stream(handle)); + RAFT_CUDA_TRY(cudaMemsetAsync( + scratch_buf_.data(), 0, scratch_buf_.size(), resource::get_cuda_stream(handle))); + } + ASSERT(!index.is_index_trained(), "index cannot be previously trained"); rmm::device_uvector R_knn_inds(index.m, resource::get_cuda_stream(handle)); @@ -666,15 +672,9 @@ void rbc_eps_nn_query(raft::resources const& handle, { ASSERT(index.is_index_trained(), "index must be previously trained"); - auto R_dists = - raft::make_device_matrix(handle, index.n_landmarks, n_query_pts); - - // find all landmarks that might have points in range - compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle()); - // query all points and write to adj perform_rbc_eps_nn_query( - handle, index, query, n_query_pts, eps, R_dists.data_handle(), 
dfunc, adj, vd); + handle, index, query, n_query_pts, eps, index.get_R().data_handle(), dfunc, adj, vd); } template (handle, index.n_landmarks, n_query_pts); - - // find all landmarks that might have points in range - compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle()); - // query all points and write to adj perform_rbc_eps_nn_query(handle, index, @@ -708,7 +702,7 @@ void rbc_eps_nn_query(raft::resources const& handle, n_query_pts, eps, max_k, - R_dists.data_handle(), + index.get_R().data_handle(), dfunc, adj_ia, adj_ja, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh index 2ed6ee3284..bee983c7b2 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh @@ -188,7 +188,7 @@ instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::DistFunc); instantiate_raft_spatial_knn_detail_rbc_eps_pass( - std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanSqFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 8b4e8f287e..3a9dc11c05 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -461,10 +461,11 @@ template -RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered, const value_t* X, + const value_int n_queries, const value_int n_cols, - const value_t* 
R_dists, + const value_t* R, const value_int m, const value_t eps, const value_int n_landmarks, @@ -476,70 +477,112 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_index, bool* adj, value_idx* vd) { - __shared__ int column_count_smem; + constexpr int num_warps = tpb / WarpSize; - // initialize - if (vd != nullptr) { - if (threadIdx.x == 0) { column_count_smem = 0; } - __syncthreads(); - } + // process 1 query per warp + const uint32_t lid = raft::laneId(); - const value_t* x_ptr = X + (n_cols * blockIdx.x); + // this should help the compiler to prevent branches + const int warp_id = raft::shfl(threadIdx.x / WarpSize, 0); + const int query_id = raft::shfl(blockIdx.x * num_warps + warp_id, 0); - for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { - // TODO: this might also be worth computing in-place here - value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; + // this is an early out for a full warp + if (query_id >= n_queries) return; - // prune all R's that can't be within eps - if (cur_R_dist - R_radius[cur_k] > eps) continue; + unsigned long long int column_count = 0; - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_k]; - value_idx R_stop_offset = R_indptr[cur_k + 1]; - - value_idx R_size = R_stop_offset - R_start_offset; + const value_t* x_ptr = X + (n_cols * query_id); - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - adj[blockIdx.x * m + cur_candidate_ind] = true; - if (vd != nullptr) atomicAdd(&column_count_smem, 1); +#pragma 
nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + // prune all R's that can't be within eps + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist = lane_k < n_landmarks + ? raft::sqrt(dfunc(x_ptr, R + lane_k * n_cols, n_cols)) + : std::numeric_limits::max(); + const int lane_check = + lane_k < n_landmarks ? static_cast(lane_R_dist - R_radius[lane_k] <= eps) : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + uint32_t k_offset = __ffs(lane_mask) - 1; + do { + const uint32_t cur_k = cur_k0 + k_offset; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= -(1 << k_offset + 1); + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::shfl(lane_R_dist, k_offset); + + const uint32_t limit = Pow2::roundDown(R_size); + int i = limit + lid; + + // look ahead for next k_offset + k_offset = lane_mask != 0 ? __ffs(lane_mask) - 1 : WarpSize; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + bool skip_following = + i < R_size ? (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps) : false; + { + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + const value_t dist = + (i >= R_size) ? 
std::numeric_limits::max() : dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + if (in_range) { + auto index = R_1nn_cols[R_start_offset + i]; + column_count++; + adj[query_id * m + index] = true; + } } - } - - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - adj[blockIdx.x * m + cur_candidate_ind] = true; - if (vd != nullptr) atomicAdd(&column_count_smem, 1); + skip_following = raft::any(skip_following); + if (skip_following) continue; + + i -= WarpSize; + for (; i >= 0 && !skip_following; i -= WarpSize) { + skip_following = (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + if (in_range) { + auto index = R_1nn_cols[R_start_offset + i]; + column_count++; + adj[query_id * m + index] = true; + } + skip_following = raft::any(skip_following); } - } + } while (k_offset < WarpSize); } if (vd != nullptr) { - __syncthreads(); - if (threadIdx.x == 0) { vd[blockIdx.x] = column_count_smem; } + value_idx row_sum = raft::warpReduce(column_count); + if (lid == 0) vd[query_id] = row_sum; } } template -RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, const value_t* X, + const value_int n_queries, const value_int n_cols, - const value_t* R_dists, + const value_t* R, const value_int m, const value_t eps, const value_int n_landmarks, @@ -551,58 +594,256 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_index, value_idx* adj_ia, value_idx* adj_ja) { - const value_t* x_ptr = X + (n_cols * blockIdx.x); + constexpr int num_warps = tpb / WarpSize; - __shared__ unsigned long long int 
column_index_smem; + // process 1 query per warp + const uint32_t lid = raft::laneId(); + const uint32_t lid_mask = (1 << lid) - 1; - bool pass2 = adj_ja != nullptr; + // this should help the compiler to prevent branches + const int warp_id = raft::shfl(threadIdx.x / WarpSize, 0); + const int query_id = raft::shfl(blockIdx.x * num_warps + warp_id, 0); - // initialize - if (threadIdx.x == 0) { column_index_smem = pass2 ? adj_ia[blockIdx.x] : 0; } + // this is an early out for a full warp + if (query_id >= n_queries) return; - __syncthreads(); + unsigned long long int column_index_offset = write_pass ? adj_ia[query_id] : 0; + + // we have no neighbors to fill for this query + if (write_pass && adj_ia[query_id + 1] == column_index_offset) return; + + const value_t* x_ptr = X + (n_cols * query_id); - for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { - // TODO: this might also be worth computing in-place here - value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations // prune all R's that can't be within eps - if (cur_R_dist - R_radius[cur_k] > eps) continue; + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist = lane_k < n_landmarks + ? raft::sqrt(dfunc(x_ptr, R + lane_k * n_cols, n_cols)) + : std::numeric_limits::max(); + const int lane_check = + lane_k < n_landmarks ? 
static_cast(lane_R_dist - R_radius[lane_k] <= eps) : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + uint32_t k_offset = __ffs(lane_mask) - 1; + do { + const uint32_t cur_k = cur_k0 + k_offset; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= -(1 << k_offset + 1); + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::shfl(lane_R_dist, k_offset); + + const uint32_t limit = Pow2::roundDown(R_size); + int i = limit + lid; + + // look ahead for next k_offset + k_offset = lane_mask != 0 ? __ffs(lane_mask) - 1 : WarpSize; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + bool skip_following = + i < R_size ? (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps) : false; + { + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + const value_t dist = + (i >= R_size) ? 
std::numeric_limits::max() : dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + if (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + auto index = R_1nn_cols[R_start_offset + i]; + auto row_pos = column_index_offset + __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + column_index_offset += __popc(mask); + } else { + column_index_offset += (in_range); + } + } - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_k]; - value_idx R_stop_offset = R_indptr[cur_k + 1]; + skip_following = raft::any(skip_following); + if (skip_following) continue; + + i -= WarpSize; + for (; i >= 0 && !skip_following; i -= WarpSize) { + skip_following = (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + if (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + auto index = R_1nn_cols[R_start_offset + i]; + auto row_pos = column_index_offset + __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + column_index_offset += __popc(mask); + } else { + column_index_offset += (in_range); + } - value_idx R_size = R_stop_offset - R_start_offset; + skip_following = raft::any(skip_following); + } + } while (k_offset < WarpSize); + } - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + if (!write_pass) { + value_idx row_sum = raft::warpReduce(column_index_offset); + if (lid == 0) adj_ia[query_id] = row_sum; + } +} + +template +RAFT_KERNEL block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered, + const value_t* X, + const value_int n_queries, + const 
value_int n_cols, + const value_t* R, + const value_int m, + const value_t eps, + const value_int n_landmarks, + const value_idx* R_indptr, + const value_idx* R_1nn_cols, + const value_t* R_1nn_dists, + const value_t* R_radius, + distance_func dfunc, + value_idx* adj_ia, + value_idx* adj_ja) +{ + constexpr int num_warps = tpb / WarpSize; + + // process 1 query per warp + const uint32_t lid = raft::laneId(); + const uint32_t lid_mask = (1 << lid) - 1; + + // this should help the compiler to prevent branches + const int warp_id = raft::shfl(threadIdx.x / WarpSize, 0); + const int query_id = raft::shfl(blockIdx.x * num_warps + warp_id, 0); - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - auto row_pos = atomicAdd(&column_index_smem, 1); - if (pass2) adj_ja[row_pos] = cur_candidate_ind; + // this is an early out for a full warp + if (query_id >= n_queries) return; + + unsigned long long int column_index_offset = write_pass ? adj_ia[query_id] : 0; + + // we have no neighbors to fill for this query + if (write_pass && adj_ia[query_id + 1] == column_index_offset) return; + + const value_t* x_ptr = X + (dim * query_id); + value_t local_x_ptr[dim]; +#pragma unroll + for (uint32_t i = 0; i < dim; ++i) { + local_x_ptr[i] = x_ptr[i]; + } + + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; + +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + // prune all R's that can't be within eps + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist = lane_k < n_landmarks + ? raft::sqrt(dfunc(local_x_ptr, R + lane_k * dim, dim)) + : std::numeric_limits::max(); + const int lane_check = + lane_k < n_landmarks ? 
static_cast(lane_R_dist - R_radius[lane_k] <= eps) : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + uint32_t k_offset = __ffs(lane_mask) - 1; + do { + const uint32_t cur_k = cur_k0 + k_offset; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= -(1 << k_offset + 1); + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::shfl(lane_R_dist, k_offset); + + const uint32_t limit = Pow2::roundDown(R_size); + int i = limit + lid; + + // look ahead for next k_offset + k_offset = lane_mask != 0 ? __ffs(lane_mask) - 1 : WarpSize; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + bool skip_following = + i < R_size ? (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps) : false; + { + const value_t* y_ptr = X_reordered + (dim * (R_start_offset + i)); + const value_t dist = + (i >= R_size) ? 
std::numeric_limits::max() : dfunc(local_x_ptr, y_ptr, dim); + const bool in_range = (dist <= eps2); + if (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + auto index = R_1nn_cols[R_start_offset + i]; + auto row_pos = column_index_offset + __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + column_index_offset += __popc(mask); + } else { + column_index_offset += (in_range); + } } - } - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + skip_following = raft::any(skip_following); + if (skip_following) continue; + + i -= WarpSize; + for (; i >= 0 && !skip_following; i -= WarpSize) { + skip_following = (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps); + const value_t* y_ptr = X_reordered + (dim * (R_start_offset + i)); + const value_t dist = dfunc(local_x_ptr, y_ptr, dim); + const bool in_range = (dist <= eps2); + if (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + auto index = R_1nn_cols[R_start_offset + i]; + auto row_pos = column_index_offset + __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + column_index_offset += __popc(mask); + } else { + column_index_offset += (in_range); + } - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - auto row_pos = atomicAdd(&column_index_smem, 1); - if (pass2) adj_ja[row_pos] = cur_candidate_ind; + skip_following = raft::any(skip_following); } - } + } while (k_offset < WarpSize); } - __syncthreads(); - if (threadIdx.x == 0 && !pass2) { adj_ia[blockIdx.x] = (value_idx)column_index_smem; } + if (!write_pass) { + value_idx row_sum = raft::warpReduce(column_index_offset); + if (lid == 0) adj_ia[query_id] = row_sum; + } } template -RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, const value_t* X, + const value_int n_queries, 
const value_int n_cols, - const value_t* R_dists, + const value_t* R, const value_int m, const value_t eps, const value_int n_landmarks, @@ -626,59 +868,107 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_index, const value_int max_k, value_idx* tmp) { - const value_t* x_ptr = X + (n_cols * blockIdx.x); + constexpr int num_warps = tpb / WarpSize; - __shared__ int column_count_smem; + // process 1 query per warp + const uint32_t lid = raft::laneId(); + const uint32_t lid_mask = (1 << lid) - 1; - // initialize - if (threadIdx.x == 0) { column_count_smem = 0; } + // this should help the compiler to prevent branches + const int warp_id = raft::shfl(threadIdx.x / WarpSize, 0); + const int query_id = raft::shfl(blockIdx.x * num_warps + warp_id, 0); - __syncthreads(); - - // we store all column indices in dense tmp store [blockDim.x * max_k] - value_int offset = blockIdx.x * max_k; + // this is an early out for a full warp + if (query_id >= n_queries) return; - for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { - // TODO: this might also be worth computing in-place here - value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; + unsigned long long int column_count = 0; - // prune all R's that can't be within eps - if (cur_R_dist - R_radius[cur_k] > eps) continue; + const value_t* x_ptr = X + (n_cols * query_id); - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_k]; - value_idx R_stop_offset = R_indptr[cur_k + 1]; + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; - value_idx R_size = R_stop_offset - R_start_offset; - - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - const value_t* y_ptr = X_index + 
(n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - int row_pos = atomicAdd(&column_count_smem, 1); - if (row_pos < max_k) tmp[row_pos + offset] = cur_candidate_ind; +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + // prune all R's that can't be within eps + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist = lane_k < n_landmarks + ? raft::sqrt(dfunc(x_ptr, R + lane_k * n_cols, n_cols)) + : std::numeric_limits::max(); + const int lane_check = + lane_k < n_landmarks ? static_cast(lane_R_dist - R_radius[lane_k] <= eps) : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + uint32_t k_offset = __ffs(lane_mask) - 1; + do { + const uint32_t cur_k = cur_k0 + k_offset; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= -(1 << k_offset + 1); + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::shfl(lane_R_dist, k_offset); + + const uint32_t limit = Pow2::roundDown(R_size); + int i = limit + lid; + + // look ahead for next k_offset + k_offset = lane_mask != 0 ? __ffs(lane_mask) - 1 : WarpSize; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + bool skip_following = + i < R_size ? (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps) : false; + { + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + const value_t dist = + (i >= R_size) ? 
std::numeric_limits::max() : dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + const int mask = raft::ballot(in_range); + if (in_range) { + auto row_pos = column_count + __popc(mask & lid_mask); + // we still continue to look for more hits to return valid vd + if (row_pos < max_k) { + auto index = R_1nn_cols[R_start_offset + i]; + tmp[query_id * max_k + row_pos] = index; + } + } + column_count += __popc(mask); } - } - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - int row_pos = atomicAdd(&column_count_smem, 1); - if (row_pos < max_k) tmp[row_pos + offset] = cur_candidate_ind; + skip_following = raft::any(skip_following); + if (skip_following) continue; + + i -= WarpSize; + for (; i >= 0 && !skip_following; i -= WarpSize) { + skip_following = (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + const int mask = raft::ballot(in_range); + if (in_range) { + auto row_pos = column_count + __popc(mask & lid_mask); + // we still continue to look for more hits to return valid vd + if (row_pos < max_k) { + auto index = R_1nn_cols[R_start_offset + i]; + tmp[query_id * max_k + row_pos] = index; + } + } + column_count += __popc(mask); + skip_following = raft::any(skip_following); } - } + } while (k_offset < WarpSize); } - __syncthreads(); - if (threadIdx.x == 0) { vd[blockIdx.x] = column_count_smem; } + if (lid == 0) vd[query_id] = column_count; } template @@ -1047,17 +1337,18 @@ void rbc_eps_pass(raft::resources const& handle, const value_t* query, const value_int n_query_rows, value_t eps, - const value_t* R_dists, + const value_t* R, dist_func& dfunc, bool* adj, value_idx* vd) { 
block_rbc_kernel_eps_dense <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, + n_query_rows, index.n, - R_dists, + R, index.m, eps, index.n_landmarks, @@ -1093,7 +1384,7 @@ void rbc_eps_pass(raft::resources const& handle, const value_int n_query_rows, value_t eps, value_int* max_k, - const value_t* R_dists, + const value_t* R, dist_func& dfunc, value_idx* adj_ia, value_idx* adj_ja, @@ -1104,22 +1395,61 @@ void rbc_eps_pass(raft::resources const& handle, if (adj_ja == nullptr) { // pass 1 -> only compute adj_ia / vd value_idx* vd_ptr = (vd != nullptr) ? vd : adj_ia; - block_rbc_kernel_eps_csr_pass - <<>>( - index.get_X().data_handle(), - query, - index.n, - R_dists, - index.m, - eps, - index.n_landmarks, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - index.get_R_radius().data_handle(), - dfunc, - vd_ptr, - nullptr); + if (index.n == 2) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + nullptr); + } else if (index.n == 3) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + nullptr); + } else { + block_rbc_kernel_eps_csr_pass + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, 
+ R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + nullptr); + } thrust::exclusive_scan(resource::get_thrust_policy(handle), vd_ptr, @@ -1129,22 +1459,61 @@ void rbc_eps_pass(raft::resources const& handle, } else { // pass 2 -> fill in adj_ja - block_rbc_kernel_eps_csr_pass - <<>>( - index.get_X().data_handle(), - query, - index.n, - R_dists, - index.m, - eps, - index.n_landmarks, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - index.get_R_radius().data_handle(), - dfunc, - adj_ia, - adj_ja); + if (index.n == 2) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + adj_ia, + adj_ja); + } else if (index.n == 3) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + adj_ia, + adj_ja); + } else { + block_rbc_kernel_eps_csr_pass + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + 
index.get_R_radius().data_handle(), + dfunc, + adj_ia, + adj_ja); + } } } else { value_int max_k_in = *max_k; @@ -1153,11 +1522,12 @@ void rbc_eps_pass(raft::resources const& handle, rmm::device_uvector tmp(n_query_rows * max_k_in, resource::get_cuda_stream(handle)); block_rbc_kernel_eps_max_k - <<>>( - index.get_X().data_handle(), + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), query, + n_query_rows, index.n, - R_dists, + R, index.m, eps, index.n_landmarks, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh index 7f4268d2dc..5d317529d1 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -60,6 +60,21 @@ struct EuclideanFunc : public DistFunc { } }; +template +struct EuclideanSqFunc : public DistFunc { + __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) override + { + value_t sum_sq = 0; + for (value_int i = 0; i < n_dims; ++i) { + value_t diff = a[i] - b[i]; + sum_sq += diff * diff; + } + return sum_sq; + } +}; + }; // namespace detail }; // namespace knn }; // namespace spatial diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py index dff2e015a4..10d9c95ece 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py @@ -121,6 +121,8 @@ dist="raft::spatial::knn::detail::DistFunc", ) +euclideanSq="raft::spatial::knn::detail::EuclideanSqFunc" # plain str: a trailing comma here would make it a 1-tuple + types = dict( int64_float=("std::int64_t", "float"), #int64_double=("std::int64_t", "double"), @@ -156,7 +158,7 @@ f.write(macro_pass_eps) for type_path, (int_t, data_t) in types.items(): f.write(f"instantiate_raft_spatial_knn_detail_rbc_eps_pass(\n") - f.write(f" {int_t}, {data_t}, {int_t}, {int_t}, {distances['euclidean']});\n") + f.write(f" {int_t}, {data_t}, {int_t}, {int_t}, {euclideanSq});\n") f.write("#undef instantiate_raft_spatial_knn_detail_rbc_eps_pass\n") print(f"src/spatial/knn/detail/ball_cover/{path}") diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu index 0d09f88b65..710291b09c 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu @@ -55,5 +55,5 @@ Mvalue_idx* vd) instantiate_raft_spatial_knn_detail_rbc_eps_pass( - std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, 
std::int64_t, raft::spatial::knn::detail::EuclideanSqFunc); #undef instantiate_raft_spatial_knn_detail_rbc_eps_pass From 1782558dc61f321aeb722e6b5811b5f03b35747b Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 5 Mar 2024 09:25:48 +0000 Subject: [PATCH 02/10] remove L2 cache flush --- cpp/include/raft/spatial/knn/detail/ball_cover.cuh | 8 -------- 1 file changed, 8 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index b8a622bf0d..38d2e102d5 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -416,14 +416,6 @@ void rbc_build_index(raft::resources const& handle, BallCoverIndex& index, distance_func dfunc) { - { - /** flush the L2 cache - Hopper at 50MB */ - size_t l2_cache_size = 50 * 1024 * 1024; - auto scratch_buf_ = rmm::device_buffer(l2_cache_size * 3, resource::get_cuda_stream(handle)); - RAFT_CUDA_TRY(cudaMemsetAsync( - scratch_buf_.data(), 0, scratch_buf_.size(), resource::get_cuda_stream(handle))); - } - ASSERT(!index.is_index_trained(), "index cannot be previously trained"); rmm::device_uvector R_knn_inds(index.m, resource::get_cuda_stream(handle)); From b3e998a6b2537c912edece84c14bd825687c351f Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 5 Mar 2024 14:24:24 +0000 Subject: [PATCH 03/10] review suggestion constexpr --- .../spatial/knn/detail/ball_cover/registers-inl.cuh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 3a9dc11c05..0337f6891f 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -661,7 +661,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, const value_t dist = (i >= R_size) ? 
std::numeric_limits::max() : dfunc(x_ptr, y_ptr, n_cols); const bool in_range = (dist <= eps2); - if (write_pass) { + if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { auto index = R_1nn_cols[R_start_offset + i]; @@ -683,7 +683,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); const value_t dist = dfunc(x_ptr, y_ptr, n_cols); const bool in_range = (dist <= eps2); - if (write_pass) { + if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { auto index = R_1nn_cols[R_start_offset + i]; @@ -700,7 +700,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, } while (k_offset < WarpSize); } - if (!write_pass) { + if constexpr (!write_pass) { value_idx row_sum = raft::warpReduce(column_index_offset); if (lid == 0) adj_ia[query_id] = row_sum; } @@ -801,7 +801,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered, const value_t dist = (i >= R_size) ? 
std::numeric_limits::max() : dfunc(local_x_ptr, y_ptr, dim); const bool in_range = (dist <= eps2); - if (write_pass) { + if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { auto index = R_1nn_cols[R_start_offset + i]; @@ -823,7 +823,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered, const value_t* y_ptr = X_reordered + (dim * (R_start_offset + i)); const value_t dist = dfunc(local_x_ptr, y_ptr, dim); const bool in_range = (dist <= eps2); - if (write_pass) { + if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { auto index = R_1nn_cols[R_start_offset + i]; @@ -840,7 +840,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered, } while (k_offset < WarpSize); } - if (!write_pass) { + if constexpr (!write_pass) { value_idx row_sum = raft::warpReduce(column_index_offset); if (lid == 0) adj_ia[query_id] = row_sum; } From 0129748a48b4637078e88c14bbff3a76c6088fa3 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 5 Mar 2024 15:17:49 +0000 Subject: [PATCH 04/10] utilize reordered index for rbc knn --- .../knn/detail/ball_cover/registers-inl.cuh | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 0337f6891f..71f9f7c492 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -157,7 +157,7 @@ template -RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, +RAFT_KERNEL compute_final_dists_registers(const value_t* X_reordered, const value_t* X, const value_int n_cols, bitset_type* bitset, @@ -238,7 +238,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, // the closest k neighbors, compute it and add to k-select value_t dist = std::numeric_limits::max(); if (z <= 
heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -267,7 +267,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, // the closest k neighbors, compute it and add to k-select value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -313,7 +313,7 @@ template -RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_reordered, const value_t* X, value_int n_cols, // n_cols should be 2 or 3 dims const value_idx* R_knn_inds, @@ -408,7 +408,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -433,7 +433,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -1013,7 +1013,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, if (k <= 32) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1033,7 +1033,7 @@ void 
rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 64) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1052,7 +1052,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 128) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1072,7 +1072,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 256) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1092,7 +1092,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 512) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1112,7 +1112,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 1024) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1182,7 +1182,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1208,7 +1208,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1234,7 +1234,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1260,7 +1260,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1285,7 +1285,7 @@ void rbc_low_dim_pass_two(raft::resources const& 
handle, 8, 64, dims><<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1310,7 +1310,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 8, 64, dims><<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), From fde845be5a7600d38795e0fde74fc08b243a08d2 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 7 Mar 2024 15:35:12 +0000 Subject: [PATCH 05/10] minor optimizations / suggestions --- .../knn/detail/ball_cover/registers-inl.cuh | 117 ++++++++++-------- 1 file changed, 62 insertions(+), 55 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 71f9f7c492..8aaf5af1f6 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -706,6 +706,12 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, } } +template +__device__ value_t squared(const value_t& a) +{ + return a * a; +} + template -RAFT_KERNEL block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered, - const value_t* X, - const value_int n_queries, - const value_int n_cols, - const value_t* R, - const value_int m, - const value_t eps, - const value_int n_landmarks, - const value_idx* R_indptr, - const value_idx* R_1nn_cols, - const value_t* R_1nn_dists, - const value_t* R_radius, - distance_func dfunc, - value_idx* adj_ia, - value_idx* adj_ja) +RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered, + const value_t* X, + const value_int n_queries, + const value_int n_cols, + const value_t* R, + const value_int m, + const value_t eps, + const value_int n_landmarks, + const value_idx* R_indptr, + const value_idx* R_1nn_cols, + const value_t* R_1nn_dists, + const value_t* R_radius, + distance_func dfunc, + value_idx* adj_ia, + 
value_idx* adj_ja) { constexpr int num_warps = tpb / WarpSize; + constexpr int max_lid = WarpSize - 1; // process 1 query per warp const uint32_t lid = raft::laneId(); const uint32_t lid_mask = (1 << lid) - 1; // this should help the compiler to prevent branches - const int warp_id = raft::shfl(threadIdx.x / WarpSize, 0); - const int query_id = raft::shfl(blockIdx.x * num_warps + warp_id, 0); + const int query_id = raft::shfl(blockIdx.x * num_warps + (threadIdx.x / WarpSize), 0); // this is an early out for a full warp if (query_id >= n_queries) return; - unsigned long long int column_index_offset = write_pass ? adj_ia[query_id] : 0; + value_idx column_index_offset = write_pass ? adj_ia[query_id] : 0; // we have no neighbors to fill for this query if (write_pass && adj_ia[query_id + 1] == column_index_offset) return; @@ -760,84 +766,85 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered, #pragma nounroll for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { // Pre-compute landmark_dist & triangularization checks for 32 iterations - // prune all R's that can't be within eps - const uint32_t lane_k = cur_k0 + lid; - const value_t lane_R_dist = lane_k < n_landmarks - ? raft::sqrt(dfunc(local_x_ptr, R + lane_k * dim, dim)) - : std::numeric_limits::max(); - const int lane_check = - lane_k < n_landmarks ? static_cast(lane_R_dist - R_radius[lane_k] <= eps) : 0; + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(local_x_ptr, R + lane_k * dim, dim) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? 
static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; int lane_mask = raft::ballot(lane_check); if (lane_mask == 0) continue; - uint32_t k_offset = __ffs(lane_mask) - 1; + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + uint32_t k_offset = __clz(lane_mask); do { const uint32_t cur_k = cur_k0 + k_offset; - // update lane_mask for next iteration - erase bits up to k_offset - lane_mask &= -(1 << k_offset + 1); - // The whole warp should iterate through the elements in the current R const value_idx R_start_offset = R_indptr[cur_k]; - const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (1 << max_lid - k_offset) - 1; + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; // we have precomputed the query<->landmark distance - const value_t cur_R_dist = raft::shfl(lane_R_dist, k_offset); + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); const uint32_t limit = Pow2::roundDown(R_size); - int i = limit + lid; + uint32_t i = limit + lid; // look ahead for next k_offset - k_offset = lane_mask != 0 ? __ffs(lane_mask) - 1 : WarpSize; + k_offset = __clz(lane_mask); // R_1nn_dists are sorted ascendingly for each landmark // Iterating backwards, after pruning the first point w.r.t. triangle // inequality all subsequent points can be pruned as well - bool skip_following = - i < R_size ? (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps) : false; + const value_t* y_ptr = X_reordered + (dim * (R_start_offset + i)); { - const value_t* y_ptr = X_reordered + (dim * (R_start_offset + i)); + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; const value_t dist = - (i >= R_size) ? std::numeric_limits::max() : dfunc(local_x_ptr, y_ptr, dim); + (i < R_size) ? 
dfunc(local_x_ptr, y_ptr, dim) : std::numeric_limits::max(); const bool in_range = (dist <= eps2); if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { - auto index = R_1nn_cols[R_start_offset + i]; - auto row_pos = column_index_offset + __popc(mask & lid_mask); - adj_ja[row_pos] = index; + const uint32_t index = R_1nn_cols[R_start_offset + i]; + const value_idx row_pos = column_index_offset + __popc(mask & lid_mask); + adj_ja[row_pos] = index; } column_index_offset += __popc(mask); } else { column_index_offset += (in_range); } + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - skip_following = raft::any(skip_following); - if (skip_following) continue; - - i -= WarpSize; - for (; i >= 0 && !skip_following; i -= WarpSize) { - skip_following = (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps); - const value_t* y_ptr = X_reordered + (dim * (R_start_offset + i)); - const value_t dist = dfunc(local_x_ptr, y_ptr, dim); - const bool in_range = (dist <= eps2); + while (i >= WarpSize) { + y_ptr -= WarpSize * dim; + i -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + const value_t dist = dfunc(local_x_ptr, y_ptr, dim); + const bool in_range = (dist <= eps2); if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { - auto index = R_1nn_cols[R_start_offset + i]; - auto row_pos = column_index_offset + __popc(mask & lid_mask); - adj_ja[row_pos] = index; + const uint32_t index = R_1nn_cols[R_start_offset + i]; + const value_idx row_pos = column_index_offset + __popc(mask & lid_mask); + adj_ja[row_pos] = index; } column_index_offset += __popc(mask); } else { column_index_offset += (in_range); } - - skip_following = raft::any(skip_following); + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - } while (k_offset < WarpSize); + } while (lane_mask); } if constexpr 
(!write_pass) { From 663dabf6bf6f0fd66a8a7d51415a5dac788be302 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 7 Mar 2024 15:56:24 +0000 Subject: [PATCH 06/10] add restrict, free 1 register --- .../knn/detail/ball_cover/registers-inl.cuh | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 8aaf5af1f6..2118cf4810 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -719,21 +719,22 @@ template -RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered, - const value_t* X, - const value_int n_queries, - const value_int n_cols, - const value_t* R, - const value_int m, - const value_t eps, - const value_int n_landmarks, - const value_idx* R_indptr, - const value_idx* R_1nn_cols, - const value_t* R_1nn_dists, - const value_t* R_radius, - distance_func dfunc, - value_idx* adj_ia, - value_idx* adj_ja) +RAFT_KERNEL __launch_bounds__(tpb) + block_rbc_kernel_eps_csr_pass_xd(const value_t* __restrict__ X_reordered, + const value_t* __restrict__ X, + const value_int n_queries, + const value_int n_cols, + const value_t* __restrict__ R, + const value_int m, + const value_t eps, + const value_int n_landmarks, + const value_idx* __restrict__ R_indptr, + const value_idx* __restrict__ R_1nn_cols, + const value_t* __restrict__ R_1nn_dists, + const value_t* __restrict__ R_radius, + distance_func dfunc, + value_idx* __restrict__ adj_ia, + value_idx* adj_ja) { constexpr int num_warps = tpb / WarpSize; constexpr int max_lid = WarpSize - 1; @@ -748,10 +749,14 @@ RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_ // this is an early out for a full warp if (query_id >= n_queries) return; - value_idx column_index_offset = write_pass ? 
adj_ia[query_id] : 0; + uint32_t column_index_offset = 0; - // we have no neighbors to fill for this query - if (write_pass && adj_ia[query_id + 1] == column_index_offset) return; + if constexpr (write_pass) { + value_idx offset = adj_ia[query_id]; + // we have no neighbors to fill for this query + if (offset == adj_ia[query_id + 1]) return; + adj_ja += offset; + } const value_t* x_ptr = X + (dim * query_id); value_t local_x_ptr[dim]; @@ -812,11 +817,11 @@ RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_ if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { - const uint32_t index = R_1nn_cols[R_start_offset + i]; - const value_idx row_pos = column_index_offset + __popc(mask & lid_mask); - adj_ja[row_pos] = index; + const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; } - column_index_offset += __popc(mask); + adj_ja += __popc(mask); } else { column_index_offset += (in_range); } @@ -833,11 +838,11 @@ RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_ if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { - const uint32_t index = R_1nn_cols[R_start_offset + i]; - const value_idx row_pos = column_index_offset + __popc(mask & lid_mask); - adj_ja[row_pos] = index; + const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; } - column_index_offset += __popc(mask); + adj_ja += __popc(mask); } else { column_index_offset += (in_range); } From 2be4c179c8b8e50e7b1142063734a585ae1b0241 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 7 Mar 2024 16:49:46 +0000 Subject: [PATCH 07/10] apply modifications to all kernel variants --- .../knn/detail/ball_cover/registers-inl.cuh | 257 +++++++++--------- 1 file changed, 134 insertions(+), 123 deletions(-) diff --git 
a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 9969724576..3454f0d02c 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -456,6 +456,12 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_reordered, } } +template +__device__ value_t squared(const value_t& a) +{ + return a * a; +} + template = n_queries) return; - unsigned long long int column_count = 0; + value_idx column_count = 0; const value_t* x_ptr = X + (n_cols * query_id); + adj += query_id * m; // we omit the sqrt() in the inner distance compute const value_t eps2 = eps * eps; @@ -499,71 +506,73 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered, #pragma nounroll for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { // Pre-compute landmark_dist & triangularization checks for 32 iterations - // prune all R's that can't be within eps - const uint32_t lane_k = cur_k0 + lid; - const value_t lane_R_dist = lane_k < n_landmarks - ? raft::sqrt(dfunc(x_ptr, R + lane_k * n_cols, n_cols)) - : std::numeric_limits::max(); - const int lane_check = - lane_k < n_landmarks ? static_cast(lane_R_dist - R_radius[lane_k] <= eps) : 0; + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(x_ptr, R + lane_k * n_cols, n_cols) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? 
static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; int lane_mask = raft::ballot(lane_check); if (lane_mask == 0) continue; - uint32_t k_offset = __ffs(lane_mask) - 1; + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + uint32_t k_offset = __clz(lane_mask); do { const uint32_t cur_k = cur_k0 + k_offset; - // update lane_mask for next iteration - erase bits up to k_offset - lane_mask &= -(1 << k_offset + 1); - // The whole warp should iterate through the elements in the current R const value_idx R_start_offset = R_indptr[cur_k]; - const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (1 << max_lid - k_offset) - 1; + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; // we have precomputed the query<->landmark distance - const value_t cur_R_dist = raft::shfl(lane_R_dist, k_offset); + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); const uint32_t limit = Pow2::roundDown(R_size); - int i = limit + lid; + uint32_t i = limit + lid; // look ahead for next k_offset - k_offset = lane_mask != 0 ? __ffs(lane_mask) - 1 : WarpSize; + k_offset = __clz(lane_mask); // R_1nn_dists are sorted ascendingly for each landmark // Iterating backwards, after pruning the first point w.r.t. triangle // inequality all subsequent points can be pruned as well - bool skip_following = - i < R_size ? (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps) : false; + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); { - const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; const value_t dist = - (i >= R_size) ? std::numeric_limits::max() : dfunc(x_ptr, y_ptr, n_cols); + (i < R_size) ? 
dfunc(x_ptr, y_ptr, n_cols) : std::numeric_limits::max(); const bool in_range = (dist <= eps2); if (in_range) { auto index = R_1nn_cols[R_start_offset + i]; column_count++; - adj[query_id * m + index] = true; + adj[index] = true; } + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - skip_following = raft::any(skip_following); - if (skip_following) continue; - - i -= WarpSize; - for (; i >= 0 && !skip_following; i -= WarpSize) { - skip_following = (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps); - const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); - const value_t dist = dfunc(x_ptr, y_ptr, n_cols); - const bool in_range = (dist <= eps2); + while (i >= WarpSize) { + y_ptr -= WarpSize * n_cols; + i -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); if (in_range) { auto index = R_1nn_cols[R_start_offset + i]; column_count++; - adj[query_id * m + index] = true; + adj[index] = true; } - skip_following = raft::any(skip_following); + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - } while (k_offset < WarpSize); + } while (lane_mask); } if (vd != nullptr) { @@ -595,22 +604,26 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, value_idx* adj_ja) { constexpr int num_warps = tpb / WarpSize; + constexpr int max_lid = WarpSize - 1; // process 1 query per warp const uint32_t lid = raft::laneId(); const uint32_t lid_mask = (1 << lid) - 1; // this should help the compiler to prevent branches - const int warp_id = raft::shfl(threadIdx.x / WarpSize, 0); - const int query_id = raft::shfl(blockIdx.x * num_warps + warp_id, 0); + const int query_id = raft::shfl(blockIdx.x * num_warps + (threadIdx.x / WarpSize), 0); // this is an early out for a full warp if (query_id >= n_queries) return; - 
unsigned long long int column_index_offset = write_pass ? adj_ia[query_id] : 0; + uint32_t column_index_offset = 0; - // we have no neighbors to fill for this query - if (write_pass && adj_ia[query_id + 1] == column_index_offset) return; + if constexpr (write_pass) { + value_idx offset = adj_ia[query_id]; + // we have no neighbors to fill for this query + if (offset == adj_ia[query_id + 1]) return; + adj_ja += offset; + } const value_t* x_ptr = X + (n_cols * query_id); @@ -620,84 +633,85 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, #pragma nounroll for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { // Pre-compute landmark_dist & triangularization checks for 32 iterations - // prune all R's that can't be within eps - const uint32_t lane_k = cur_k0 + lid; - const value_t lane_R_dist = lane_k < n_landmarks - ? raft::sqrt(dfunc(x_ptr, R + lane_k * n_cols, n_cols)) - : std::numeric_limits::max(); - const int lane_check = - lane_k < n_landmarks ? static_cast(lane_R_dist - R_radius[lane_k] <= eps) : 0; + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(x_ptr, R + lane_k * n_cols, n_cols) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? 
static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; int lane_mask = raft::ballot(lane_check); if (lane_mask == 0) continue; - uint32_t k_offset = __ffs(lane_mask) - 1; + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + uint32_t k_offset = __clz(lane_mask); do { const uint32_t cur_k = cur_k0 + k_offset; - // update lane_mask for next iteration - erase bits up to k_offset - lane_mask &= -(1 << k_offset + 1); - // The whole warp should iterate through the elements in the current R const value_idx R_start_offset = R_indptr[cur_k]; - const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (1 << max_lid - k_offset) - 1; + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; // we have precomputed the query<->landmark distance - const value_t cur_R_dist = raft::shfl(lane_R_dist, k_offset); + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); const uint32_t limit = Pow2::roundDown(R_size); - int i = limit + lid; + uint32_t i = limit + lid; // look ahead for next k_offset - k_offset = lane_mask != 0 ? __ffs(lane_mask) - 1 : WarpSize; + k_offset = __clz(lane_mask); // R_1nn_dists are sorted ascendingly for each landmark // Iterating backwards, after pruning the first point w.r.t. triangle // inequality all subsequent points can be pruned as well - bool skip_following = - i < R_size ? (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps) : false; + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); { - const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; const value_t dist = - (i >= R_size) ? std::numeric_limits::max() : dfunc(x_ptr, y_ptr, n_cols); + (i < R_size) ? 
dfunc(x_ptr, y_ptr, n_cols) : std::numeric_limits::max(); const bool in_range = (dist <= eps2); if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { - auto index = R_1nn_cols[R_start_offset + i]; - auto row_pos = column_index_offset + __popc(mask & lid_mask); - adj_ja[row_pos] = index; + const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; } - column_index_offset += __popc(mask); + adj_ja += __popc(mask); } else { column_index_offset += (in_range); } + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - skip_following = raft::any(skip_following); - if (skip_following) continue; - - i -= WarpSize; - for (; i >= 0 && !skip_following; i -= WarpSize) { - skip_following = (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps); - const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); - const value_t dist = dfunc(x_ptr, y_ptr, n_cols); - const bool in_range = (dist <= eps2); + while (i >= WarpSize) { + y_ptr -= WarpSize * n_cols; + i -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { - auto index = R_1nn_cols[R_start_offset + i]; - auto row_pos = column_index_offset + __popc(mask & lid_mask); - adj_ja[row_pos] = index; + const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; } - column_index_offset += __popc(mask); + adj_ja += __popc(mask); } else { column_index_offset += (in_range); } - - skip_following = raft::any(skip_following); + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - } while (k_offset < WarpSize); + } while (lane_mask); } if 
constexpr (!write_pass) { @@ -706,12 +720,6 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, } } -template -__device__ value_t squared(const value_t& a) -{ - return a * a; -} - template = n_queries) return; - unsigned long long int column_count = 0; + value_idx column_count = 0; const value_t* x_ptr = X + (n_cols * query_id); + tmp += query_id * max_k; // we omit the sqrt() in the inner distance compute const value_t eps2 = eps * eps; @@ -903,81 +912,83 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, #pragma nounroll for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { // Pre-compute landmark_dist & triangularization checks for 32 iterations - // prune all R's that can't be within eps - const uint32_t lane_k = cur_k0 + lid; - const value_t lane_R_dist = lane_k < n_landmarks - ? raft::sqrt(dfunc(x_ptr, R + lane_k * n_cols, n_cols)) - : std::numeric_limits::max(); - const int lane_check = - lane_k < n_landmarks ? static_cast(lane_R_dist - R_radius[lane_k] <= eps) : 0; + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(x_ptr, R + lane_k * n_cols, n_cols) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? 
static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; int lane_mask = raft::ballot(lane_check); if (lane_mask == 0) continue; - uint32_t k_offset = __ffs(lane_mask) - 1; + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + uint32_t k_offset = __clz(lane_mask); do { const uint32_t cur_k = cur_k0 + k_offset; - // update lane_mask for next iteration - erase bits up to k_offset - lane_mask &= -(1 << k_offset + 1); - // The whole warp should iterate through the elements in the current R const value_idx R_start_offset = R_indptr[cur_k]; - const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (1 << max_lid - k_offset) - 1; + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; // we have precomputed the query<->landmark distance - const value_t cur_R_dist = raft::shfl(lane_R_dist, k_offset); + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); const uint32_t limit = Pow2::roundDown(R_size); - int i = limit + lid; + uint32_t i = limit + lid; // look ahead for next k_offset - k_offset = lane_mask != 0 ? __ffs(lane_mask) - 1 : WarpSize; + k_offset = __clz(lane_mask); // R_1nn_dists are sorted ascendingly for each landmark // Iterating backwards, after pruning the first point w.r.t. triangle // inequality all subsequent points can be pruned as well - bool skip_following = - i < R_size ? (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps) : false; + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); { - const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; const value_t dist = - (i >= R_size) ? std::numeric_limits::max() : dfunc(x_ptr, y_ptr, n_cols); + (i < R_size) ? 
dfunc(x_ptr, y_ptr, n_cols) : std::numeric_limits::max(); const bool in_range = (dist <= eps2); const int mask = raft::ballot(in_range); if (in_range) { auto row_pos = column_count + __popc(mask & lid_mask); // we still continue to look for more hits to return valid vd if (row_pos < max_k) { - auto index = R_1nn_cols[R_start_offset + i]; - tmp[query_id * max_k + row_pos] = index; + auto index = R_1nn_cols[R_start_offset + i]; + tmp[row_pos] = index; } } column_count += __popc(mask); + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - skip_following = raft::any(skip_following); - if (skip_following) continue; - - i -= WarpSize; - for (; i >= 0 && !skip_following; i -= WarpSize) { - skip_following = (cur_R_dist - R_1nn_dists[R_start_offset + i] > eps); - const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); - const value_t dist = dfunc(x_ptr, y_ptr, n_cols); - const bool in_range = (dist <= eps2); - const int mask = raft::ballot(in_range); + while (i >= WarpSize) { + y_ptr -= WarpSize * n_cols; + i -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + const int mask = raft::ballot(in_range); if (in_range) { auto row_pos = column_count + __popc(mask & lid_mask); // we still continue to look for more hits to return valid vd if (row_pos < max_k) { - auto index = R_1nn_cols[R_start_offset + i]; - tmp[query_id * max_k + row_pos] = index; + auto index = R_1nn_cols[R_start_offset + i]; + tmp[row_pos] = index; } } column_count += __popc(mask); - skip_following = raft::any(skip_following); + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - } while (k_offset < WarpSize); + } while (lane_mask); } if (lid == 0) vd[query_id] = column_count; From 768eb48788e043e7290a169edae9ca4f000d40dc Mon Sep 17 00:00:00 2001 From: Malte 
Foerster Date: Thu, 7 Mar 2024 19:33:24 +0000 Subject: [PATCH 08/10] remove one more FLO per landmark visited --- .../knn/detail/ball_cover/registers-inl.cuh | 36 +++++++++---------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 3454f0d02c..85771e9c26 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -517,9 +517,11 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered, if (lane_mask == 0) continue; // reverse to use __clz instead of __ffs - lane_mask = __brev(lane_mask); - uint32_t k_offset = __clz(lane_mask); + lane_mask = __brev(lane_mask); do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + const uint32_t cur_k = cur_k0 + k_offset; // The whole warp should iterate through the elements in the current R @@ -536,9 +538,6 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered, const uint32_t limit = Pow2::roundDown(R_size); uint32_t i = limit + lid; - // look ahead for next k_offset - k_offset = __clz(lane_mask); - // R_1nn_dists are sorted ascendingly for each landmark // Iterating backwards, after pruning the first point w.r.t. 
triangle // inequality all subsequent points can be pruned as well @@ -644,9 +643,11 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, if (lane_mask == 0) continue; // reverse to use __clz instead of __ffs - lane_mask = __brev(lane_mask); - uint32_t k_offset = __clz(lane_mask); + lane_mask = __brev(lane_mask); do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + const uint32_t cur_k = cur_k0 + k_offset; // The whole warp should iterate through the elements in the current R @@ -663,9 +664,6 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, const uint32_t limit = Pow2::roundDown(R_size); uint32_t i = limit + lid; - // look ahead for next k_offset - k_offset = __clz(lane_mask); - // R_1nn_dists are sorted ascendingly for each landmark // Iterating backwards, after pruning the first point w.r.t. triangle // inequality all subsequent points can be pruned as well @@ -790,9 +788,11 @@ RAFT_KERNEL __launch_bounds__(tpb) if (lane_mask == 0) continue; // reverse to use __clz instead of __ffs - lane_mask = __brev(lane_mask); - uint32_t k_offset = __clz(lane_mask); + lane_mask = __brev(lane_mask); do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + const uint32_t cur_k = cur_k0 + k_offset; // The whole warp should iterate through the elements in the current R @@ -809,9 +809,6 @@ RAFT_KERNEL __launch_bounds__(tpb) const uint32_t limit = Pow2::roundDown(R_size); uint32_t i = limit + lid; - // look ahead for next k_offset - k_offset = __clz(lane_mask); - // R_1nn_dists are sorted ascendingly for each landmark // Iterating backwards, after pruning the first point w.r.t. 
triangle // inequality all subsequent points can be pruned as well @@ -923,9 +920,11 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, if (lane_mask == 0) continue; // reverse to use __clz instead of __ffs - lane_mask = __brev(lane_mask); - uint32_t k_offset = __clz(lane_mask); + lane_mask = __brev(lane_mask); do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + const uint32_t cur_k = cur_k0 + k_offset; // The whole warp should iterate through the elements in the current R @@ -942,9 +941,6 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, const uint32_t limit = Pow2::roundDown(R_size); uint32_t i = limit + lid; - // look ahead for next k_offset - k_offset = __clz(lane_mask); - // R_1nn_dists are sorted ascendingly for each landmark // Iterating backwards, after pruning the first point w.r.t. triangle // inequality all subsequent points can be pruned as well From 2315e2f18caa59dc4f6e73c442ace82b9b7c3e02 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 8 Mar 2024 14:19:22 +0000 Subject: [PATCH 09/10] change mask update --- .../spatial/knn/detail/ball_cover/registers-inl.cuh | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 85771e9c26..2062e6e421 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -484,7 +484,6 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered, value_idx* vd) { constexpr int num_warps = tpb / WarpSize; - constexpr int max_lid = WarpSize - 1; // process 1 query per warp const uint32_t lid = raft::laneId(); @@ -528,7 +527,7 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered, const value_idx R_start_offset = R_indptr[cur_k]; // update lane_mask for next iteration - erase bits up to 
k_offset - lane_mask &= (1 << max_lid - k_offset) - 1; + lane_mask &= (0x7fffffff >> k_offset); const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; @@ -603,7 +602,6 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, value_idx* adj_ja) { constexpr int num_warps = tpb / WarpSize; - constexpr int max_lid = WarpSize - 1; // process 1 query per warp const uint32_t lid = raft::laneId(); @@ -654,7 +652,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, const value_idx R_start_offset = R_indptr[cur_k]; // update lane_mask for next iteration - erase bits up to k_offset - lane_mask &= (1 << max_lid - k_offset) - 1; + lane_mask &= (0x7fffffff >> k_offset); const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; @@ -743,7 +741,6 @@ RAFT_KERNEL __launch_bounds__(tpb) value_idx* adj_ja) { constexpr int num_warps = tpb / WarpSize; - constexpr int max_lid = WarpSize - 1; // process 1 query per warp const uint32_t lid = raft::laneId(); @@ -799,7 +796,7 @@ RAFT_KERNEL __launch_bounds__(tpb) const value_idx R_start_offset = R_indptr[cur_k]; // update lane_mask for next iteration - erase bits up to k_offset - lane_mask &= (1 << max_lid - k_offset) - 1; + lane_mask &= (0x7fffffff >> k_offset); const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; @@ -886,7 +883,6 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, value_idx* tmp) { constexpr int num_warps = tpb / WarpSize; - constexpr int max_lid = WarpSize - 1; // process 1 query per warp const uint32_t lid = raft::laneId(); @@ -931,7 +927,7 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, const value_idx R_start_offset = R_indptr[cur_k]; // update lane_mask for next iteration - erase bits up to k_offset - lane_mask &= (1 << max_lid - k_offset) - 1; + lane_mask &= (0x7fffffff >> k_offset); const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; From 2d06d0a77e8d318659d73b2f3cda93ae71054d37 Mon Sep 17 00:00:00 2001 From: 
Malte Foerster Date: Fri, 8 Mar 2024 16:58:53 +0000 Subject: [PATCH 10/10] remove warp divergence --- .../knn/detail/ball_cover/registers-inl.cuh | 48 +++++++++++-------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 2062e6e421..eda6d33293 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -556,19 +556,21 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered, i *= (cur_R_dist - min_warp_dist <= eps); } - while (i >= WarpSize) { + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { y_ptr -= WarpSize * n_cols; - i -= WarpSize; - const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; const value_t dist = dfunc(x_ptr, y_ptr, n_cols); const bool in_range = (dist <= eps2); if (in_range) { - auto index = R_1nn_cols[R_start_offset + i]; + auto index = R_1nn_cols[R_start_offset + i0 + lid]; column_count++; adj[index] = true; } // abort in case subsequent points cannot possibly be in reach - i *= (cur_R_dist - min_warp_dist <= eps); + i0 *= (cur_R_dist - min_warp_dist <= eps); } } while (lane_mask); } @@ -687,16 +689,18 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, i *= (cur_R_dist - min_warp_dist <= eps); } - while (i >= WarpSize) { + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { y_ptr -= WarpSize * n_cols; - i -= WarpSize; - const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; const value_t dist = dfunc(x_ptr, y_ptr, n_cols); const bool in_range = (dist <= eps2); if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { - const uint32_t 
index = R_1nn_cols[R_start_offset + i]; + const uint32_t index = R_1nn_cols[R_start_offset + i0 + lid]; const uint32_t row_pos = __popc(mask & lid_mask); adj_ja[row_pos] = index; } @@ -705,7 +709,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, column_index_offset += (in_range); } // abort in case subsequent points cannot possibly be in reach - i *= (cur_R_dist - min_warp_dist <= eps); + i0 *= (cur_R_dist - min_warp_dist <= eps); } } while (lane_mask); } @@ -831,16 +835,18 @@ RAFT_KERNEL __launch_bounds__(tpb) i *= (cur_R_dist - min_warp_dist <= eps); } - while (i >= WarpSize) { + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { y_ptr -= WarpSize * dim; - i -= WarpSize; - const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; const value_t dist = dfunc(local_x_ptr, y_ptr, dim); const bool in_range = (dist <= eps2); if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { - const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t index = R_1nn_cols[R_start_offset + i0 + lid]; const uint32_t row_pos = __popc(mask & lid_mask); adj_ja[row_pos] = index; } @@ -849,7 +855,7 @@ RAFT_KERNEL __launch_bounds__(tpb) column_index_offset += (in_range); } // abort in case subsequent points cannot possibly be in reach - i *= (cur_R_dist - min_warp_dist <= eps); + i0 *= (cur_R_dist - min_warp_dist <= eps); } } while (lane_mask); } @@ -961,10 +967,12 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, i *= (cur_R_dist - min_warp_dist <= eps); } - while (i >= WarpSize) { + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { y_ptr -= WarpSize * n_cols; - i -= WarpSize; - const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; const value_t dist = dfunc(x_ptr, y_ptr, n_cols); const bool 
in_range = (dist <= eps2); const int mask = raft::ballot(in_range); @@ -972,13 +980,13 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, auto row_pos = column_count + __popc(mask & lid_mask); // we still continue to look for more hits to return valid vd if (row_pos < max_k) { - auto index = R_1nn_cols[R_start_offset + i]; + auto index = R_1nn_cols[R_start_offset + i0 + lid]; tmp[row_pos] = index; } } column_count += __popc(mask); // abort in case subsequent points cannot possibly be in reach - i *= (cur_R_dist - min_warp_dist <= eps); + i0 *= (cur_R_dist - min_warp_dist <= eps); } } while (lane_mask); }