diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index fab845396c..d8219a48f9 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -696,8 +696,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); __syncthreads(); - topk::block_sort queue( - k, interleaved_scan_kernel_smem + query_smem_elems * sizeof(T)); + using block_sort_t = topk::block_sort; + block_sort_t queue(k, interleaved_scan_kernel_smem + query_smem_elems * sizeof(T)); { using align_warp = Pow2; @@ -766,8 +766,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } // Enqueue one element per thread - constexpr float kDummy = Ascending ? upper_bound() : lower_bound(); - const float val = valid ? static_cast(dist) : kDummy; + const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; const size_t idx = valid ? static_cast(list_indices[list_offset + vec_id]) : 0; queue.add(val, idx); } @@ -826,7 +825,7 @@ void launch_kernel(Lambda lambda, std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); int smem_size = query_smem_elems * sizeof(T); constexpr int kSubwarpSize = std::min(Capacity, WarpSize); - smem_size += raft::spatial::knn::detail::topk::calc_smem_size_for_block_wide( + smem_size += raft::spatial::knn::detail::topk::calc_smem_size_for_block_wide( kThreadsPerBlock / kSubwarpSize, k); // power-of-two less than cuda limit (for better addr alignment) diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh index 017678afbb..23448b6dc4 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh +++ b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh @@ -135,6 +135,7 @@ constexpr auto calc_capacity(int k) -> int template class warp_sort { static_assert(isPo2(Capacity)); + static_assert(std::is_default_constructible_v); public: /** @@ -158,6 +159,7 @@ class warp_sort { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { val_arr_[i] = kDummy; + idx_arr_[i] = IdxT{}; } } @@ -280,6 +282,7 @@ class warp_sort_filtered : public warp_sort { #pragma unroll for (int i = 0; i < kMaxBufLen; i++) { val_buf_[i] = kDummy; + idx_buf_[i] = IdxT{}; } } @@ -371,6 +374,7 @@ class warp_sort_immediate : public warp_sort { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { val_buf_[i] = kDummy; + idx_buf_[i] = IdxT{}; } } @@ -429,9 +433,9 @@ template