From a1ace038e62b7857bdd70ce68e9428aec763f22e Mon Sep 17 00:00:00 2001
From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com>
Date: Mon, 16 May 2022 18:11:13 +0200
Subject: [PATCH] Improve performance of select-top-k WARP_SORT implementation
 (#606)

A few simplifications and tricks to improve the performance of the kernel:

 - Promote some constants to static constexpr
 - Allow `capacity < WarpSize`
 - Reduce the frequency of `sort` operations for the `filtered` version
 - Remove `warp_sort::load` to simplify the API and implementation

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: https://github.com/rapidsai/raft/pull/606
---
 .../spatial/knn/detail/topk/warpsort_topk.cuh | 573 ++++++++----------
 1 file changed, 263 insertions(+), 310 deletions(-)
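Not part of the patch, but useful context for the last bullet: with `warp_sort::load` removed, a bulk insertion is expressed through the remaining `add()` API by padding the tail of the range with the queue's `kDummy` sentinel, so every thread keeps participating in the warp-collective merge steps. A minimal sketch, in which `Queue` and `add_range` are illustrative names only, not identifiers from the file:

    template <typename Queue, typename T, typename IdxT>
    __device__ void add_range(Queue& queue, const T* in, const IdxT* in_idx, IdxT len)
    {
      // Round the length up to a multiple of the block size, so that all threads
      // perform the same number of add() calls and none skips the collective steps.
      const IdxT block_dim = IdxT(blockDim.x);
      const IdxT len_up    = ((len + block_dim - 1) / block_dim) * block_dim;
      for (IdxT i = IdxT(threadIdx.x); i < len_up; i += block_dim) {
        // Threads past the end push the `empty` sentinel; its index is never read.
        T val    = (i < len) ? in[i] : Queue::kDummy;
        IdxT idx = (i < len) ? in_idx[i] : IdxT{};
        queue.add(val, idx);
      }
    }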
diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
index f5ea8ba879..a599e8367e 100644
--- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
+++ b/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
@@ -40,33 +40,22 @@
    although the API is not typical.
    class warp_sort_filtered and warp_sort_immediate can be used to instantiate block_sort.

-   It uses dynamic shared memory as intermediate buffer.
+   It uses dynamic shared memory as an intermediate buffer.
    So the required shared memory size should be calculated using
    calc_smem_size_for_block_wide() and passed as the 3rd kernel launch parameter.

-   Two overload functions can be used to add items to the queue.
-   One is load(const T* in, IdxT start, IdxT end) and it adds a range of items,
-   namely [start, end) of in. The idx is inferred from start.
-   This function should be called only once to add all items, and should not be
-   used together with the add().
-   The second one is add(T val, IdxT idx), and it adds only one item pair.
-   Note that the range [start, end) is for the whole block of threads, that is,
-   each thread in the same block should get the same start/end.
-   In contrast, the parameters of the second form are for only one thread,
-   so each thread must get different val/idx.
+   To add elements to the queue, use add(T val, IdxT idx) with unique values per thread.
+   Use the WarpSortClass<...>::kDummy constant for threads outside of the input bounds.

-   After adding is finished, function done() should be called. And finally,
-   store() is used to get the top-k result.
+   After adding is finished, function done() should be called, and finally store() is used
+   to get the top-k result.

    Example:
      __global__ void kernel() {
        block_sort<warp_sort_immediate, ...> queue(...);
-       // way 1, [0, len) is same for the whole block
-       queue.load(in, 0, len);
-       // way 2, each thread gets its own val/idx pair
        for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
-         queue.add(in[i], idx[i]);
+         queue.add(in[i], in_idx[i]);
        }

        queue.done();
@@ -79,10 +68,7 @@
   3. class warp_sort_filtered and class warp_sort_immediate
      These two classes can be regarded as fixed-size priority queues for a warp.
-     Usage is similar to class block_sort.
-     Two types of add() functions are provided, and also note that [start, end) is
-     for a whole warp, while val/idx is for a thread.
-     No shared memory is needed.
+     Usage is similar to class block_sort. No shared memory is needed.

      The host function (warp_sort_topk) uses a heuristic to choose between these two classes for
      sorting, warp_sort_immediate being chosen when the number of inputs per warp is somewhat small
@@ -94,16 +80,13 @@
        int warp_id = threadIdx.x / WarpSize;
        int lane_id = threadIdx.x % WarpSize;
-       // way 1, [0, len) is same for the whole warp
-       queue.load(in, 0, len);
-       // way 2, each thread gets its own val/idx pair
        for (IdxT i = lane_id; i < len; i += WarpSize) {
          queue.add(in[i], idx[i]);
        }

        queue.done();
        // each warp outputs to a different offset
-       queue.store(out+ warp_id * k, out_idx+ warp_id * k);
+       queue.store(out + warp_id * k, out_idx + warp_id * k);
      }
  */
@@ -124,7 +107,6 @@ __device__ __forceinline__ auto is_ordered(T left, T right) -> bool
 constexpr auto calc_capacity(int k) -> int
 {
   int capacity = isPo2(k) ? k : (1 << (log2(k) + 1));
-  if (capacity < WarpSize) { capacity = WarpSize; }  // TODO: remove this to allow small sizes.
   return capacity;
 }
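With the clamp gone, capacities smaller than WarpSize become representable. A few sample values of the rounding, written as compile-time checks (a sketch; it assumes isPo2 and log2 above are the constexpr integer helpers used elsewhere in this codebase):

    static_assert(calc_capacity(3) == 4);    // no longer clamped up to WarpSize
    static_assert(calc_capacity(10) == 16);  // rounded up to the next power of two
    static_assert(calc_capacity(32) == 32);  // powers of two are kept as-is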
   */
-  __device__ void load_sorted(const T* in, const IdxT* in_idx)
+  __device__ void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true)
   {
-    IdxT idx = kWarpWidth - 1 - Pow2<kWarpWidth>::mod(laneId());
+    if (do_merge) {
+      int idx = Pow2<kWarpWidth>::mod(laneId()) ^ Pow2<kWarpWidth>::Mask;
 #pragma unroll
-    for (int i = kMaxArrLen - 1; i >= 0; --i, idx += kWarpWidth) {
-      if (idx < k_) {
-        T t = in[idx];
-        if (is_ordered<Ascending>(t, val_arr_[i])) {
-          val_arr_[i] = t;
-          idx_arr_[i] = in_idx[idx];
+      for (int i = kMaxArrLen - 1; i >= 0; --i, idx += kWarpWidth) {
+        if (idx < k) {
+          T t = in[idx];
+          if (is_ordered<Ascending>(t, val_arr_[i])) {
+            val_arr_[i] = t;
+            idx_arr_[i] = in_idx[idx];
+          }
         }
       }
     }
-    topk::bitonic<kMaxArrLen>(Ascending, kWarpWidth).merge(val_arr_, idx_arr_);
+    if (kWarpWidth < WarpSize || do_merge) {
+      topk::bitonic<kMaxArrLen>(Ascending, kWarpWidth).merge(val_arr_, idx_arr_);
+    }
   }

-  /** Save the content by the pointer location. */
+  /**
+   * Save the content by the pointer location.
+   *
+   * @param[out] out
+   *   device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth`
+   *   (length: k <= kWarpWidth * kMaxArrLen).
+   * @param[out] out_idx
+   *   device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth`
+   *   (length: k <= kWarpWidth * kMaxArrLen).
+   */
   __device__ void store(T* out, IdxT* out_idx) const
   {
-    IdxT idx = Pow2<kWarpWidth>::mod(laneId());
+    int idx = Pow2<kWarpWidth>::mod(laneId());
 #pragma unroll kMaxArrLen
-    for (int i = 0; i < kMaxArrLen && idx < k_; i++, idx += kWarpWidth) {
+    for (int i = 0; i < kMaxArrLen && idx < k; i++, idx += kWarpWidth) {
       out[idx]     = val_arr_[i];
       out_idx[idx] = idx_arr_[i];
     }
   }

  protected:
-  static constexpr int kWarpWidth = std::min(Capacity, WarpSize);
   static constexpr int kMaxArrLen = Capacity / kWarpWidth;

-  const IdxT k_;
-  const T dummy_;
   T val_arr_[kMaxArrLen];
   IdxT idx_arr_[kMaxArrLen];
@@ -237,7 +251,7 @@ class warp_sort {
         idx_arr_[kMaxArrLen - i] = ids_in[PerThreadSizeIn - i];
       }
     }
-    topk::bitonic<kMaxArrLen>(Ascending).merge(val_arr_, idx_arr_);
+    topk::bitonic<kMaxArrLen>(Ascending, kWarpWidth).merge(val_arr_, idx_arr_);
   }
 };
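A note on the rewritten index expression in `load_sorted`: XOR-ing with `Pow2<kWarpWidth>::Mask` (i.e. `kWarpWidth - 1`) reverses a lane's position within the subwarp, which is exactly what the old `kWarpWidth - 1 - ...` form computed; reading the sorted input backwards is what lets a single bitonic `merge` combine it with the queue contents. The identity is easy to check in isolation (plain host C++, illustrative only):

    #include <cassert>

    int main()
    {
      constexpr int kWarpWidth = 32;              // any power of two works
      constexpr int kMask      = kWarpWidth - 1;  // stands in for Pow2<kWarpWidth>::Mask
      for (int lane = 0; lane < kWarpWidth; lane++) {
        assert((lane ^ kMask) == kWarpWidth - 1 - lane);
      }
      return 0;
    }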
+    bool do_add = is_ordered<Ascending>(val, k_th_);
+    // merge the buf if it's full and we cannot add an element anymore
+    if (any(buf_len_ + do_add > kMaxBufLen)) {
+      // still, add an element before merging if possible for this thread
+      if (do_add && buf_len_ < kMaxBufLen) {
+        add_to_buf_(val, idx);
+        do_add = false;
       }
-      ++buf_len_;
+      merge_buf_();
     }
-    if (any(buf_len_ == kMaxBufLen)) { merge_buf_(); }
+    // add the element if it's needed and hasn't been added yet
+    if (do_add) { add_to_buf_(val, idx); }
   }

   __device__ void done()
@@ -298,30 +303,42 @@
   }

  private:
-  __device__ void set_k_th_()
+  __device__ __forceinline__ void set_k_th_()
   {
     // NB on using srcLane: it's ok if it is outside the warp size / width;
     // the modulo op will be done inside the __shfl_sync.
-    k_th_ = shfl(val_arr_[kMaxArrLen - 1], k_ - 1);
+    k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth);
   }

-  __device__ void merge_buf_()
+  __device__ __forceinline__ void merge_buf_()
   {
-    topk::bitonic<kMaxBufLen>(!Ascending).sort(val_buf_, idx_buf_);
+    topk::bitonic<kMaxBufLen>(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
     this->merge_in<kMaxBufLen>(val_buf_, idx_buf_);
     buf_len_ = 0;
     set_k_th_();  // contains warp sync
 #pragma unroll
     for (int i = 0; i < kMaxBufLen; i++) {
-      val_buf_[i] = dummy_;
+      val_buf_[i] = kDummy;
     }
   }

+  __device__ __forceinline__ void add_to_buf_(T val, IdxT idx)
+  {
+    // NB: the loop is used here to ensure the constant indexing,
+    // so as not to force the buffers to spill into local memory.
+#pragma unroll
+    for (int i = 0; i < kMaxBufLen; i++) {
+      if (i == buf_len_) {
+        val_buf_[i] = val;
+        idx_buf_[i] = idx;
+      }
+    }
+    buf_len_++;
+  }
+
   using warp_sort<Capacity, Ascending, T, IdxT>::kMaxArrLen;
   using warp_sort<Capacity, Ascending, T, IdxT>::val_arr_;
   using warp_sort<Capacity, Ascending, T, IdxT>::idx_arr_;
-  using warp_sort<Capacity, Ascending, T, IdxT>::k_;
-  using warp_sort<Capacity, Ascending, T, IdxT>::dummy_;

   static constexpr int kMaxBufLen = (Capacity <= 64) ? 2 : 4;
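Why the warp-wide `any()` vote in `add()` above: `merge_buf_()` runs warp-collective bitonic steps, so all lanes must enter it together, including lanes whose buffers still have room. The helper is presumably (an assumption of this note, as it is not shown in the diff) a thin wrapper over the warp-vote intrinsic:

    // Hypothetical shape of the any() helper used above.
    __device__ __forceinline__ bool any(bool pred)
    {
      return __any_sync(0xffffffffu, pred) != 0;  // true if pred holds on any lane
    }

Buffering up to kMaxBufLen candidates per thread before each (comparatively expensive) sort-and-merge is what the commit message means by reducing the frequency of `sort` operations for the filtered version.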
-  /** Fill in the secondary val_buf_/idx_buf_ */
-  __device__ void add_extra_(const T* in, const IdxT* in_idx, IdxT start, IdxT end)
-  {
-    IdxT idx = start + laneId();
-    for (int i = 0; i < kMaxArrLen; ++i, idx += WarpSize) {
-      val_buf_[i] = (idx < end) ? in[idx] : dummy_;
-      idx_buf_[i] = (idx < end) ? in_idx[idx] : std::numeric_limits<IdxT>::max();
-    }
-    topk::bitonic<kMaxArrLen>(!Ascending).sort(val_buf_, idx_buf_);
-  }
-
   using warp_sort<Capacity, Ascending, T, IdxT>::kMaxArrLen;
   using warp_sort<Capacity, Ascending, T, IdxT>::val_arr_;
   using warp_sort<Capacity, Ascending, T, IdxT>::idx_arr_;
-  using warp_sort<Capacity, Ascending, T, IdxT>::k_;
-  using warp_sort<Capacity, Ascending, T, IdxT>::dummy_;

   T val_buf_[kMaxArrLen];
   IdxT idx_buf_[kMaxArrLen];
   int buf_len_;
 };

-/**
- * This one is used for the second pass only:
- *  if the first pass happens in multiple blocks, the output consists of a series
- *  of sorted arrays, length `k` each.
- *  Under this assumption, we can use load_sorted to just do the merging, rather than
- *  the full sort.
- */
-template <int Capacity, bool Ascending, typename T, typename IdxT>
-class warp_merge : public warp_sort<Capacity, Ascending, T, IdxT> {
- public:
-  __device__ warp_merge(int k, T dummy) : warp_sort<Capacity, Ascending, T, IdxT>(k, dummy) {}
-
-  // NB: the input is already sorted, because it's the second pass.
-  __device__ void load(const T* in, const IdxT* in_idx, IdxT start, IdxT end)
-  {
-    for (; start < end; start += k_) {
-      load_sorted(in + start, in_idx + start);
-    }
-  }
-
-  __device__ void done() {}
-
- private:
-  using warp_sort<Capacity, Ascending, T, IdxT>::kWarpWidth;
-  using warp_sort<Capacity, Ascending, T, IdxT>::kMaxArrLen;
-  using warp_sort<Capacity, Ascending, T, IdxT>::val_arr_;
-  using warp_sort<Capacity, Ascending, T, IdxT>::idx_arr_;
-  using warp_sort<Capacity, Ascending, T, IdxT>::k_;
-  using warp_sort<Capacity, Ascending, T, IdxT>::dummy_;
-};
-
 template <typename T, typename IdxT>
-int calc_smem_size_for_block_wide(int num_of_warp, IdxT k)
+auto calc_smem_size_for_block_wide(int num_of_warp, int k) -> int
 {
-  return Pow2<256>::roundUp(num_of_warp / 2 * sizeof(T) * k) + num_of_warp / 2 * sizeof(IdxT) * k;
+  return Pow2<256>::roundUp(ceildiv(num_of_warp, 2) * sizeof(T) * k) +
+         ceildiv(num_of_warp, 2) * sizeof(IdxT) * k;
 }

 template
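Finally, a host-side sketch of how the corrected sizing is meant to be used: per the usage comment at the top of the file, the result goes in as the 3rd kernel launch parameter. The kernel name and signature here are assumptions of this example, not names from the file:

    template <typename T, typename IdxT>
    __global__ void example_topk_kernel(
      const T* in, const IdxT* in_idx, int len, int k, T* out, IdxT* out_idx);

    template <typename T, typename IdxT>
    void launch_example(dim3 grid, int block_dim, int k, const T* in, const IdxT* in_idx,
                        int len, T* out, IdxT* out_idx, cudaStream_t stream)
    {
      int num_of_warp = block_dim / WarpSize;
      // dynamic shared memory for the block-wide merge staging buffers
      int smem_size = calc_smem_size_for_block_wide<T, IdxT>(num_of_warp, k);
      example_topk_kernel<T, IdxT><<<grid, block_dim, smem_size, stream>>>(
        in, in_idx, len, k, out, out_idx);
    }

The switch from `num_of_warp / 2` to `ceildiv(num_of_warp, 2)` matters for odd warp counts: with plain integer division a single-warp block would receive zero bytes of staging buffer, whereas rounding up keeps the allocation sufficient.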