diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh index d91c4c328d..31d4f7a009 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh @@ -36,34 +36,6 @@ namespace raft::spatial::knn::detail::ivf_flat { -inline size_t calc_aligned_size(const std::vector& sizes) -{ - const size_t ALIGN_BYTES = 256; - const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); - size_t total = 0; - for (auto sz : sizes) { - total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; - } - return total + ALIGN_BYTES - 1; -} - -inline std::vector calc_aligned_pointers(const void* p, const std::vector& sizes) -{ - const size_t ALIGN_BYTES = 256; - const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); - - char* ptr = reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); - - std::vector aligned_pointers; - aligned_pointers.reserve(sizes.size()); - for (auto sz : sizes) { - aligned_pointers.push_back(ptr); - ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; - } - - return aligned_pointers; -} - constexpr int BLOCK_DIM = 512; constexpr int ITEM_PER_THREAD = 32; @@ -438,125 +410,6 @@ __global__ void radix_kernel(const T* in_buf, } } -template -__global__ void final_filter(const T* in, - const IdxT len, - const IdxT k, - Counter* counters, - T* out, - IdxT* out_idx, - bool greater) -{ - const int batch_id = blockIdx.y; - const T kth_value = counters[batch_id].kth_value; - const IdxT needed_num_of_kth = counters[batch_id].k; - IdxT& out_cnt = counters[batch_id].out_cnt; - IdxT& out_back_cnt = counters[batch_id].out_back_cnt; - - in = in + batch_id * len; - out = out + batch_id * k; - out_idx = out_idx + batch_id * k; - - auto f = [k, greater, kth_value, needed_num_of_kth, &out_cnt, &out_back_cnt, out, out_idx]( - T val, IdxT i) { - if ((greater && val > kth_value) || (!greater && val < kth_value)) { - IdxT pos = atomicAdd(&out_cnt, 1); - out[pos] = val; - out_idx[pos] = i; - } else if (val == kth_value) { - IdxT back_pos = atomicAdd(&out_back_cnt, 1); - if (back_pos < needed_num_of_kth) { - IdxT pos = k - 1 - back_pos; - out[pos] = val; - out_idx[pos] = i; - } - } - }; - vectorized_process(in, len, f); -} - -template -void radix_select_topk(void* buf, - size_t& buf_size, - const T* in, - IdxT batch_size, - IdxT len, - IdxT k, - T* out, - IdxT* out_idx, - bool greater, - cudaStream_t stream) -{ - // TODO: is it possible to relax this restriction? - static_assert(calc_num_passes() > 1); - constexpr int num_buckets = calc_num_buckets(); - - Counter* counters = nullptr; - IdxT* histograms = nullptr; - T* buf1 = nullptr; - T* buf2 = nullptr; - { - std::vector sizes = {sizeof(*counters) * batch_size, - sizeof(*histograms) * num_buckets * batch_size, - sizeof(*buf1) * len * batch_size, - sizeof(*buf2) * len * batch_size}; - size_t total_size = calc_aligned_size(sizes); - if (!buf) { - buf_size = total_size; - return; - } - - std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); - counters = static_cast(aligned_pointers[0]); - histograms = static_cast(aligned_pointers[1]); - buf1 = static_cast(aligned_pointers[2]); - buf2 = static_cast(aligned_pointers[3]); - - RAFT_CUDA_TRY(cudaMemsetAsync( - buf, - 0, - static_cast(aligned_pointers[2]) - static_cast(aligned_pointers[0]), - stream)); - } - - const T* in_buf = nullptr; - T* out_buf = nullptr; - - dim3 blocks((len - 1) / (NUM_THREAD * ITEM_PER_THREAD) + 1, batch_size); - - constexpr int num_passes = calc_num_passes(); - for (int pass = 0; pass < num_passes; ++pass) { - if (pass == 0) { - in_buf = in; - out_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - out_buf = buf1; - } else { - in_buf = (pass % 2 == 0) ? buf1 : buf2; - out_buf = (pass % 2 == 0) ? buf2 : buf1; - } - radix_kernel<<>>(in_buf, - nullptr, - out_buf, - nullptr, - nullptr, - nullptr, - counters, - histograms, - len, - k, - greater, - pass); - } - - constexpr int FILTER_BLOCK_DIM = 256; - constexpr int FILTER_ITEM_PER_THREAD = 32; - dim3 filter_blocks((len - 1) / (FILTER_BLOCK_DIM * FILTER_ITEM_PER_THREAD) + 1, batch_size); - final_filter<<>>( - in, len, k, counters, out, out_idx, greater); -} - template void radix_topk(const T* in, const IdxT* in_idx,