From 8b97c1baa1a574c8fa45cd87366e58107a19ebe8 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Fri, 2 Sep 2022 14:06:30 +0200 Subject: [PATCH] Minor follow-up fixes for ivf-flat (#796) Small fixes to ivf-flat and its dependency warp_sort: - fix the template parameter for the call of `calc_smem_size_for_block_wide` (may reduce the use of shared memory); - force initialize warp_sort internal value buffers to avoid uninitialized output in case of very small input data size; - small readability fixes. Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/796 --- cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh | 9 ++++----- .../raft/spatial/knn/detail/topk/warpsort_topk.cuh | 6 +++++- cpp/test/spatial/ann_ivf_flat.cu | 7 +++++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index fab845396c..d8219a48f9 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -696,8 +696,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); __syncthreads(); - topk::block_sort queue( - k, interleaved_scan_kernel_smem + query_smem_elems * sizeof(T)); + using block_sort_t = topk::block_sort; + block_sort_t queue(k, interleaved_scan_kernel_smem + query_smem_elems * sizeof(T)); { using align_warp = Pow2; @@ -766,8 +766,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } // Enqueue one element per thread - constexpr float kDummy = Ascending ? upper_bound() : lower_bound(); - const float val = valid ? static_cast(dist) : kDummy; + const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; const size_t idx = valid ? static_cast(list_indices[list_offset + vec_id]) : 0; queue.add(val, idx); } @@ -826,7 +825,7 @@ void launch_kernel(Lambda lambda, std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); int smem_size = query_smem_elems * sizeof(T); constexpr int kSubwarpSize = std::min(Capacity, WarpSize); - smem_size += raft::spatial::knn::detail::topk::calc_smem_size_for_block_wide( + smem_size += raft::spatial::knn::detail::topk::calc_smem_size_for_block_wide( kThreadsPerBlock / kSubwarpSize, k); // power-of-two less than cuda limit (for better addr alignment) diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh index 017678afbb..23448b6dc4 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh +++ b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh @@ -135,6 +135,7 @@ constexpr auto calc_capacity(int k) -> int template class warp_sort { static_assert(isPo2(Capacity)); + static_assert(std::is_default_constructible_v); public: /** @@ -158,6 +159,7 @@ class warp_sort { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { val_arr_[i] = kDummy; + idx_arr_[i] = IdxT{}; } } @@ -280,6 +282,7 @@ class warp_sort_filtered : public warp_sort { #pragma unroll for (int i = 0; i < kMaxBufLen; i++) { val_buf_[i] = kDummy; + idx_buf_[i] = IdxT{}; } } @@ -371,6 +374,7 @@ class warp_sort_immediate : public warp_sort { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { val_buf_[i] = kDummy; + idx_buf_[i] = IdxT{}; } } @@ -429,9 +433,9 @@ template