diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 9f0a6096d9..5214047571 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -19,6 +19,7 @@ set(RAFT_CPP_BENCH_TARGET "bench_raft") # (please keep the filenames in alphabetical order) add_executable(${RAFT_CPP_BENCH_TARGET} bench/linalg/reduce.cu + bench/spatial/selection.cu bench/main.cpp ) diff --git a/cpp/bench/spatial/selection.cu b/cpp/bench/spatial/selection.cu new file mode 100644 index 0000000000..09d02940a5 --- /dev/null +++ b/cpp/bench/spatial/selection.cu @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include +#include + +namespace raft::bench::spatial { + +struct params { + int n_inputs; + int input_len; + int k; + int select_min; +}; + +template +struct selection : public fixture { + explicit selection(const params& p) + : params_(p), + in_dists_(p.n_inputs * p.input_len, stream), + in_ids_(p.n_inputs * p.input_len, stream), + out_dists_(p.n_inputs * p.k, stream), + out_ids_(p.n_inputs * p.k, stream) + { + raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream); + raft::random::Rng(42).uniform( + in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0), stream); + } + + void run_benchmark(::benchmark::State& state) override + { + using_pool_memory_res res; + try { + std::ostringstream label_stream; + label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k; + state.SetLabel(label_stream.str()); + loop_on_state(state, [this]() { + raft::spatial::knn::select_k(in_dists_.data(), + in_ids_.data(), + params_.n_inputs, + params_.input_len, + out_dists_.data(), + out_ids_.data(), + params_.select_min, + params_.k, + stream, + Algo); + }); + } catch (raft::exception& e) { + state.SkipWithError(e.what()); + } + } + + private: + const params params_; + rmm::device_uvector in_dists_, out_dists_; + rmm::device_uvector in_ids_, out_ids_; +}; + +const std::vector kInputs{ + {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, + {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, + {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, + + {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, + {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, + {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, + + {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, + {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, + {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, + + {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, + {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, + {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, +}; + +#define SELECTION_REGISTER(KeyT, IdxT, Algo) \ + 
namespace BENCHMARK_PRIVATE_NAME(selection) \ + { \ + using SelectK = selection; \ + RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \ + } + +SELECTION_REGISTER(float, int, FAISS); +SELECTION_REGISTER(float, int, RADIX_8_BITS); +SELECTION_REGISTER(float, int, RADIX_11_BITS); +SELECTION_REGISTER(float, int, WARP_SORT); + +SELECTION_REGISTER(double, int, FAISS); +SELECTION_REGISTER(double, int, RADIX_8_BITS); +SELECTION_REGISTER(double, int, RADIX_11_BITS); +SELECTION_REGISTER(double, int, WARP_SORT); + +SELECTION_REGISTER(double, size_t, FAISS); +SELECTION_REGISTER(double, size_t, RADIX_8_BITS); +SELECTION_REGISTER(double, size_t, RADIX_11_BITS); +SELECTION_REGISTER(double, size_t, WARP_SORT); + +} // namespace raft::bench::spatial diff --git a/cpp/include/raft/cudart_utils.h b/cpp/include/raft/cudart_utils.h index 4ba1e18768..05fce6c0c4 100644 --- a/cpp/include/raft/cudart_utils.h +++ b/cpp/include/raft/cudart_utils.h @@ -404,6 +404,22 @@ IntType gcd(IntType a, IntType b) return a; } +template +constexpr T lower_bound() +{ + if constexpr (std::numeric_limits::has_infinity && std::numeric_limits::is_signed) { + return -std::numeric_limits::infinity(); + } + return std::numeric_limits::lowest(); +} + +template +constexpr T upper_bound() +{ + if constexpr (std::numeric_limits::has_infinity) { return std::numeric_limits::infinity(); } + return std::numeric_limits::max(); +} + } // namespace raft #endif diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh index 03a4eabaac..2d2fabd9d6 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,24 +31,30 @@ namespace spatial { namespace knn { namespace detail { -template -__global__ void select_k_kernel(K* inK, - IndexType* inV, +template +constexpr int kFaissMaxK() +{ + return (sizeof(key_t) + sizeof(payload_t) > 8) ? 
512 : 1024; +} + +template +__global__ void select_k_kernel(key_t* inK, + payload_t* inV, size_t n_rows, size_t n_cols, - K* outK, - IndexType* outV, - K initK, - IndexType initV, + key_t* outK, + payload_t* outV, + key_t initK, + payload_t initV, int k) { constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; - __shared__ K smemK[kNumWarps * warp_q]; - __shared__ IndexType smemV[kNumWarps * warp_q]; + __shared__ key_t smemK[kNumWarps * warp_q]; + __shared__ payload_t smemV[kNumWarps * warp_q]; faiss::gpu:: - BlockSelect, warp_q, thread_q, tpb> + BlockSelect, warp_q, thread_q, tpb> heap(initK, initV, smemK, smemV, k); // Grid is exactly sized to rows available @@ -56,8 +62,8 @@ __global__ void select_k_kernel(K* inK, int i = threadIdx.x; int idx = row * n_cols; - K* inKStart = inK + idx + i; - IndexType* inVStart = inV + idx + i; + key_t* inKStart = inK + idx + i; + payload_t* inVStart = inV + idx + i; // Whole warps must participate in the selection int limit = faiss::gpu::utils::roundDown(n_cols, faiss::gpu::kWarpSize); @@ -84,13 +90,13 @@ __global__ void select_k_kernel(K* inK, } } -template -inline void select_k_impl(value_t* inK, - value_idx* inV, +template +inline void select_k_impl(key_t* inK, + payload_t* inV, size_t n_rows, size_t n_cols, - value_t* outK, - value_idx* outV, + key_t* outK, + payload_t* outV, bool select_min, int k, cudaStream_t stream) @@ -100,14 +106,13 @@ inline void select_k_impl(value_t* inK, constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; auto block = dim3(n_threads); - auto kInit = - select_min ? faiss::gpu::Limits::getMax() : faiss::gpu::Limits::getMin(); + auto kInit = select_min ? upper_bound() : lower_bound(); auto vInit = -1; if (select_min) { - select_k_kernel + select_k_kernel <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); } else { - select_k_kernel + select_k_kernel <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); } RAFT_CUDA_TRY(cudaGetLastError()); @@ -127,38 +132,41 @@ inline void select_k_impl(value_t* inK, * @param[in] k number of neighbors per partition (also number of merged neighbors) * @param[in] stream CUDA stream to use */ -template -inline void select_k(value_t* inK, - value_idx* inV, +template +inline void select_k(key_t* inK, + payload_t* inV, size_t n_rows, size_t n_cols, - value_t* outK, - value_idx* outV, + key_t* outK, + payload_t* outV, bool select_min, int k, cudaStream_t stream) { + constexpr int max_k = kFaissMaxK(); if (k == 1) - select_k_impl( + select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); else if (k <= 32) - select_k_impl( + select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); else if (k <= 64) - select_k_impl( + select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); else if (k <= 128) - select_k_impl( + select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); else if (k <= 256) - select_k_impl( + select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); else if (k <= 512) - select_k_impl( + select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 1024) - select_k_impl( + else if (k <= 1024 && k <= max_k) + select_k_impl( inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else + ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); } }; // namespace detail diff --git a/cpp/include/raft/spatial/knn/detail/topk/bitonic_sort.cuh b/cpp/include/raft/spatial/knn/detail/topk/bitonic_sort.cuh new file mode 100644 index 
0000000000..44ffe6bc50 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/topk/bitonic_sort.cuh @@ -0,0 +1,236 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::spatial::knn::detail::topk { + +namespace helpers { + +template +__device__ __forceinline__ void swap(T& x, T& y) +{ + T t = x; + x = y; + y = t; +} + +template +__device__ __forceinline__ void conditional_assign(bool cond, T& ptr, T x) +{ + if (cond) { ptr = x; } +} + +} // namespace helpers + +/** + * Warp-wide bitonic merge and sort. + * The data is strided among `warp_width` threads, + * e.g. calling `bitonic<4>(ascending=true).sort(arr)` takes a unique 4-element array as input of + * each thread in a warp and sorts them, such that for a fixed i, arr[i] are sorted within the + * threads in a warp, and for any i < j, arr[j] in any thread is not smaller than arr[i] in any + * other thread. + * When `warp_width < WarpSize`, the data is sorted within all subwarps of the warp independently. + * + * As an example, assuming `Size = 4`, `warp_width = 16`, and `WarpSize = 32`, sorting a permutation + * of numbers 0-63 in each subwarp yield the following result: + * ` + * arr_i \ laneId() + * 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 ... + * subwarp_1 subwarp_2 + * 0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 0 1 2 ... + * 1 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 16 17 18 ... + * 2 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 32 33 34 ... + * 3 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 48 49 50 ... + * ` + * + * @tparam Size + * number of elements processed in each thread; + * i.e. the total data size is `Size * warp_width`. + * Must be power-of-two. + * + */ +template +class bitonic { + static_assert(isPo2(Size)); + + public: + /** + * Initialize bitonic sort config. + * + * @param ascending + * the resulting order (true: ascending, false: descending). + * @param warp_width + * the number of threads participating in the warp-level primitives; + * the total size of the sorted data is `Size * warp_width`. + * Must be power-of-two, not larger than the WarpSize. + */ + __device__ __forceinline__ explicit bitonic(bool ascending, int warp_width = WarpSize) + : ascending_(ascending), warp_width_(warp_width) + { + } + + bitonic(bitonic const&) = delete; + bitonic(bitonic&&) = delete; + auto operator=(bitonic const&) -> bitonic& = delete; + auto operator=(bitonic&&) -> bitonic& = delete; + + /** + * You can think of this function in two ways: + * + * 1) Sort any bitonic sequence. + * 2) Merge two halfs of the input data assuming they're already sorted, and their order is + * opposite (i.e. either ascending, descending or vice-versa). + * + * The input pointers are unique per-thread. + * See the class description for the description of the data layout. + * + * @param keys + * is a device pointer to a contiguous array of keys, unique per thread; must be at least `Size` + * elements long. 
+ * @param payloads + * are zero or more associated arrays of the same size as keys, which are sorted together with + * the keys; must be at least `Size` elements long. + */ + template + __device__ __forceinline__ void merge(KeyT* __restrict__ keys, + PayloadTs* __restrict__... payloads) const + { + return bitonic::merge_(ascending_, warp_width_, keys, payloads...); + } + + /** + * Sort the data. + * The input pointers are unique per-thread. + * See the class description for the description of the data layout. + * + * @param keys + * is a device pointer to a contiguous array of keys, unique per thread; must be at least `Size` + * elements long. + * @param payloads + * are zero or more associated arrays of the same size as keys, which are sorted together with + * the keys; must be at least `Size` elements long. + */ + template + __device__ __forceinline__ void sort(KeyT* __restrict__ keys, + PayloadTs* __restrict__... payloads) const + { + return bitonic::sort_(ascending_, warp_width_, keys, payloads...); + } + + /** + * @brief `merge` variant for the case of one element per thread + * (pass input by a reference instead of a pointer). + * + * @param key + * @param payload + */ + template + __device__ __forceinline__ auto merge(KeyT& __restrict__ key, + PayloadTs& __restrict__... payload) const + -> std::enable_if_t // SFINAE to enable this for Size == 1 only + { + static_assert(S == Size); + return merge(&key, &payload...); + } + + /** + * @brief `sort` variant for the case of one element per thread + * (pass input by a reference instead of a pointer). + * + * @param key + * @param payload + */ + template + __device__ __forceinline__ auto sort(KeyT& __restrict__ key, + PayloadTs& __restrict__... payload) const + -> std::enable_if_t // SFINAE to enable this for Size == 1 only + { + static_assert(S == Size); + return sort(&key, &payload...); + } + + private: + const int warp_width_; + const bool ascending_; + + template + friend class bitonic; + + template + static __device__ __forceinline__ void merge_(bool ascending, + int warp_width, + KeyT* __restrict__ keys, + PayloadTs* __restrict__... payloads) + { +#pragma unroll + for (int size = Size; size > 1; size >>= 1) { + const int stride = size >> 1; +#pragma unroll + for (int offset = 0; offset < Size; offset += size) { +#pragma unroll + for (int i = offset + stride - 1; i >= offset; i--) { + const int other_i = i + stride; + KeyT& key = keys[i]; + KeyT& other = keys[other_i]; + if (ascending ? key > other : key < other) { + helpers::swap(key, other); + (helpers::swap(payloads[i], payloads[other_i]), ...); + } + } + } + } + const int lane = laneId(); +#pragma unroll + for (int i = 0; i < Size; i++) { + KeyT& key = keys[i]; + for (int stride = (warp_width >> 1); stride > 0; stride >>= 1) { + const bool is_second = lane & stride; + const KeyT other = shfl_xor(key, stride, warp_width); + const bool do_assign = (ascending != is_second) ? key > other : key < other; + + helpers::conditional_assign(do_assign, key, other); + // NB: don't put shfl_xor in a conditional; it must be called by all threads in a warp. + (helpers::conditional_assign( + do_assign, payloads[i], shfl_xor(payloads[i], stride, warp_width)), + ...); + } + } + } + + template + static __device__ __forceinline__ void sort_(bool ascending, + int warp_width, + KeyT* __restrict__ keys, + PayloadTs* __restrict__... 
payloads) + { + if constexpr (Size == 1) { + const int lane = laneId(); + for (int width = 2; width < warp_width; width <<= 1) { + bitonic<1>::merge_(lane & width, width, keys, payloads...); + } + } else { + constexpr int kSize2 = Size / 2; + bitonic::sort_(false, warp_width, keys, payloads...); + bitonic::sort_(true, warp_width, keys + kSize2, (payloads + kSize2)...); + } + bitonic::merge_(ascending, warp_width, keys, payloads...); + } +}; + +} // namespace raft::spatial::knn::detail::topk diff --git a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh b/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh new file mode 100644 index 0000000000..21e6ea026c --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh @@ -0,0 +1,608 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace raft::spatial::knn::detail::topk { + +constexpr int ITEM_PER_THREAD = 32; +constexpr int VECTORIZED_READ_SIZE = 16; + +template +__host__ __device__ constexpr int calc_num_buckets() +{ + return 1 << BitsPerPass; +} + +template +__host__ __device__ constexpr int calc_num_passes() +{ + return ceildiv(sizeof(T) * 8, BitsPerPass); +} + +/** + * Bit 0 is the least significant (rightmost); + * this implementation processes input from the most to the least significant bit. + * This way, we can skip some passes in the end at the cost of having an unsorted output. + * + * NB: Use pass=-1 for calc_mask(). + */ +template +__device__ constexpr int calc_start_bit(int pass) +{ + int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; + if (start_bit < 0) { start_bit = 0; } + return start_bit; +} + +template +__device__ constexpr unsigned calc_mask(int pass) +{ + static_assert(BitsPerPass <= 31); + int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); + return (1 << num_bits) - 1; +} + +/** + * Use cub to twiddle bits - so that we can correctly compare bits of floating-point values as well + * as of integers. + */ +template +__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) +{ + auto bits = reinterpret_cast::UnsignedBits&>(key); + bits = cub::Traits::TwiddleIn(bits); + if (greater) { bits = ~bits; } + return bits; +} + +template +__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) +{ + static_assert(BitsPerPass <= sizeof(int) * 8 - 1); // so return type can be int + return (twiddle_in(x, greater) >> start_bit) & mask; +} + +/** + * Map a Func over the input data, using vectorized load instructions if possible. 
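+ *
+ * A minimal usage sketch (hypothetical device-side caller; `row`, `hist` and `bin()` are assumed
+ * to exist only for illustration):
+ *   `vectorized_process(row, row_len, [hist](float x, int i) { atomicAdd(hist + bin(x), 1); });`
+ * Every input element is visited exactly once by some thread of the grid, although in no
+ * particular order.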
+ * + * NB: in future, we should move this to cpp/include/raft/linalg/detail/unary_op.cuh, which + * currently does not support the second lambda argument (index of an element) + * + * @tparam T element type + * @tparam IdxT indexing type + * @tparam Func void (T x, IdxT idx) + * + * @param in the input data + * @param len the number of elements to read + * @param f the lambda taking two arguments (T x, IdxT idx) + */ +template +__device__ void vectorized_process(const T* in, IdxT len, Func f) +{ + const IdxT stride = blockDim.x * gridDim.x; + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if constexpr (sizeof(T) >= VECTORIZED_READ_SIZE || VECTORIZED_READ_SIZE % sizeof(T) != 0) { + for (IdxT i = tid; i < len; i += stride) { + f(in[i], i); + } + } else { + using wide_t = TxN_t; + using align_bytes = Pow2<(size_t)VECTORIZED_READ_SIZE>; + using align_elems = Pow2; + wide_t wide; + + // how many elements to skip in order to do aligned vectorized load + const IdxT skip_cnt_left = std::min((IdxT)(align_bytes::roundUp(in) - in), len); + + // The main loop: process all aligned data + for (IdxT i = tid * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; + i += stride * wide_t::Ratio) { + wide.load(in, i); +#pragma unroll + for (int j = 0; j < wide_t::Ratio; ++j) { + f(wide.val.data[j], i + j); + } + } + + static_assert(WarpSize >= wide_t::Ratio); + // Processes the skipped elements on the left + if (tid < skip_cnt_left) { f(in[tid], tid); } + // Processes the skipped elements on the right + const IdxT skip_cnt_right = align_elems::mod(len - skip_cnt_left); + const IdxT remain_i = len - skip_cnt_right + tid; + if (remain_i < len) { f(in[remain_i], remain_i); } + } +} + +template +struct Counter { + IdxT k; + IdxT len; + IdxT previous_len; + int bucket; + + IdxT filter_cnt; + unsigned int finished_block_cnt; + IdxT out_cnt; + IdxT out_back_cnt; +}; + +/** + * Fused filtering of the current phase and building histogram for the next phase + * (see steps 4-1 in `radix_kernel` description). + */ +template +__device__ void filter_and_histogram(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + IdxT len, + Counter* counter, + IdxT* histogram, + bool greater, + int pass, + int k) +{ + constexpr int num_buckets = calc_num_buckets(); + __shared__ IdxT histogram_smem[num_buckets]; + for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram_smem[i] = 0; + } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + + if (pass == 0) { + // Passed to vectorized_process, this function executes in all blocks in parallel, + // i.e. the work is split along the input (both, in batches and chunks of a single row). + // Later, the histograms are merged using atomicAdd. + auto f = [greater, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, IdxT(1)); + }; + vectorized_process(in_buf, len, f); + } else { + const IdxT previous_len = counter->previous_len; + const int want_bucket = counter->bucket; + IdxT& filter_cnt = counter->filter_cnt; + IdxT& out_cnt = counter->out_cnt; + const IdxT counter_len = counter->len; + const int previous_start_bit = calc_start_bit(pass - 1); + const unsigned previous_mask = calc_mask(pass - 1); + + // See the remark above on the distributed execution of `f` using vectorized_process. 
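+ // For every input element, `f` inspects the digit selected in the *previous* pass:
+ //   - elements from buckets smaller than `want_bucket` are already known to belong to the
+ //     top-k and are written directly to the output (`out`/`out_idx`);
+ //   - elements from `want_bucket` itself are compacted into `out_buf`/`out_idx_buf` for the
+ //     next pass and contribute to the histogram of the current digit;
+ //   - all remaining elements are discarded.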
+ auto f = [in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + greater, + k, + start_bit, + mask, + previous_start_bit, + previous_mask, + want_bucket, + &filter_cnt, + &out_cnt, + counter_len](T value, IdxT i) { + int prev_bucket = + calc_bucket(value, previous_start_bit, previous_mask, greater); + if (prev_bucket == want_bucket) { + IdxT pos = atomicAdd(&filter_cnt, IdxT(1)); + out_buf[pos] = value; + if (out_idx_buf) { out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, IdxT(1)); + + if (counter_len == 1) { + out[k - 1] = value; + out_idx[k - 1] = in_idx_buf ? in_idx_buf[i] : i; + } + } else if (prev_bucket < want_bucket) { + IdxT pos = atomicAdd(&out_cnt, IdxT(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + }; + + vectorized_process(in_buf, previous_len, f); + } + __syncthreads(); + + // merge histograms produced by individual blocks + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } + } +} + +/** + * Replace a part of the histogram with its own prefix sum, starting from the `start` and adding + * `current` to each entry of the result. + * (step 2 in `radix_kernel` description) + */ +template +__device__ void scan(volatile IdxT* histogram, + const int start, + const int num_buckets, + const IdxT current) +{ + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + IdxT thread_data = 0; + int index = start + threadIdx.x; + if (index < num_buckets) { thread_data = histogram[index]; } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + if (index < num_buckets) { histogram[index] = thread_data + current; } + __syncthreads(); // This sync is necessary, as the content of histogram needs + // to be read after +} + +/** + * Calculate in which bucket the k-th value will fall + * (steps 2-3 in `radix_kernel` description) + */ +template +__device__ void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) +{ + constexpr int num_buckets = calc_num_buckets(); + int index = threadIdx.x; + IdxT last_prefix_sum = 0; + int num_pass = 1; + if constexpr (num_buckets >= BlockSize) { + static_assert(num_buckets % BlockSize == 0); + num_pass = num_buckets / BlockSize; + } + + for (int i = 0; i < num_pass && (last_prefix_sum < k); i++) { + // Turn the i-th chunk of the histogram into its prefix sum. + scan(histogram, i * BlockSize, num_buckets, last_prefix_sum); + if (index < num_buckets) { + // Number of values in the previous `index-1` buckets (see the `scan` op above) + IdxT prev = (index == 0) ? 0 : histogram[index - 1]; + // Number of values in `index` buckets + IdxT cur = histogram[index]; + + // one and only one thread will satisfy this condition, so only write once + if (prev < k && cur >= k) { + counter->k = k - prev; // how many values still are there to find + counter->previous_len = counter->len; + counter->len = cur - prev; // number of values in `index` bucket + counter->bucket = index; + } + } + index += BlockSize; + // this will break the loop when the counter is set (cur >= k), because last_prefix_sum >= cur + last_prefix_sum = histogram[(i + 1) * BlockSize - 1]; + } +} + +/** + * + * It is expected to call this kernel multiple times (passes), in each pass we process a radix, + * going from the most significant towards the least significant bits (MSD). 
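+ * For example, with a 32-bit key type and BitsPerPass = 11, calc_num_passes() = ceildiv(32, 11) = 3,
+ * and the three passes examine bits [21, 31], [10, 20] and [0, 9] respectively (calc_start_bit()
+ * clamps the last pass at bit 0, so it covers only the remaining 10 bits).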
+ * + * Conceptually, each pass consists of 4 steps: + * + * 1. Calculate histogram + * First, transform bits into a digit, the value of which is in the range + * [0, 2^{BITS_PER_PASS}-1]. Then count the frequency of each digit value and the result is a + * histogram. That is, histogram[i] contains the count of inputs having value i. + * + * 2. Scan the histogram + * Inclusive prefix sum is computed for the histogram. After this step, histogram[i] contains + * the count of inputs having value <= i. + * + * 3. Find the bucket j of the histogram that the k-th value falls into + * + * 4. Filtering + * Input elements whose digit value +__global__ void __launch_bounds__(BlockSize) radix_kernel(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counters, + IdxT* histograms, + const IdxT len, + const int k, + const bool greater, + const int pass) +{ + __shared__ bool isLastBlockDone; + + constexpr int num_buckets = calc_num_buckets(); + constexpr int num_passes = calc_num_passes(); + const int batch_id = blockIdx.y; + in_buf += batch_id * len; + out_buf += batch_id * len; + out += batch_id * k; + out_idx += batch_id * k; + if (in_idx_buf) { in_idx_buf += batch_id * len; } + if (out_idx_buf) { out_idx_buf += batch_id * len; } + + auto counter = counters + batch_id; + auto histogram = histograms + batch_id * num_buckets; + + filter_and_histogram(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + len, + counter, + histogram, + greater, + pass, + k); + __threadfence(); + + if (threadIdx.x == 0) { + unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); + isLastBlockDone = (finished == (gridDim.x - 1)); + } + + // Synchronize to make sure that each thread reads the correct value of + // isLastBlockDone. + __syncthreads(); + if (isLastBlockDone) { + if (counter->len == 1 && threadIdx.x == 0) { + counter->previous_len = 0; + counter->len = 0; + } + // init counter, other members of counter is initialized with 0 by + // cudaMemset() + if (pass == 0 && threadIdx.x == 0) { + counter->k = k; + counter->len = len; + counter->out_back_cnt = 0; + } + __syncthreads(); + + IdxT ori_k = counter->k; + + if (counter->len > 0) { + choose_bucket(counter, histogram, ori_k); + } + + __syncthreads(); + if (pass == num_passes - 1) { + const IdxT previous_len = counter->previous_len; + const int want_bucket = counter->bucket; + int start_bit = calc_start_bit(pass); + unsigned mask = calc_mask(pass); + + // radix topk + IdxT& out_cnt = counter->out_cnt; + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = out_buf[i]; + int bucket = calc_bucket(value, start_bit, mask, greater); + if (bucket < want_bucket) { + IdxT pos = atomicAdd(&out_cnt, IdxT(1)); + out[pos] = value; + out_idx[pos] = out_idx_buf[i]; + } else if (bucket == want_bucket) { + IdxT needed_num_of_kth = counter->k; + IdxT back_pos = atomicAdd(&(counter->out_back_cnt), IdxT(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = out_idx_buf[i]; + } + } + } + __syncthreads(); + } else { + // reset for next pass + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + if (threadIdx.x == 0) { counter->filter_cnt = 0; } + } + } +} + +/** + * Calculate the minimal batch size, such that GPU is still fully occupied. 
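+ *
+ * A worked example with hypothetical numbers: for sm_count = 108, occupancy = 2 and
+ * blocks_per_row = 25, the initial estimate is ceildiv(108 * 2, 25) = 9, which is rounded up to
+ * the next power of two, 16. The estimate is then enlarged further while it is much smaller than
+ * the requested batch size (or while the overall grid would be small), and is finally capped by
+ * the maximum grid y-dimension and by the requested batch size itself.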
+ */ +template +inline uint16_t get_optimal_batch_size(size_t req_batch_size, size_t blocks_per_row) +{ + int dev_id, sm_count, occupancy, max_grid_dim_y; + RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&max_grid_dim_y, cudaDevAttrMaxGridDimY, dev_id)); + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, radix_kernel, BlockSize, 0)); + + // fully occupy GPU + size_t opt_batch_size = ceildiv(sm_count * occupancy, blocks_per_row); + // round it up to the closest pow-of-two for better data alignment + opt_batch_size = isPo2(opt_batch_size) ? opt_batch_size : (1 << (log2(opt_batch_size) + 1)); + // Take a max possible pow-of-two grid_dim_y + max_grid_dim_y = isPo2(max_grid_dim_y) ? max_grid_dim_y : (1 << log2(max_grid_dim_y)); + // If the optimal batch size is very small compared to the requested batch size, we know + // the extra required memory is not significant and we can increase the batch size for + // better occupancy when the grid size is not multiple of the SM count. + // Also don't split the batch size when there is not much work overall. + const size_t safe_enlarge_factor = 9; + const size_t min_grid_size = 1024; + while ((opt_batch_size << safe_enlarge_factor) < req_batch_size || + blocks_per_row * opt_batch_size < min_grid_size) { + opt_batch_size <<= 1; + } + + // Do not exceed the max grid size. + opt_batch_size = std::min(opt_batch_size, size_t(max_grid_dim_y)); + + // Don't do more work than needed + return uint16_t(std::min(opt_batch_size, req_batch_size)); +} + +/** + * Select k smallest or largest key/values from each row in the input data. + * + * If you think of the input data `in_keys` as a row-major matrix with len columns and + * batch_size rows, then this function selects k smallest/largest values in each row and fills + * in the row-major matrix `out` of size (batch_size, k). + * + * Note, the output is NOT sorted within the groups of `k` selected elements. + * + * @tparam T + * the type of the keys (what is being compared). + * @tparam IdxT + * the index type (what is being selected together with the keys). + * @tparam BitsPerPass + * The size of the radix; + * it affects the number of passes and number of buckets. + * @tparam BlockSize + * Number of threads in a kernel thread block. + * + * @param[in] in + * contiguous device array of inputs of size (len * batch_size); + * these are compared and selected. + * @param[in] in_idx + * contiguous device array of inputs of size (len * batch_size); + * typically, these are indices of the corresponding in_keys. + * @param[in] batch_size + * number of input rows, i.e. the batch size. + * @param[in] len + * length of a single input array (row); also sometimes referred as n_cols. + * Invariant: len >= k. + * @param[in] k + * the number of outputs to select in each input row. + * @param[out] out + * contiguous device array of outputs of size (k * batch_size); + * the k smallest/largest values from each row of the `in_keys`. + * @param[out] out_idx + * contiguous device array of outputs of size (k * batch_size); + * the payload selected together with `out`. + * @param[in] select_min + * whether to select k smallest (true) or largest (false) keys. 
+ * @param[in] stream + */ +template +void radix_topk(const T* in, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream) +{ + // TODO: is it possible to relax this restriction? + static_assert(calc_num_passes() > 1); + constexpr int num_buckets = calc_num_buckets(); + + size_t blocks_per_row = ceildiv(len, BlockSize * ITEM_PER_THREAD); + uint16_t max_chunk_size = + get_optimal_batch_size(batch_size, blocks_per_row); + + rmm::device_uvector> counters(max_chunk_size, stream); + rmm::device_uvector histograms(num_buckets * max_chunk_size, stream); + rmm::device_uvector buf1(len * max_chunk_size, stream); + rmm::device_uvector idx_buf1(len * max_chunk_size, stream); + rmm::device_uvector buf2(len * max_chunk_size, stream); + rmm::device_uvector idx_buf2(len * max_chunk_size, stream); + + for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) { + auto chunk_size = uint16_t(std::min(max_chunk_size, batch_size - offset)); + + RAFT_CUDA_TRY( + cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); + + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + IdxT* out_idx_buf = nullptr; + + dim3 blocks(blocks_per_row, chunk_size); + + constexpr int num_passes = calc_num_passes(); + + for (int pass = 0; pass < num_passes; ++pass) { + if (pass == 0) { + in_buf = in + offset * len; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in + offset * len; + in_idx_buf = in_idx ? in_idx + offset * len : nullptr; + out_buf = buf1.data(); + out_idx_buf = idx_buf1.data(); + } else if (pass % 2 == 0) { + in_buf = buf1.data(); + in_idx_buf = idx_buf1.data(); + out_buf = buf2.data(); + out_idx_buf = idx_buf2.data(); + } else { + in_buf = buf2.data(); + in_idx_buf = idx_buf2.data(); + out_buf = buf1.data(); + out_idx_buf = idx_buf1.data(); + } + + radix_kernel + <<>>(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out + offset * k, + out_idx + offset * k, + counters.data(), + histograms.data(), + len, + k, + !select_min, + pass); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + } +} + +} // namespace raft::spatial::knn::detail::topk diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh new file mode 100644 index 0000000000..f5ea8ba879 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh @@ -0,0 +1,881 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "bitonic_sort.cuh" + +#include +#include + +#include +#include +#include + +/* + Three APIs of different scopes are provided: + 1. host function: warp_sort_topk() + 2. block-wide API: class block_sort + 3. 
warp-wide API: class warp_sort_filtered and class warp_sort_immediate + + + 1. warp_sort_topk() + (see the docstring) + + 2. class block_sort + It can be regarded as a fixed size priority queue for a thread block, + although the API is not typical. + class warp_sort_filtered and warp_sort_immediate can be used to instantiate block_sort. + + It uses dynamic shared memory as intermediate buffer. + So the required shared memory size should be calculated using + calc_smem_size_for_block_wide() and passed as the 3rd kernel launch parameter. + + Two overload functions can be used to add items to the queue. + One is load(const T* in, IdxT start, IdxT end) and it adds a range of items, + namely [start, end) of in. The idx is inferred from start. + This function should be called only once to add all items, and should not be + used together with the add(). + The second one is add(T val, IdxT idx), and it adds only one item pair. + Note that the range [start, end) is for the whole block of threads, that is, + each thread in the same block should get the same start/end. + In contrast, the parameters of the second form are for only one thread, + so each thread must get different val/idx. + + After adding is finished, function done() should be called. And finally, + store() is used to get the top-k result. + + Example: + __global__ void kernel() { + block_sort queue(...); + + // way 1, [0, len) is same for the whole block + queue.load(in, 0, len); + // way 2, each thread gets its own val/idx pair + for (IdxT i = threadIdx.x; i < len, i += blockDim.x) { + queue.add(in[i], idx[i]); + } + + queue.done(); + queue.store(out, out_idx); + } + + int smem_size = calc_smem_size_for_block_wide(...); + kernel<<>>(); + + + 3. class warp_sort_filtered and class warp_sort_immediate + These two classes can be regarded as fixed size priority queue for a warp. + Usage is similar to class block_sort. + Two types of add() functions are provided, and also note that [start, end) is + for a whole warp, while val/idx is for a thread. + No shared memory is needed. + + The host function (warp_sort_topk) uses a heuristic to choose between these two classes for + sorting, warp_sort_immediate being chosen when the number of inputs per warp is somewhat small + (see the usage of LaunchThreshold::len_factor_for_choosing). + + Example: + __global__ void kernel() { + warp_sort_immediate<...> queue(...); + int warp_id = threadIdx.x / WarpSize; + int lane_id = threadIdx.x % WarpSize; + + // way 1, [0, len) is same for the whole warp + queue.load(in, 0, len); + // way 2, each thread gets its own val/idx pair + for (IdxT i = lane_id; i < len, i += WarpSize) { + queue.add(in[i], idx[i]); + } + + queue.done(); + // each warp outputs to a different offset + queue.store(out+ warp_id * k, out_idx+ warp_id * k); + } + */ + +namespace raft::spatial::knn::detail::topk { + +static constexpr int kMaxCapacity = 256; + +namespace { + +/** Whether 'left` should indeed be on the left w.r.t. `right`. */ +template +__device__ __forceinline__ auto is_ordered(T left, T right) -> bool +{ + if constexpr (Ascending) { return left < right; } + if constexpr (!Ascending) { return left > right; } +} + +constexpr auto calc_capacity(int k) -> int +{ + int capacity = isPo2(k) ? k : (1 << (log2(k) + 1)); + if (capacity < WarpSize) { capacity = WarpSize; } // TODO: remove this to allow small sizes. + return capacity; +} + +} // namespace + +/** + * A fixed-size warp-level priority queue. 
+ * By feeding the data through this queue, you get the `k <= Capacity` + * smallest/greatest values in the data. + * + * @tparam Capacity + * maximum number of elements in the queue. + * @tparam Ascending + * which comparison to use: `true` means `<`, collect the smallest elements, + * `false` means `>`, collect the greatest elements. + * @tparam T + * the type of keys (what is being compared) + * @tparam IdxT + * the type of payload (normally, indices of elements), i.e. + * the content sorted alongside the keys. + */ +template +class warp_sort { + static_assert(isPo2(Capacity)); + + public: + /** + * Construct the warp_sort empty queue. + * + * @param k + * number of elements to select. + * @param dummy + * the `empty` value for the choosen binary operation, + * i.e. `Ascending ? upper_bound() : lower_bound()`. + * + */ + __device__ warp_sort(IdxT k, T dummy) : k_(k), dummy_(dummy) + { +#pragma unroll + for (int i = 0; i < kMaxArrLen; i++) { + val_arr_[i] = dummy_; + } + } + + /** + * Load k values from the pointers at the given position, and merge them in the storage. + */ + __device__ void load_sorted(const T* in, const IdxT* in_idx) + { + IdxT idx = kWarpWidth - 1 - Pow2::mod(laneId()); +#pragma unroll + for (int i = kMaxArrLen - 1; i >= 0; --i, idx += kWarpWidth) { + if (idx < k_) { + T t = in[idx]; + if (is_ordered(t, val_arr_[i])) { + val_arr_[i] = t; + idx_arr_[i] = in_idx[idx]; + } + } + } + topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + } + + /** Save the content by the pointer location. */ + __device__ void store(T* out, IdxT* out_idx) const + { + IdxT idx = Pow2::mod(laneId()); +#pragma unroll kMaxArrLen + for (int i = 0; i < kMaxArrLen && idx < k_; i++, idx += kWarpWidth) { + out[idx] = val_arr_[i]; + out_idx[idx] = idx_arr_[i]; + } + } + + protected: + static constexpr int kWarpWidth = std::min(Capacity, WarpSize); + static constexpr int kMaxArrLen = Capacity / kWarpWidth; + + const IdxT k_; + const T dummy_; + T val_arr_[kMaxArrLen]; + IdxT idx_arr_[kMaxArrLen]; + + /** + * Merge another array (sorted in the opposite direction) in the queue. + * Thanks to the other array being sorted in the opposite direction, + * it's enough to call bitonic.merge once to maintain the valid state + * of the queue. + * + * @tparam PerThreadSizeIn + * the size of the other array per-thread (compared to `kMaxArrLen`). + * + * @param keys_in + * the values to be merged in. Pointers are unique per-thread. The values + * must already be sorted in the opposite direction. + * The layout of `keys_in` must be the same as the layout of `val_arr_`. + * @param ids_in + * the associated indices of the elements in the same format as `keys_in`. + */ + template + __device__ __forceinline__ void merge_in(const T* __restrict__ keys_in, + const IdxT* __restrict__ ids_in) + { +#pragma unroll + for (int i = std::min(kMaxArrLen, PerThreadSizeIn); i > 0; i--) { + T& key = val_arr_[kMaxArrLen - i]; + T other = keys_in[PerThreadSizeIn - i]; + if (is_ordered(other, key)) { + key = other; + idx_arr_[kMaxArrLen - i] = ids_in[PerThreadSizeIn - i]; + } + } + topk::bitonic(Ascending).merge(val_arr_, idx_arr_); + } +}; + +/** + * This version of warp_sort compares each input element against the current + * estimate of k-th value before adding it to the intermediate sorting buffer. + * This makes the algorithm do less sorting steps for long input sequences + * at the cost of extra checks on each step. + * + * This implementation is preferred for large len values. 
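+ *
+ * Rough per-element flow (see add() below): an incoming (val, idx) pair is first compared
+ * against the current estimate of the k-th value (k_th_); only pairs passing this check are
+ * appended to a small per-thread buffer (val_buf_/idx_buf_). The buffer is bitonic-sorted and
+ * merged into the main queue only when some lane of the warp runs out of buffer space, after
+ * which k_th_ is refreshed.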
+ */ +template +class warp_sort_filtered : public warp_sort { + static_assert(Capacity >= WarpSize); + + public: + __device__ warp_sort_filtered(int k, T dummy) + : warp_sort(k, dummy), buf_len_(0), k_th_(dummy) + { +#pragma unroll + for (int i = 0; i < kMaxBufLen; i++) { + val_buf_[i] = dummy_; + } + } + + __device__ void load(const T* in, const IdxT* in_idx, IdxT start, IdxT end) + { + const IdxT end_for_fullwarp = Pow2::roundUp(end - start) + start; + for (IdxT i = start + laneId(); i < end_for_fullwarp; i += WarpSize) { + T val = (i < end) ? in[i] : dummy_; + IdxT idx = (i < end) ? in_idx[i] : std::numeric_limits::max(); + add(val, idx); + } + } + + __device__ void add(T val, IdxT idx) + { + // comparing for k_th should reduce the total amount of updates: + // `false` means the input value is surely not in the top-k values. + if (is_ordered(val, k_th_)) { + // NB: the loop is used here to ensure the constant indexing, + // to not force the buffers spill into the local memory. +#pragma unroll + for (int i = 0; i < kMaxBufLen; i++) { + if (i == buf_len_) { + val_buf_[i] = val; + idx_buf_[i] = idx; + } + } + ++buf_len_; + } + if (any(buf_len_ == kMaxBufLen)) { merge_buf_(); } + } + + __device__ void done() + { + if (any(buf_len_ != 0)) { merge_buf_(); } + } + + private: + __device__ void set_k_th_() + { + // NB on using srcLane: it's ok if it is outside the warp size / width; + // the modulo op will be done inside the __shfl_sync. + k_th_ = shfl(val_arr_[kMaxArrLen - 1], k_ - 1); + } + + __device__ void merge_buf_() + { + topk::bitonic(!Ascending).sort(val_buf_, idx_buf_); + this->merge_in(val_buf_, idx_buf_); + buf_len_ = 0; + set_k_th_(); // contains warp sync +#pragma unroll + for (int i = 0; i < kMaxBufLen; i++) { + val_buf_[i] = dummy_; + } + } + + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + using warp_sort::k_; + using warp_sort::dummy_; + + static constexpr int kMaxBufLen = (Capacity <= 64) ? 2 : 4; + + T val_buf_[kMaxBufLen]; + IdxT idx_buf_[kMaxBufLen]; + int buf_len_; + + T k_th_; +}; + +/** + * This version of warp_sort adds every input element into the intermediate sorting + * buffer, and thus does the sorting step every `Capacity` input elements. + * + * This implementation is preferred for very small len values. + */ +template +class warp_sort_immediate : public warp_sort { + static_assert(Capacity >= WarpSize); + + public: + __device__ warp_sort_immediate(int k, T dummy) + : warp_sort(k, dummy), buf_len_(0) + { +#pragma unroll + for (int i = 0; i < kMaxArrLen; i++) { + val_buf_[i] = dummy_; + } + } + + __device__ void load(const T* in, const IdxT* in_idx, IdxT start, IdxT end) + { + add_first_(in, in_idx, start, end); + start += Capacity; + while (start < end) { + add_extra_(in, in_idx, start, end); + this->merge_in(val_buf_, idx_buf_); + start += Capacity; + } + } + + __device__ void add(T val, IdxT idx) + { + // NB: the loop is used here to ensure the constant indexing, + // to not force the buffers spill into the local memory. 
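+ // (with the loop fully unrolled, `i` is a compile-time constant in every iteration, so
+ // val_buf_[i]/idx_buf_[i] can stay in registers; indexing with the runtime value buf_len_
+ // directly would push the arrays into local memory)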
+#pragma unroll + for (int i = 0; i < kMaxArrLen; ++i) { + if (i == buf_len_) { + val_buf_[i] = val; + idx_buf_[i] = idx; + } + } + + ++buf_len_; + if (buf_len_ == kMaxArrLen) { + topk::bitonic(!Ascending).sort(val_buf_, idx_buf_); + this->merge_in(val_buf_, idx_buf_); +#pragma unroll + for (int i = 0; i < kMaxArrLen; i++) { + val_buf_[i] = dummy_; + } + buf_len_ = 0; + } + } + + __device__ void done() + { + if (buf_len_ != 0) { + topk::bitonic(!Ascending).sort(val_buf_, idx_buf_); + this->merge_in(val_buf_, idx_buf_); + } + } + + private: + /** Fill in the primary val_arr_/idx_arr_ */ + __device__ void add_first_(const T* in, const IdxT* in_idx, IdxT start, IdxT end) + { + IdxT idx = start + laneId(); + for (int i = 0; i < kMaxArrLen; ++i, idx += WarpSize) { + if (idx < end) { + val_arr_[i] = in[idx]; + idx_arr_[i] = in_idx[idx]; + } + } + topk::bitonic(Ascending).sort(val_arr_, idx_arr_); + } + + /** Fill in the secondary val_buf_/idx_buf_ */ + __device__ void add_extra_(const T* in, const IdxT* in_idx, IdxT start, IdxT end) + { + IdxT idx = start + laneId(); + for (int i = 0; i < kMaxArrLen; ++i, idx += WarpSize) { + val_buf_[i] = (idx < end) ? in[idx] : dummy_; + idx_buf_[i] = (idx < end) ? in_idx[idx] : std::numeric_limits::max(); + } + topk::bitonic(!Ascending).sort(val_buf_, idx_buf_); + } + + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + using warp_sort::k_; + using warp_sort::dummy_; + + T val_buf_[kMaxArrLen]; + IdxT idx_buf_[kMaxArrLen]; + int buf_len_; +}; + +/** + * This one is used for the second pass only: + * if the first pass happens in multiple blocks, the output consists of a series + * of sorted arrays, length `k` each. + * Under this assumption, we can use load_sorted to just do the merging, rather than + * the full sort. + */ +template +class warp_merge : public warp_sort { + public: + __device__ warp_merge(int k, T dummy) : warp_sort(k, dummy) {} + + // NB: the input is already sorted, because it's the second pass. + __device__ void load(const T* in, const IdxT* in_idx, IdxT start, IdxT end) + { + for (; start < end; start += k_) { + load_sorted(in + start, in_idx + start); + } + } + + __device__ void done() {} + + private: + using warp_sort::kWarpWidth; + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + using warp_sort::k_; + using warp_sort::dummy_; +}; + +template +int calc_smem_size_for_block_wide(int num_of_warp, IdxT k) +{ + return Pow2<256>::roundUp(num_of_warp / 2 * sizeof(T) * k) + num_of_warp / 2 * sizeof(IdxT) * k; +} + +template