From 765f4d4236b72a1b1e5c741e968435d6b0b79367 Mon Sep 17 00:00:00 2001
From: Yong Wang
Date: Thu, 8 Dec 2022 12:35:34 +0800
Subject: [PATCH 01/29] Improve performance of radix top-k

---
 .../knn/detail/topk/radix_topk_updated.cuh    | 1207 +++++++++++++++++
 1 file changed, 1207 insertions(+)
 create mode 100644 cpp/include/raft/spatial/knn/detail/topk/radix_topk_updated.cuh

diff --git a/cpp/include/raft/spatial/knn/detail/topk/radix_topk_updated.cuh b/cpp/include/raft/spatial/knn/detail/topk/radix_topk_updated.cuh
new file mode 100644
index 0000000000..0cc268e068
--- /dev/null
+++ b/cpp/include/raft/spatial/knn/detail/topk/radix_topk_updated.cuh
@@ -0,0 +1,1207 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace raft::spatial::knn::detail::topk {
+namespace radix_impl {
+
+constexpr int BLOCK_DIM           = 512;
+using WideT                       = float4;
+constexpr int LAZY_WRITING_FACTOR = 4;
+
+template <int BITS_PER_PASS>
+__host__ __device__ constexpr int calc_num_buckets()
+{
+  return 1 << BITS_PER_PASS;
+}
+
+template <typename T, int BITS_PER_PASS>
+__host__ __device__ constexpr int calc_num_passes()
+{
+  return (sizeof(T) * 8 - 1) / BITS_PER_PASS + 1;
+}
+
+// bit 0 is the least significant (rightmost) bit
+// this function works even when pass=-1, which is used in calc_mask()
+template <typename T, int BITS_PER_PASS>
+__device__ constexpr int calc_start_bit(int pass)
+{
+  int start_bit = static_cast<int>(sizeof(T) * 8) - (pass + 1) * BITS_PER_PASS;
+  if (start_bit < 0) { start_bit = 0; }
+  return start_bit;
+}
+
+template <typename T, int BITS_PER_PASS>
+__device__ constexpr unsigned calc_mask(int pass)
+{
+  static_assert(BITS_PER_PASS <= 31);
+  int num_bits =
+    calc_start_bit<T, BITS_PER_PASS>(pass - 1) - calc_start_bit<T, BITS_PER_PASS>(pass);
+  return (1 << num_bits) - 1;
+}
+
+template <typename T>
+__device__ typename cub::Traits<T>::UnsignedBits twiddle_in(T key, bool greater)
+{
+  auto bits = reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(key);
+  bits      = cub::Traits<T>::TwiddleIn(bits);
+  if (greater) { bits = ~bits; }
+  return bits;
+}
+
+template <typename T>
+__device__ T twiddle_out(typename cub::Traits<T>::UnsignedBits bits, bool greater)
+{
+  if (greater) { bits = ~bits; }
+  bits = cub::Traits<T>::TwiddleOut(bits);
+  return reinterpret_cast<T&>(bits);
+}
+
+template <typename T, int BITS_PER_PASS>
+__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater)
+{
+  static_assert(BITS_PER_PASS <= sizeof(int) * 8 - 1,
+                "BITS_PER_PASS is too large for the bucket index to fit in an int");
+  return (twiddle_in(x, greater) >> start_bit) & mask;
+}
+
+template <typename T, typename idxT, typename Func>
+__device__ void vectorized_process(const T* in, idxT len, Func f)
+{
+  const idxT stride = blockDim.x * gridDim.x;
+  const int tid     = blockIdx.x * blockDim.x + threadIdx.x;
+  if constexpr (sizeof(T) >= sizeof(WideT)) {
+    for (idxT i = tid; i < len; i += stride) {
+      f(in[i], i);
+    }
+  } else {
+    static_assert(sizeof(WideT) % sizeof(T) == 0);
+    constexpr int items_per_scalar = sizeof(WideT) / sizeof(T);
+    // TODO: it's UB (type punning through a union)
+    union {
+      WideT scalar;
+      T array[items_per_scalar];
+    }
wide; + + int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) + ? ((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) + : 0; + if (skip_cnt > len) { skip_cnt = len; } + const WideT* in_cast = reinterpret_cast(in + skip_cnt); + const idxT len_cast = (len - skip_cnt) / items_per_scalar; + + for (idxT i = tid; i < len_cast; i += stride) { + wide.scalar = in_cast[i]; + const idxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for (int j = 0; j < items_per_scalar; ++j) { + f(wide.array[j], real_i + j); + } + } + + static_assert(WARP_SIZE >= items_per_scalar); + // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt + // no need to use loop + if (tid < skip_cnt) { f(in[tid], tid); } + // because len_cast = (len - skip_cnt) / items_per_scalar, + // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; + // and so + // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WARP_SIZE + // no need to use loop + const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + if (remain_i < len) { f(in[remain_i], remain_i); } + } +} + +// sync_width should >= WARP_SIZE +template +__device__ void vectorized_process(const T* in, idxT len, Func f, int sync_width) +{ + const idxT stride = blockDim.x * gridDim.x; + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if constexpr (sizeof(T) >= sizeof(WideT)) { + for (idxT i = tid; i < len; i += stride) { + f(in[i], i, true); + } + } else { + static_assert(sizeof(WideT) % sizeof(T) == 0); + constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); + union { + WideT scalar; + T array[items_per_scalar]; + } wide; + + int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) + ? ((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) + : 0; + if (skip_cnt > len) { skip_cnt = len; } + const WideT* in_cast = reinterpret_cast(in + skip_cnt); + const idxT len_cast = (len - skip_cnt) / items_per_scalar; + + const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; + for (idxT i = tid; i < len_cast_for_sync; i += stride) { + bool valid = i < len_cast; + if (valid) { wide.scalar = in_cast[i]; } + const idxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for (int j = 0; j < items_per_scalar; ++j) { + f(wide.array[j], real_i + j, valid); + } + } + + static_assert(WARP_SIZE >= items_per_scalar); + // need at most one warp for skipped and remained elements, + // and sync_width >= WARP_SIZE + if (tid < sync_width) { + bool valid = tid < skip_cnt; + T value = valid ? in[tid] : T(); + f(value, tid, valid); + + const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + valid = remain_i < len; + value = valid ? 
in[remain_i] : T(); + f(value, remain_i, valid); + } + } +} + +template +struct alignas(128) Counter { + idxT k; + idxT len; + idxT previous_len; + typename cub::Traits::UnsignedBits kth_value_bits; + + alignas(128) idxT filter_cnt; + alignas(128) unsigned int finished_block_cnt; + alignas(128) idxT out_cnt; + alignas(128) idxT out_back_cnt; +}; + +// not actually used since the specialization for FilterAndHistogram doesn't use this +// implementation +template +class DirectStore { + public: + __device__ void store(T value, idxT index, bool valid, T* out, idxT* out_idx, idxT* p_out_cnt) + { + if (!valid) { return; } + idxT pos = atomicAdd(p_out_cnt, 1); + out[pos] = value; + out_idx[pos] = index; + } + + __device__ void flush(T*, idxT*, idxT*) {} +}; + +template +class BufferedStore { + public: + __device__ BufferedStore() + { + const int warp_id = threadIdx.x >> 5; + lane_id_ = threadIdx.x % WARP_SIZE; + + __shared__ T value_smem[NUM_THREAD]; + __shared__ idxT index_smem[NUM_THREAD]; + + value_smem_ = value_smem + (warp_id << 5); + index_smem_ = index_smem + (warp_id << 5); + warp_pos_ = 0; + } + + __device__ void store(T value, idxT index, bool valid, T* out, idxT* out_idx, idxT* p_out_cnt) + { + unsigned int valid_mask = __ballot_sync(FULL_WARP_MASK, valid); + if (valid_mask == 0) { return; } + + int pos = __popc(valid_mask & ((0x1u << lane_id_) - 1)) + warp_pos_; + if (valid && pos < WARP_SIZE) { + value_smem_[pos] = value; + index_smem_[pos] = index; + } + + warp_pos_ += __popc(valid_mask); + // Check if the buffer is full + if (warp_pos_ >= WARP_SIZE) { + idxT pos_smem; + if (lane_id_ == 0) { pos_smem = atomicAdd(p_out_cnt, WARP_SIZE); } + pos_smem = __shfl_sync(FULL_WARP_MASK, pos_smem, 0); + + __syncwarp(); + out[pos_smem + lane_id_] = value_smem_[lane_id_]; + out_idx[pos_smem + lane_id_] = index_smem_[lane_id_]; + __syncwarp(); + // Now the buffer is clean + if (valid && pos >= WARP_SIZE) { + pos -= WARP_SIZE; + value_smem_[pos] = value; + index_smem_[pos] = index; + } + + warp_pos_ -= WARP_SIZE; + } + } + + __device__ void flush(T* out, idxT* out_idx, idxT* p_out_cnt) + { + if (warp_pos_ > 0) { + idxT pos_smem; + if (lane_id_ == 0) { pos_smem = atomicAdd(p_out_cnt, warp_pos_); } + pos_smem = __shfl_sync(FULL_WARP_MASK, pos_smem, 0); + + __syncwarp(); + if (lane_id_ < warp_pos_) { + out[pos_smem + lane_id_] = value_smem_[lane_id_]; + out_idx[pos_smem + lane_id_] = index_smem_[lane_id_]; + } + } + } + + private: + T* value_smem_; + idxT* index_smem_; + idxT lane_id_; //@TODO: Can be const variable + int warp_pos_; +}; + +template + class Store> +class FilterAndHistogram { + public: + __device__ void operator()(const T* in_buf, + const idxT* in_idx_buf, + T* out_buf, + idxT* out_idx_buf, + T* out, + idxT* out_idx, + idxT previous_len, + Counter* counter, + idxT* histogram, + bool greater, + int pass, + bool early_stop) + { + constexpr int num_buckets = calc_num_buckets(); + __shared__ idxT histogram_smem[num_buckets]; + for (idxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram_smem[i] = 0; + } + Store store; + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + + if (pass == 0) { + auto f = [greater, start_bit, mask](T value, idxT) { + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, 1); + }; + vectorized_process(in_buf, previous_len, f); + } else { + idxT* p_filter_cnt = &counter->filter_cnt; + idxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = 
counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + auto f = [in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + greater, + start_bit, + mask, + previous_start_bit, + kth_value_bits, + p_filter_cnt, + p_out_cnt, + early_stop, + &store](T value, idxT i, bool valid) { + const auto previous_bits = (twiddle_in(value, greater) >> previous_start_bit) + << previous_start_bit; + + if (valid && previous_bits == kth_value_bits) { + if (early_stop) { + idxT pos = atomicAdd(p_out_cnt, 1); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else { + if (out_buf) { + idxT pos = atomicAdd(p_filter_cnt, 1); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, 1); + } + } + + if (out_buf || early_stop) { + store.store(value, + in_idx_buf ? in_idx_buf[i] : i, + valid && previous_bits < kth_value_bits, + out, + out_idx, + p_out_cnt); + } + }; + vectorized_process(in_buf, previous_len, f, WARP_SIZE); + store.flush(out, out_idx, p_out_cnt); + } + if (early_stop) { return; } + + __syncthreads(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } + } + } +}; + +template +class FilterAndHistogram { + public: + __device__ void operator()(const T* in_buf, + const idxT* in_idx_buf, + T* out_buf, + idxT* out_idx_buf, + T* out, + idxT* out_idx, + idxT previous_len, + Counter* counter, + idxT* histogram, + bool greater, + int pass, + bool early_stop) + { + constexpr int num_buckets = calc_num_buckets(); + __shared__ idxT histogram_smem[num_buckets]; + for (idxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram_smem[i] = 0; + } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + + if (pass == 0) { + auto f = [greater, start_bit, mask](T value, idxT) { + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, 1); + }; + vectorized_process(in_buf, previous_len, f); + } else { + idxT* p_filter_cnt = &counter->filter_cnt; + idxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + auto f = [in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + greater, + start_bit, + mask, + previous_start_bit, + kth_value_bits, + p_filter_cnt, + p_out_cnt, + early_stop](T value, idxT i) { + const auto previous_bits = (twiddle_in(value, greater) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + if (early_stop) { + idxT pos = atomicAdd(p_out_cnt, 1); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else { + if (out_buf) { + idxT pos = atomicAdd(p_filter_cnt, 1); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, 1); + } + } + // '(out_buf || early_stop)': + // If we skip writing to 'out_buf' (when !out_buf), we should skip + // writing to 'out' too. So we won't write the same value to 'out' + // multiple times. And if we keep skipping the writing, values will be + // written in last_filter_kernel at last. But when 'early_stop' is true, + // we need to write to 'out' since it's the last chance. 
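+        // Worked example (assuming T=float and BITS_PER_PASS=11, i.e. 3 passes
+        // over bit ranges [21,31], [10,20] and [0,9] of the twiddled key):
+        // 'previous_bits' keeps only the bits decided in earlier passes, so an
+        // element whose 'previous_bits' compare less than 'kth_value_bits' is
+        // definitely among the selected top-k, equal means it is still a
+        // candidate, and greater means it is discarded.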
+ else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { + idxT pos = atomicAdd(p_out_cnt, 1); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + }; + vectorized_process(in_buf, previous_len, f); + } + if (early_stop) { return; } + + __syncthreads(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } + } + } +}; + +template +__device__ void scan(volatile idxT* histogram) +{ + constexpr int num_buckets = calc_num_buckets(); + if constexpr (num_buckets >= NUM_THREAD) { + static_assert(num_buckets % NUM_THREAD == 0); + constexpr int items_per_thread = num_buckets / NUM_THREAD; + typedef cub::BlockLoad BlockLoad; + typedef cub::BlockStore + BlockStore; + typedef cub::BlockScan BlockScan; + + __shared__ union { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + idxT thread_data[items_per_thread]; + + BlockLoad(temp_storage.load).Load(histogram, thread_data); + __syncthreads(); + + BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + BlockStore(temp_storage.store).Store(histogram, thread_data); + } else { + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + idxT thread_data = 0; + if (threadIdx.x < num_buckets) { thread_data = histogram[threadIdx.x]; } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + if (threadIdx.x < num_buckets) { histogram[threadIdx.x] = thread_data; } + } +} + +template +__device__ void choose_bucket(Counter* counter, + const idxT* histogram, + const idxT k, + const int pass) +{ + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + idxT prev = (i == 0) ? 0 : histogram[i - 1]; + idxT cur = histogram[i]; + + // one and only one thread will satisfy this condition, so only write once + if (prev < k && cur >= k) { + counter->k = k - prev; + counter->len = cur - prev; + typename cub::Traits::UnsignedBits bucket = i; + int start_bit = calc_start_bit(pass); + counter->kth_value_bits |= bucket << start_bit; + } + } +} + +// For one-block version, last_filter() could be called when pass < num_passes - 1. +// So pass could not be constexpr +template +__device__ void last_filter(const T* out_buf, + const idxT* out_idx_buf, + T* out, + idxT* out_idx, + idxT current_len, + idxT k, + Counter* counter, + const bool greater, + const int pass) +{ + const auto kth_value_bits = counter->kth_value_bits; + const int start_bit = calc_start_bit(pass); + + // changed in choose_bucket(), need to reload + const idxT needed_num_of_kth = counter->k; + idxT* p_out_cnt = &counter->out_cnt; + idxT* p_out_back_cnt = &counter->out_back_cnt; + for (idxT i = threadIdx.x; i < current_len; i += blockDim.x) { + const T value = out_buf[i]; + const auto bits = (twiddle_in(value, greater) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + idxT pos = atomicAdd(p_out_cnt, 1); + out[pos] = value; + // for one-block version, 'out_idx_buf' could be nullptr at pass 0; + // and for dynamic version, 'out_idx_buf' could be nullptr if 'out_buf' is + // 'in' + out_idx[pos] = out_idx_buf ? out_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + idxT back_pos = atomicAdd(p_out_back_cnt, 1); + if (back_pos < needed_num_of_kth) { + idxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = out_idx_buf ? 
out_idx_buf[i] : i; + } + } + } +} + +template +__global__ void last_filter_kernel(const T* in, + const T* in_buf, + const idxT* in_idx_buf, + T* out, + idxT* out_idx, + idxT len, + idxT k, + Counter* counters, + const bool greater) +{ + const int batch_id = blockIdx.y; + + Counter* counter = counters + batch_id; + idxT previous_len = counter->previous_len; + if (previous_len == 0) { return; } + if (previous_len > len / LAZY_WRITING_FACTOR) { + in_buf = in; + in_idx_buf = nullptr; + previous_len = len; + } + + in_buf += batch_id * len; + if (in_idx_buf) { in_idx_buf += batch_id * len; } + out += batch_id * k; + out_idx += batch_id * k; + + constexpr int pass = calc_num_passes() - 1; + constexpr int start_bit = calc_start_bit(pass); + + const auto kth_value_bits = counter->kth_value_bits; + const idxT needed_num_of_kth = counter->k; + idxT* p_out_cnt = &counter->out_cnt; + idxT* p_out_back_cnt = &counter->out_back_cnt; + + auto f = [k, + greater, + kth_value_bits, + needed_num_of_kth, + p_out_cnt, + p_out_back_cnt, + in_idx_buf, + out, + out_idx](T value, idxT i) { + const auto bits = (twiddle_in(value, greater) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + idxT pos = atomicAdd(p_out_cnt, 1); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + idxT back_pos = atomicAdd(p_out_back_cnt, 1); + if (back_pos < needed_num_of_kth) { + idxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + }; + + vectorized_process(in_buf, previous_len, f); +} + +template + class Store> +__global__ void radix_kernel(const T* in, + const T* in_buf, + const idxT* in_idx_buf, + T* out_buf, + idxT* out_idx_buf, + T* out, + idxT* out_idx, + Counter* counters, + idxT* histograms, + const idxT len, + const idxT k, + const bool greater, + const int pass) +{ + __shared__ bool isLastBlock; + + const int batch_id = blockIdx.y; + auto counter = counters + batch_id; + idxT current_k; + idxT previous_len; + idxT current_len; + if (pass == 0) { + current_k = k; + previous_len = len; + // Need to do this so setting counter->previous_len for the next pass is correct. + // This value is meaningless for pass 0, but it's fine because pass 0 won't be the + // last pass in current implementation so pass 0 won't hit the "if (pass == + // num_passes - 1)" branch. 
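+    // Note: radix_topk() requires calc_num_passes() > 1 via a static_assert
+    // (e.g. 3 passes for a 32-bit T with 11 bits per pass), so pass 0 can
+    // never be the last pass.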
+ // Maybe it's better to reload counter->previous_len and use it rather than + // current_len in last_filter() + current_len = len; + } else { + current_k = counter->k; + current_len = counter->len; + previous_len = counter->previous_len; + } + if (current_len == 0) { return; } + bool early_stop = (current_len == current_k); + + constexpr int num_buckets = calc_num_buckets(); + constexpr int num_passes = calc_num_passes(); + + if constexpr (use_dynamic) { + // Figure out if the previous pass writes buffer + if (previous_len > len / LAZY_WRITING_FACTOR) { + previous_len = len; + in_buf = in; + in_idx_buf = nullptr; + } + // Figure out if this pass need to write buffer + if (current_len > len / LAZY_WRITING_FACTOR) { + out_buf = nullptr; + out_idx_buf = nullptr; + } + } + in_buf += batch_id * len; + if (in_idx_buf) { in_idx_buf += batch_id * len; } + if (out_buf) { out_buf += batch_id * len; } + if (out_idx_buf) { out_idx_buf += batch_id * len; } + if (out) { + out += batch_id * k; + out_idx += batch_id * k; + } + auto histogram = histograms + batch_id * num_buckets; + + FilterAndHistogram()(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + previous_len, + counter, + histogram, + greater, + pass, + early_stop); + __threadfence(); + + if (threadIdx.x == 0) { + unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); + isLastBlock = (finished == (gridDim.x - 1)); + } + + // Synchronize to make sure that each thread reads the correct value of isLastBlock. + __syncthreads(); + if (isLastBlock) { + if (early_stop) { + if (threadIdx.x == 0) { + // last_filter_kernel from dynamic version requires setting previous_len + counter->previous_len = 0; + counter->len = 0; + } + return; + } + + scan(histogram); + __syncthreads(); + choose_bucket(counter, histogram, current_k, pass); + __syncthreads(); + + // reset for next pass + if (pass != num_passes - 1) { + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + } + if (threadIdx.x == 0) { + // last_filter_kernel requires setting previous_len even in the last pass + counter->previous_len = current_len; + // not necessary for the last pass, but put it here anyway + counter->filter_cnt = 0; + } + + if constexpr (!use_dynamic) { + if (pass == num_passes - 1) { + last_filter( + out_buf, out_idx_buf, out, out_idx, current_len, k, counter, greater, pass); + } + } + } +} + +template + class Store> +unsigned calc_grid_dim(int batch_size, idxT len, int sm_cnt, bool use_dynamic) +{ + static_assert(sizeof(WideT) / sizeof(T) >= 1); + + int active_blocks; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &active_blocks, + use_dynamic ? 
radix_kernel + : radix_kernel, + NUM_THREAD, + 0)); + active_blocks *= sm_cnt; + + idxT best_num_blocks = 0; + float best_tail_wave_penalty = 1.0f; + const idxT max_num_blocks = (len - 1) / (sizeof(WideT) / sizeof(T) * NUM_THREAD) + 1; + for (int num_waves = 1;; ++num_waves) { + int num_blocks = std::min(max_num_blocks, std::max(num_waves * active_blocks / batch_size, 1)); + idxT items_per_thread = (len - 1) / (num_blocks * NUM_THREAD) + 1; + items_per_thread = (items_per_thread - 1) / (sizeof(WideT) / sizeof(T)) + 1; + items_per_thread *= sizeof(WideT) / sizeof(T); + num_blocks = (len - 1) / (items_per_thread * NUM_THREAD) + 1; + float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; + float tail_wave_penalty = + (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); + + // 0.15 is determined experimentally. It also ensures breaking the loop early, + // e.g. when num_waves > 7, tail_wave_penalty will always <0.15 + if (tail_wave_penalty < 0.15) { + best_num_blocks = num_blocks; + break; + } else if (tail_wave_penalty < best_tail_wave_penalty) { + best_num_blocks = num_blocks; + best_tail_wave_penalty = tail_wave_penalty; + } + + if (num_blocks == max_num_blocks) { break; } + } + return best_num_blocks; +} + +template + class Store> +void radix_topk(void* buf, + size_t& buf_size, + const T* in, + int batch_size, + idxT len, + idxT k, + T* out, + idxT* out_idx, + bool greater, + cudaStream_t stream, + bool use_dynamic = false) +{ + // TODO: is it possible to relax this restriction? + static_assert(calc_num_passes() > 1); + constexpr int num_buckets = calc_num_buckets(); + + Counter* counters = nullptr; + idxT* histograms = nullptr; + T* buf1 = nullptr; + idxT* idx_buf1 = nullptr; + T* buf2 = nullptr; + idxT* idx_buf2 = nullptr; + { + std::vector sizes = {sizeof(*counters) * batch_size, + sizeof(*histograms) * num_buckets * batch_size, + sizeof(*buf1) * len * batch_size, + sizeof(*idx_buf1) * len * batch_size, + sizeof(*buf2) * len * batch_size, + sizeof(*idx_buf2) * len * batch_size}; + size_t total_size = calc_aligned_size(sizes); + if (!buf) { + buf_size = total_size; + return; + } + + std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); + counters = static_cast(aligned_pointers[0]); + histograms = static_cast(aligned_pointers[1]); + buf1 = static_cast(aligned_pointers[2]); + idx_buf1 = static_cast(aligned_pointers[3]); + buf2 = static_cast(aligned_pointers[4]); + idx_buf2 = static_cast(aligned_pointers[5]); + + RAFT_CUDA_TRY(cudaMemsetAsync( + buf, + 0, + static_cast(aligned_pointers[2]) - static_cast(aligned_pointers[0]), + stream)); + } + + const T* in_buf = nullptr; + const idxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + idxT* out_idx_buf = nullptr; + + int sm_cnt; + { + int dev; + RAFT_CUDA_TRY(cudaGetDevice(&dev)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); + } + dim3 blocks( + calc_grid_dim(batch_size, len, sm_cnt, use_dynamic), + batch_size); + + constexpr int num_passes = calc_num_passes(); + + for (int pass = 0; pass < num_passes; ++pass) { + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = buf1; + out_idx_buf = idx_buf1; + } else if (pass % 2 == 0) { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } else { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } + + if 
(!use_dynamic) { + radix_kernel + <<>>(in, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + counters, + histograms, + len, + k, + greater, + pass); + } else { + radix_kernel + <<>>(in, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + counters, + histograms, + len, + k, + greater, + pass); + } + } + + if (use_dynamic) { + dim3 blocks((len / (sizeof(WideT) / sizeof(T)) - 1) / NUM_THREAD + 1, batch_size); + last_filter_kernel<<>>( + in, out_buf, out_idx_buf, out, out_idx, len, k, counters, greater); + } +} + +template +__device__ void filter_and_histogram(const T* in_buf, + const idxT* in_idx_buf, + T* out_buf, + idxT* out_idx_buf, + T* out, + idxT* out_idx, + Counter* counter, + idxT* histogram, + bool greater, + int pass) +{ + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + idxT* p_filter_cnt = &counter->filter_cnt; + if (threadIdx.x == 0) { *p_filter_cnt = 0; } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + const idxT previous_len = counter->previous_len; + + if (pass == 0) { + // Could not use vectorized_process() as in FilterAndHistogram because + // vectorized_process() assumes multi-block, e.g. uses gridDim.x + for (idxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + T value = in_buf[i]; + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram + bucket, 1); + } + } else { + idxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + for (idxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto previous_bits = (twiddle_in(value, greater) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + idxT pos = atomicAdd(p_filter_cnt, 1); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram + bucket, 1); + } else if (previous_bits < kth_value_bits) { + idxT pos = atomicAdd(p_out_cnt, 1); + out[pos] = value; + out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; + } + } + } +} + +template +__global__ void radix_topk_one_block_kernel(const T* in, + const idxT len, + const idxT k, + T* out, + idxT* out_idx, + const bool greater, + T* buf1, + idxT* idx_buf1, + T* buf2, + idxT* idx_buf2) +{ + constexpr int num_buckets = calc_num_buckets(); + __shared__ Counter counter; + __shared__ idxT histogram[num_buckets]; + + if (threadIdx.x == 0) { + counter.k = k; + counter.len = len; + counter.previous_len = len; + counter.kth_value_bits = 0; + counter.out_cnt = 0; + counter.out_back_cnt = 0; + } + __syncthreads(); + + in += blockIdx.x * len; + out += blockIdx.x * k; + out_idx += blockIdx.x * k; + buf1 += blockIdx.x * len; + idx_buf1 += blockIdx.x * len; + buf2 += blockIdx.x * len; + idx_buf2 += blockIdx.x * len; + const T* in_buf = nullptr; + const idxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + idxT* out_idx_buf = nullptr; + + constexpr int num_passes = calc_num_passes(); + for (int pass = 0; pass < num_passes; ++pass) { + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = buf1; + out_idx_buf = idx_buf1; + } else if (pass % 2 == 0) { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } else { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } + idxT current_len = counter.len; + idxT current_k = counter.k; + + filter_and_histogram( + in_buf, in_idx_buf, out_buf, out_idx_buf, out, out_idx, &counter, histogram, greater, pass); + __syncthreads(); + + scan(histogram); + __syncthreads(); + + choose_bucket(&counter, histogram, current_k, pass); + if (threadIdx.x == 0) { counter.previous_len = current_len; } + __syncthreads(); + + if (counter.len == counter.k || pass == num_passes - 1) { + last_filter(pass == 0 ? in : out_buf, + pass == 0 ? 
nullptr : out_idx_buf, + out, + out_idx, + current_len, + k, + &counter, + greater, + pass); + break; + } + } +} + +template +void radix_topk_one_block(void* buf, + size_t& buf_size, + const T* in, + int batch_size, + idxT len, + idxT k, + T* out, + idxT* out_idx, + bool greater, + cudaStream_t stream) +{ + static_assert(calc_num_passes() > 1); + + T* buf1 = nullptr; + idxT* idx_buf1 = nullptr; + T* buf2 = nullptr; + idxT* idx_buf2 = nullptr; + { + std::vector sizes = {sizeof(*buf1) * len * batch_size, + sizeof(*idx_buf1) * len * batch_size, + sizeof(*buf2) * len * batch_size, + sizeof(*idx_buf2) * len * batch_size}; + size_t total_size = calc_aligned_size(sizes); + if (!buf) { + buf_size = total_size; + return; + } + + std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); + buf1 = static_cast(aligned_pointers[0]); + idx_buf1 = static_cast(aligned_pointers[1]); + buf2 = static_cast(aligned_pointers[2]); + idx_buf2 = static_cast(aligned_pointers[3]); + } + + radix_topk_one_block_kernel + <<>>( + in, len, k, out, out_idx, greater, buf1, idx_buf1, buf2, idx_buf2); +} + +} // namespace radix_impl + +template +void radix_topk_11bits(void* buf, + size_t& buf_size, + const T* in, + int batch_size, + idxT len, + idxT k, + T* out, + idxT* out_idx = nullptr, + bool greater = true, + cudaStream_t stream = 0) +{ + constexpr int items_per_thread = 32; + if (len <= radix_impl::BLOCK_DIM * items_per_thread) { + radix_impl::radix_topk_one_block( + buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); + } else if (len < 100.0 * k / batch_size + 0.01) { + radix_impl::radix_topk( + buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); + } else { + radix_impl::radix_topk( + buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); + } +} + +template +void radix_topk_11bits_dynamic(void* buf, + size_t& buf_size, + const T* in, + int batch_size, + idxT len, + idxT k, + T* out, + idxT* out_idx = nullptr, + bool greater = true, + cudaStream_t stream = 0) +{ + constexpr bool use_dynamic = true; + + constexpr int items_per_thread = 32; + if (len <= radix_impl::BLOCK_DIM * items_per_thread) { + radix_impl::radix_topk_one_block( + buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); + } else if (len < 100.0 * k / batch_size + 0.01) { + radix_impl::radix_topk( + buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream, use_dynamic); + } else { + radix_impl::radix_topk( + buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream, use_dynamic); + } +} + +} // namespace raft::spatial::knn::detail::topk From d085e587616cba74a1f85370b6476454dce702d6 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Thu, 5 Jan 2023 16:59:19 +0800 Subject: [PATCH 02/29] radix top-k: conform to RAFT code style --- .../matrix/detail/select_radix_updated.cuh | 647 +++++++++--------- 1 file changed, 323 insertions(+), 324 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 0cc268e068..1867065939 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -37,34 +37,33 @@ constexpr int BLOCK_DIM = 512; using WideT = float4; constexpr int LAZY_WRITING_FACTOR = 4; -template +template __host__ __device__ constexpr int calc_num_buckets() { - return 1 << BITS_PER_PASS; + return 1 << BitsPerPass; } -template +template __host__ __device__ constexpr int calc_num_passes() { - return (sizeof(T) * 
8 - 1) / BITS_PER_PASS + 1; + return (sizeof(T) * 8 - 1) / BitsPerPass + 1; } // bit 0 is the least significant (rightmost) bit // this function works even when pass=-1, which is used in calc_mask() -template +template __device__ constexpr int calc_start_bit(int pass) { - int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BITS_PER_PASS; + int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; if (start_bit < 0) { start_bit = 0; } return start_bit; } -template +template __device__ constexpr unsigned calc_mask(int pass) { - static_assert(BITS_PER_PASS <= 31); - int num_bits = - calc_start_bit(pass - 1) - calc_start_bit(pass); + static_assert(BitsPerPass <= 31); + int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); return (1 << num_bits) - 1; } @@ -85,21 +84,21 @@ __device__ T twiddle_out(typename cub::Traits::UnsignedBits bits, bool greate return reinterpret_cast(bits); } -template +template __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) { - static_assert(BITS_PER_PASS <= sizeof(int) * 8 - 1, - "BITS_PER_PASS is too large that the result type could not be int"); + static_assert(BitsPerPass <= sizeof(int) * 8 - 1, + "BitsPerPass is too large that the result type could not be int"); return (twiddle_in(x, greater) >> start_bit) & mask; } -template -__device__ void vectorized_process(const T* in, idxT len, Func f) +template +__device__ void vectorized_process(const T* in, IdxT len, Func f) { - const idxT stride = blockDim.x * gridDim.x; + const IdxT stride = blockDim.x * gridDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (sizeof(T) >= sizeof(WideT)) { - for (idxT i = tid; i < len; i += stride) { + for (IdxT i = tid; i < len; i += stride) { f(in[i], i); } } else { @@ -116,39 +115,39 @@ __device__ void vectorized_process(const T* in, idxT len, Func f) : 0; if (skip_cnt > len) { skip_cnt = len; } const WideT* in_cast = reinterpret_cast(in + skip_cnt); - const idxT len_cast = (len - skip_cnt) / items_per_scalar; + const IdxT len_cast = (len - skip_cnt) / items_per_scalar; - for (idxT i = tid; i < len_cast; i += stride) { + for (IdxT i = tid; i < len_cast; i += stride) { wide.scalar = in_cast[i]; - const idxT real_i = skip_cnt + i * items_per_scalar; + const IdxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { f(wide.array[j], real_i + j); } } - static_assert(WARP_SIZE >= items_per_scalar); - // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt + static_assert(WarpSize >= items_per_scalar); + // and because items_per_scalar > skip_cnt, WarpSize > skip_cnt // no need to use loop if (tid < skip_cnt) { f(in[tid], tid); } // because len_cast = (len - skip_cnt) / items_per_scalar, // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; // and so - // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WARP_SIZE + // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WarpSize // no need to use loop - const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; if (remain_i < len) { f(in[remain_i], remain_i); } } } -// sync_width should >= WARP_SIZE -template -__device__ void vectorized_process(const T* in, idxT len, Func f, int sync_width) +// sync_width should >= WarpSize +template +__device__ void vectorized_process(const T* in, IdxT len, Func f, int sync_width) { - const idxT stride = blockDim.x * gridDim.x; + const IdxT 
stride = blockDim.x * gridDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (sizeof(T) >= sizeof(WideT)) { - for (idxT i = tid; i < len; i += stride) { + for (IdxT i = tid; i < len; i += stride) { f(in[i], i, true); } } else { @@ -164,28 +163,28 @@ __device__ void vectorized_process(const T* in, idxT len, Func f, int sync_width : 0; if (skip_cnt > len) { skip_cnt = len; } const WideT* in_cast = reinterpret_cast(in + skip_cnt); - const idxT len_cast = (len - skip_cnt) / items_per_scalar; + const IdxT len_cast = (len - skip_cnt) / items_per_scalar; - const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; - for (idxT i = tid; i < len_cast_for_sync; i += stride) { + const IdxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; + for (IdxT i = tid; i < len_cast_for_sync; i += stride) { bool valid = i < len_cast; if (valid) { wide.scalar = in_cast[i]; } - const idxT real_i = skip_cnt + i * items_per_scalar; + const IdxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { f(wide.array[j], real_i + j, valid); } } - static_assert(WARP_SIZE >= items_per_scalar); + static_assert(WarpSize >= items_per_scalar); // need at most one warp for skipped and remained elements, - // and sync_width >= WARP_SIZE + // and sync_width >= WarpSize if (tid < sync_width) { bool valid = tid < skip_cnt; T value = valid ? in[tid] : T(); f(value, tid, valid); - const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; valid = remain_i < len; value = valid ? in[remain_i] : T(); f(value, remain_i, valid); @@ -193,67 +192,67 @@ __device__ void vectorized_process(const T* in, idxT len, Func f, int sync_width } } -template +template struct alignas(128) Counter { - idxT k; - idxT len; - idxT previous_len; + IdxT k; + IdxT len; + IdxT previous_len; typename cub::Traits::UnsignedBits kth_value_bits; - alignas(128) idxT filter_cnt; + alignas(128) IdxT filter_cnt; alignas(128) unsigned int finished_block_cnt; - alignas(128) idxT out_cnt; - alignas(128) idxT out_back_cnt; + alignas(128) IdxT out_cnt; + alignas(128) IdxT out_back_cnt; }; // not actually used since the specialization for FilterAndHistogram doesn't use this // implementation -template +template class DirectStore { public: - __device__ void store(T value, idxT index, bool valid, T* out, idxT* out_idx, idxT* p_out_cnt) + __device__ void store(T value, IdxT index, bool valid, T* out, IdxT* out_idx, IdxT* p_out_cnt) { if (!valid) { return; } - idxT pos = atomicAdd(p_out_cnt, 1); + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = index; } - __device__ void flush(T*, idxT*, idxT*) {} + __device__ void flush(T*, IdxT*, IdxT*) {} }; -template +template class BufferedStore { public: __device__ BufferedStore() { const int warp_id = threadIdx.x >> 5; - lane_id_ = threadIdx.x % WARP_SIZE; + lane_id_ = threadIdx.x % WarpSize; - __shared__ T value_smem[NUM_THREAD]; - __shared__ idxT index_smem[NUM_THREAD]; + __shared__ T value_smem[BlockSize]; + __shared__ IdxT index_smem[BlockSize]; value_smem_ = value_smem + (warp_id << 5); index_smem_ = index_smem + (warp_id << 5); warp_pos_ = 0; } - __device__ void store(T value, idxT index, bool valid, T* out, idxT* out_idx, idxT* p_out_cnt) + __device__ void store(T value, IdxT index, bool valid, T* out, IdxT* out_idx, IdxT* p_out_cnt) { unsigned int valid_mask = __ballot_sync(FULL_WARP_MASK, valid); if (valid_mask == 
0) { return; } int pos = __popc(valid_mask & ((0x1u << lane_id_) - 1)) + warp_pos_; - if (valid && pos < WARP_SIZE) { + if (valid && pos < WarpSize) { value_smem_[pos] = value; index_smem_[pos] = index; } warp_pos_ += __popc(valid_mask); // Check if the buffer is full - if (warp_pos_ >= WARP_SIZE) { - idxT pos_smem; - if (lane_id_ == 0) { pos_smem = atomicAdd(p_out_cnt, WARP_SIZE); } + if (warp_pos_ >= WarpSize) { + IdxT pos_smem; + if (lane_id_ == 0) { pos_smem = atomicAdd(p_out_cnt, static_cast(WarpSize)); } pos_smem = __shfl_sync(FULL_WARP_MASK, pos_smem, 0); __syncwarp(); @@ -261,21 +260,21 @@ class BufferedStore { out_idx[pos_smem + lane_id_] = index_smem_[lane_id_]; __syncwarp(); // Now the buffer is clean - if (valid && pos >= WARP_SIZE) { - pos -= WARP_SIZE; + if (valid && pos >= WarpSize) { + pos -= WarpSize; value_smem_[pos] = value; index_smem_[pos] = index; } - warp_pos_ -= WARP_SIZE; + warp_pos_ -= WarpSize; } } - __device__ void flush(T* out, idxT* out_idx, idxT* p_out_cnt) + __device__ void flush(T* out, IdxT* out_idx, IdxT* p_out_cnt) { if (warp_pos_ > 0) { - idxT pos_smem; - if (lane_id_ == 0) { pos_smem = atomicAdd(p_out_cnt, warp_pos_); } + IdxT pos_smem; + if (lane_id_ == 0) { pos_smem = atomicAdd(p_out_cnt, static_cast(warp_pos_)); } pos_smem = __shfl_sync(FULL_WARP_MASK, pos_smem, 0); __syncwarp(); @@ -288,54 +287,54 @@ class BufferedStore { private: T* value_smem_; - idxT* index_smem_; - idxT lane_id_; //@TODO: Can be const variable + IdxT* index_smem_; + IdxT lane_id_; //@TODO: Can be const variable int warp_pos_; }; template class Store> class FilterAndHistogram { public: __device__ void operator()(const T* in_buf, - const idxT* in_idx_buf, + const IdxT* in_idx_buf, T* out_buf, - idxT* out_idx_buf, + IdxT* out_idx_buf, T* out, - idxT* out_idx, - idxT previous_len, - Counter* counter, - idxT* histogram, + IdxT* out_idx, + IdxT previous_len, + Counter* counter, + IdxT* histogram, bool greater, int pass, bool early_stop) { - constexpr int num_buckets = calc_num_buckets(); - __shared__ idxT histogram_smem[num_buckets]; - for (idxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { + constexpr int num_buckets = calc_num_buckets(); + __shared__ IdxT histogram_smem[num_buckets]; + for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { histogram_smem[i] = 0; } - Store store; + Store store; __syncthreads(); - const int start_bit = calc_start_bit(pass); - const unsigned mask = calc_mask(pass); + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); if (pass == 0) { - auto f = [greater, start_bit, mask](T value, idxT) { - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, 1); + auto f = [greater, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, static_cast(1)); }; vectorized_process(in_buf, previous_len, f); } else { - idxT* p_filter_cnt = &counter->filter_cnt; - idxT* p_out_cnt = &counter->out_cnt; + IdxT* p_filter_cnt = &counter->filter_cnt; + IdxT* p_out_cnt = &counter->out_cnt; const auto kth_value_bits = counter->kth_value_bits; - const int previous_start_bit = calc_start_bit(pass - 1); + const int previous_start_bit = calc_start_bit(pass - 1); auto f = [in_idx_buf, out_buf, @@ -350,24 +349,24 @@ class FilterAndHistogram { p_filter_cnt, p_out_cnt, early_stop, - &store](T value, idxT i, bool valid) { + &store](T value, IdxT i, bool valid) { const auto previous_bits = (twiddle_in(value, greater) >> 
previous_start_bit) << previous_start_bit; if (valid && previous_bits == kth_value_bits) { if (early_stop) { - idxT pos = atomicAdd(p_out_cnt, 1); + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } else { if (out_buf) { - idxT pos = atomicAdd(p_filter_cnt, 1); + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); out_buf[pos] = value; out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, 1); + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, static_cast(1)); } } @@ -380,7 +379,7 @@ class FilterAndHistogram { p_out_cnt); } }; - vectorized_process(in_buf, previous_len, f, WARP_SIZE); + vectorized_process(in_buf, previous_len, f, WarpSize); store.flush(out, out_idx, p_out_cnt); } if (early_stop) { return; } @@ -392,43 +391,43 @@ class FilterAndHistogram { } }; -template -class FilterAndHistogram { +template +class FilterAndHistogram { public: __device__ void operator()(const T* in_buf, - const idxT* in_idx_buf, + const IdxT* in_idx_buf, T* out_buf, - idxT* out_idx_buf, + IdxT* out_idx_buf, T* out, - idxT* out_idx, - idxT previous_len, - Counter* counter, - idxT* histogram, + IdxT* out_idx, + IdxT previous_len, + Counter* counter, + IdxT* histogram, bool greater, int pass, bool early_stop) { - constexpr int num_buckets = calc_num_buckets(); - __shared__ idxT histogram_smem[num_buckets]; - for (idxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { + constexpr int num_buckets = calc_num_buckets(); + __shared__ IdxT histogram_smem[num_buckets]; + for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { histogram_smem[i] = 0; } __syncthreads(); - const int start_bit = calc_start_bit(pass); - const unsigned mask = calc_mask(pass); + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); if (pass == 0) { - auto f = [greater, start_bit, mask](T value, idxT) { - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, 1); + auto f = [greater, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, static_cast(1)); }; vectorized_process(in_buf, previous_len, f); } else { - idxT* p_filter_cnt = &counter->filter_cnt; - idxT* p_out_cnt = &counter->out_cnt; + IdxT* p_filter_cnt = &counter->filter_cnt; + IdxT* p_out_cnt = &counter->out_cnt; const auto kth_value_bits = counter->kth_value_bits; - const int previous_start_bit = calc_start_bit(pass - 1); + const int previous_start_bit = calc_start_bit(pass - 1); auto f = [in_idx_buf, out_buf, @@ -442,23 +441,23 @@ class FilterAndHistogram { kth_value_bits, p_filter_cnt, p_out_cnt, - early_stop](T value, idxT i) { + early_stop](T value, IdxT i) { const auto previous_bits = (twiddle_in(value, greater) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { if (early_stop) { - idxT pos = atomicAdd(p_out_cnt, 1); + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } else { if (out_buf) { - idxT pos = atomicAdd(p_filter_cnt, 1); + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); out_buf[pos] = value; out_idx_buf[pos] = in_idx_buf ? 
in_idx_buf[i] : i; } - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, 1); + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram_smem + bucket, static_cast(1)); } } // '(out_buf || early_stop)': @@ -468,7 +467,7 @@ class FilterAndHistogram { // written in last_filter_kernel at last. But when 'early_stop' is true, // we need to write to 'out' since it's the last chance. else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { - idxT pos = atomicAdd(p_out_cnt, 1); + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } @@ -484,24 +483,24 @@ class FilterAndHistogram { } }; -template -__device__ void scan(volatile idxT* histogram) +template +__device__ void scan(volatile IdxT* histogram) { - constexpr int num_buckets = calc_num_buckets(); - if constexpr (num_buckets >= NUM_THREAD) { - static_assert(num_buckets % NUM_THREAD == 0); - constexpr int items_per_thread = num_buckets / NUM_THREAD; - typedef cub::BlockLoad BlockLoad; - typedef cub::BlockStore + constexpr int num_buckets = calc_num_buckets(); + if constexpr (num_buckets >= BlockSize) { + static_assert(num_buckets % BlockSize == 0); + constexpr int items_per_thread = num_buckets / BlockSize; + typedef cub::BlockLoad BlockLoad; + typedef cub::BlockStore BlockStore; - typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScan; __shared__ union { typename BlockLoad::TempStorage load; typename BlockScan::TempStorage scan; typename BlockStore::TempStorage store; } temp_storage; - idxT thread_data[items_per_thread]; + IdxT thread_data[items_per_thread]; BlockLoad(temp_storage.load).Load(histogram, thread_data); __syncthreads(); @@ -511,10 +510,10 @@ __device__ void scan(volatile idxT* histogram) BlockStore(temp_storage.store).Store(histogram, thread_data); } else { - typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; - idxT thread_data = 0; + IdxT thread_data = 0; if (threadIdx.x < num_buckets) { thread_data = histogram[threadIdx.x]; } BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); @@ -524,23 +523,23 @@ __device__ void scan(volatile idxT* histogram) } } -template -__device__ void choose_bucket(Counter* counter, - const idxT* histogram, - const idxT k, +template +__device__ void choose_bucket(Counter* counter, + const IdxT* histogram, + const IdxT k, const int pass) { - constexpr int num_buckets = calc_num_buckets(); + constexpr int num_buckets = calc_num_buckets(); for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { - idxT prev = (i == 0) ? 0 : histogram[i - 1]; - idxT cur = histogram[i]; + IdxT prev = (i == 0) ? 0 : histogram[i - 1]; + IdxT cur = histogram[i]; // one and only one thread will satisfy this condition, so only write once if (prev < k && cur >= k) { counter->k = k - prev; counter->len = cur - prev; typename cub::Traits::UnsignedBits bucket = i; - int start_bit = calc_start_bit(pass); + int start_bit = calc_start_bit(pass); counter->kth_value_bits |= bucket << start_bit; } } @@ -548,38 +547,38 @@ __device__ void choose_bucket(Counter* counter, // For one-block version, last_filter() could be called when pass < num_passes - 1. 
// So pass could not be constexpr -template +template __device__ void last_filter(const T* out_buf, - const idxT* out_idx_buf, + const IdxT* out_idx_buf, T* out, - idxT* out_idx, - idxT current_len, - idxT k, - Counter* counter, + IdxT* out_idx, + IdxT current_len, + IdxT k, + Counter* counter, const bool greater, const int pass) { const auto kth_value_bits = counter->kth_value_bits; - const int start_bit = calc_start_bit(pass); + const int start_bit = calc_start_bit(pass); // changed in choose_bucket(), need to reload - const idxT needed_num_of_kth = counter->k; - idxT* p_out_cnt = &counter->out_cnt; - idxT* p_out_back_cnt = &counter->out_back_cnt; - for (idxT i = threadIdx.x; i < current_len; i += blockDim.x) { + const IdxT needed_num_of_kth = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { const T value = out_buf[i]; const auto bits = (twiddle_in(value, greater) >> start_bit) << start_bit; if (bits < kth_value_bits) { - idxT pos = atomicAdd(p_out_cnt, 1); + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; // for one-block version, 'out_idx_buf' could be nullptr at pass 0; // and for dynamic version, 'out_idx_buf' could be nullptr if 'out_buf' is // 'in' out_idx[pos] = out_idx_buf ? out_idx_buf[i] : i; } else if (bits == kth_value_bits) { - idxT back_pos = atomicAdd(p_out_back_cnt, 1); + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); if (back_pos < needed_num_of_kth) { - idxT pos = k - 1 - back_pos; + IdxT pos = k - 1 - back_pos; out[pos] = value; out_idx[pos] = out_idx_buf ? out_idx_buf[i] : i; } @@ -587,21 +586,21 @@ __device__ void last_filter(const T* out_buf, } } -template +template __global__ void last_filter_kernel(const T* in, const T* in_buf, - const idxT* in_idx_buf, + const IdxT* in_idx_buf, T* out, - idxT* out_idx, - idxT len, - idxT k, - Counter* counters, + IdxT* out_idx, + IdxT len, + IdxT k, + Counter* counters, const bool greater) { const int batch_id = blockIdx.y; - Counter* counter = counters + batch_id; - idxT previous_len = counter->previous_len; + Counter* counter = counters + batch_id; + IdxT previous_len = counter->previous_len; if (previous_len == 0) { return; } if (previous_len > len / LAZY_WRITING_FACTOR) { in_buf = in; @@ -614,13 +613,13 @@ __global__ void last_filter_kernel(const T* in, out += batch_id * k; out_idx += batch_id * k; - constexpr int pass = calc_num_passes() - 1; - constexpr int start_bit = calc_start_bit(pass); + constexpr int pass = calc_num_passes() - 1; + constexpr int start_bit = calc_start_bit(pass); const auto kth_value_bits = counter->kth_value_bits; - const idxT needed_num_of_kth = counter->k; - idxT* p_out_cnt = &counter->out_cnt; - idxT* p_out_back_cnt = &counter->out_back_cnt; + const IdxT needed_num_of_kth = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; auto f = [k, greater, @@ -630,16 +629,16 @@ __global__ void last_filter_kernel(const T* in, p_out_back_cnt, in_idx_buf, out, - out_idx](T value, idxT i) { + out_idx](T value, IdxT i) { const auto bits = (twiddle_in(value, greater) >> start_bit) << start_bit; if (bits < kth_value_bits) { - idxT pos = atomicAdd(p_out_cnt, 1); + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; } else if (bits == kth_value_bits) { - idxT back_pos = atomicAdd(p_out_back_cnt, 1); + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); if (back_pos < needed_num_of_kth) { - idxT pos = k - 1 - back_pos; + IdxT pos = k - 1 - back_pos; out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } @@ -650,23 +649,23 @@ __global__ void last_filter_kernel(const T* in, } template class Store> __global__ void radix_kernel(const T* in, const T* in_buf, - const idxT* in_idx_buf, + const IdxT* in_idx_buf, T* out_buf, - idxT* out_idx_buf, + IdxT* out_idx_buf, T* out, - idxT* out_idx, - Counter* counters, - idxT* histograms, - const idxT len, - const idxT k, + IdxT* out_idx, + Counter* counters, + IdxT* histograms, + const IdxT len, + const IdxT k, const bool greater, const int pass) { @@ -674,9 +673,9 @@ __global__ void radix_kernel(const T* in, const int batch_id = blockIdx.y; auto counter = counters + batch_id; - idxT current_k; - idxT previous_len; - idxT current_len; + IdxT current_k; + IdxT previous_len; + IdxT current_len; if (pass == 0) { current_k = k; previous_len = len; @@ -695,8 +694,8 @@ __global__ void radix_kernel(const T* in, if (current_len == 0) { return; } bool early_stop = (current_len == current_k); - constexpr int num_buckets = calc_num_buckets(); - constexpr int num_passes = calc_num_passes(); + constexpr int num_buckets = calc_num_buckets(); + constexpr int num_passes = calc_num_passes(); if constexpr (use_dynamic) { // Figure out if the previous pass writes buffer @@ -721,18 +720,18 @@ __global__ void radix_kernel(const T* in, } auto histogram = histograms + batch_id * num_buckets; - FilterAndHistogram()(in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - previous_len, - counter, - histogram, - greater, - pass, - early_stop); + FilterAndHistogram()(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + previous_len, + counter, + histogram, + greater, + pass, + early_stop); __threadfence(); if (threadIdx.x == 0) { @@ -752,9 +751,9 @@ __global__ void radix_kernel(const T* in, return; } - scan(histogram); + scan(histogram); __syncthreads(); - choose_bucket(counter, histogram, current_k, pass); + choose_bucket(counter, histogram, current_k, pass); __syncthreads(); // reset for next pass @@ -772,7 +771,7 @@ __global__ void radix_kernel(const T* in, if constexpr (!use_dynamic) { if (pass == num_passes - 1) { - last_filter( + last_filter( out_buf, out_idx_buf, out, out_idx, current_len, k, counter, greater, pass); } } @@ -780,33 +779,33 @@ __global__ void radix_kernel(const T* in, } template class Store> -unsigned calc_grid_dim(int batch_size, idxT len, int sm_cnt, bool use_dynamic) +unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) { static_assert(sizeof(WideT) / sizeof(T) >= 1); int active_blocks; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &active_blocks, - use_dynamic ? radix_kernel - : radix_kernel, - NUM_THREAD, + use_dynamic ? 
radix_kernel + : radix_kernel, + BlockSize, 0)); active_blocks *= sm_cnt; - idxT best_num_blocks = 0; + IdxT best_num_blocks = 0; float best_tail_wave_penalty = 1.0f; - const idxT max_num_blocks = (len - 1) / (sizeof(WideT) / sizeof(T) * NUM_THREAD) + 1; + const IdxT max_num_blocks = (len - 1) / (sizeof(WideT) / sizeof(T) * BlockSize) + 1; for (int num_waves = 1;; ++num_waves) { int num_blocks = std::min(max_num_blocks, std::max(num_waves * active_blocks / batch_size, 1)); - idxT items_per_thread = (len - 1) / (num_blocks * NUM_THREAD) + 1; + IdxT items_per_thread = (len - 1) / (num_blocks * BlockSize) + 1; items_per_thread = (items_per_thread - 1) / (sizeof(WideT) / sizeof(T)) + 1; items_per_thread *= sizeof(WideT) / sizeof(T); - num_blocks = (len - 1) / (items_per_thread * NUM_THREAD) + 1; + num_blocks = (len - 1) / (items_per_thread * BlockSize) + 1; float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; float tail_wave_penalty = (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); @@ -827,33 +826,33 @@ unsigned calc_grid_dim(int batch_size, idxT len, int sm_cnt, bool use_dynamic) } template class Store> void radix_topk(void* buf, size_t& buf_size, const T* in, int batch_size, - idxT len, - idxT k, + IdxT len, + IdxT k, T* out, - idxT* out_idx, + IdxT* out_idx, bool greater, cudaStream_t stream, bool use_dynamic = false) { // TODO: is it possible to relax this restriction? - static_assert(calc_num_passes() > 1); - constexpr int num_buckets = calc_num_buckets(); + static_assert(calc_num_passes() > 1); + constexpr int num_buckets = calc_num_buckets(); - Counter* counters = nullptr; - idxT* histograms = nullptr; + Counter* counters = nullptr; + IdxT* histograms = nullptr; T* buf1 = nullptr; - idxT* idx_buf1 = nullptr; + IdxT* idx_buf1 = nullptr; T* buf2 = nullptr; - idxT* idx_buf2 = nullptr; + IdxT* idx_buf2 = nullptr; { std::vector sizes = {sizeof(*counters) * batch_size, sizeof(*histograms) * num_buckets * batch_size, @@ -883,9 +882,9 @@ void radix_topk(void* buf, } const T* in_buf = nullptr; - const idxT* in_idx_buf = nullptr; + const IdxT* in_idx_buf = nullptr; T* out_buf = nullptr; - idxT* out_idx_buf = nullptr; + IdxT* out_idx_buf = nullptr; int sm_cnt; { @@ -894,10 +893,10 @@ void radix_topk(void* buf, RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); } dim3 blocks( - calc_grid_dim(batch_size, len, sm_cnt, use_dynamic), + calc_grid_dim(batch_size, len, sm_cnt, use_dynamic), batch_size); - constexpr int num_passes = calc_num_passes(); + constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { if (pass == 0) { @@ -923,95 +922,95 @@ void radix_topk(void* buf, } if (!use_dynamic) { - radix_kernel - <<>>(in, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - counters, - histograms, - len, - k, - greater, - pass); + radix_kernel + <<>>(in, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + counters, + histograms, + len, + k, + greater, + pass); } else { - radix_kernel - <<>>(in, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - counters, - histograms, - len, - k, - greater, - pass); + radix_kernel + <<>>(in, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + counters, + histograms, + len, + k, + greater, + pass); } } if (use_dynamic) { - dim3 blocks((len / (sizeof(WideT) / sizeof(T)) - 1) / NUM_THREAD + 1, batch_size); - last_filter_kernel<<>>( + dim3 blocks((len / (sizeof(WideT) / sizeof(T)) - 1) / 
BlockSize + 1, batch_size); + last_filter_kernel<<>>( in, out_buf, out_idx_buf, out, out_idx, len, k, counters, greater); } } -template +template __device__ void filter_and_histogram(const T* in_buf, - const idxT* in_idx_buf, + const IdxT* in_idx_buf, T* out_buf, - idxT* out_idx_buf, + IdxT* out_idx_buf, T* out, - idxT* out_idx, - Counter* counter, - idxT* histogram, + IdxT* out_idx, + Counter* counter, + IdxT* histogram, bool greater, int pass) { - constexpr int num_buckets = calc_num_buckets(); + constexpr int num_buckets = calc_num_buckets(); for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { histogram[i] = 0; } - idxT* p_filter_cnt = &counter->filter_cnt; + IdxT* p_filter_cnt = &counter->filter_cnt; if (threadIdx.x == 0) { *p_filter_cnt = 0; } __syncthreads(); - const int start_bit = calc_start_bit(pass); - const unsigned mask = calc_mask(pass); - const idxT previous_len = counter->previous_len; + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + const IdxT previous_len = counter->previous_len; if (pass == 0) { // Could not use vectorized_process() as in FilterAndHistogram because // vectorized_process() assumes multi-block, e.g. uses gridDim.x - for (idxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { T value = in_buf[i]; - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram + bucket, 1); + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram + bucket, static_cast(1)); } } else { - idxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_cnt = &counter->out_cnt; const auto kth_value_bits = counter->kth_value_bits; - const int previous_start_bit = calc_start_bit(pass - 1); + const int previous_start_bit = calc_start_bit(pass - 1); - for (idxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { const T value = in_buf[i]; const auto previous_bits = (twiddle_in(value, greater) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { - idxT pos = atomicAdd(p_filter_cnt, 1); + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); out_buf[pos] = value; out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram + bucket, 1); + int bucket = calc_bucket(value, start_bit, mask, greater); + atomicAdd(histogram + bucket, static_cast(1)); } else if (previous_bits < kth_value_bits) { - idxT pos = atomicAdd(p_out_cnt, 1); + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; } @@ -1019,21 +1018,21 @@ __device__ void filter_and_histogram(const T* in_buf, } } -template +template __global__ void radix_topk_one_block_kernel(const T* in, - const idxT len, - const idxT k, + const IdxT len, + const IdxT k, T* out, - idxT* out_idx, + IdxT* out_idx, const bool greater, T* buf1, - idxT* idx_buf1, + IdxT* idx_buf1, T* buf2, - idxT* idx_buf2) + IdxT* idx_buf2) { - constexpr int num_buckets = calc_num_buckets(); - __shared__ Counter counter; - __shared__ idxT histogram[num_buckets]; + constexpr int num_buckets = calc_num_buckets(); + __shared__ Counter counter; + __shared__ IdxT histogram[num_buckets]; if (threadIdx.x == 0) { counter.k = k; @@ -1053,11 +1052,11 @@ __global__ void radix_topk_one_block_kernel(const T* in, buf2 += blockIdx.x * len; idx_buf2 += blockIdx.x * len; const T* in_buf = nullptr; - const idxT* in_idx_buf = nullptr; + const IdxT* in_idx_buf = nullptr; T* out_buf = nullptr; - idxT* out_idx_buf = nullptr; + IdxT* out_idx_buf = nullptr; - constexpr int num_passes = calc_num_passes(); + constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { if (pass == 0) { in_buf = in; @@ -1080,53 +1079,53 @@ __global__ void radix_topk_one_block_kernel(const T* in, out_buf = buf1; out_idx_buf = idx_buf1; } - idxT current_len = counter.len; - idxT current_k = counter.k; + IdxT current_len = counter.len; + IdxT current_k = counter.k; - filter_and_histogram( + filter_and_histogram( in_buf, in_idx_buf, out_buf, out_idx_buf, out, out_idx, &counter, histogram, greater, pass); __syncthreads(); - scan(histogram); + scan(histogram); __syncthreads(); - choose_bucket(&counter, histogram, current_k, pass); + choose_bucket(&counter, histogram, current_k, pass); if (threadIdx.x == 0) { counter.previous_len = current_len; } __syncthreads(); if (counter.len == counter.k || pass == num_passes - 1) { - last_filter(pass == 0 ? in : out_buf, - pass == 0 ? nullptr : out_idx_buf, - out, - out_idx, - current_len, - k, - &counter, - greater, - pass); + last_filter(pass == 0 ? in : out_buf, + pass == 0 ? 
nullptr : out_idx_buf, + out, + out_idx, + current_len, + k, + &counter, + greater, + pass); break; } } } -template +template void radix_topk_one_block(void* buf, size_t& buf_size, const T* in, int batch_size, - idxT len, - idxT k, + IdxT len, + IdxT k, T* out, - idxT* out_idx, + IdxT* out_idx, bool greater, cudaStream_t stream) { - static_assert(calc_num_passes() > 1); + static_assert(calc_num_passes() > 1); T* buf1 = nullptr; - idxT* idx_buf1 = nullptr; + IdxT* idx_buf1 = nullptr; T* buf2 = nullptr; - idxT* idx_buf2 = nullptr; + IdxT* idx_buf2 = nullptr; { std::vector sizes = {sizeof(*buf1) * len * batch_size, sizeof(*idx_buf1) * len * batch_size, @@ -1145,47 +1144,47 @@ void radix_topk_one_block(void* buf, idx_buf2 = static_cast(aligned_pointers[3]); } - radix_topk_one_block_kernel - <<>>( + radix_topk_one_block_kernel + <<>>( in, len, k, out, out_idx, greater, buf1, idx_buf1, buf2, idx_buf2); } } // namespace radix_impl -template +template void radix_topk_11bits(void* buf, size_t& buf_size, const T* in, int batch_size, - idxT len, - idxT k, + IdxT len, + IdxT k, T* out, - idxT* out_idx = nullptr, + IdxT* out_idx = nullptr, bool greater = true, cudaStream_t stream = 0) { constexpr int items_per_thread = 32; if (len <= radix_impl::BLOCK_DIM * items_per_thread) { - radix_impl::radix_topk_one_block( + radix_impl::radix_topk_one_block( buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); } else if (len < 100.0 * k / batch_size + 0.01) { - radix_impl::radix_topk( + radix_impl::radix_topk( buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); } else { - radix_impl::radix_topk( + radix_impl::radix_topk( buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); } } -template +template void radix_topk_11bits_dynamic(void* buf, size_t& buf_size, const T* in, int batch_size, - idxT len, - idxT k, + IdxT len, + IdxT k, T* out, - idxT* out_idx = nullptr, + IdxT* out_idx = nullptr, bool greater = true, cudaStream_t stream = 0) { @@ -1193,13 +1192,13 @@ void radix_topk_11bits_dynamic(void* buf, constexpr int items_per_thread = 32; if (len <= radix_impl::BLOCK_DIM * items_per_thread) { - radix_impl::radix_topk_one_block( + radix_impl::radix_topk_one_block( buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); } else if (len < 100.0 * k / batch_size + 0.01) { - radix_impl::radix_topk( + radix_impl::radix_topk( buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream, use_dynamic); } else { - radix_impl::radix_topk( + radix_impl::radix_topk( buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream, use_dynamic); } } From 5f63cbdfe24076e3c314384529528b46b85b4804 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Wed, 11 Jan 2023 12:57:48 +0800 Subject: [PATCH 03/29] radix top-k: add extra input parameter in_idx --- .../matrix/detail/select_radix_updated.cuh | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 1867065939..c92a8cdd2b 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -588,6 +588,7 @@ __device__ void last_filter(const T* out_buf, template __global__ void last_filter_kernel(const T* in, + const IdxT* in_idx, const T* in_buf, const IdxT* in_idx_buf, T* out, @@ -604,7 +605,7 @@ __global__ void last_filter_kernel(const T* in, if (previous_len == 0) { return; } if (previous_len > 
len / LAZY_WRITING_FACTOR) { in_buf = in; - in_idx_buf = nullptr; + in_idx_buf = in_idx; previous_len = len; } @@ -656,6 +657,7 @@ template class Store> __global__ void radix_kernel(const T* in, + const IdxT* in_idx, const T* in_buf, const IdxT* in_idx_buf, T* out_buf, @@ -702,7 +704,7 @@ __global__ void radix_kernel(const T* in, if (previous_len > len / LAZY_WRITING_FACTOR) { previous_len = len; in_buf = in; - in_idx_buf = nullptr; + in_idx_buf = in_idx; } // Figure out if this pass need to write buffer if (current_len > len / LAZY_WRITING_FACTOR) { @@ -834,6 +836,7 @@ template <<>>(in, + in_idx, in_buf, in_idx_buf, out_buf, @@ -939,6 +943,7 @@ void radix_topk(void* buf, } else { radix_kernel <<>>(in, + in_idx, in_buf, in_idx_buf, out_buf, @@ -957,7 +962,7 @@ void radix_topk(void* buf, if (use_dynamic) { dim3 blocks((len / (sizeof(WideT) / sizeof(T)) - 1) / BlockSize + 1, batch_size); last_filter_kernel<<>>( - in, out_buf, out_idx_buf, out, out_idx, len, k, counters, greater); + in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, greater); } } @@ -1020,6 +1025,7 @@ __device__ void filter_and_histogram(const T* in_buf, template __global__ void radix_topk_one_block_kernel(const T* in, + const IdxT* in_idx, const IdxT len, const IdxT k, T* out, @@ -1045,6 +1051,7 @@ __global__ void radix_topk_one_block_kernel(const T* in, __syncthreads(); in += blockIdx.x * len; + if (in_idx) { in_idx += blockIdx.x * len; } out += blockIdx.x * k; out_idx += blockIdx.x * k; buf1 += blockIdx.x * len; @@ -1065,7 +1072,7 @@ __global__ void radix_topk_one_block_kernel(const T* in, out_idx_buf = nullptr; } else if (pass == 1) { in_buf = in; - in_idx_buf = nullptr; + in_idx_buf = in_idx; out_buf = buf1; out_idx_buf = idx_buf1; } else if (pass % 2 == 0) { @@ -1095,7 +1102,7 @@ __global__ void radix_topk_one_block_kernel(const T* in, if (counter.len == counter.k || pass == num_passes - 1) { last_filter(pass == 0 ? in : out_buf, - pass == 0 ? nullptr : out_idx_buf, + pass == 0 ? 
in_idx : out_idx_buf, out, out_idx, current_len, @@ -1112,6 +1119,7 @@ template void radix_topk_one_block(void* buf, size_t& buf_size, const T* in, + const IdxT* in_idx, int batch_size, IdxT len, IdxT k, @@ -1146,7 +1154,7 @@ void radix_topk_one_block(void* buf, radix_topk_one_block_kernel <<>>( - in, len, k, out, out_idx, greater, buf1, idx_buf1, buf2, idx_buf2); + in, in_idx, len, k, out, out_idx, greater, buf1, idx_buf1, buf2, idx_buf2); } } // namespace radix_impl @@ -1155,6 +1163,7 @@ template void radix_topk_11bits(void* buf, size_t& buf_size, const T* in, + const IdxT* in_idx, int batch_size, IdxT len, IdxT k, @@ -1166,13 +1175,13 @@ void radix_topk_11bits(void* buf, constexpr int items_per_thread = 32; if (len <= radix_impl::BLOCK_DIM * items_per_thread) { radix_impl::radix_topk_one_block( - buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream); } else if (len < 100.0 * k / batch_size + 0.01) { radix_impl::radix_topk( - buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream); } else { radix_impl::radix_topk( - buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream); } } @@ -1180,6 +1189,7 @@ template void radix_topk_11bits_dynamic(void* buf, size_t& buf_size, const T* in, + const IdxT* in_idx, int batch_size, IdxT len, IdxT k, @@ -1193,13 +1203,13 @@ void radix_topk_11bits_dynamic(void* buf, constexpr int items_per_thread = 32; if (len <= radix_impl::BLOCK_DIM * items_per_thread) { radix_impl::radix_topk_one_block( - buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream); } else if (len < 100.0 * k / batch_size + 0.01) { radix_impl::radix_topk( - buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream, use_dynamic); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream, use_dynamic); } else { radix_impl::radix_topk( - buf, buf_size, in, batch_size, len, k, out, out_idx, greater, stream, use_dynamic); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream, use_dynamic); } } From 60864544f8c448a59d4deeee6a2cae72e6d13556 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Thu, 12 Jan 2023 22:20:02 +0800 Subject: [PATCH 04/29] radix top-k: replace greater with select_min --- .../matrix/detail/select_radix_updated.cuh | 104 ++++++++++-------- 1 file changed, 56 insertions(+), 48 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index c92a8cdd2b..5715f61192 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -68,28 +68,28 @@ __device__ constexpr unsigned calc_mask(int pass) } template -__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) +__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) { auto bits = reinterpret_cast::UnsignedBits&>(key); bits = cub::Traits::TwiddleIn(bits); - if (greater) { bits = ~bits; } + if (!select_min) { bits = ~bits; } return bits; } template -__device__ T twiddle_out(typename cub::Traits::UnsignedBits bits, bool greater) +__device__ T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) { 
- if (greater) { bits = ~bits; } + if (!select_min) { bits = ~bits; } bits = cub::Traits::TwiddleOut(bits); return reinterpret_cast(bits); } template -__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) +__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) { static_assert(BitsPerPass <= sizeof(int) * 8 - 1, "BitsPerPass is too large that the result type could not be int"); - return (twiddle_in(x, greater) >> start_bit) & mask; + return (twiddle_in(x, select_min) >> start_bit) & mask; } template @@ -309,7 +309,7 @@ class FilterAndHistogram { IdxT previous_len, Counter* counter, IdxT* histogram, - bool greater, + bool select_min, int pass, bool early_stop) { @@ -325,8 +325,8 @@ class FilterAndHistogram { const unsigned mask = calc_mask(pass); if (pass == 0) { - auto f = [greater, start_bit, mask](T value, IdxT) { - int bucket = calc_bucket(value, start_bit, mask, greater); + auto f = [select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); }; vectorized_process(in_buf, previous_len, f); @@ -341,7 +341,7 @@ class FilterAndHistogram { out_idx_buf, out, out_idx, - greater, + select_min, start_bit, mask, previous_start_bit, @@ -350,7 +350,7 @@ class FilterAndHistogram { p_out_cnt, early_stop, &store](T value, IdxT i, bool valid) { - const auto previous_bits = (twiddle_in(value, greater) >> previous_start_bit) + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (valid && previous_bits == kth_value_bits) { @@ -365,7 +365,7 @@ class FilterAndHistogram { out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } - int bucket = calc_bucket(value, start_bit, mask, greater); + int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); } } @@ -403,7 +403,7 @@ class FilterAndHistogram { IdxT previous_len, Counter* counter, IdxT* histogram, - bool greater, + bool select_min, int pass, bool early_stop) { @@ -418,8 +418,8 @@ class FilterAndHistogram { const unsigned mask = calc_mask(pass); if (pass == 0) { - auto f = [greater, start_bit, mask](T value, IdxT) { - int bucket = calc_bucket(value, start_bit, mask, greater); + auto f = [select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); }; vectorized_process(in_buf, previous_len, f); @@ -434,7 +434,7 @@ class FilterAndHistogram { out_idx_buf, out, out_idx, - greater, + select_min, start_bit, mask, previous_start_bit, @@ -442,7 +442,7 @@ class FilterAndHistogram { p_filter_cnt, p_out_cnt, early_stop](T value, IdxT i) { - const auto previous_bits = (twiddle_in(value, greater) >> previous_start_bit) + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { if (early_stop) { @@ -456,7 +456,7 @@ class FilterAndHistogram { out_idx_buf[pos] = in_idx_buf ? 
in_idx_buf[i] : i; } - int bucket = calc_bucket(value, start_bit, mask, greater); + int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); } } @@ -555,7 +555,7 @@ __device__ void last_filter(const T* out_buf, IdxT current_len, IdxT k, Counter* counter, - const bool greater, + const bool select_min, const int pass) { const auto kth_value_bits = counter->kth_value_bits; @@ -567,7 +567,7 @@ __device__ void last_filter(const T* out_buf, IdxT* p_out_back_cnt = &counter->out_back_cnt; for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { const T value = out_buf[i]; - const auto bits = (twiddle_in(value, greater) >> start_bit) << start_bit; + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; if (bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; @@ -596,7 +596,7 @@ __global__ void last_filter_kernel(const T* in, IdxT len, IdxT k, Counter* counters, - const bool greater) + const bool select_min) { const int batch_id = blockIdx.y; @@ -623,7 +623,7 @@ __global__ void last_filter_kernel(const T* in, IdxT* p_out_back_cnt = &counter->out_back_cnt; auto f = [k, - greater, + select_min, kth_value_bits, needed_num_of_kth, p_out_cnt, @@ -631,7 +631,7 @@ __global__ void last_filter_kernel(const T* in, in_idx_buf, out, out_idx](T value, IdxT i) { - const auto bits = (twiddle_in(value, greater) >> start_bit) << start_bit; + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; if (bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; @@ -668,7 +668,7 @@ __global__ void radix_kernel(const T* in, IdxT* histograms, const IdxT len, const IdxT k, - const bool greater, + const bool select_min, const int pass) { __shared__ bool isLastBlock; @@ -731,7 +731,7 @@ __global__ void radix_kernel(const T* in, previous_len, counter, histogram, - greater, + select_min, pass, early_stop); __threadfence(); @@ -774,7 +774,7 @@ __global__ void radix_kernel(const T* in, if constexpr (!use_dynamic) { if (pass == num_passes - 1) { last_filter( - out_buf, out_idx_buf, out, out_idx, current_len, k, counter, greater, pass); + out_buf, out_idx_buf, out, out_idx, current_len, k, counter, select_min, pass); } } } @@ -842,7 +842,7 @@ void radix_topk(void* buf, IdxT k, T* out, IdxT* out_idx, - bool greater, + bool select_min, cudaStream_t stream, bool use_dynamic = false) { @@ -938,7 +938,7 @@ void radix_topk(void* buf, histograms, len, k, - greater, + select_min, pass); } else { radix_kernel @@ -954,7 +954,7 @@ void radix_topk(void* buf, histograms, len, k, - greater, + select_min, pass); } } @@ -962,7 +962,7 @@ void radix_topk(void* buf, if (use_dynamic) { dim3 blocks((len / (sizeof(WideT) / sizeof(T)) - 1) / BlockSize + 1, batch_size); last_filter_kernel<<>>( - in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, greater); + in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, select_min); } } @@ -975,7 +975,7 @@ __device__ void filter_and_histogram(const T* in_buf, IdxT* out_idx, Counter* counter, IdxT* histogram, - bool greater, + bool select_min, int pass) { constexpr int num_buckets = calc_num_buckets(); @@ -995,7 +995,7 @@ __device__ void filter_and_histogram(const T* in_buf, // vectorized_process() assumes multi-block, e.g. 
uses gridDim.x for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { T value = in_buf[i]; - int bucket = calc_bucket(value, start_bit, mask, greater); + int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram + bucket, static_cast(1)); } } else { @@ -1005,14 +1005,14 @@ __device__ void filter_and_histogram(const T* in_buf, for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { const T value = in_buf[i]; - const auto previous_bits = (twiddle_in(value, greater) >> previous_start_bit) + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); out_buf[pos] = value; out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; - int bucket = calc_bucket(value, start_bit, mask, greater); + int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram + bucket, static_cast(1)); } else if (previous_bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); @@ -1030,7 +1030,7 @@ __global__ void radix_topk_one_block_kernel(const T* in, const IdxT k, T* out, IdxT* out_idx, - const bool greater, + const bool select_min, T* buf1, IdxT* idx_buf1, T* buf2, @@ -1089,8 +1089,16 @@ __global__ void radix_topk_one_block_kernel(const T* in, IdxT current_len = counter.len; IdxT current_k = counter.k; - filter_and_histogram( - in_buf, in_idx_buf, out_buf, out_idx_buf, out, out_idx, &counter, histogram, greater, pass); + filter_and_histogram(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + &counter, + histogram, + select_min, + pass); __syncthreads(); scan(histogram); @@ -1108,7 +1116,7 @@ __global__ void radix_topk_one_block_kernel(const T* in, current_len, k, &counter, - greater, + select_min, pass); break; } @@ -1125,7 +1133,7 @@ void radix_topk_one_block(void* buf, IdxT k, T* out, IdxT* out_idx, - bool greater, + bool select_min, cudaStream_t stream) { static_assert(calc_num_passes() > 1); @@ -1154,7 +1162,7 @@ void radix_topk_one_block(void* buf, radix_topk_one_block_kernel <<>>( - in, in_idx, len, k, out, out_idx, greater, buf1, idx_buf1, buf2, idx_buf2); + in, in_idx, len, k, out, out_idx, select_min, buf1, idx_buf1, buf2, idx_buf2); } } // namespace radix_impl @@ -1169,19 +1177,19 @@ void radix_topk_11bits(void* buf, IdxT k, T* out, IdxT* out_idx = nullptr, - bool greater = true, + bool select_min = true, cudaStream_t stream = 0) { constexpr int items_per_thread = 32; if (len <= radix_impl::BLOCK_DIM * items_per_thread) { radix_impl::radix_topk_one_block( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); } else if (len < 100.0 * k / batch_size + 0.01) { radix_impl::radix_topk( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); } else { radix_impl::radix_topk( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); } } @@ -1195,7 +1203,7 @@ void radix_topk_11bits_dynamic(void* buf, IdxT k, T* out, IdxT* out_idx = nullptr, - bool greater = true, + bool select_min = true, cudaStream_t stream = 0) { constexpr bool use_dynamic = true; @@ -1203,13 +1211,13 @@ void radix_topk_11bits_dynamic(void* buf, constexpr int items_per_thread = 32; if (len <= 
radix_impl::BLOCK_DIM * items_per_thread) { radix_impl::radix_topk_one_block( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); } else if (len < 100.0 * k / batch_size + 0.01) { radix_impl::radix_topk( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream, use_dynamic); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, use_dynamic); } else { radix_impl::radix_topk( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, greater, stream, use_dynamic); + buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, use_dynamic); } } From dd63770266f9c25fd9d1ed878e3e6ee31885603d Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Wed, 11 Jan 2023 14:07:31 +0800 Subject: [PATCH 05/29] radix top-k: make it compiled --- .../matrix/detail/select_radix_updated.cuh | 191 +++++++++--------- 1 file changed, 92 insertions(+), 99 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 5715f61192..de46b75866 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -33,6 +33,7 @@ namespace raft::spatial::knn::detail::topk { namespace radix_impl { +constexpr unsigned FULL_WARP_MASK = 0xffffffff; constexpr int BLOCK_DIM = 512; using WideT = float4; constexpr int LAZY_WRITING_FACTOR = 4; @@ -803,7 +804,8 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) float best_tail_wave_penalty = 1.0f; const IdxT max_num_blocks = (len - 1) / (sizeof(WideT) / sizeof(T) * BlockSize) + 1; for (int num_waves = 1;; ++num_waves) { - int num_blocks = std::min(max_num_blocks, std::max(num_waves * active_blocks / batch_size, 1)); + IdxT num_blocks = std::min( + max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); IdxT items_per_thread = (len - 1) / (num_blocks * BlockSize) + 1; items_per_thread = (items_per_thread - 1) / (sizeof(WideT) / sizeof(T)) + 1; items_per_thread *= sizeof(WideT) / sizeof(T); @@ -833,9 +835,7 @@ template class Store> -void radix_topk(void* buf, - size_t& buf_size, - const T* in, +void radix_topk(const T* in, const IdxT* in_idx, int batch_size, IdxT len, @@ -843,47 +843,38 @@ void radix_topk(void* buf, T* out, IdxT* out_idx, bool select_min, - cudaStream_t stream, - bool use_dynamic = false) + bool use_dynamic, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { // TODO: is it possible to relax this restriction? 
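  // The multi-pass requirement below appears to stem from the buffer ping-pong in this
  // function: pass 0 reads `in` directly and writes no buffer, so at least one more pass
  // is needed before `out_buf`/`out_idx_buf` hold valid data for the final filtering.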
static_assert(calc_num_passes() > 1); constexpr int num_buckets = calc_num_buckets(); - Counter* counters = nullptr; - IdxT* histograms = nullptr; - T* buf1 = nullptr; - IdxT* idx_buf1 = nullptr; - T* buf2 = nullptr; - IdxT* idx_buf2 = nullptr; - { - std::vector sizes = {sizeof(*counters) * batch_size, - sizeof(*histograms) * num_buckets * batch_size, - sizeof(*buf1) * len * batch_size, - sizeof(*idx_buf1) * len * batch_size, - sizeof(*buf2) * len * batch_size, - sizeof(*idx_buf2) * len * batch_size}; - size_t total_size = calc_aligned_size(sizes); - if (!buf) { - buf_size = total_size; - return; - } - - std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); - counters = static_cast(aligned_pointers[0]); - histograms = static_cast(aligned_pointers[1]); - buf1 = static_cast(aligned_pointers[2]); - idx_buf1 = static_cast(aligned_pointers[3]); - buf2 = static_cast(aligned_pointers[4]); - idx_buf2 = static_cast(aligned_pointers[5]); - - RAFT_CUDA_TRY(cudaMemsetAsync( - buf, - 0, - static_cast(aligned_pointers[2]) - static_cast(aligned_pointers[0]), - stream)); + auto pool_guard = + raft::get_pool_memory_resource(mr, + batch_size * (sizeof(Counter) // counters + + sizeof(IdxT) * num_buckets // histograms + + sizeof(T) * len * 2 // T bufs + + sizeof(IdxT) * len * 2 // IdxT bufs + ) + + 256 * 6); + if (pool_guard) { + RAFT_LOG_DEBUG("radix_topk: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); } + rmm::device_uvector> counters(batch_size, stream, mr); + rmm::device_uvector histograms(num_buckets * batch_size, stream, mr); + rmm::device_uvector buf1(len * batch_size, stream, mr); + rmm::device_uvector idx_buf1(len * batch_size, stream, mr); + rmm::device_uvector buf2(len * batch_size, stream, mr); + rmm::device_uvector idx_buf2(len * batch_size, stream, mr); + + RAFT_CUDA_TRY( + cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); + const T* in_buf = nullptr; const IdxT* in_idx_buf = nullptr; T* out_buf = nullptr; @@ -910,18 +901,18 @@ void radix_topk(void* buf, } else if (pass == 1) { in_buf = in; in_idx_buf = in_idx; - out_buf = buf1; - out_idx_buf = idx_buf1; + out_buf = buf1.data(); + out_idx_buf = idx_buf1.data(); } else if (pass % 2 == 0) { - in_buf = buf1; - in_idx_buf = idx_buf1; - out_buf = buf2; - out_idx_buf = idx_buf2; + in_buf = buf1.data(); + in_idx_buf = idx_buf1.data(); + out_buf = buf2.data(); + out_idx_buf = idx_buf2.data(); } else { - in_buf = buf2; - in_idx_buf = idx_buf2; - out_buf = buf1; - out_idx_buf = idx_buf1; + in_buf = buf2.data(); + in_idx_buf = idx_buf2.data(); + out_buf = buf1.data(); + out_idx_buf = idx_buf1.data(); } if (!use_dynamic) { @@ -934,8 +925,8 @@ void radix_topk(void* buf, out_idx_buf, out, out_idx, - counters, - histograms, + counters.data(), + histograms.data(), len, k, select_min, @@ -950,8 +941,8 @@ void radix_topk(void* buf, out_idx_buf, out, out_idx, - counters, - histograms, + counters.data(), + histograms.data(), len, k, select_min, @@ -962,7 +953,7 @@ void radix_topk(void* buf, if (use_dynamic) { dim3 blocks((len / (sizeof(WideT) / sizeof(T)) - 1) / BlockSize + 1, batch_size); last_filter_kernel<<>>( - in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, select_min); + in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters.data(), select_min); } } @@ -1124,9 +1115,7 @@ __global__ void radix_topk_one_block_kernel(const T* in, } template -void 
radix_topk_one_block(void* buf, - size_t& buf_size, - const T* in, +void radix_topk_one_block(const T* in, const IdxT* in_idx, int batch_size, IdxT len, @@ -1134,90 +1123,94 @@ void radix_topk_one_block(void* buf, T* out, IdxT* out_idx, bool select_min, - cudaStream_t stream) + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { static_assert(calc_num_passes() > 1); - T* buf1 = nullptr; - IdxT* idx_buf1 = nullptr; - T* buf2 = nullptr; - IdxT* idx_buf2 = nullptr; - { - std::vector sizes = {sizeof(*buf1) * len * batch_size, - sizeof(*idx_buf1) * len * batch_size, - sizeof(*buf2) * len * batch_size, - sizeof(*idx_buf2) * len * batch_size}; - size_t total_size = calc_aligned_size(sizes); - if (!buf) { - buf_size = total_size; - return; - } - - std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); - buf1 = static_cast(aligned_pointers[0]); - idx_buf1 = static_cast(aligned_pointers[1]); - buf2 = static_cast(aligned_pointers[2]); - idx_buf2 = static_cast(aligned_pointers[3]); + auto pool_guard = + raft::get_pool_memory_resource(mr, + batch_size * (sizeof(T) * len * 2 // T bufs + + sizeof(IdxT) * len * 2 // IdxT bufs + ) + + 256 * 4); + if (pool_guard) { + RAFT_LOG_DEBUG("radix_topk: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); } + rmm::device_uvector buf1(len * batch_size, stream, mr); + rmm::device_uvector idx_buf1(len * batch_size, stream, mr); + rmm::device_uvector buf2(len * batch_size, stream, mr); + rmm::device_uvector idx_buf2(len * batch_size, stream, mr); + radix_topk_one_block_kernel - <<>>( - in, in_idx, len, k, out, out_idx, select_min, buf1, idx_buf1, buf2, idx_buf2); + <<>>(in, + in_idx, + len, + k, + out, + out_idx, + select_min, + buf1.data(), + idx_buf1.data(), + buf2.data(), + idx_buf2.data()); } } // namespace radix_impl template -void radix_topk_11bits(void* buf, - size_t& buf_size, - const T* in, +void radix_topk_11bits(const T* in, const IdxT* in_idx, int batch_size, IdxT len, IdxT k, T* out, - IdxT* out_idx = nullptr, - bool select_min = true, - cudaStream_t stream = 0) + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { + constexpr bool use_dynamic = false; constexpr int items_per_thread = 32; + if (len <= radix_impl::BLOCK_DIM * items_per_thread) { radix_impl::radix_topk_one_block( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); } else if (len < 100.0 * k / batch_size + 0.01) { radix_impl::radix_topk( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); } else { radix_impl::radix_topk( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); } } template -void radix_topk_11bits_dynamic(void* buf, - size_t& buf_size, - const T* in, +void radix_topk_11bits_dynamic(const T* in, const IdxT* in_idx, int batch_size, IdxT len, IdxT k, T* out, - IdxT* out_idx = nullptr, - bool select_min = true, - cudaStream_t stream = 0) + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { - constexpr bool use_dynamic = true; - + constexpr bool use_dynamic = true; constexpr int items_per_thread = 32; + if (len <= radix_impl::BLOCK_DIM * items_per_thread) { 
radix_impl::radix_topk_one_block( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); } else if (len < 100.0 * k / batch_size + 0.01) { radix_impl::radix_topk( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, use_dynamic); + in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); } else { radix_impl::radix_topk( - buf, buf_size, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, use_dynamic); + in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); } } From aea2bd08aa8ea10e1a77962a247f1830ca3f1655 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Wed, 11 Jan 2023 14:08:40 +0800 Subject: [PATCH 06/29] radix top-k: polish style --- .../matrix/detail/select_radix_updated.cuh | 223 +++++++++++++----- 1 file changed, 165 insertions(+), 58 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index de46b75866..8f6521ec0c 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,10 +33,9 @@ namespace raft::spatial::knn::detail::topk { namespace radix_impl { -constexpr unsigned FULL_WARP_MASK = 0xffffffff; -constexpr int BLOCK_DIM = 512; -using WideT = float4; -constexpr int LAZY_WRITING_FACTOR = 4; +constexpr int BLOCK_DIM = 512; +constexpr int VECTORIZED_READ_SIZE = 16; +constexpr int LAZY_WRITING_FACTOR = 4; template __host__ __device__ constexpr int calc_num_buckets() @@ -47,11 +46,16 @@ __host__ __device__ constexpr int calc_num_buckets() template __host__ __device__ constexpr int calc_num_passes() { - return (sizeof(T) * 8 - 1) / BitsPerPass + 1; + return ceildiv(sizeof(T) * 8, BitsPerPass); } -// bit 0 is the least significant (rightmost) bit -// this function works even when pass=-1, which is used in calc_mask() +/** + * Bit 0 is the least significant (rightmost); + * this implementation processes input from the most to the least significant bit. + * This way, we can skip some passes in the end at the cost of having an unsorted output. + * + * NB: Use pass=-1 for calc_mask(). + */ template __device__ constexpr int calc_start_bit(int pass) { @@ -68,6 +72,10 @@ __device__ constexpr unsigned calc_mask(int pass) return (1 << num_bits) - 1; } +/** + * Use cub to twiddle bits - so that we can correctly compare bits of floating-point values as well + * as of integers. + */ template __device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) { @@ -93,50 +101,54 @@ __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) return (twiddle_in(x, select_min) >> start_bit) & mask; } +/** + * Map a Func over the input data, using vectorized load instructions if possible. 
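+ * For instance (an illustrative sketch, not part of the interface), a histogram pass can be
+ * phrased as:
+ *   vectorized_process(in, len, [=] __device__(T value, IdxT) {
+ *     int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
+ *     atomicAdd(histogram + bucket, static_cast<IdxT>(1));
+ *   });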
+ * + * NB: in future, we should move this to cpp/include/raft/linalg/detail/unary_op.cuh, which + * currently does not support the second lambda argument (index of an element) + * + * @tparam T element type + * @tparam IdxT indexing type + * @tparam Func void (T x, IdxT idx) + * + * @param in the input data + * @param len the number of elements to read + * @param f the lambda taking two arguments (T x, IdxT idx) + */ template __device__ void vectorized_process(const T* in, IdxT len, Func f) { const IdxT stride = blockDim.x * gridDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; - if constexpr (sizeof(T) >= sizeof(WideT)) { + if constexpr (sizeof(T) >= VECTORIZED_READ_SIZE || VECTORIZED_READ_SIZE % sizeof(T) != 0) { for (IdxT i = tid; i < len; i += stride) { f(in[i], i); } } else { - static_assert(sizeof(WideT) % sizeof(T) == 0); - constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); - // TODO: it's UB - union { - WideT scalar; - T array[items_per_scalar]; - } wide; - - int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) - ? ((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) - : 0; - if (skip_cnt > len) { skip_cnt = len; } - const WideT* in_cast = reinterpret_cast(in + skip_cnt); - const IdxT len_cast = (len - skip_cnt) / items_per_scalar; - - for (IdxT i = tid; i < len_cast; i += stride) { - wide.scalar = in_cast[i]; - const IdxT real_i = skip_cnt + i * items_per_scalar; + using wide_t = TxN_t; + using align_bytes = Pow2<(size_t)VECTORIZED_READ_SIZE>; + using align_elems = Pow2; + wide_t wide; + + // how many elements to skip in order to do aligned vectorized load + const IdxT skip_cnt_left = std::min((IdxT)(align_bytes::roundUp(in) - in), len); + + // The main loop: process all aligned data + for (IdxT i = tid * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; + i += stride * wide_t::Ratio) { + wide.load(in, i); #pragma unroll - for (int j = 0; j < items_per_scalar; ++j) { - f(wide.array[j], real_i + j); + for (int j = 0; j < wide_t::Ratio; ++j) { + f(wide.val.data[j], i + j); } } - static_assert(WarpSize >= items_per_scalar); - // and because items_per_scalar > skip_cnt, WarpSize > skip_cnt - // no need to use loop - if (tid < skip_cnt) { f(in[tid], tid); } - // because len_cast = (len - skip_cnt) / items_per_scalar, - // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; - // and so - // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WarpSize - // no need to use loop - const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + static_assert(WarpSize >= wide_t::Ratio); + // Processes the skipped elements on the left + if (tid < skip_cnt_left) { f(in[tid], tid); } + // Processes the skipped elements on the right + const IdxT skip_cnt_right = align_elems::mod(len - skip_cnt_left); + const IdxT remain_i = len - skip_cnt_right + tid; if (remain_i < len) { f(in[remain_i], remain_i); } } } @@ -145,6 +157,7 @@ __device__ void vectorized_process(const T* in, IdxT len, Func f) template __device__ void vectorized_process(const T* in, IdxT len, Func f, int sync_width) { + using WideT = float4; const IdxT stride = blockDim.x * gridDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (sizeof(T) >= sizeof(WideT)) { @@ -240,7 +253,7 @@ class BufferedStore { __device__ void store(T value, IdxT index, bool valid, T* out, IdxT* out_idx, IdxT* p_out_cnt) { - unsigned int valid_mask = __ballot_sync(FULL_WARP_MASK, valid); + unsigned int valid_mask = __ballot_sync(FULL_WARP_MASK_, valid); if 
(valid_mask == 0) { return; } int pos = __popc(valid_mask & ((0x1u << lane_id_) - 1)) + warp_pos_; @@ -254,7 +267,7 @@ class BufferedStore { if (warp_pos_ >= WarpSize) { IdxT pos_smem; if (lane_id_ == 0) { pos_smem = atomicAdd(p_out_cnt, static_cast(WarpSize)); } - pos_smem = __shfl_sync(FULL_WARP_MASK, pos_smem, 0); + pos_smem = __shfl_sync(FULL_WARP_MASK_, pos_smem, 0); __syncwarp(); out[pos_smem + lane_id_] = value_smem_[lane_id_]; @@ -276,7 +289,7 @@ class BufferedStore { if (warp_pos_ > 0) { IdxT pos_smem; if (lane_id_ == 0) { pos_smem = atomicAdd(p_out_cnt, static_cast(warp_pos_)); } - pos_smem = __shfl_sync(FULL_WARP_MASK, pos_smem, 0); + pos_smem = __shfl_sync(FULL_WARP_MASK_, pos_smem, 0); __syncwarp(); if (lane_id_ < warp_pos_) { @@ -287,6 +300,7 @@ class BufferedStore { } private: + const unsigned FULL_WARP_MASK_{0xffffffff}; T* value_smem_; IdxT* index_smem_; IdxT lane_id_; //@TODO: Can be const variable @@ -392,6 +406,10 @@ class FilterAndHistogram { } }; +/** + * Fused filtering of the current phase and building histogram for the next phase + * (see steps 4-1 in `radix_kernel` description). + */ template class FilterAndHistogram { public: @@ -419,6 +437,9 @@ class FilterAndHistogram { const unsigned mask = calc_mask(pass); if (pass == 0) { + // Passed to vectorized_process, this function executes in all blocks in parallel, + // i.e. the work is split along the input (both, in batches and chunks of a single row). + // Later, the histograms are merged using atomicAdd. auto f = [select_min, start_bit, mask](T value, IdxT) { int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); @@ -430,6 +451,7 @@ class FilterAndHistogram { const auto kth_value_bits = counter->kth_value_bits; const int previous_start_bit = calc_start_bit(pass - 1); + // See the remark above on the distributed execution of `f` using vectorized_process. auto f = [in_idx_buf, out_buf, out_idx_buf, @@ -461,12 +483,11 @@ class FilterAndHistogram { atomicAdd(histogram_smem + bucket, static_cast(1)); } } - // '(out_buf || early_stop)': - // If we skip writing to 'out_buf' (when !out_buf), we should skip - // writing to 'out' too. So we won't write the same value to 'out' - // multiple times. And if we keep skipping the writing, values will be - // written in last_filter_kernel at last. But when 'early_stop' is true, - // we need to write to 'out' since it's the last chance. + // '(out_buf || early_stop)' is a little tricky: + // If we skip writing to 'out_buf' (when 'out_buf' is false), we should skip writing to + // 'out' too. So we won't write the same value to 'out' multiple times. And if we keep + // skipping the writing, values will be written in last_filter_kernel at last. + // But when 'early_stop' is true, we need to write to 'out' since it's the last chance. else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; @@ -476,14 +497,20 @@ class FilterAndHistogram { vectorized_process(in_buf, previous_len, f); } if (early_stop) { return; } - __syncthreads(); + + // merge histograms produced by individual blocks for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } } } }; +/** + * Replace a part of the histogram with its own prefix sum, starting from the `start` and adding + * `current` to each entry of the result. 
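+ * For example (illustrative): a histogram of {2, 0, 3, 1} would become {2, 2, 5, 6}, so that
+ * entry i holds the count of inputs falling into buckets 0..i.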
+ * (step 2 in `radix_kernel` description) + */ template __device__ void scan(volatile IdxT* histogram) { @@ -524,6 +551,10 @@ __device__ void scan(volatile IdxT* histogram) } } +/** + * Calculate in which bucket the k-th value will fall + * (steps 3 in `radix_kernel` description) + */ template __device__ void choose_bucket(Counter* counter, const IdxT* histogram, @@ -537,8 +568,8 @@ __device__ void choose_bucket(Counter* counter, // one and only one thread will satisfy this condition, so only write once if (prev < k && cur >= k) { - counter->k = k - prev; - counter->len = cur - prev; + counter->k = k - prev; // how many values still are there to find + counter->len = cur - prev; // number of values in `index` bucket typename cub::Traits::UnsignedBits bucket = i; int start_bit = calc_start_bit(pass); counter->kth_value_bits |= bucket << start_bit; @@ -650,6 +681,35 @@ __global__ void last_filter_kernel(const T* in, vectorized_process(in_buf, previous_len, f); } +/** + * + * It is expected to call this kernel multiple times (passes), in each pass we process a radix, + * going from the most significant towards the least significant bits (MSD). + * + * Conceptually, each pass consists of 4 steps: + * + * 1. Calculate histogram + * First, transform bits into a digit, the value of which is in the range + * [0, 2^{BITS_PER_PASS}-1]. Then count the frequency of each digit value and the result is a + * histogram. That is, histogram[i] contains the count of inputs having value i. + * + * 2. Scan the histogram + * Inclusive prefix sum is computed for the histogram. After this step, histogram[i] contains + * the count of inputs having value <= i. + * + * 3. Find the bucket j of the histogram that the k-th value falls into + * + * 4. Filtering + * Input elements whose digit value unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) { - static_assert(sizeof(WideT) / sizeof(T) >= 1); + static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); int active_blocks; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( @@ -802,13 +862,13 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) IdxT best_num_blocks = 0; float best_tail_wave_penalty = 1.0f; - const IdxT max_num_blocks = (len - 1) / (sizeof(WideT) / sizeof(T) * BlockSize) + 1; + const IdxT max_num_blocks = (len - 1) / (VECTORIZED_READ_SIZE / sizeof(T) * BlockSize) + 1; for (int num_waves = 1;; ++num_waves) { IdxT num_blocks = std::min( max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); IdxT items_per_thread = (len - 1) / (num_blocks * BlockSize) + 1; - items_per_thread = (items_per_thread - 1) / (sizeof(WideT) / sizeof(T)) + 1; - items_per_thread *= sizeof(WideT) / sizeof(T); + items_per_thread = (items_per_thread - 1) / (VECTORIZED_READ_SIZE / sizeof(T)) + 1; + items_per_thread *= VECTORIZED_READ_SIZE / sizeof(T); num_blocks = (len - 1) / (items_per_thread * BlockSize) + 1; float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; float tail_wave_penalty = @@ -829,6 +889,53 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) return best_num_blocks; } +/** + * Select k smallest or largest key/values from each row in the input data. + * + * If you think of the input data `in_keys` as a row-major matrix with len columns and + * batch_size rows, then this function selects k smallest/largest values in each row and fills + * in the row-major matrix `out` of size (batch_size, k). 
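+ * For example (illustrative): with batch_size = 2, len = 4, k = 2 and select_min = true,
+ * input rows {3, 1, 4, 1} and {5, 9, 2, 6} would yield {1, 1} and {2, 5} respectively
+ * (in no particular order, see the note below).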
+ * + * Note, the output is NOT sorted within the groups of `k` selected elements. + * + * @tparam T + * the type of the keys (what is being compared). + * @tparam IdxT + * the index type (what is being selected together with the keys). + * @tparam BitsPerPass + * The size of the radix; + * it affects the number of passes and number of buckets. + * @tparam BlockSize + * Number of threads in a kernel thread block. + * + * @param[in] in + * contiguous device array of inputs of size (len * batch_size); + * these are compared and selected. + * @param[in] in_idx + * contiguous device array of inputs of size (len * batch_size); + * typically, these are indices of the corresponding in_keys. + * @param batch_size + * number of input rows, i.e. the batch size. + * @param len + * length of a single input array (row); also sometimes referred as n_cols. + * Invariant: len >= k. + * @param k + * the number of outputs to select in each input row. + * @param[out] out + * contiguous device array of outputs of size (k * batch_size); + * the k smallest/largest values from each row of the `in_keys`. + * @param[out] out_idx + * contiguous device array of outputs of size (k * batch_size); + * the payload selected together with `out`. + * @param select_min + * whether to select k smallest (true) or largest (false) keys. + * @param use_dynamic + * whether to use the dynamic implementation, which is favorable if the leading bits of input data + * are almost the same. + * @param stream + * @param mr an optional memory resource to use across the calls (you can provide a large enough + * memory pool here to avoid memory allocations within the call). + */ template pool_size()); @@ -951,7 +1058,7 @@ void radix_topk(const T* in, } if (use_dynamic) { - dim3 blocks((len / (sizeof(WideT) / sizeof(T)) - 1) / BlockSize + 1, batch_size); + dim3 blocks((len / (VECTORIZED_READ_SIZE / sizeof(T)) - 1) / BlockSize + 1, batch_size); last_filter_kernel<<>>( in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters.data(), select_min); } @@ -1133,7 +1240,7 @@ void radix_topk_one_block(const T* in, batch_size * (sizeof(T) * len * 2 // T bufs + sizeof(IdxT) * len * 2 // IdxT bufs ) + - 256 * 4); + 256 * 4); // might need extra memory for alignment if (pool_guard) { RAFT_LOG_DEBUG("radix_topk: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); From 99924bd69a72776ab517ebbc174ae05d2c3748ba Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Fri, 13 Jan 2023 19:55:24 +0800 Subject: [PATCH 07/29] radix top-k: polish code --- .../matrix/detail/select_radix_updated.cuh | 167 +++++++----------- 1 file changed, 68 insertions(+), 99 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 8f6521ec0c..3dcd9664c1 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -27,13 +27,12 @@ #include #include -#include +#include #include namespace raft::spatial::knn::detail::topk { namespace radix_impl { -constexpr int BLOCK_DIM = 512; constexpr int VECTORIZED_READ_SIZE = 16; constexpr int LAZY_WRITING_FACTOR = 4; @@ -732,8 +731,6 @@ __global__ void radix_kernel(const T* in, const bool select_min, const int pass) { - __shared__ bool isLastBlock; - const int batch_id = blockIdx.y; auto counter = counters + batch_id; IdxT current_k; @@ -797,14 +794,13 @@ __global__ void radix_kernel(const T* in, early_stop); __threadfence(); + bool isLastBlock = false; if 
(threadIdx.x == 0) { unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); isLastBlock = (finished == (gridDim.x - 1)); } - // Synchronize to make sure that each thread reads the correct value of isLastBlock. - __syncthreads(); - if (isLastBlock) { + if (__syncthreads_or(isLastBlock)) { if (early_stop) { if (threadIdx.x == 0) { // last_filter_kernel from dynamic version requires setting previous_len @@ -854,8 +850,8 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) int active_blocks; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &active_blocks, - use_dynamic ? radix_kernel - : radix_kernel, + use_dynamic ? radix_kernel + : radix_kernel, BlockSize, 0)); active_blocks *= sm_cnt; @@ -889,53 +885,6 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) return best_num_blocks; } -/** - * Select k smallest or largest key/values from each row in the input data. - * - * If you think of the input data `in_keys` as a row-major matrix with len columns and - * batch_size rows, then this function selects k smallest/largest values in each row and fills - * in the row-major matrix `out` of size (batch_size, k). - * - * Note, the output is NOT sorted within the groups of `k` selected elements. - * - * @tparam T - * the type of the keys (what is being compared). - * @tparam IdxT - * the index type (what is being selected together with the keys). - * @tparam BitsPerPass - * The size of the radix; - * it affects the number of passes and number of buckets. - * @tparam BlockSize - * Number of threads in a kernel thread block. - * - * @param[in] in - * contiguous device array of inputs of size (len * batch_size); - * these are compared and selected. - * @param[in] in_idx - * contiguous device array of inputs of size (len * batch_size); - * typically, these are indices of the corresponding in_keys. - * @param batch_size - * number of input rows, i.e. the batch size. - * @param len - * length of a single input array (row); also sometimes referred as n_cols. - * Invariant: len >= k. - * @param k - * the number of outputs to select in each input row. - * @param[out] out - * contiguous device array of outputs of size (k * batch_size); - * the k smallest/largest values from each row of the `in_keys`. - * @param[out] out_idx - * contiguous device array of outputs of size (k * batch_size); - * the payload selected together with `out`. - * @param select_min - * whether to select k smallest (true) or largest (false) keys. - * @param use_dynamic - * whether to use the dynamic implementation, which is favorable if the leading bits of input data - * are almost the same. - * @param stream - * @param mr an optional memory resource to use across the calls (you can provide a large enough - * memory pool here to avoid memory allocations within the call). 
- */ template -void radix_topk_11bits(const T* in, - const IdxT* in_idx, - int batch_size, - IdxT len, - IdxT k, - T* out, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) -{ - constexpr bool use_dynamic = false; - constexpr int items_per_thread = 32; - - if (len <= radix_impl::BLOCK_DIM * items_per_thread) { - radix_impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); - } else if (len < 100.0 * k / batch_size + 0.01) { - radix_impl::radix_topk( - in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); - } else { - radix_impl::radix_topk( - in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); - } -} - -template -void radix_topk_11bits_dynamic(const T* in, - const IdxT* in_idx, - int batch_size, - IdxT len, - IdxT k, - T* out, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) +/** + * Select k smallest or largest key/values from each row in the input data. + * + * If you think of the input data `in_keys` as a row-major matrix with len columns and + * batch_size rows, then this function selects k smallest/largest values in each row and fills + * in the row-major matrix `out` of size (batch_size, k). + * + * Note, the output is NOT sorted within the groups of `k` selected elements. + * + * @tparam T + * the type of the keys (what is being compared). + * @tparam IdxT + * the index type (what is being selected together with the keys). + * @tparam BitsPerPass + * The size of the radix; + * it affects the number of passes and number of buckets. + * @tparam BlockSize + * Number of threads in a kernel thread block. + * + * @param[in] in + * contiguous device array of inputs of size (len * batch_size); + * these are compared and selected. + * @param[in] in_idx + * contiguous device array of inputs of size (len * batch_size); + * typically, these are indices of the corresponding in_keys. + * @param batch_size + * number of input rows, i.e. the batch size. + * @param len + * length of a single input array (row); also sometimes referred as n_cols. + * Invariant: len >= k. + * @param k + * the number of outputs to select in each input row. + * @param[out] out + * contiguous device array of outputs of size (k * batch_size); + * the k smallest/largest values from each row of the `in_keys`. + * @param[out] out_idx + * contiguous device array of outputs of size (k * batch_size); + * the payload selected together with `out`. + * @param select_min + * whether to select k smallest (true) or largest (false) keys. + * @param use_dynamic + * whether to use the dynamic implementation, which is favorable if the leading bits of input data + * are almost the same. + * @param stream + * @param mr an optional memory resource to use across the calls (you can provide a large enough + * memory pool here to avoid memory allocations within the call). 
+ */ +template +void radix_topk_updated(const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + bool use_dynamic, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { - constexpr bool use_dynamic = true; constexpr int items_per_thread = 32; - if (len <= radix_impl::BLOCK_DIM * items_per_thread) { - radix_impl::radix_topk_one_block( + if (len <= BlockSize * items_per_thread) { + radix_impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); } else if (len < 100.0 * k / batch_size + 0.01) { - radix_impl::radix_topk( + radix_impl::radix_topk( in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); } else { - radix_impl::radix_topk( + radix_impl::radix_topk( in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); } } From 2746173df7e61e8394a382230d7f6dedfcb0d7e1 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Wed, 18 Jan 2023 18:28:34 +0800 Subject: [PATCH 08/29] radix top-k: remove Store classes --- .../matrix/detail/select_radix_updated.cuh | 477 ++++-------------- 1 file changed, 107 insertions(+), 370 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 3dcd9664c1..039a5c9c79 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -152,59 +152,6 @@ __device__ void vectorized_process(const T* in, IdxT len, Func f) } } -// sync_width should >= WarpSize -template -__device__ void vectorized_process(const T* in, IdxT len, Func f, int sync_width) -{ - using WideT = float4; - const IdxT stride = blockDim.x * gridDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - if constexpr (sizeof(T) >= sizeof(WideT)) { - for (IdxT i = tid; i < len; i += stride) { - f(in[i], i, true); - } - } else { - static_assert(sizeof(WideT) % sizeof(T) == 0); - constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); - union { - WideT scalar; - T array[items_per_scalar]; - } wide; - - int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) - ? ((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) - : 0; - if (skip_cnt > len) { skip_cnt = len; } - const WideT* in_cast = reinterpret_cast(in + skip_cnt); - const IdxT len_cast = (len - skip_cnt) / items_per_scalar; - - const IdxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; - for (IdxT i = tid; i < len_cast_for_sync; i += stride) { - bool valid = i < len_cast; - if (valid) { wide.scalar = in_cast[i]; } - const IdxT real_i = skip_cnt + i * items_per_scalar; -#pragma unroll - for (int j = 0; j < items_per_scalar; ++j) { - f(wide.array[j], real_i + j, valid); - } - } - - static_assert(WarpSize >= items_per_scalar); - // need at most one warp for skipped and remained elements, - // and sync_width >= WarpSize - if (tid < sync_width) { - bool valid = tid < skip_cnt; - T value = valid ? in[tid] : T(); - f(value, tid, valid); - - const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; - valid = remain_i < len; - value = valid ? 
in[remain_i] : T(); - f(value, remain_i, valid); - } - } -} - template struct alignas(128) Counter { IdxT k; @@ -218,292 +165,102 @@ struct alignas(128) Counter { alignas(128) IdxT out_back_cnt; }; -// not actually used since the specialization for FilterAndHistogram doesn't use this -// implementation -template -class DirectStore { - public: - __device__ void store(T value, IdxT index, bool valid, T* out, IdxT* out_idx, IdxT* p_out_cnt) - { - if (!valid) { return; } - IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); - out[pos] = value; - out_idx[pos] = index; - } - - __device__ void flush(T*, IdxT*, IdxT*) {} -}; - -template -class BufferedStore { - public: - __device__ BufferedStore() - { - const int warp_id = threadIdx.x >> 5; - lane_id_ = threadIdx.x % WarpSize; - - __shared__ T value_smem[BlockSize]; - __shared__ IdxT index_smem[BlockSize]; - - value_smem_ = value_smem + (warp_id << 5); - index_smem_ = index_smem + (warp_id << 5); - warp_pos_ = 0; - } - - __device__ void store(T value, IdxT index, bool valid, T* out, IdxT* out_idx, IdxT* p_out_cnt) - { - unsigned int valid_mask = __ballot_sync(FULL_WARP_MASK_, valid); - if (valid_mask == 0) { return; } - - int pos = __popc(valid_mask & ((0x1u << lane_id_) - 1)) + warp_pos_; - if (valid && pos < WarpSize) { - value_smem_[pos] = value; - index_smem_[pos] = index; - } - - warp_pos_ += __popc(valid_mask); - // Check if the buffer is full - if (warp_pos_ >= WarpSize) { - IdxT pos_smem; - if (lane_id_ == 0) { pos_smem = atomicAdd(p_out_cnt, static_cast(WarpSize)); } - pos_smem = __shfl_sync(FULL_WARP_MASK_, pos_smem, 0); - - __syncwarp(); - out[pos_smem + lane_id_] = value_smem_[lane_id_]; - out_idx[pos_smem + lane_id_] = index_smem_[lane_id_]; - __syncwarp(); - // Now the buffer is clean - if (valid && pos >= WarpSize) { - pos -= WarpSize; - value_smem_[pos] = value; - index_smem_[pos] = index; - } - - warp_pos_ -= WarpSize; - } - } - - __device__ void flush(T* out, IdxT* out_idx, IdxT* p_out_cnt) - { - if (warp_pos_ > 0) { - IdxT pos_smem; - if (lane_id_ == 0) { pos_smem = atomicAdd(p_out_cnt, static_cast(warp_pos_)); } - pos_smem = __shfl_sync(FULL_WARP_MASK_, pos_smem, 0); - - __syncwarp(); - if (lane_id_ < warp_pos_) { - out[pos_smem + lane_id_] = value_smem_[lane_id_]; - out_idx[pos_smem + lane_id_] = index_smem_[lane_id_]; - } - } - } - - private: - const unsigned FULL_WARP_MASK_{0xffffffff}; - T* value_smem_; - IdxT* index_smem_; - IdxT lane_id_; //@TODO: Can be const variable - int warp_pos_; -}; - -template - class Store> -class FilterAndHistogram { - public: - __device__ void operator()(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - IdxT previous_len, - Counter* counter, - IdxT* histogram, - bool select_min, - int pass, - bool early_stop) - { - constexpr int num_buckets = calc_num_buckets(); - __shared__ IdxT histogram_smem[num_buckets]; - for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { - histogram_smem[i] = 0; - } - Store store; - __syncthreads(); - - const int start_bit = calc_start_bit(pass); - const unsigned mask = calc_mask(pass); - - if (pass == 0) { - auto f = [select_min, start_bit, mask](T value, IdxT) { - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram_smem + bucket, static_cast(1)); - }; - vectorized_process(in_buf, previous_len, f); - } else { - IdxT* p_filter_cnt = &counter->filter_cnt; - IdxT* p_out_cnt = &counter->out_cnt; - const auto kth_value_bits = counter->kth_value_bits; - const int 
previous_start_bit = calc_start_bit(pass - 1); - - auto f = [in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - select_min, - start_bit, - mask, - previous_start_bit, - kth_value_bits, - p_filter_cnt, - p_out_cnt, - early_stop, - &store](T value, IdxT i, bool valid) { - const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) - << previous_start_bit; - - if (valid && previous_bits == kth_value_bits) { - if (early_stop) { - IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); - out[pos] = value; - out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; - } else { - if (out_buf) { - IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); - out_buf[pos] = value; - out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; - } - - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram_smem + bucket, static_cast(1)); - } - } - - if (out_buf || early_stop) { - store.store(value, - in_idx_buf ? in_idx_buf[i] : i, - valid && previous_bits < kth_value_bits, - out, - out_idx, - p_out_cnt); - } - }; - vectorized_process(in_buf, previous_len, f, WarpSize); - store.flush(out, out_idx, p_out_cnt); - } - if (early_stop) { return; } - - __syncthreads(); - for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { - if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } - } - } -}; - /** * Fused filtering of the current phase and building histogram for the next phase * (see steps 4-1 in `radix_kernel` description). */ -template -class FilterAndHistogram { - public: - __device__ void operator()(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - IdxT previous_len, - Counter* counter, - IdxT* histogram, - bool select_min, - int pass, - bool early_stop) - { - constexpr int num_buckets = calc_num_buckets(); - __shared__ IdxT histogram_smem[num_buckets]; - for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { - histogram_smem[i] = 0; - } - __syncthreads(); +template +__device__ void filter_and_histogram(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + IdxT previous_len, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass, + bool early_stop) +{ + constexpr int num_buckets = calc_num_buckets(); + __shared__ IdxT histogram_smem[num_buckets]; + for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram_smem[i] = 0; + } + __syncthreads(); - const int start_bit = calc_start_bit(pass); - const unsigned mask = calc_mask(pass); + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); - if (pass == 0) { - // Passed to vectorized_process, this function executes in all blocks in parallel, - // i.e. the work is split along the input (both, in batches and chunks of a single row). - // Later, the histograms are merged using atomicAdd. - auto f = [select_min, start_bit, mask](T value, IdxT) { - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram_smem + bucket, static_cast(1)); - }; - vectorized_process(in_buf, previous_len, f); - } else { - IdxT* p_filter_cnt = &counter->filter_cnt; - IdxT* p_out_cnt = &counter->out_cnt; - const auto kth_value_bits = counter->kth_value_bits; - const int previous_start_bit = calc_start_bit(pass - 1); - - // See the remark above on the distributed execution of `f` using vectorized_process. 
- auto f = [in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - select_min, - start_bit, - mask, - previous_start_bit, - kth_value_bits, - p_filter_cnt, - p_out_cnt, - early_stop](T value, IdxT i) { - const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) - << previous_start_bit; - if (previous_bits == kth_value_bits) { - if (early_stop) { - IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); - out[pos] = value; - out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; - } else { - if (out_buf) { - IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); - out_buf[pos] = value; - out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; - } - - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram_smem + bucket, static_cast(1)); - } - } - // '(out_buf || early_stop)' is a little tricky: - // If we skip writing to 'out_buf' (when 'out_buf' is false), we should skip writing to - // 'out' too. So we won't write the same value to 'out' multiple times. And if we keep - // skipping the writing, values will be written in last_filter_kernel at last. - // But when 'early_stop' is true, we need to write to 'out' since it's the last chance. - else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { + if (pass == 0) { + // Passed to vectorized_process, this function executes in all blocks in parallel, + // i.e. the work is split along the input (both, in batches and chunks of a single row). + // Later, the histograms are merged using atomicAdd. + auto f = [select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); + }; + vectorized_process(in_buf, previous_len, f); + } else { + IdxT* p_filter_cnt = &counter->filter_cnt; + IdxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + // See the remark above on the distributed execution of `f` using vectorized_process. + auto f = [in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + select_min, + start_bit, + mask, + previous_start_bit, + kth_value_bits, + p_filter_cnt, + p_out_cnt, + early_stop](T value, IdxT i) { + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + if (early_stop) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else { + if (out_buf) { + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); } - }; - vectorized_process(in_buf, previous_len, f); - } - if (early_stop) { return; } - __syncthreads(); + } + // '(out_buf || early_stop)' is a little tricky: + // If we skip writing to 'out_buf' (when 'out_buf' is false), we should skip writing to + // 'out' too. So we won't write the same value to 'out' multiple times. And if we keep + // skipping the writing, values will be written in last_filter_kernel at last. + // But when 'early_stop' is true, we need to write to 'out' since it's the last chance. + else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; + } + }; + vectorized_process(in_buf, previous_len, f); + } + if (early_stop) { return; } + __syncthreads(); - // merge histograms produced by individual blocks - for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { - if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } - } + // merge histograms produced by individual blocks + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } } -}; +} /** * Replace a part of the histogram with its own prefix sum, starting from the `start` and adding @@ -709,13 +466,7 @@ __global__ void last_filter_kernel(const T* in, * In the implementation, the filtering step is delayed to the next pass so the filtering and * histogram computation are fused. In this way, inputs are read once rather than twice. */ -template - class Store> +template __global__ void radix_kernel(const T* in, const IdxT* in_idx, const T* in_buf, @@ -780,18 +531,18 @@ __global__ void radix_kernel(const T* in, } auto histogram = histograms + batch_id * num_buckets; - FilterAndHistogram()(in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - previous_len, - counter, - histogram, - select_min, - pass, - early_stop); + filter_and_histogram(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + previous_len, + counter, + histogram, + select_min, + pass, + early_stop); __threadfence(); bool isLastBlock = false; @@ -837,12 +588,7 @@ __global__ void radix_kernel(const T* in, } } -template - class Store> +template unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) { static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); @@ -850,8 +596,8 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) int active_blocks; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &active_blocks, - use_dynamic ? radix_kernel - : radix_kernel, + use_dynamic ? radix_kernel + : radix_kernel, BlockSize, 0)); active_blocks *= sm_cnt; @@ -885,12 +631,7 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) return best_num_blocks; } -template - class Store> +template void radix_topk(const T* in, const IdxT* in_idx, int batch_size, @@ -942,9 +683,8 @@ void radix_topk(const T* in, RAFT_CUDA_TRY(cudaGetDevice(&dev)); RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); } - dim3 blocks( - calc_grid_dim(batch_size, len, sm_cnt, use_dynamic), - batch_size); + dim3 blocks(calc_grid_dim(batch_size, len, sm_cnt, use_dynamic), + batch_size); constexpr int num_passes = calc_num_passes(); @@ -972,7 +712,7 @@ void radix_topk(const T* in, } if (!use_dynamic) { - radix_kernel + radix_kernel <<>>(in, in_idx, in_buf, @@ -988,7 +728,7 @@ void radix_topk(const T* in, select_min, pass); } else { - radix_kernel + radix_kernel <<>>(in, in_idx, in_buf, @@ -1038,7 +778,7 @@ __device__ void filter_and_histogram(const T* in_buf, const IdxT previous_len = counter->previous_len; if (pass == 0) { - // Could not use vectorized_process() as in FilterAndHistogram because + // Could not use vectorized_process() as in filter_and_histogram because // vectorized_process() assumes multi-block, e.g. 
uses gridDim.x for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { T value = in_buf[i]; @@ -1281,11 +1021,8 @@ void radix_topk_updated(const T* in, if (len <= BlockSize * items_per_thread) { radix_impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); - } else if (len < 100.0 * k / batch_size + 0.01) { - radix_impl::radix_topk( - in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); } else { - radix_impl::radix_topk( + radix_impl::radix_topk( in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); } } From 381c0759ce8f60bd244e42359733d724215a4448 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Thu, 19 Jan 2023 11:25:07 +0800 Subject: [PATCH 09/29] radix top-k: polish code comments --- .../matrix/detail/select_radix_updated.cuh | 106 ++++++++++-------- 1 file changed, 60 insertions(+), 46 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 039a5c9c79..fb3e0da08c 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -34,7 +34,6 @@ namespace raft::spatial::knn::detail::topk { namespace radix_impl { constexpr int VECTORIZED_READ_SIZE = 16; -constexpr int LAZY_WRITING_FACTOR = 4; template __host__ __device__ constexpr int calc_num_buckets() @@ -72,7 +71,7 @@ __device__ constexpr unsigned calc_mask(int pass) } /** - * Use cub to twiddle bits - so that we can correctly compare bits of floating-point values as well + * Use CUB to twiddle bits - so that we can correctly compare bits of floating-point values as well * as of integers. */ template @@ -100,6 +99,16 @@ __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) return (twiddle_in(x, select_min) >> start_bit) & mask; } +template +__device__ bool use_lazy_writing(IdxT original_len, IdxT len) +{ + // When using lazy writing, only read `in`(type T). + // When not using it, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) and + // `out_idx_buf`(IdxT). + constexpr float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T); + return len * ratio > original_len; +} + /** * Map a Func over the input data, using vectorized load instructions if possible. * @@ -166,8 +175,8 @@ struct alignas(128) Counter { }; /** - * Fused filtering of the current phase and building histogram for the next phase - * (see steps 4-1 in `radix_kernel` description). + * Fused filtering of the current pass and building histogram for the next pass + * (see steps 4 & 1 in `radix_kernel` description). */ template __device__ void filter_and_histogram(const T* in_buf, @@ -240,11 +249,12 @@ __device__ void filter_and_histogram(const T* in_buf, atomicAdd(histogram_smem + bucket, static_cast(1)); } } - // '(out_buf || early_stop)' is a little tricky: - // If we skip writing to 'out_buf' (when 'out_buf' is false), we should skip writing to - // 'out' too. So we won't write the same value to 'out' multiple times. And if we keep - // skipping the writing, values will be written in last_filter_kernel at last. - // But when 'early_stop' is true, we need to write to 'out' since it's the last chance. + // the condition `(out_buf || early_stop)` is a little tricky: + // If we skip writing to `out_buf` (when `out_buf` is false), we should skip writing to + // `out` too. So we won't write the same value to `out` multiple times in different passes. 
+ // And if we keep skipping the writing, values will be written in `last_filter_kernel` at + // last. + // But when `early_stop` is true, we need to write to `out` since it's the last chance. else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; @@ -263,8 +273,7 @@ __device__ void filter_and_histogram(const T* in_buf, } /** - * Replace a part of the histogram with its own prefix sum, starting from the `start` and adding - * `current` to each entry of the result. + * Replace histogram with its own prefix sum * (step 2 in `radix_kernel` description) */ template @@ -322,10 +331,10 @@ __device__ void choose_bucket(Counter* counter, IdxT prev = (i == 0) ? 0 : histogram[i - 1]; IdxT cur = histogram[i]; - // one and only one thread will satisfy this condition, so only write once + // one and only one thread will satisfy this condition, so counter is written by only one thread if (prev < k && cur >= k) { counter->k = k - prev; // how many values still are there to find - counter->len = cur - prev; // number of values in `index` bucket + counter->len = cur - prev; // number of values in next pass typename cub::Traits::UnsignedBits bucket = i; int start_bit = calc_start_bit(pass); counter->kth_value_bits |= bucket << start_bit; @@ -334,7 +343,7 @@ __device__ void choose_bucket(Counter* counter, } // For one-block version, last_filter() could be called when pass < num_passes - 1. -// So pass could not be constexpr +// So `pass` could not be constexpr template __device__ void last_filter(const T* out_buf, const IdxT* out_idx_buf, @@ -349,7 +358,7 @@ __device__ void last_filter(const T* out_buf, const auto kth_value_bits = counter->kth_value_bits; const int start_bit = calc_start_bit(pass); - // changed in choose_bucket(), need to reload + // changed in choose_bucket(); need to reload const IdxT needed_num_of_kth = counter->k; IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; @@ -359,9 +368,8 @@ __device__ void last_filter(const T* out_buf, if (bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; - // for one-block version, 'out_idx_buf' could be nullptr at pass 0; - // and for dynamic version, 'out_idx_buf' could be nullptr if 'out_buf' is - // 'in' + // For one-block version, `out_idx_buf` could be nullptr at pass 0. + // And for dynamic version, `out_idx_buf` could be nullptr if `out_buf` is `in` out_idx[pos] = out_idx_buf ? out_idx_buf[i] : i; } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); @@ -374,6 +382,7 @@ __device__ void last_filter(const T* out_buf, } } +// used only for dynamic version template __global__ void last_filter_kernel(const T* in, const IdxT* in_idx, @@ -391,7 +400,7 @@ __global__ void last_filter_kernel(const T* in, Counter* counter = counters + batch_id; IdxT previous_len = counter->previous_len; if (previous_len == 0) { return; } - if (previous_len > len / LAZY_WRITING_FACTOR) { + if (use_lazy_writing(len, previous_len)) { in_buf = in; in_idx_buf = in_idx; previous_len = len; @@ -492,7 +501,7 @@ __global__ void radix_kernel(const T* in, previous_len = len; // Need to do this so setting counter->previous_len for the next pass is correct. 
// This value is meaningless for pass 0, but it's fine because pass 0 won't be the - // last pass in current implementation so pass 0 won't hit the "if (pass == + // last pass in this implementation so pass 0 won't hit the "if (pass == // num_passes - 1)" branch. // Maybe it's better to reload counter->previous_len and use it rather than // current_len in last_filter() @@ -510,13 +519,13 @@ __global__ void radix_kernel(const T* in, if constexpr (use_dynamic) { // Figure out if the previous pass writes buffer - if (previous_len > len / LAZY_WRITING_FACTOR) { + if (use_lazy_writing(len, previous_len)) { previous_len = len; in_buf = in; in_idx_buf = in_idx; } // Figure out if this pass need to write buffer - if (current_len > len / LAZY_WRITING_FACTOR) { + if (use_lazy_writing(len, current_len)) { out_buf = nullptr; out_idx_buf = nullptr; } @@ -554,7 +563,7 @@ __global__ void radix_kernel(const T* in, if (__syncthreads_or(isLastBlock)) { if (early_stop) { if (threadIdx.x == 0) { - // last_filter_kernel from dynamic version requires setting previous_len + // last_filter_kernel from the dynamic version requires setting previous_len counter->previous_len = 0; counter->len = 0; } @@ -579,6 +588,8 @@ __global__ void radix_kernel(const T* in, counter->filter_cnt = 0; } + // For non-dynamic version, we do the last filtering using the last thread block. + // For dynamic version, we'll use a multi-block kernel (last_filter_kernel). if constexpr (!use_dynamic) { if (pass == num_passes - 1) { last_filter( @@ -753,17 +764,20 @@ void radix_topk(const T* in, } } +// The following a few functions are for the one-block version, which uses single thread block for +// each row of a batch. It's used when len is relatively small, so intermediate data, like counters +// and histograms, can be kept in shared memory and cheap sync operations can be used. template -__device__ void filter_and_histogram(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - Counter* counter, - IdxT* histogram, - bool select_min, - int pass) +__device__ void filter_and_histogram_for_one_block(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass) { constexpr int num_buckets = calc_num_buckets(); for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { @@ -778,7 +792,7 @@ __device__ void filter_and_histogram(const T* in_buf, const IdxT previous_len = counter->previous_len; if (pass == 0) { - // Could not use vectorized_process() as in filter_and_histogram because + // Could not use vectorized_process() as in filter_and_histogram() because // vectorized_process() assumes multi-block, e.g. uses gridDim.x for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { T value = in_buf[i]; @@ -876,16 +890,16 @@ __global__ void radix_topk_one_block_kernel(const T* in, IdxT current_len = counter.len; IdxT current_k = counter.k; - filter_and_histogram(in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - &counter, - histogram, - select_min, - pass); + filter_and_histogram_for_one_block(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + &counter, + histogram, + select_min, + pass); __syncthreads(); scan(histogram); @@ -997,8 +1011,8 @@ void radix_topk_one_block(const T* in, * @param select_min * whether to select k smallest (true) or largest (false) keys. 
* @param use_dynamic - * whether to use the dynamic implementation, which is favorable if the leading bits of input data - * are almost the same. + * whether to use the dynamic implementation, which is favorable if the most significant bits of + * input data are almost the same. That is, when the value range of input data is narrow. * @param stream * @param mr an optional memory resource to use across the calls (you can provide a large enough * memory pool here to avoid memory allocations within the call). From d20a480b0f2c44733f0c4e0a353657960510cd2d Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Thu, 19 Jan 2023 11:56:20 +0800 Subject: [PATCH 10/29] radix top-k: change dynamic to adaptive --- .../matrix/detail/select_radix_updated.cuh | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index fb3e0da08c..a56fcc20ce 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -369,7 +369,7 @@ __device__ void last_filter(const T* out_buf, IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; // For one-block version, `out_idx_buf` could be nullptr at pass 0. - // And for dynamic version, `out_idx_buf` could be nullptr if `out_buf` is `in` + // And for adaptive version, `out_idx_buf` could be nullptr if `out_buf` is `in` out_idx[pos] = out_idx_buf ? out_idx_buf[i] : i; } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); @@ -382,7 +382,7 @@ __device__ void last_filter(const T* out_buf, } } -// used only for dynamic version +// used only for adaptive version template __global__ void last_filter_kernel(const T* in, const IdxT* in_idx, @@ -475,7 +475,7 @@ __global__ void last_filter_kernel(const T* in, * In the implementation, the filtering step is delayed to the next pass so the filtering and * histogram computation are fused. In this way, inputs are read once rather than twice. */ -template +template __global__ void radix_kernel(const T* in, const IdxT* in_idx, const T* in_buf, @@ -517,7 +517,7 @@ __global__ void radix_kernel(const T* in, constexpr int num_buckets = calc_num_buckets(); constexpr int num_passes = calc_num_passes(); - if constexpr (use_dynamic) { + if constexpr (adaptive) { // Figure out if the previous pass writes buffer if (use_lazy_writing(len, previous_len)) { previous_len = len; @@ -563,7 +563,7 @@ __global__ void radix_kernel(const T* in, if (__syncthreads_or(isLastBlock)) { if (early_stop) { if (threadIdx.x == 0) { - // last_filter_kernel from the dynamic version requires setting previous_len + // last_filter_kernel from the adaptive version requires setting previous_len counter->previous_len = 0; counter->len = 0; } @@ -588,9 +588,9 @@ __global__ void radix_kernel(const T* in, counter->filter_cnt = 0; } - // For non-dynamic version, we do the last filtering using the last thread block. - // For dynamic version, we'll use a multi-block kernel (last_filter_kernel). - if constexpr (!use_dynamic) { + // For non-adaptive version, we do the last filtering using the last thread block. + // For adaptive version, we'll use a multi-block kernel (last_filter_kernel). 
+ if constexpr (!adaptive) { if (pass == num_passes - 1) { last_filter( out_buf, out_idx_buf, out, out_idx, current_len, k, counter, select_min, pass); @@ -600,15 +600,15 @@ __global__ void radix_kernel(const T* in, } template -unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool use_dynamic) +unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool adaptive) { static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); int active_blocks; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &active_blocks, - use_dynamic ? radix_kernel - : radix_kernel, + adaptive ? radix_kernel + : radix_kernel, BlockSize, 0)); active_blocks *= sm_cnt; @@ -651,7 +651,7 @@ void radix_topk(const T* in, T* out, IdxT* out_idx, bool select_min, - bool use_dynamic, + bool adaptive, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -694,7 +694,7 @@ void radix_topk(const T* in, RAFT_CUDA_TRY(cudaGetDevice(&dev)); RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); } - dim3 blocks(calc_grid_dim(batch_size, len, sm_cnt, use_dynamic), + dim3 blocks(calc_grid_dim(batch_size, len, sm_cnt, adaptive), batch_size); constexpr int num_passes = calc_num_passes(); @@ -722,7 +722,7 @@ void radix_topk(const T* in, out_idx_buf = idx_buf1.data(); } - if (!use_dynamic) { + if (!adaptive) { radix_kernel <<>>(in, in_idx, @@ -757,7 +757,7 @@ void radix_topk(const T* in, } } - if (use_dynamic) { + if (adaptive) { dim3 blocks((len / (VECTORIZED_READ_SIZE / sizeof(T)) - 1) / BlockSize + 1, batch_size); last_filter_kernel<<>>( in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters.data(), select_min); @@ -1010,9 +1010,9 @@ void radix_topk_one_block(const T* in, * the payload selected together with `out`. * @param select_min * whether to select k smallest (true) or largest (false) keys. - * @param use_dynamic - * whether to use the dynamic implementation, which is favorable if the most significant bits of - * input data are almost the same. That is, when the value range of input data is narrow. + * @param adaptive + * whether to use the adaptive implementation, which is preferable when the most significant bits + * of input data are almost the same. That is, when the value range of input data is narrow. * @param stream * @param mr an optional memory resource to use across the calls (you can provide a large enough * memory pool here to avoid memory allocations within the call). 
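For readers skimming the series, a minimal usage sketch of the API documented above. It is not part of the patch: the header path, the wrapper function, and the BitsPerPass = 11 / BlockSize = 512 instantiation are assumptions mirroring the 11-bit variant exercised by the tests and benchmarks added later in this series (the entry point is renamed to select_k_updated in a later patch).

#include <raft/matrix/detail/select_radix_updated.cuh>  // assumed header path for this revision
#include <rmm/cuda_stream_view.hpp>

// Hypothetical wrapper (illustration only): select the k smallest keys per row of a
// row-major [batch_size, len] matrix of floats, returning row positions as the payload.
inline void select_k_smallest_sketch(const float* d_in,  // batch_size * len keys (device)
                                     int batch_size,
                                     int len,
                                     int k,               // invariant: k <= len
                                     float* d_out,        // batch_size * k keys (device)
                                     int* d_out_idx,      // batch_size * k indices (device)
                                     rmm::cuda_stream_view stream)
{
  raft::spatial::knn::detail::topk::radix_topk_updated<float, int, 11, 512>(
    d_in,
    nullptr,             // no input indices: row offsets 0..len-1 are returned
    batch_size,
    len,
    k,
    d_out,
    d_out_idx,
    /*select_min=*/true,
    /*adaptive=*/false,  // pass true when the leading bits of the keys are mostly equal
    stream);             // mr is omitted; the default memory resource is used
}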
@@ -1026,7 +1026,7 @@ void radix_topk_updated(const T* in, T* out, IdxT* out_idx, bool select_min, - bool use_dynamic, + bool adaptive, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = nullptr) { @@ -1037,7 +1037,7 @@ void radix_topk_updated(const T* in, in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); } else { radix_impl::radix_topk( - in, in_idx, batch_size, len, k, out, out_idx, select_min, use_dynamic, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, adaptive, stream, mr); } } From c4082450d1691f751ca1ec868d6e9e1ddda8fd29 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Mon, 23 Jan 2023 18:58:16 +0800 Subject: [PATCH 11/29] modify radix top-k so that it conforms the latest select_k code --- .../matrix/detail/select_radix_updated.cuh | 163 +++++++++--------- 1 file changed, 86 insertions(+), 77 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index a56fcc20ce..6fb2cf9975 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -30,19 +31,19 @@ #include #include -namespace raft::spatial::knn::detail::topk { -namespace radix_impl { +namespace raft::matrix::detail::select::radix { +namespace impl { constexpr int VECTORIZED_READ_SIZE = 16; template -__host__ __device__ constexpr int calc_num_buckets() +_RAFT_HOST_DEVICE constexpr int calc_num_buckets() { return 1 << BitsPerPass; } template -__host__ __device__ constexpr int calc_num_passes() +_RAFT_HOST_DEVICE constexpr int calc_num_passes() { return ceildiv(sizeof(T) * 8, BitsPerPass); } @@ -55,7 +56,7 @@ __host__ __device__ constexpr int calc_num_passes() * NB: Use pass=-1 for calc_mask(). */ template -__device__ constexpr int calc_start_bit(int pass) +_RAFT_DEVICE constexpr int calc_start_bit(int pass) { int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; if (start_bit < 0) { start_bit = 0; } @@ -63,7 +64,7 @@ __device__ constexpr int calc_start_bit(int pass) } template -__device__ constexpr unsigned calc_mask(int pass) +_RAFT_DEVICE constexpr unsigned calc_mask(int pass) { static_assert(BitsPerPass <= 31); int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); @@ -75,7 +76,7 @@ __device__ constexpr unsigned calc_mask(int pass) * as of integers. 
*/ template -__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) +_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) { auto bits = reinterpret_cast::UnsignedBits&>(key); bits = cub::Traits::TwiddleIn(bits); @@ -84,7 +85,7 @@ __device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_m } template -__device__ T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) +_RAFT_DEVICE T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) { if (!select_min) { bits = ~bits; } bits = cub::Traits::TwiddleOut(bits); @@ -92,7 +93,7 @@ __device__ T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select } template -__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) +_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) { static_assert(BitsPerPass <= sizeof(int) * 8 - 1, "BitsPerPass is too large that the result type could not be int"); @@ -100,7 +101,7 @@ __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) } template -__device__ bool use_lazy_writing(IdxT original_len, IdxT len) +_RAFT_DEVICE bool use_lazy_writing(IdxT original_len, IdxT len) { // When using lazy writing, only read `in`(type T). // When not using it, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) and @@ -124,7 +125,7 @@ __device__ bool use_lazy_writing(IdxT original_len, IdxT len) * @param f the lambda taking two arguments (T x, IdxT idx) */ template -__device__ void vectorized_process(const T* in, IdxT len, Func f) +_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) { const IdxT stride = blockDim.x * gridDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -179,18 +180,18 @@ struct alignas(128) Counter { * (see steps 4 & 1 in `radix_kernel` description). */ template -__device__ void filter_and_histogram(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - IdxT previous_len, - Counter* counter, - IdxT* histogram, - bool select_min, - int pass, - bool early_stop) +_RAFT_DEVICE void filter_and_histogram(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + IdxT previous_len, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass, + bool early_stop) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -277,7 +278,7 @@ __device__ void filter_and_histogram(const T* in_buf, * (step 2 in `radix_kernel` description) */ template -__device__ void scan(volatile IdxT* histogram) +_RAFT_DEVICE void scan(volatile IdxT* histogram) { constexpr int num_buckets = calc_num_buckets(); if constexpr (num_buckets >= BlockSize) { @@ -321,10 +322,10 @@ __device__ void scan(volatile IdxT* histogram) * (steps 3 in `radix_kernel` description) */ template -__device__ void choose_bucket(Counter* counter, - const IdxT* histogram, - const IdxT k, - const int pass) +_RAFT_DEVICE void choose_bucket(Counter* counter, + const IdxT* histogram, + const IdxT k, + const int pass) { constexpr int num_buckets = calc_num_buckets(); for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { @@ -345,15 +346,15 @@ __device__ void choose_bucket(Counter* counter, // For one-block version, last_filter() could be called when pass < num_passes - 1. 
// So `pass` could not be constexpr template -__device__ void last_filter(const T* out_buf, - const IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - IdxT current_len, - IdxT k, - Counter* counter, - const bool select_min, - const int pass) +_RAFT_DEVICE void last_filter(const T* out_buf, + const IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + IdxT current_len, + IdxT k, + Counter* counter, + const bool select_min, + const int pass) { const auto kth_value_bits = counter->kth_value_bits; const int start_bit = calc_start_bit(pass); @@ -659,25 +660,33 @@ void radix_topk(const T* in, static_assert(calc_num_passes() > 1); constexpr int num_buckets = calc_num_buckets(); - auto pool_guard = - raft::get_pool_memory_resource(mr, - batch_size * (sizeof(Counter) // counters - + sizeof(IdxT) * num_buckets // histograms - + sizeof(T) * len * 2 // T bufs - + sizeof(IdxT) * len * 2 // IdxT bufs - ) + - 256 * 6); // might need extra memory for alignment + size_t req_aux = batch_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); + size_t req_buf = batch_size * len * 2 * (sizeof(T) + sizeof(IdxT)); + size_t mem_req = req_aux + req_buf; + size_t mem_free, mem_total; + RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); + std::optional managed_memory; + rmm::mr::device_memory_resource* mr_buf = nullptr; + if (mem_req > mem_free) { + // if there's not enough memory for buffers on the device, resort to the managed memory. + mem_req = req_aux; + managed_memory.emplace(); + mr_buf = &managed_memory.value(); + } + + auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); if (pool_guard) { - RAFT_LOG_DEBUG("radix_topk: using pool memory resource with initial size %zu bytes", + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } + if (mr_buf == nullptr) { mr_buf = mr; } rmm::device_uvector> counters(batch_size, stream, mr); - rmm::device_uvector histograms(num_buckets * batch_size, stream, mr); - rmm::device_uvector buf1(len * batch_size, stream, mr); - rmm::device_uvector idx_buf1(len * batch_size, stream, mr); - rmm::device_uvector buf2(len * batch_size, stream, mr); - rmm::device_uvector idx_buf2(len * batch_size, stream, mr); + rmm::device_uvector histograms(batch_size * num_buckets, stream, mr); + rmm::device_uvector buf1(batch_size * len, stream, mr_buf); + rmm::device_uvector idx_buf1(batch_size * len, stream, mr_buf); + rmm::device_uvector buf2(batch_size * len, stream, mr_buf); + rmm::device_uvector idx_buf2(batch_size * len, stream, mr_buf); RAFT_CUDA_TRY( cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); @@ -768,16 +777,16 @@ void radix_topk(const T* in, // each row of a batch. It's used when len is relatively small, so intermediate data, like counters // and histograms, can be kept in shared memory and cheap sync operations can be used. 
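// Sketch (not part of the patch): quantifying "relatively small". The dispatch at the end of
// this file takes the one-block path when len <= BlockSize * items_per_thread, with
// items_per_thread = 32. Assuming the BlockSize = 512 instantiation used by the tests and
// benchmarks in this series, that means rows of at most 16384 elements; the names below are
// illustrative only.
constexpr int sketch_items_per_thread  = 32;   // same constant as in the dispatch code
constexpr int sketch_assumed_blocksize = 512;  // assumption: configuration used by the tests
static_assert(sketch_items_per_thread * sketch_assumed_blocksize == 16384,
              "in that configuration the one-block kernel handles rows of at most 16384 elements");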
template -__device__ void filter_and_histogram_for_one_block(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - Counter* counter, - IdxT* histogram, - bool select_min, - int pass) +_RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass) { constexpr int num_buckets = calc_num_buckets(); for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { @@ -945,7 +954,7 @@ void radix_topk_one_block(const T* in, ) + 256 * 4); // might need extra memory for alignment if (pool_guard) { - RAFT_LOG_DEBUG("radix_topk: using pool memory resource with initial size %zu bytes", + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } @@ -968,7 +977,7 @@ void radix_topk_one_block(const T* in, idx_buf2.data()); } -} // namespace radix_impl +} // namespace impl /** * Select k smallest or largest key/values from each row in the input data. @@ -1018,27 +1027,27 @@ void radix_topk_one_block(const T* in, * memory pool here to avoid memory allocations within the call). */ template -void radix_topk_updated(const T* in, - const IdxT* in_idx, - int batch_size, - IdxT len, - IdxT k, - T* out, - IdxT* out_idx, - bool select_min, - bool adaptive, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) +void select_k_updated(const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + bool adaptive, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { constexpr int items_per_thread = 32; if (len <= BlockSize * items_per_thread) { - radix_impl::radix_topk_one_block( + impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); } else { - radix_impl::radix_topk( + impl::radix_topk( in, in_idx, batch_size, len, k, out, out_idx, select_min, adaptive, stream, mr); } } -} // namespace raft::spatial::knn::detail::topk +} // namespace raft::matrix::detail::select::radix From 53ebcb87086c446eccfd54ef65913d1bf7a92357 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Tue, 24 Jan 2023 14:54:48 +0800 Subject: [PATCH 12/29] fix the case when k equals len --- .../matrix/detail/select_radix_updated.cuh | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 6fb2cf9975..49200aace7 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -513,6 +513,10 @@ __global__ void radix_kernel(const T* in, previous_len = counter->previous_len; } if (current_len == 0) { return; } + + // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle + // correctly the case that pass=0 and early_stop=true. However, this special case of k=len is + // handled in other way in select_k() so such case is not possible here. 
bool early_stop = (current_len == current_k); constexpr int num_buckets = calc_num_buckets(); @@ -977,6 +981,16 @@ void radix_topk_one_block(const T* in, idx_buf2.data()); } +template +__global__ void fill_idx_kernel(int batch_size, IdxT len, IdxT* out_idx) +{ + const int batch_i = blockIdx.y; + const int stride = blockDim.x * gridDim.x; + for (IdxT i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += stride) { + out_idx[batch_i * len + i] = i; + } +} + } // namespace impl /** @@ -1039,6 +1053,20 @@ void select_k_updated(const T* in, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = nullptr) { + if (k == len) { + RAFT_CUDA_TRY( + cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); + if (in_idx) { + RAFT_CUDA_TRY(cudaMemcpyAsync( + out_idx, in_idx, sizeof(IdxT) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); + } else { + constexpr int block_dim = 256; + dim3 grid_dim((len - 1) / block_dim + 1, batch_size, 1); + impl::fill_idx_kernel<<>>(batch_size, len, out_idx); + } + return; + } + constexpr int items_per_thread = 32; if (len <= BlockSize * items_per_thread) { From fdd30e93525255657449877e2eb770aa6d801055 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Mon, 23 Jan 2023 15:57:27 +0800 Subject: [PATCH 13/29] radix top-k: update tests and benchmarks --- cpp/bench/matrix/select_k.cu | 134 ++++++++++++++++++++++++++--------- cpp/test/matrix/select_k.cu | 20 +++++- cpp/test/matrix/select_k.cuh | 46 +++++++++++- 3 files changed, 161 insertions(+), 39 deletions(-) diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu index 452a50ba50..252d2bea4f 100644 --- a/cpp/bench/matrix/select_k.cu +++ b/cpp/bench/matrix/select_k.cu @@ -36,6 +36,10 @@ #include #include +#include +#include +#include + namespace raft::matrix { using namespace raft::bench; // NOLINT @@ -51,7 +55,23 @@ struct selection : public fixture { { raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream); raft::random::RngState state{42}; - raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); + + KeyT min_value = -1.0; + KeyT max_value = 1.0; + if (p.use_same_leading_bits) { + if constexpr (std::is_same_v) { + uint32_t min_bits = 0x3F800000; // 1.0 + uint32_t max_bits = 0x3F8000FF; // 1.00003 + memcpy(&min_value, &min_bits, sizeof(KeyT)); + memcpy(&max_value, &max_bits, sizeof(KeyT)); + } else if constexpr (std::is_same_v) { + uint64_t min_bits = 0x3FF0000000000000; // 1.0 + uint64_t max_bits = 0x3FF0000FFFFFFFFF; // 1.000015 + memcpy(&min_value, &min_bits, sizeof(KeyT)); + memcpy(&max_value, &max_bits, sizeof(KeyT)); + } + } + raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value); } void run_benchmark(::benchmark::State& state) override // NOLINT @@ -61,6 +81,7 @@ struct selection : public fixture { try { std::ostringstream label_stream; label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; + if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; } state.SetLabel(label_stream.str()); loop_on_state(state, [this, &handle]() { select::select_k_impl(handle, @@ -86,21 +107,55 @@ struct selection : public fixture { }; const std::vector kInputs{ - {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, - {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, - {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, - - {1000, 10000, 1, true}, {1000, 10000, 2, 
true}, {1000, 10000, 4, true}, - {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, - {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, - - {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, - {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, - {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, - - {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, - {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, - {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, + {20000, 500, 1, true}, + {20000, 500, 2, true}, + {20000, 500, 4, true}, + {20000, 500, 8, true}, + {20000, 500, 16, true}, + {20000, 500, 32, true}, + {20000, 500, 64, true}, + {20000, 500, 128, true}, + {20000, 500, 256, true}, + + {1000, 10000, 1, true}, + {1000, 10000, 2, true}, + {1000, 10000, 4, true}, + {1000, 10000, 8, true}, + {1000, 10000, 16, true}, + {1000, 10000, 32, true}, + {1000, 10000, 64, true}, + {1000, 10000, 128, true}, + {1000, 10000, 256, true}, + + {100, 100000, 1, true}, + {100, 100000, 2, true}, + {100, 100000, 4, true}, + {100, 100000, 8, true}, + {100, 100000, 16, true}, + {100, 100000, 32, true}, + {100, 100000, 64, true}, + {100, 100000, 128, true}, + {100, 100000, 256, true}, + + {10, 1000000, 1, true}, + {10, 1000000, 2, true}, + {10, 1000000, 4, true}, + {10, 1000000, 8, true}, + {10, 1000000, 16, true}, + {10, 1000000, 32, true}, + {10, 1000000, 64, true}, + {10, 1000000, 128, true}, + {10, 1000000, 256, true}, + + {10, 1000000, 1, true, false, true}, + {10, 1000000, 2, true, false, true}, + {10, 1000000, 4, true, false, true}, + {10, 1000000, 8, true, false, true}, + {10, 1000000, 16, true, false, true}, + {10, 1000000, 32, true, false, true}, + {10, 1000000, 64, true, false, true}, + {10, 1000000, 128, true, false, true}, + {10, 1000000, 256, true, false, true}, }; #define SELECTION_REGISTER(KeyT, IdxT, A) \ @@ -110,24 +165,33 @@ const std::vector kInputs{ RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ } -SELECTION_REGISTER(float, int, kPublicApi); // NOLINT -SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT -SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT -SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT -SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT -SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT -SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT -SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT - -SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT -SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT - -SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT -SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT -SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT -SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT -SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT +SELECTION_REGISTER(float, int, kPublicApi); // NOLINT +SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(float, int, kRadix8bitsUpdated); // NOLINT +SELECTION_REGISTER(float, int, kRadix11bitsUpdated); // NOLINT +SELECTION_REGISTER(float, int, kRadix11bitsAdaptive); // NOLINT +SELECTION_REGISTER(float, int, kWarpAuto); 
// NOLINT +SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT +SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT + +SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, int, kRadix8bitsUpdated); // NOLINT +SELECTION_REGISTER(double, int, kRadix11bitsUpdated); // NOLINT +SELECTION_REGISTER(double, int, kRadix11bitsAdaptive); // NOLINT +SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT + +SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, size_t, kRadix8bitsUpdated); // NOLINT +SELECTION_REGISTER(double, size_t, kRadix11bitsUpdated); // NOLINT +SELECTION_REGISTER(double, size_t, kRadix11bitsAdaptive); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT } // namespace raft::matrix diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index cb92c15790..f04d2c20d3 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -327,6 +327,9 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Values(select::Algo::kPublicApi, select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix8bitsUpdated, + select::Algo::kRadix11bitsUpdated, + select::Algo::kRadix11bitsAdaptive, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed))); @@ -421,6 +424,9 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Combine(inputs_random_longlist, testing::Values(select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix8bitsUpdated, + select::Algo::kRadix11bitsUpdated, + select::Algo::kRadix11bitsAdaptive, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed, @@ -435,6 +441,9 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Combine(inputs_random_longlist, testing::Values(select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix8bitsUpdated, + select::Algo::kRadix11bitsUpdated, + select::Algo::kRadix11bitsAdaptive, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed, @@ -446,7 +455,11 @@ TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomDoubleInt, - testing::Combine(inputs_random_largesize, testing::Values(select::Algo::kWarpAuto))); + testing::Combine(inputs_random_largesize, + testing::Values(select::Algo::kWarpAuto, + select::Algo::kRadix8bitsUpdated, + select::Algo::kRadix11bitsUpdated, + select::Algo::kRadix11bitsAdaptive))); using ReferencedRandomFloatSizeT = SelectK::params_random>; @@ -454,6 +467,9 @@ TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT ReferencedRandomFloatSizeT, testing::Combine(inputs_random_largek, - testing::Values(select::Algo::kRadix11bits))); + testing::Values(select::Algo::kRadix11bits, + select::Algo::kRadix8bitsUpdated, + select::Algo::kRadix11bitsUpdated, + select::Algo::kRadix11bitsAdaptive))); } // namespace raft::matrix diff --git a/cpp/test/matrix/select_k.cuh b/cpp/test/matrix/select_k.cuh index ee79b1ff80..08ed9a2ff8 100644 --- 
a/cpp/test/matrix/select_k.cuh +++ b/cpp/test/matrix/select_k.cuh @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -27,7 +28,8 @@ struct params { size_t len; int k; bool select_min; - bool use_index_input = true; + bool use_index_input = true; + bool use_same_leading_bits = false; }; inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& @@ -36,7 +38,8 @@ inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& os << ", len: " << ss.len; os << ", k: " << ss.k; os << (ss.select_min ? ", asc" : ", dsc"); - os << (ss.use_index_input ? "}" : ", no-input-index}"); + os << (ss.use_index_input ? "" : ", no-input-index"); + os << (ss.use_same_leading_bits ? ", same-leading-bits}" : "}"); return os; } @@ -44,6 +47,9 @@ enum class Algo { kPublicApi, kRadix8bits, kRadix11bits, + kRadix8bitsUpdated, + kRadix11bitsUpdated, + kRadix11bitsAdaptive, kWarpAuto, kWarpImmediate, kWarpFiltered, @@ -57,6 +63,9 @@ inline auto operator<<(std::ostream& os, const Algo& algo) -> std::ostream& case Algo::kPublicApi: return os << "kPublicApi"; case Algo::kRadix8bits: return os << "kRadix8bits"; case Algo::kRadix11bits: return os << "kRadix11bits"; + case Algo::kRadix8bitsUpdated: return os << "kRadix8bitsUpdated"; + case Algo::kRadix11bitsUpdated: return os << "kRadix11bitsUpdated"; + case Algo::kRadix11bitsAdaptive: return os << "kRadix11bitsAdaptive"; case Algo::kWarpAuto: return os << "kWarpAuto"; case Algo::kWarpImmediate: return os << "kWarpImmediate"; case Algo::kWarpFiltered: return os << "kWarpFiltered"; @@ -102,6 +111,39 @@ void select_k_impl(const handle_t& handle, case Algo::kRadix11bits: return detail::select::radix::select_k( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + case Algo::kRadix8bitsUpdated: + return detail::select::radix::select_k_updated(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + false, // adaptive + stream); + case Algo::kRadix11bitsUpdated: + return detail::select::radix::select_k_updated(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + false, // adaptive + stream); + case Algo::kRadix11bitsAdaptive: + return detail::select::radix::select_k_updated(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + true, // adaptive + stream); case Algo::kWarpAuto: return detail::select::warpsort::select_k( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); From 87cd66ece86b203b1acc3d559f66155279923c53 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Sun, 12 Feb 2023 13:27:54 +0800 Subject: [PATCH 14/29] add comments and revise code --- .../matrix/detail/select_radix_updated.cuh | 175 +++++++++++------- 1 file changed, 105 insertions(+), 70 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 49200aace7..469ce114f0 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -16,9 +16,11 @@ #pragma once -#include #include #include +#include +#include +#include #include #include #include @@ -164,14 +166,39 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) template struct alignas(128) Counter { + // We are processing the values in multiple passes, from most significant to least significant. In + // each pass, we keep the length of input (`len`) and the `k` of current pass, and update them at + // the end of the pass. 
IdxT k; IdxT len; + + // `previous_len` is the length of the input in the previous pass. Note that `previous_len` rather + // than `len` is used for the filtering step because filtering is indeed for the previous pass (see + // comments before `radix_kernel`). IdxT previous_len; + + // We determine the bits of the k-th value inside the mask processed by the pass. The + // already known bits are stored in `kth_value_bits`. It's used to determine whether an element is a + // result (written to `out`), a candidate for the next pass (written to `out_buf`), or not useful + // (discarded). The bits that are not yet processed do not matter for this purpose. typename cub::Traits::UnsignedBits kth_value_bits; + // Record how many elements have passed filtering. It's used to determine the position in the + // `out_buf` where an element should be written to. alignas(128) IdxT filter_cnt; + + // For a row inside a batch, we may launch multiple thread blocks. This counter is used to + // determine if the current block is the last running block. If so, this block will execute scan() + // and choose_bucket(). alignas(128) unsigned int finished_block_cnt; + + // Record how many elements have been written to the front of `out`. Elements less (if + // select_min==true) than the k-th value are written from front to back. alignas(128) IdxT out_cnt; + + // Record how many elements have been written to the back of `out`. Elements equal to the k-th + // value are written from back to front. We need to keep count of them separately because the + // number of elements that are <= the k-th value might exceed k. alignas(128) IdxT out_back_cnt; }; @@ -370,7 +397,7 @@ _RAFT_DEVICE void last_filter(const T* out_buf, IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; // For one-block version, `out_idx_buf` could be nullptr at pass 0. - // And for adaptive version, `out_idx_buf` could be nullptr if `out_buf` is `in` + // And for adaptive mode, `out_idx_buf` could be nullptr if `out_buf` is `in` out_idx[pos] = out_idx_buf ? out_idx_buf[i] : i; } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); @@ -383,7 +410,7 @@ } } -// used only for adaptive version +// used only for adaptive mode template __global__ void last_filter_kernel(const T* in, const IdxT* in_idx, @@ -475,6 +502,12 @@ __global__ void last_filter_kernel(const T* in, * * In the implementation, the filtering step is delayed to the next pass so the filtering and * histogram computation are fused. In this way, inputs are read once rather than twice. + * + * For the adaptive mode, we don't write candidates (elements in bucket j) to `out_buf` in the + * filtering step if the number of candidates is relatively large (this could happen when the + * leading bits of input values are almost the same). And in the next pass, inputs are read from + * `in` rather than from `in_buf`. The benefit is that we can save the cost of writing candidates + * and their indices.
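
To make the pass structure above easier to follow, here is a minimal single-threaded sketch of the per-pass narrowing that the kernels implement. It is an illustration only (not part of this patch): it assumes unsigned 32-bit keys with select_min=true so no cub twiddling is needed, and it omits the fusion of filtering with the next pass's histogram as well as the candidate-buffer skipping described above.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: histogram the current candidates by the digit of this pass, pick the bucket holding
// the k-th smallest value (choose_bucket), then keep only that bucket's elements as candidates
// for the next pass (the filtering step). Requires 1 <= k <= candidates.size().
template <int BitsPerPass>
uint32_t kth_smallest_sketch(std::vector<uint32_t> candidates, size_t k)  // k is 1-based
{
  constexpr int num_buckets = 1 << BitsPerPass;
  constexpr int num_passes  = (32 - 1) / BitsPerPass + 1;

  for (int pass = 0; pass < num_passes; ++pass) {
    // mirrors calc_start_bit()/calc_mask(): the last pass may cover fewer than BitsPerPass bits
    const int start_bit      = std::max(32 - (pass + 1) * BitsPerPass, 0);
    const int prev_start_bit = std::max(32 - pass * BitsPerPass, 0);
    const uint32_t mask      = (1u << (prev_start_bit - start_bit)) - 1;

    std::vector<size_t> histogram(num_buckets, 0);
    for (uint32_t v : candidates) { ++histogram[(v >> start_bit) & mask]; }

    // choose_bucket(): walk the histogram until the bucket containing the k-th value is found;
    // k becomes the rank of the searched value inside that bucket
    size_t bucket = 0;
    while (k > histogram[bucket]) { k -= histogram[bucket]; ++bucket; }

    // filtering step: candidates outside the chosen bucket can never be the k-th value
    std::vector<uint32_t> next;
    for (uint32_t v : candidates) {
      if (((v >> start_bit) & mask) == bucket) { next.push_back(v); }
    }
    candidates.swap(next);
    if (candidates.size() == 1) { break; }  // early stop: the k-th value is fully determined
  }
  return candidates.front();  // all remaining candidates equal the k-th smallest value
}

The real implementation does this cooperatively: the histogram and its scan run in shared memory per thread block, elements already known to be smaller than the k-th value are written to `out` as soon as they are identified, and the last filter finally pulls in just enough elements equal to the k-th value to fill exactly k output slots.
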
*/ template __global__ void radix_kernel(const T* in, @@ -568,7 +601,7 @@ __global__ void radix_kernel(const T* in, if (__syncthreads_or(isLastBlock)) { if (early_stop) { if (threadIdx.x == 0) { - // last_filter_kernel from the adaptive version requires setting previous_len + // last_filter_kernel from the adaptive mode requires setting previous_len counter->previous_len = 0; counter->len = 0; } @@ -593,8 +626,8 @@ __global__ void radix_kernel(const T* in, counter->filter_cnt = 0; } - // For non-adaptive version, we do the last filtering using the last thread block. - // For adaptive version, we'll use a multi-block kernel (last_filter_kernel). + // For non-adaptive mode, we do the last filtering using the last thread block. + // For adaptive mode, we'll use a multi-block kernel (last_filter_kernel). if constexpr (!adaptive) { if (pass == num_passes - 1) { last_filter( @@ -620,14 +653,13 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool adaptive) IdxT best_num_blocks = 0; float best_tail_wave_penalty = 1.0f; - const IdxT max_num_blocks = (len - 1) / (VECTORIZED_READ_SIZE / sizeof(T) * BlockSize) + 1; + const IdxT max_num_blocks = ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize); for (int num_waves = 1;; ++num_waves) { IdxT num_blocks = std::min( max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); - IdxT items_per_thread = (len - 1) / (num_blocks * BlockSize) + 1; - items_per_thread = (items_per_thread - 1) / (VECTORIZED_READ_SIZE / sizeof(T)) + 1; - items_per_thread *= VECTORIZED_READ_SIZE / sizeof(T); - num_blocks = (len - 1) / (items_per_thread * BlockSize) + 1; + IdxT items_per_thread = ceildiv(len, num_blocks * BlockSize); + items_per_thread = alignTo(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T)); + num_blocks = ceildiv(len, items_per_thread * BlockSize); float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; float tail_wave_penalty = (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); @@ -647,6 +679,42 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool adaptive) return best_num_blocks; } +template +_RAFT_HOST_DEVICE void set_buf_pointers(int pass, + const T* in, + const IdxT* in_idx, + T* buf1, + IdxT* idx_buf1, + T* buf2, + IdxT* idx_buf2, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = buf1; + out_idx_buf = idx_buf1; + } else if (pass % 2 == 0) { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } else { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } +} + template void radix_topk(const T* in, const IdxT* in_idx, @@ -713,27 +781,17 @@ void radix_topk(const T* in, constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = buf1.data(); - out_idx_buf = idx_buf1.data(); - } else if (pass % 2 == 0) { - in_buf = buf1.data(); - in_idx_buf = idx_buf1.data(); - out_buf = buf2.data(); - out_idx_buf = idx_buf2.data(); - } else { - in_buf = buf2.data(); - in_idx_buf = idx_buf2.data(); - out_buf = buf1.data(); - out_idx_buf = idx_buf1.data(); - } + 
set_buf_pointers(pass, + in, + in_idx, + buf1.data(), + idx_buf1.data(), + buf2.data(), + idx_buf2.data(), + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); if (!adaptive) { radix_kernel @@ -768,18 +826,19 @@ void radix_topk(const T* in, select_min, pass); } + RAFT_CUDA_TRY(cudaPeekAtLastError()); } if (adaptive) { - dim3 blocks((len / (VECTORIZED_READ_SIZE / sizeof(T)) - 1) / BlockSize + 1, batch_size); + dim3 blocks(ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize), batch_size); last_filter_kernel<<>>( in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters.data(), select_min); + RAFT_CUDA_TRY(cudaPeekAtLastError()); } } // The following a few functions are for the one-block version, which uses single thread block for -// each row of a batch. It's used when len is relatively small, so intermediate data, like counters -// and histograms, can be kept in shared memory and cheap sync operations can be used. +// each row of a batch. template _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, const IdxT* in_idx_buf, @@ -879,27 +938,8 @@ __global__ void radix_topk_one_block_kernel(const T* in, constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = buf1; - out_idx_buf = idx_buf1; - } else if (pass % 2 == 0) { - in_buf = buf1; - in_idx_buf = idx_buf1; - out_buf = buf2; - out_idx_buf = idx_buf2; - } else { - in_buf = buf2; - in_idx_buf = idx_buf2; - out_buf = buf1; - out_idx_buf = idx_buf1; - } + set_buf_pointers( + pass, in, in_idx, buf1, idx_buf1, buf2, idx_buf2, in_buf, in_idx_buf, out_buf, out_idx_buf); IdxT current_len = counter.len; IdxT current_k = counter.k; @@ -937,6 +977,10 @@ __global__ void radix_topk_one_block_kernel(const T* in, } } +// radix_topk() might use multiple thread blocks for one row of a batch. In contrast, the following +// one-block version uses single thread block for one row of a batch, so intermediate data, like +// counters and global histograms, can be kept in shared memory and cheap sync operations can be +// used. It's used when len is relatively small. template void radix_topk_one_block(const T* in, const IdxT* in_idx, @@ -981,16 +1025,6 @@ void radix_topk_one_block(const T* in, idx_buf2.data()); } -template -__global__ void fill_idx_kernel(int batch_size, IdxT len, IdxT* out_idx) -{ - const int batch_i = blockIdx.y; - const int stride = blockDim.x * gridDim.x; - for (IdxT i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += stride) { - out_idx[batch_i * len + i] = i; - } -} - } // namespace impl /** @@ -1034,7 +1068,7 @@ __global__ void fill_idx_kernel(int batch_size, IdxT len, IdxT* out_idx) * @param select_min * whether to select k smallest (true) or largest (false) keys. * @param adaptive - * whether to use the adaptive implementation, which is preferable when the most significant bits + * whether to use the adaptive mode, which is preferable when the most significant bits * of input data are almost the same. That is, when the value range of input data is narrow. 
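
For reference, a minimal usage sketch of this API (illustration only, not from the patch). It follows the call shape used in the test harness above; the template arguments <T, IdxT, BitsPerPass, BlockSize> and the fully qualified namespace are assumptions based on the kRadix11bitsUpdated registration and the file path, and all buffer management is left to the caller.

#include <raft/matrix/detail/select_radix_updated.cuh>
#include <rmm/cuda_stream_view.hpp>

// Select the k smallest keys (and their positions) from each row of a
// [batch_size, len] row-major matrix of keys living in device memory.
void select_smallest_per_row(const float* in,   // device pointer, batch_size * len keys
                             int batch_size,
                             int len,
                             int k,
                             float* out,        // device pointer, batch_size * k keys
                             int* out_idx,      // device pointer, batch_size * k indices
                             rmm::cuda_stream_view stream)
{
  // in_idx == nullptr means positions 0..len-1 within each row are used as the payload.
  raft::matrix::detail::select::radix::select_k_updated<float, int, 11, 512>(
    in, /*in_idx=*/nullptr, batch_size, len, k, out, out_idx,
    /*select_min=*/true,
    /*adaptive=*/false,  // true enables the adaptive mode described above
    stream);
}
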
* @param stream * @param mr an optional memory resource to use across the calls (you can provide a large enough @@ -1060,9 +1094,10 @@ void select_k_updated(const T* in, RAFT_CUDA_TRY(cudaMemcpyAsync( out_idx, in_idx, sizeof(IdxT) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); } else { - constexpr int block_dim = 256; - dim3 grid_dim((len - 1) / block_dim + 1, batch_size, 1); - impl::fill_idx_kernel<<>>(batch_size, len, out_idx); + auto out_idx_view = + raft::make_device_vector_view(out_idx, static_cast(len) * batch_size); + raft::device_resources handle(stream); + raft::linalg::map_offset(handle, out_idx_view, raft::mod_const_op(len)); } return; } From 6add0b81992324e124b117b4a30d0c46a65cd7ee Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Sun, 12 Feb 2023 14:21:10 +0800 Subject: [PATCH 15/29] radix one-block: enable vectorized loading when pass==0 --- .../matrix/detail/select_radix_updated.cuh | 48 ++++++++++++------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 469ce114f0..23c264acaf 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -122,17 +122,18 @@ _RAFT_DEVICE bool use_lazy_writing(IdxT original_len, IdxT len) * @tparam IdxT indexing type * @tparam Func void (T x, IdxT idx) * + * @param thread_rank rank of the calling thread among all participated threads + * @param num_threads number of the threads that participate in processing * @param in the input data * @param len the number of elements to read * @param f the lambda taking two arguments (T x, IdxT idx) */ template -_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) +_RAFT_DEVICE void vectorized_process( + IdxT thread_rank, IdxT num_threads, const T* in, IdxT len, Func f) { - const IdxT stride = blockDim.x * gridDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (sizeof(T) >= VECTORIZED_READ_SIZE || VECTORIZED_READ_SIZE % sizeof(T) != 0) { - for (IdxT i = tid; i < len; i += stride) { + for (IdxT i = thread_rank; i < len; i += num_threads) { f(in[i], i); } } else { @@ -145,8 +146,8 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) const IdxT skip_cnt_left = std::min((IdxT)(align_bytes::roundUp(in) - in), len); // The main loop: process all aligned data - for (IdxT i = tid * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; - i += stride * wide_t::Ratio) { + for (IdxT i = thread_rank * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; + i += num_threads * wide_t::Ratio) { wide.load(in, i); #pragma unroll for (int j = 0; j < wide_t::Ratio; ++j) { @@ -156,10 +157,10 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) static_assert(WarpSize >= wide_t::Ratio); // Processes the skipped elements on the left - if (tid < skip_cnt_left) { f(in[tid], tid); } + if (thread_rank < skip_cnt_left) { f(in[thread_rank], thread_rank); } // Processes the skipped elements on the right const IdxT skip_cnt_right = align_elems::mod(len - skip_cnt_left); - const IdxT remain_i = len - skip_cnt_right + tid; + const IdxT remain_i = len - skip_cnt_right + thread_rank; if (remain_i < len) { f(in[remain_i], remain_i); } } } @@ -238,7 +239,12 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); }; - 
vectorized_process(in_buf, previous_len, f); + vectorized_process(static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x), + static_cast(blockDim.x) * static_cast(gridDim.x), + in_buf, + previous_len, + f); } else { IdxT* p_filter_cnt = &counter->filter_cnt; IdxT* p_out_cnt = &counter->out_cnt; @@ -289,7 +295,12 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } }; - vectorized_process(in_buf, previous_len, f); + vectorized_process(static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x), + static_cast(blockDim.x) * static_cast(gridDim.x), + in_buf, + previous_len, + f); } if (early_stop) { return; } __syncthreads(); @@ -471,7 +482,12 @@ __global__ void last_filter_kernel(const T* in, } }; - vectorized_process(in_buf, previous_len, f); + vectorized_process( + static_cast(blockIdx.x) * static_cast(blockDim.x) + static_cast(threadIdx.x), + static_cast(blockDim.x) * static_cast(gridDim.x), + in_buf, + previous_len, + f); } /** @@ -864,14 +880,14 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, const IdxT previous_len = counter->previous_len; if (pass == 0) { - // Could not use vectorized_process() as in filter_and_histogram() because - // vectorized_process() assumes multi-block, e.g. uses gridDim.x - for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { - T value = in_buf[i]; + auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram + bucket, static_cast(1)); - } + }; + vectorized_process( + static_cast(threadIdx.x), static_cast(blockDim.x), in_buf, previous_len, f); } else { + // not use vectorized_process here because it increases #registers a lot IdxT* p_out_cnt = &counter->out_cnt; const auto kth_value_bits = counter->kth_value_bits; const int previous_start_bit = calc_start_bit(pass - 1); From 204b370f163a1a5e789a75d1a5f3b1c828b94b9d Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Sun, 12 Feb 2023 14:28:17 +0800 Subject: [PATCH 16/29] radix: use one-block version when calculated gridDim.x==1 --- .../matrix/detail/select_radix_updated.cuh | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 23c264acaf..cfbb74ba88 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -741,6 +741,7 @@ void radix_topk(const T* in, IdxT* out_idx, bool select_min, bool adaptive, + unsigned grid_dim, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -785,15 +786,7 @@ void radix_topk(const T* in, T* out_buf = nullptr; IdxT* out_idx_buf = nullptr; - int sm_cnt; - { - int dev; - RAFT_CUDA_TRY(cudaGetDevice(&dev)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); - } - dim3 blocks(calc_grid_dim(batch_size, len, sm_cnt, adaptive), - batch_size); - + dim3 blocks(grid_dim, batch_size); constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { @@ -1124,8 +1117,21 @@ void select_k_updated(const T* in, impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); } else { - impl::radix_topk( - in, in_idx, batch_size, len, k, out, out_idx, select_min, adaptive, stream, mr); + int sm_cnt; + { + int dev; + 
RAFT_CUDA_TRY(cudaGetDevice(&dev)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); + } + unsigned grid_dim = + impl::calc_grid_dim(batch_size, len, sm_cnt, adaptive); + if (grid_dim == 1) { + impl::radix_topk_one_block( + in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + } else { + impl::radix_topk( + in, in_idx, batch_size, len, k, out, out_idx, select_min, adaptive, grid_dim, stream, mr); + } } } From b482e293bdc417c7170eb8e6d2820bcffdc219e2 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Sun, 12 Feb 2023 21:48:42 +0800 Subject: [PATCH 17/29] radix: add chunking --- .../matrix/detail/select_radix_updated.cuh | 254 ++++++++++-------- 1 file changed, 146 insertions(+), 108 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index cfbb74ba88..950bb643a3 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -653,6 +653,20 @@ __global__ void radix_kernel(const T* in, } } +template +int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel) +{ + int active_blocks; + RAFT_CUDA_TRY( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks, kernel, BlockSize, 0)); + + constexpr int items_per_thread = 32; + constexpr int num_waves = 10; + int chunk_size = + std::max(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len); + return std::min(chunk_size, batch_size); +} + template unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool adaptive) { @@ -696,13 +710,13 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool adaptive) } template -_RAFT_HOST_DEVICE void set_buf_pointers(int pass, - const T* in, +_RAFT_HOST_DEVICE void set_buf_pointers(const T* in, const IdxT* in_idx, T* buf1, IdxT* idx_buf1, T* buf2, IdxT* idx_buf2, + int pass, const T*& in_buf, const IdxT*& in_idx_buf, T*& out_buf, @@ -742,6 +756,7 @@ void radix_topk(const T* in, bool select_min, bool adaptive, unsigned grid_dim, + int sm_cnt, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -749,8 +764,17 @@ void radix_topk(const T* in, static_assert(calc_num_passes() > 1); constexpr int num_buckets = calc_num_buckets(); - size_t req_aux = batch_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); - size_t req_buf = batch_size * len * 2 * (sizeof(T) + sizeof(IdxT)); + auto kernel = adaptive ? 
radix_kernel + : radix_kernel; + const int max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel); + if (max_chunk_size != batch_size) { + grid_dim = + calc_grid_dim(max_chunk_size, len, sm_cnt, adaptive); + } + + size_t req_aux = + static_cast(max_chunk_size) * (sizeof(Counter) + num_buckets * sizeof(IdxT)); + size_t req_buf = static_cast(max_chunk_size) * len * 2 * (sizeof(T) + sizeof(IdxT)); size_t mem_req = req_aux + req_buf; size_t mem_free, mem_total; RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); @@ -770,79 +794,76 @@ void radix_topk(const T* in, } if (mr_buf == nullptr) { mr_buf = mr; } - rmm::device_uvector> counters(batch_size, stream, mr); - rmm::device_uvector histograms(batch_size * num_buckets, stream, mr); - rmm::device_uvector buf1(batch_size * len, stream, mr_buf); - rmm::device_uvector idx_buf1(batch_size * len, stream, mr_buf); - rmm::device_uvector buf2(batch_size * len, stream, mr_buf); - rmm::device_uvector idx_buf2(batch_size * len, stream, mr_buf); - - RAFT_CUDA_TRY( - cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); - - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; - - dim3 blocks(grid_dim, batch_size); - constexpr int num_passes = calc_num_passes(); + rmm::device_uvector> counters(max_chunk_size, stream, mr); + rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); + rmm::device_uvector buf1(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector idx_buf1(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector buf2(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector idx_buf2(max_chunk_size * len, stream, mr_buf); - for (int pass = 0; pass < num_passes; ++pass) { - set_buf_pointers(pass, - in, - in_idx, - buf1.data(), - idx_buf1.data(), - buf2.data(), - idx_buf2.data(), - in_buf, - in_idx_buf, - out_buf, - out_idx_buf); - - if (!adaptive) { - radix_kernel - <<>>(in, - in_idx, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - counters.data(), - histograms.data(), - len, - k, - select_min, - pass); - } else { - radix_kernel - <<>>(in, - in_idx, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - counters.data(), - histograms.data(), - len, - k, - select_min, - pass); + for (int offset = 0; offset < batch_size; offset += max_chunk_size) { + int chunk_size = std::min(max_chunk_size, batch_size - offset); + RAFT_CUDA_TRY( + cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); + + const T* chunk_in = in + offset * len; + const IdxT* chunk_in_idx = in_idx ? 
(in_idx + offset * len) : nullptr; + T* chunk_out = out + offset * k; + IdxT* chunk_out_idx = out_idx + offset * k; + + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + IdxT* out_idx_buf = nullptr; + + dim3 blocks(grid_dim, chunk_size); + constexpr int num_passes = calc_num_passes(); + + for (int pass = 0; pass < num_passes; ++pass) { + set_buf_pointers(chunk_in, + chunk_in_idx, + buf1.data(), + idx_buf1.data(), + buf2.data(), + idx_buf2.data(), + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + + kernel<<>>(chunk_in, + chunk_in_idx, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + chunk_out, + chunk_out_idx, + counters.data(), + histograms.data(), + len, + k, + select_min, + pass); + RAFT_CUDA_TRY(cudaPeekAtLastError()); } - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - if (adaptive) { - dim3 blocks(ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize), batch_size); - last_filter_kernel<<>>( - in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters.data(), select_min); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + if (adaptive) { + dim3 blocks(ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize), chunk_size); + last_filter_kernel<<>>(chunk_in, + chunk_in_idx, + out_buf, + out_idx_buf, + chunk_out, + chunk_out_idx, + len, + k, + counters.data(), + select_min); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } } } @@ -948,7 +969,8 @@ __global__ void radix_topk_one_block_kernel(const T* in, constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { set_buf_pointers( - pass, in, in_idx, buf1, idx_buf1, buf2, idx_buf2, in_buf, in_idx_buf, out_buf, out_idx_buf); + in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); + IdxT current_len = counter.len; IdxT current_k = counter.k; @@ -999,39 +1021,43 @@ void radix_topk_one_block(const T* in, T* out, IdxT* out_idx, bool select_min, + int sm_cnt, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { static_assert(calc_num_passes() > 1); - auto pool_guard = - raft::get_pool_memory_resource(mr, - batch_size * (sizeof(T) * len * 2 // T bufs - + sizeof(IdxT) * len * 2 // IdxT bufs - ) + - 256 * 4); // might need extra memory for alignment + auto kernel = radix_topk_one_block_kernel; + const int max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel); + + auto pool_guard = raft::get_pool_memory_resource( + mr, + static_cast(max_chunk_size) * len * 2 * (sizeof(T) + sizeof(IdxT)) + + 256 * 4); // might need extra memory for alignment if (pool_guard) { RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } - rmm::device_uvector buf1(len * batch_size, stream, mr); - rmm::device_uvector idx_buf1(len * batch_size, stream, mr); - rmm::device_uvector buf2(len * batch_size, stream, mr); - rmm::device_uvector idx_buf2(len * batch_size, stream, mr); - - radix_topk_one_block_kernel - <<>>(in, - in_idx, - len, - k, - out, - out_idx, - select_min, - buf1.data(), - idx_buf1.data(), - buf2.data(), - idx_buf2.data()); + rmm::device_uvector buf1(len * max_chunk_size, stream, mr); + rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr); + rmm::device_uvector buf2(len * max_chunk_size, stream, mr); + rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); + + for (int offset = 0; offset < batch_size; offset += max_chunk_size) { + int chunk_size = std::min(max_chunk_size, batch_size - offset); + kernel<<>>(in + offset * len, + in_idx ? 
(in_idx + offset * len) : nullptr, + len, + k, + out + offset * k, + out_idx + offset * k, + select_min, + buf1.data(), + idx_buf1.data(), + buf2.data(), + idx_buf2.data()); + } } } // namespace impl @@ -1111,26 +1137,38 @@ void select_k_updated(const T* in, return; } + int sm_cnt; + { + int dev; + RAFT_CUDA_TRY(cudaGetDevice(&dev)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); + } + constexpr int items_per_thread = 32; if (len <= BlockSize * items_per_thread) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); } else { - int sm_cnt; - { - int dev; - RAFT_CUDA_TRY(cudaGetDevice(&dev)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); - } unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt, adaptive); if (grid_dim == 1) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); } else { - impl::radix_topk( - in, in_idx, batch_size, len, k, out, out_idx, select_min, adaptive, grid_dim, stream, mr); + impl::radix_topk(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + adaptive, + grid_dim, + sm_cnt, + stream, + mr); } } } From 834dc4bad832284da56ba4a4d514a2c0e8db50dc Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Sun, 12 Feb 2023 14:44:34 +0800 Subject: [PATCH 18/29] radix: reduce buf size for adaptive mode --- .../matrix/detail/select_radix_updated.cuh | 43 +++++++++++-------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 950bb643a3..4c81014266 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -103,13 +103,15 @@ _RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) } template -_RAFT_DEVICE bool use_lazy_writing(IdxT original_len, IdxT len) +_RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len, bool adaptive) { - // When using lazy writing, only read `in`(type T). - // When not using it, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) and - // `out_idx_buf`(IdxT). + if (!adaptive) { return len; } + // When writing is skipped in adaptive mode, only read `in`(type T). + // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) + // and `out_idx_buf`(IdxT). + // The ratio between these cases determines whether to skip writing and hence the buffer size. constexpr float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T); - return len * ratio > original_len; + return len / ratio; } /** @@ -439,14 +441,15 @@ __global__ void last_filter_kernel(const T* in, Counter* counter = counters + batch_id; IdxT previous_len = counter->previous_len; if (previous_len == 0) { return; } - if (use_lazy_writing(len, previous_len)) { + const IdxT buf_len = calc_buf_len(len, true); + if (previous_len > buf_len) { in_buf = in; in_idx_buf = in_idx; previous_len = len; } - in_buf += batch_id * len; - if (in_idx_buf) { in_idx_buf += batch_id * len; } + in_buf += batch_id * (in_buf == in ? len : buf_len); + if (in_idx_buf) { in_idx_buf += batch_id * (in_idx_buf == in_idx ? 
len : buf_len); } out += batch_id * k; out_idx += batch_id * k; @@ -570,24 +573,25 @@ __global__ void radix_kernel(const T* in, constexpr int num_buckets = calc_num_buckets(); constexpr int num_passes = calc_num_passes(); + const IdxT buf_len = calc_buf_len(len, adaptive); if constexpr (adaptive) { // Figure out if the previous pass writes buffer - if (use_lazy_writing(len, previous_len)) { + if (previous_len > buf_len) { previous_len = len; in_buf = in; in_idx_buf = in_idx; } // Figure out if this pass need to write buffer - if (use_lazy_writing(len, current_len)) { + if (current_len > buf_len) { out_buf = nullptr; out_idx_buf = nullptr; } } - in_buf += batch_id * len; - if (in_idx_buf) { in_idx_buf += batch_id * len; } - if (out_buf) { out_buf += batch_id * len; } - if (out_idx_buf) { out_idx_buf += batch_id * len; } + in_buf += batch_id * (in_buf == in ? len : buf_len); + if (in_idx_buf) { in_idx_buf += batch_id * (in_idx_buf == in_idx ? len : buf_len); } + if (out_buf) { out_buf += batch_id * buf_len; } + if (out_idx_buf) { out_idx_buf += batch_id * buf_len; } if (out) { out += batch_id * k; out_idx += batch_id * k; @@ -771,10 +775,11 @@ void radix_topk(const T* in, grid_dim = calc_grid_dim(max_chunk_size, len, sm_cnt, adaptive); } + const IdxT buf_len = calc_buf_len(len, adaptive); size_t req_aux = static_cast(max_chunk_size) * (sizeof(Counter) + num_buckets * sizeof(IdxT)); - size_t req_buf = static_cast(max_chunk_size) * len * 2 * (sizeof(T) + sizeof(IdxT)); + size_t req_buf = static_cast(max_chunk_size) * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); size_t mem_req = req_aux + req_buf; size_t mem_free, mem_total; RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); @@ -796,10 +801,10 @@ void radix_topk(const T* in, rmm::device_uvector> counters(max_chunk_size, stream, mr); rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); - rmm::device_uvector buf1(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector idx_buf1(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector buf2(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector idx_buf2(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr_buf); + rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr_buf); + rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr_buf); + rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr_buf); for (int offset = 0; offset < batch_size; offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); From 117f94d5f602a04e7694b34be2f510ca2739540e Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Sun, 19 Feb 2023 11:18:53 +0800 Subject: [PATCH 19/29] make implementation adaptive --- .../matrix/detail/select_radix_updated.cuh | 129 +++++++++--------- 1 file changed, 65 insertions(+), 64 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 4c81014266..9bd4ab88bc 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -103,10 +103,9 @@ _RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) } template -_RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len, bool adaptive) +_RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len) { - if (!adaptive) { return len; } - // When writing is skipped in adaptive mode, only read `in`(type T). + // When writing is skipped, only read `in`(type T). 
// When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) // and `out_idx_buf`(IdxT). // The ratio between these cases determines whether to skip writing and hence the buffer size. @@ -286,7 +285,7 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, } } // the condition `(out_buf || early_stop)` is a little tricky: - // If we skip writing to `out_buf` (when `out_buf` is false), we should skip writing to + // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should skip writing to // `out` too. So we won't write the same value to `out` multiple times in different passes. // And if we keep skipping the writing, values will be written in `last_filter_kernel` at // last. @@ -386,8 +385,8 @@ _RAFT_DEVICE void choose_bucket(Counter* counter, // For one-block version, last_filter() could be called when pass < num_passes - 1. // So `pass` could not be constexpr template -_RAFT_DEVICE void last_filter(const T* out_buf, - const IdxT* out_idx_buf, +_RAFT_DEVICE void last_filter(const T* in_buf, + const IdxT* in_idx_buf, T* out, IdxT* out_idx, IdxT current_len, @@ -404,26 +403,25 @@ _RAFT_DEVICE void last_filter(const T* out_buf, IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { - const T value = out_buf[i]; + const T value = in_buf[i]; const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; if (bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; - // For one-block version, `out_idx_buf` could be nullptr at pass 0. - // And for adaptive mode, `out_idx_buf` could be nullptr if `out_buf` is `in` - out_idx[pos] = out_idx_buf ? out_idx_buf[i] : i; + // For one-block version, `in_idx_buf` could be nullptr at pass 0. + // And when writing is skipped, `in_idx_buf` could be nullptr if `in_buf` is `in` + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); if (back_pos < needed_num_of_kth) { IdxT pos = k - 1 - back_pos; out[pos] = value; - out_idx[pos] = out_idx_buf ? out_idx_buf[i] : i; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } } } } -// used only for adaptive mode template __global__ void last_filter_kernel(const T* in, const IdxT* in_idx, @@ -441,15 +439,15 @@ __global__ void last_filter_kernel(const T* in, Counter* counter = counters + batch_id; IdxT previous_len = counter->previous_len; if (previous_len == 0) { return; } - const IdxT buf_len = calc_buf_len(len, true); - if (previous_len > buf_len) { - in_buf = in; - in_idx_buf = in_idx; + const IdxT buf_len = calc_buf_len(len); + if (previous_len > buf_len || in_buf == in) { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; previous_len = len; + } else { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; } - - in_buf += batch_id * (in_buf == in ? len : buf_len); - if (in_idx_buf) { in_idx_buf += batch_id * (in_idx_buf == in_idx ? len : buf_len); } out += batch_id * k; out_idx += batch_id * k; @@ -528,7 +526,7 @@ __global__ void last_filter_kernel(const T* in, * `in` rather than from `in_buf`. The benefit is that we can save the cost of writing candidates * and their indices. */ -template +template __global__ void radix_kernel(const T* in, const IdxT* in_idx, const T* in_buf, @@ -569,34 +567,31 @@ __global__ void radix_kernel(const T* in, // When k=len, early_stop will be true at pass 0. 
It means filter_and_histogram() should handle // correctly the case that pass=0 and early_stop=true. However, this special case of k=len is // handled in other way in select_k() so such case is not possible here. - bool early_stop = (current_len == current_k); + const bool early_stop = (current_len == current_k); + const IdxT buf_len = calc_buf_len(len); - constexpr int num_buckets = calc_num_buckets(); - constexpr int num_passes = calc_num_passes(); - const IdxT buf_len = calc_buf_len(len, adaptive); - - if constexpr (adaptive) { - // Figure out if the previous pass writes buffer - if (previous_len > buf_len) { - previous_len = len; - in_buf = in; - in_idx_buf = in_idx; - } - // Figure out if this pass need to write buffer - if (current_len > buf_len) { - out_buf = nullptr; - out_idx_buf = nullptr; - } + // "previous_len > buf_len" means previous pass skips writing buffer + if (pass == 0 || pass == 1 || previous_len > buf_len) { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; + previous_len = len; + } else { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; } - in_buf += batch_id * (in_buf == in ? len : buf_len); - if (in_idx_buf) { in_idx_buf += batch_id * (in_idx_buf == in_idx ? len : buf_len); } - if (out_buf) { out_buf += batch_id * buf_len; } - if (out_idx_buf) { out_idx_buf += batch_id * buf_len; } - if (out) { - out += batch_id * k; - out_idx += batch_id * k; + // "current_len > buf_len" means current pass will skip writing buffer + if (pass == 0 || current_len > buf_len) { + out_buf = nullptr; + out_idx_buf = nullptr; + } else { + out_buf += batch_id * buf_len; + out_idx_buf += batch_id * buf_len; } - auto histogram = histograms + batch_id * num_buckets; + out += batch_id * k; + out_idx += batch_id * k; + + constexpr int num_buckets = calc_num_buckets(); + auto histogram = histograms + batch_id * num_buckets; filter_and_histogram(in_buf, in_idx_buf, @@ -633,6 +628,7 @@ __global__ void radix_kernel(const T* in, choose_bucket(counter, histogram, current_k, pass); __syncthreads(); + constexpr int num_passes = calc_num_passes(); // reset for next pass if (pass != num_passes - 1) { for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { @@ -648,10 +644,17 @@ __global__ void radix_kernel(const T* in, // For non-adaptive mode, we do the last filtering using the last thread block. // For adaptive mode, we'll use a multi-block kernel (last_filter_kernel). - if constexpr (!adaptive) { + if constexpr (fused_last_filter) { if (pass == num_passes - 1) { - last_filter( - out_buf, out_idx_buf, out, out_idx, current_len, k, counter, select_min, pass); + last_filter(out_buf ? out_buf : in_buf, + out_idx_buf ? out_idx_buf : in_idx_buf, + out, + out_idx, + out_buf ? current_len : len, + k, + counter, + select_min, + pass); } } } @@ -672,17 +675,13 @@ int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel) } template -unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt, bool adaptive) +unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) { static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); int active_blocks; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &active_blocks, - adaptive ? 
radix_kernel - : radix_kernel, - BlockSize, - 0)); + &active_blocks, radix_kernel, BlockSize, 0)); active_blocks *= sm_cnt; IdxT best_num_blocks = 0; @@ -758,7 +757,7 @@ void radix_topk(const T* in, T* out, IdxT* out_idx, bool select_min, - bool adaptive, + bool fused_last_filter, unsigned grid_dim, int sm_cnt, rmm::cuda_stream_view stream, @@ -768,14 +767,12 @@ void radix_topk(const T* in, static_assert(calc_num_passes() > 1); constexpr int num_buckets = calc_num_buckets(); - auto kernel = adaptive ? radix_kernel - : radix_kernel; + auto kernel = radix_kernel; const int max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel); if (max_chunk_size != batch_size) { - grid_dim = - calc_grid_dim(max_chunk_size, len, sm_cnt, adaptive); + grid_dim = calc_grid_dim(max_chunk_size, len, sm_cnt); } - const IdxT buf_len = calc_buf_len(len, adaptive); + const IdxT buf_len = calc_buf_len(len); size_t req_aux = static_cast(max_chunk_size) * (sizeof(Counter) + num_buckets * sizeof(IdxT)); @@ -838,6 +835,10 @@ void radix_topk(const T* in, out_buf, out_idx_buf); + if (fused_last_filter && pass == num_passes - 1) { + kernel = radix_kernel; + } + kernel<<>>(chunk_in, chunk_in_idx, in_buf, @@ -855,7 +856,7 @@ void radix_topk(const T* in, RAFT_CUDA_TRY(cudaPeekAtLastError()); } - if (adaptive) { + if (!fused_last_filter) { dim3 blocks(ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize), chunk_size); last_filter_kernel<<>>(chunk_in, chunk_in_idx, @@ -1123,7 +1124,7 @@ void select_k_updated(const T* in, T* out, IdxT* out_idx, bool select_min, - bool adaptive, + bool fused_last_filter, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = nullptr) { @@ -1156,7 +1157,7 @@ void select_k_updated(const T* in, in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); } else { unsigned grid_dim = - impl::calc_grid_dim(batch_size, len, sm_cnt, adaptive); + impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); @@ -1169,7 +1170,7 @@ void select_k_updated(const T* in, out, out_idx, select_min, - adaptive, + fused_last_filter, grid_dim, sm_cnt, stream, From 6555bbea7593ee106b421afa14a59bcd73a6ab51 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Sun, 19 Feb 2023 20:47:53 +0800 Subject: [PATCH 20/29] polish code comments --- .../matrix/detail/select_radix_updated.cuh | 41 ++++++++++--------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 9bd4ab88bc..67dbb2a8c9 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -123,7 +123,7 @@ _RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len) * @tparam IdxT indexing type * @tparam Func void (T x, IdxT idx) * - * @param thread_rank rank of the calling thread among all participated threads + * @param thread_rank rank of the calling thread among all participating threads * @param num_threads number of the threads that participate in processing * @param in the input data * @param len the number of elements to read @@ -186,7 +186,7 @@ struct alignas(128) Counter { typename cub::Traits::UnsignedBits kth_value_bits; // Record how many elements have passed filtering. It's used to determine the position in the - // `out_buf` where an element should be written to. + // `out_buf` where an element should be written. 
alignas(128) IdxT filter_cnt; // For a row inside a batch, we may launch multiple thread blocks. This counter is used to @@ -287,9 +287,8 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, // the condition `(out_buf || early_stop)` is a little tricky: // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should skip writing to // `out` too. So we won't write the same value to `out` multiple times in different passes. - // And if we keep skipping the writing, values will be written in `last_filter_kernel` at - // last. - // But when `early_stop` is true, we need to write to `out` since it's the last chance. + // And if we keep skipping the writing, values will be written in `last_filter_kernel()` at + // last. But when `early_stop` is true, we need to write to `out` since it's the last chance. else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; @@ -409,7 +408,8 @@ _RAFT_DEVICE void last_filter(const T* in_buf, IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; // For one-block version, `in_idx_buf` could be nullptr at pass 0. - // And when writing is skipped, `in_idx_buf` could be nullptr if `in_buf` is `in` + // For non one-block version, if writing has been skipped, `in_idx_buf` could be nullptr if + // `in_buf` is `in` out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); @@ -520,11 +520,11 @@ __global__ void last_filter_kernel(const T* in, * In the implementation, the filtering step is delayed to the next pass so the filtering and * histogram computation are fused. In this way, inputs are read once rather than twice. * - * For the adaptive mode, we don't write candidates (elements in bucket j) to `out_buf` in the - * filtering step if the number of candidates is relatively large (this could happen when the - * leading bits of input values are almost the same). And in the next pass, inputs are read from - * `in` rather than from `in_buf`. The benefit is that we can save the cost of writing candidates - * and their indices. + * During the filtering step, we won't write candidates (elements in bucket j) to `out_buf` if the + * number of candidates is larger than the length of `out_buf` (this could happen when the leading + * bits of input values are almost the same). And then in the next pass, inputs are read from `in` + * rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and + * their indices. */ template __global__ void radix_kernel(const T* in, @@ -616,7 +616,7 @@ __global__ void radix_kernel(const T* in, if (__syncthreads_or(isLastBlock)) { if (early_stop) { if (threadIdx.x == 0) { - // last_filter_kernel from the adaptive mode requires setting previous_len + // `last_filter_kernel()` requires setting previous_len counter->previous_len = 0; counter->len = 0; } @@ -636,14 +636,12 @@ __global__ void radix_kernel(const T* in, } } if (threadIdx.x == 0) { - // last_filter_kernel requires setting previous_len even in the last pass + // `last_filter_kernel()` requires setting previous_len even in the last pass counter->previous_len = current_len; // not necessary for the last pass, but put it here anyway counter->filter_cnt = 0; } - // For non-adaptive mode, we do the last filtering using the last thread block. - // For adaptive mode, we'll use a multi-block kernel (last_filter_kernel). 
if constexpr (fused_last_filter) { if (pass == num_passes - 1) { last_filter(out_buf ? out_buf : in_buf, @@ -1017,7 +1015,8 @@ __global__ void radix_topk_one_block_kernel(const T* in, // radix_topk() might use multiple thread blocks for one row of a batch. In contrast, the following // one-block version uses single thread block for one row of a batch, so intermediate data, like // counters and global histograms, can be kept in shared memory and cheap sync operations can be -// used. It's used when len is relatively small. +// used. It's used when len is relatively small or when the number of blocks per row calculated by +// `calc_grid_dim()` is 1. template void radix_topk_one_block(const T* in, const IdxT* in_idx, @@ -1108,9 +1107,13 @@ void radix_topk_one_block(const T* in, * the payload selected together with `out`. * @param select_min * whether to select k smallest (true) or largest (false) keys. - * @param adaptive - * whether to use the adaptive mode, which is preferable when the most significant bits - * of input data are almost the same. That is, when the value range of input data is narrow. + * @param fused_last_filter + * when it's true, the last filter is fused into the kernel in the last pass and only one thread + * block will do the filtering; when false, a standalone filter kernel with multiple thread + * blocks is called. The later case is preferable when the most significant bits of input data are + * almost the same. That is, when the value range of input data is narrow. In such case, there + * could be a large number of inputs for the last filter, hence using multiple thread blocks is + * beneficial. * @param stream * @param mr an optional memory resource to use across the calls (you can provide a large enough * memory pool here to avoid memory allocations within the call). From 48c4faf7a4c0193a57a9df1504044609ba92246a Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Sun, 19 Feb 2023 21:32:00 +0800 Subject: [PATCH 21/29] fix potential mul overflow --- .../matrix/detail/select_radix_updated.cuh | 79 +++++++++---------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 67dbb2a8c9..7a50aadf29 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -131,7 +131,7 @@ _RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len) */ template _RAFT_DEVICE void vectorized_process( - IdxT thread_rank, IdxT num_threads, const T* in, IdxT len, Func f) + size_t thread_rank, size_t num_threads, const T* in, IdxT len, Func f) { if constexpr (sizeof(T) >= VECTORIZED_READ_SIZE || VECTORIZED_READ_SIZE % sizeof(T) != 0) { for (IdxT i = thread_rank; i < len; i += num_threads) { @@ -240,9 +240,8 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); }; - vectorized_process(static_cast(blockIdx.x) * static_cast(blockDim.x) + - static_cast(threadIdx.x), - static_cast(blockDim.x) * static_cast(gridDim.x), + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); @@ -295,9 +294,8 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; } }; - vectorized_process(static_cast(blockIdx.x) * static_cast(blockDim.x) + - static_cast(threadIdx.x), - static_cast(blockDim.x) * static_cast(gridDim.x), + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); @@ -434,7 +432,7 @@ __global__ void last_filter_kernel(const T* in, Counter* counters, const bool select_min) { - const int batch_id = blockIdx.y; + const size_t batch_id = blockIdx.y; // size_t to avoid multiplication overflow Counter* counter = counters + batch_id; IdxT previous_len = counter->previous_len; @@ -483,12 +481,11 @@ __global__ void last_filter_kernel(const T* in, } }; - vectorized_process( - static_cast(blockIdx.x) * static_cast(blockDim.x) + static_cast(threadIdx.x), - static_cast(blockDim.x) * static_cast(gridDim.x), - in_buf, - previous_len, - f); + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); } /** @@ -542,8 +539,8 @@ __global__ void radix_kernel(const T* in, const bool select_min, const int pass) { - const int batch_id = blockIdx.y; - auto counter = counters + batch_id; + const size_t batch_id = blockIdx.y; + auto counter = counters + batch_id; IdxT current_k; IdxT previous_len; IdxT current_len; @@ -765,16 +762,16 @@ void radix_topk(const T* in, static_assert(calc_num_passes() > 1); constexpr int num_buckets = calc_num_buckets(); - auto kernel = radix_kernel; - const int max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel); - if (max_chunk_size != batch_size) { + auto kernel = radix_kernel; + const size_t max_chunk_size = + calc_chunk_size(batch_size, len, sm_cnt, kernel); + if (max_chunk_size != static_cast(batch_size)) { grid_dim = calc_grid_dim(max_chunk_size, len, sm_cnt); } const IdxT buf_len = calc_buf_len(len); - size_t req_aux = - static_cast(max_chunk_size) * (sizeof(Counter) + num_buckets * sizeof(IdxT)); - size_t req_buf = static_cast(max_chunk_size) * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); + size_t req_buf = max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); size_t mem_req = req_aux + req_buf; size_t mem_free, mem_total; RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); @@ -801,7 +798,7 @@ void radix_topk(const T* in, rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr_buf); rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr_buf); - for (int offset = 0; offset < batch_size; offset += max_chunk_size) { + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); RAFT_CUDA_TRY( cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); @@ -902,8 +899,7 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram + bucket, static_cast(1)); }; - vectorized_process( - static_cast(threadIdx.x), static_cast(blockDim.x), in_buf, previous_len, f); + vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); } else { // not use vectorized_process here because it increases #registers a lot IdxT* p_out_cnt = &counter->out_cnt; @@ -957,14 +953,15 @@ __global__ void radix_topk_one_block_kernel(const T* in, } __syncthreads(); - in += blockIdx.x * len; - if (in_idx) { in_idx += blockIdx.x * len; } - out += blockIdx.x * 
k; - out_idx += blockIdx.x * k; - buf1 += blockIdx.x * len; - idx_buf1 += blockIdx.x * len; - buf2 += blockIdx.x * len; - idx_buf2 += blockIdx.x * len; + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + in += batch_id * len; + if (in_idx) { in_idx += batch_id * len; } + out += batch_id * k; + out_idx += batch_id * k; + buf1 += batch_id * len; + idx_buf1 += batch_id * len; + buf2 += batch_id * len; + idx_buf2 += batch_id * len; const T* in_buf = nullptr; const IdxT* in_idx_buf = nullptr; T* out_buf = nullptr; @@ -1032,13 +1029,15 @@ void radix_topk_one_block(const T* in, { static_assert(calc_num_passes() > 1); - auto kernel = radix_topk_one_block_kernel; - const int max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel); + auto kernel = radix_topk_one_block_kernel; + const size_t max_chunk_size = + calc_chunk_size(batch_size, len, sm_cnt, kernel); - auto pool_guard = raft::get_pool_memory_resource( - mr, - static_cast(max_chunk_size) * len * 2 * (sizeof(T) + sizeof(IdxT)) + - 256 * 4); // might need extra memory for alignment + auto pool_guard = + raft::get_pool_memory_resource(mr, + max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)) + + 256 * 4 // might need extra memory for alignment + ); if (pool_guard) { RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); @@ -1049,7 +1048,7 @@ void radix_topk_one_block(const T* in, rmm::device_uvector buf2(len * max_chunk_size, stream, mr); rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); - for (int offset = 0; offset < batch_size; offset += max_chunk_size) { + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); kernel<<>>(in + offset * len, in_idx ? 
(in_idx + offset * len) : nullptr, From 415798aea4f87f513fc837b46225b2b437c2241f Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Mon, 20 Feb 2023 15:25:11 +0800 Subject: [PATCH 22/29] fix launch conf of last_filter_kernel --- cpp/include/raft/matrix/detail/select_radix_updated.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 7a50aadf29..6713bc1916 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -852,7 +852,6 @@ void radix_topk(const T* in, } if (!fused_last_filter) { - dim3 blocks(ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize), chunk_size); last_filter_kernel<<>>(chunk_in, chunk_in_idx, out_buf, From 77a14d94f46e0d75921e1c94ea912c53e41ee2db Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Mon, 20 Feb 2023 21:23:47 +0800 Subject: [PATCH 23/29] update test and benchmark --- cpp/bench/matrix/select_k.cu | 56 +++++++++---------- .../raft_internal/matrix/select_k.cuh | 12 ++-- cpp/test/matrix/select_k.cu | 10 ++-- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu index a487b4201b..d152ff98c0 100644 --- a/cpp/bench/matrix/select_k.cu +++ b/cpp/bench/matrix/select_k.cu @@ -160,33 +160,33 @@ const std::vector kInputs{ RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ } -SELECTION_REGISTER(float, int, kPublicApi); // NOLINT -SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT -SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT -SELECTION_REGISTER(float, int, kRadix8bitsUpdated); // NOLINT -SELECTION_REGISTER(float, int, kRadix11bitsUpdated); // NOLINT -SELECTION_REGISTER(float, int, kRadix11bitsAdaptive); // NOLINT -SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT -SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT -SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT -SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT -SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT - -SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT -SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, int, kRadix8bitsUpdated); // NOLINT -SELECTION_REGISTER(double, int, kRadix11bitsUpdated); // NOLINT -SELECTION_REGISTER(double, int, kRadix11bitsAdaptive); // NOLINT -SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT - -SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT -SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, size_t, kRadix8bitsUpdated); // NOLINT -SELECTION_REGISTER(double, size_t, kRadix11bitsUpdated); // NOLINT -SELECTION_REGISTER(double, size_t, kRadix11bitsAdaptive); // NOLINT -SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT -SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT -SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT -SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT +SELECTION_REGISTER(float, int, kPublicApi); // NOLINT +SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(float, int, kRadix8bitsUpdated); // NOLINT +SELECTION_REGISTER(float, int, kRadix11bitsUpdated); // NOLINT +SELECTION_REGISTER(float, int, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT +SELECTION_REGISTER(float, int, 
kWarpImmediate); // NOLINT +SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT + +SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, int, kRadix8bitsUpdated); // NOLINT +SELECTION_REGISTER(double, int, kRadix11bitsUpdated); // NOLINT +SELECTION_REGISTER(double, int, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT + +SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, size_t, kRadix8bitsUpdated); // NOLINT +SELECTION_REGISTER(double, size_t, kRadix11bitsUpdated); // NOLINT +SELECTION_REGISTER(double, size_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT } // namespace raft::matrix diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index 3412220ee1..d63c77050a 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -55,7 +55,7 @@ enum class Algo { kRadix11bits, kRadix8bitsUpdated, kRadix11bitsUpdated, - kRadix11bitsAdaptive, + kRadix11bitsExtraPass, kWarpAuto, kWarpImmediate, kWarpFiltered, @@ -71,7 +71,7 @@ inline auto operator<<(std::ostream& os, const Algo& algo) -> std::ostream& case Algo::kRadix11bits: return os << "kRadix11bits"; case Algo::kRadix8bitsUpdated: return os << "kRadix8bitsUpdated"; case Algo::kRadix11bitsUpdated: return os << "kRadix11bitsUpdated"; - case Algo::kRadix11bitsAdaptive: return os << "kRadix11bitsAdaptive"; + case Algo::kRadix11bitsExtraPass: return os << "kRadix11bitsExtraPass"; case Algo::kWarpAuto: return os << "kWarpAuto"; case Algo::kWarpImmediate: return os << "kWarpImmediate"; case Algo::kWarpFiltered: return os << "kWarpFiltered"; @@ -126,7 +126,7 @@ void select_k_impl(const device_resources& handle, out, out_idx, select_min, - false, // adaptive + true, // fused_last_filter stream); case Algo::kRadix11bitsUpdated: return detail::select::radix::select_k_updated(in, @@ -137,9 +137,9 @@ void select_k_impl(const device_resources& handle, out, out_idx, select_min, - false, // adaptive + true, // fused_last_filter stream); - case Algo::kRadix11bitsAdaptive: + case Algo::kRadix11bitsExtraPass: return detail::select::radix::select_k_updated(in, in_idx, batch_size, @@ -148,7 +148,7 @@ void select_k_impl(const device_resources& handle, out, out_idx, select_min, - true, // adaptive + false, // fused_last_filter stream); case Algo::kWarpAuto: return detail::select::warpsort::select_k( diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index b62c2038fe..b1f406ed15 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -330,7 +330,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kRadix11bits, select::Algo::kRadix8bitsUpdated, select::Algo::kRadix11bitsUpdated, - select::Algo::kRadix11bitsAdaptive, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed))); @@ -427,7 +427,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kRadix11bits, 
select::Algo::kRadix8bitsUpdated, select::Algo::kRadix11bitsUpdated, - select::Algo::kRadix11bitsAdaptive, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed, @@ -444,7 +444,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kRadix11bits, select::Algo::kRadix8bitsUpdated, select::Algo::kRadix11bitsUpdated, - select::Algo::kRadix11bitsAdaptive, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed, @@ -460,7 +460,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Values(select::Algo::kWarpAuto, select::Algo::kRadix8bitsUpdated, select::Algo::kRadix11bitsUpdated, - select::Algo::kRadix11bitsAdaptive))); + select::Algo::kRadix11bitsExtraPass))); using ReferencedRandomFloatSizeT = SelectK::params_random>; @@ -471,6 +471,6 @@ INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT testing::Values(select::Algo::kRadix11bits, select::Algo::kRadix8bitsUpdated, select::Algo::kRadix11bitsUpdated, - select::Algo::kRadix11bitsAdaptive))); + select::Algo::kRadix11bitsExtraPass))); } // namespace raft::matrix From 532cb4ffee55ecac7bf7b9f895adf49d63a462d0 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Tue, 21 Feb 2023 22:26:51 +0800 Subject: [PATCH 24/29] remove managed_memory_resource and refine code comments --- .../matrix/detail/select_radix_updated.cuh | 30 +++++++------------ 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh index 6713bc1916..7daa8cf6e6 100644 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ b/cpp/include/raft/matrix/detail/select_radix_updated.cuh @@ -772,31 +772,20 @@ void radix_topk(const T* in, size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); size_t req_buf = max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); - size_t mem_req = req_aux + req_buf; - size_t mem_free, mem_total; - RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); - std::optional managed_memory; - rmm::mr::device_memory_resource* mr_buf = nullptr; - if (mem_req > mem_free) { - // if there's not enough memory for buffers on the device, resort to the managed memory. 
- mem_req = req_aux; - managed_memory.emplace(); - mr_buf = &managed_memory.value(); - } + size_t mem_req = req_aux + req_buf + 256 * 6; // might need extra memory for alignment auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); if (pool_guard) { RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } - if (mr_buf == nullptr) { mr_buf = mr; } rmm::device_uvector> counters(max_chunk_size, stream, mr); rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); - rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr_buf); - rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr_buf); - rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr_buf); - rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr_buf); + rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); @@ -1108,10 +1097,9 @@ void radix_topk_one_block(const T* in, * @param fused_last_filter * when it's true, the last filter is fused into the kernel in the last pass and only one thread * block will do the filtering; when false, a standalone filter kernel with multiple thread - * blocks is called. The later case is preferable when the most significant bits of input data are - * almost the same. That is, when the value range of input data is narrow. In such case, there - * could be a large number of inputs for the last filter, hence using multiple thread blocks is - * beneficial. + * blocks is called. The later case is preferable when leading bits of input data are almost the + * same. That is, when the value range of input data is narrow. In such case, there could be a + * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. * @param stream * @param mr an optional memory resource to use across the calls (you can provide a large enough * memory pool here to avoid memory allocations within the call). 
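For reference, the pool size requested in the hunk above can be written out as a small host-side estimate. This is only an illustrative sketch: the helper name and the `counter_bytes`/`key_bytes`/`idx_bytes` parameters are placeholders for the `sizeof(Counter<T, IdxT>)`, `sizeof(T)` and `sizeof(IdxT)` expressions that the real code uses directly.

#include <cstddef>

// Mirrors radix_topk()'s allocations: per-chunk counters and histograms (req_aux),
// plus two ping-pong key/index buffer pairs (req_buf), plus alignment slack
// (presumably 256 B for each of the six device_uvector allocations).
inline std::size_t radix_topk_workspace_estimate(std::size_t max_chunk_size,
                                                 std::size_t buf_len,
                                                 std::size_t counter_bytes,
                                                 std::size_t num_buckets,
                                                 std::size_t key_bytes,
                                                 std::size_t idx_bytes)
{
  const std::size_t req_aux = max_chunk_size * (counter_bytes + num_buckets * idx_bytes);
  const std::size_t req_buf = max_chunk_size * buf_len * 2 * (key_bytes + idx_bytes);
  return req_aux + req_buf + 256 * 6;
}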
@@ -1144,6 +1132,8 @@ void select_k_updated(const T* in, return; } + // TODO: use device_resources::get_device_properties() instead; should change it when we refactor + // resource management int sm_cnt; { int dev; From 63c5fd064d7ae730b4c043b2dd8f07536016d873 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Wed, 22 Feb 2023 21:05:26 +0800 Subject: [PATCH 25/29] replace select_radix.cuh with select_radix_updated.cuh --- cpp/bench/matrix/select_k.cu | 6 - cpp/include/raft/matrix/detail/select_k.cuh | 2 +- .../raft/matrix/detail/select_radix.cuh | 1145 +++++++++++----- .../matrix/detail/select_radix_updated.cuh | 1173 ----------------- .../raft_internal/matrix/select_k.cuh | 71 +- cpp/test/matrix/select_k.cu | 12 +- 6 files changed, 862 insertions(+), 1547 deletions(-) delete mode 100644 cpp/include/raft/matrix/detail/select_radix_updated.cuh diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu index d152ff98c0..060a8cf0d3 100644 --- a/cpp/bench/matrix/select_k.cu +++ b/cpp/bench/matrix/select_k.cu @@ -163,8 +163,6 @@ const std::vector kInputs{ SELECTION_REGISTER(float, int, kPublicApi); // NOLINT SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT -SELECTION_REGISTER(float, int, kRadix8bitsUpdated); // NOLINT -SELECTION_REGISTER(float, int, kRadix11bitsUpdated); // NOLINT SELECTION_REGISTER(float, int, kRadix11bitsExtraPass); // NOLINT SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT @@ -174,15 +172,11 @@ SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, int, kRadix8bitsUpdated); // NOLINT -SELECTION_REGISTER(double, int, kRadix11bitsUpdated); // NOLINT SELECTION_REGISTER(double, int, kRadix11bitsExtraPass); // NOLINT SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, size_t, kRadix8bitsUpdated); // NOLINT -SELECTION_REGISTER(double, size_t, kRadix11bitsUpdated); // NOLINT SELECTION_REGISTER(double, size_t, kRadix11bitsExtraPass); // NOLINT SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT diff --git a/cpp/include/raft/matrix/detail/select_k.cuh b/cpp/include/raft/matrix/detail/select_k.cuh index ac1ba3dfa3..20c2fb119d 100644 --- a/cpp/include/raft/matrix/detail/select_k.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -84,7 +84,7 @@ void select_k(const T* in_val, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); } else { select::radix::select_k= 4 ? 11 : 8), 512>( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr); } } diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index de19e63a4c..dc1ce1920c 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,9 +16,11 @@ #pragma once -#include #include #include +#include +#include +#include #include #include #include @@ -32,8 +34,8 @@ #include namespace raft::matrix::detail::select::radix { +namespace impl { -constexpr int ITEM_PER_THREAD = 32; constexpr int VECTORIZED_READ_SIZE = 16; template @@ -48,13 +50,6 @@ _RAFT_HOST_DEVICE constexpr int calc_num_passes() return ceildiv(sizeof(T) * 8, BitsPerPass); } -// Minimum reasonable block size for the given radix size. -template -_RAFT_HOST_DEVICE constexpr int calc_min_block_size() -{ - return 1 << std::max(BitsPerPass - 4, Pow2::Log2 + 1); -} - /** * Bit 0 is the least significant (rightmost); * this implementation processes input from the most to the least significant bit. @@ -79,23 +74,43 @@ _RAFT_DEVICE constexpr unsigned calc_mask(int pass) } /** - * Use cub to twiddle bits - so that we can correctly compare bits of floating-point values as well + * Use CUB to twiddle bits - so that we can correctly compare bits of floating-point values as well * as of integers. */ template -_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) +_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) { auto bits = reinterpret_cast::UnsignedBits&>(key); bits = cub::Traits::TwiddleIn(bits); - if (greater) { bits = ~bits; } + if (!select_min) { bits = ~bits; } return bits; } +template +_RAFT_DEVICE T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) +{ + if (!select_min) { bits = ~bits; } + bits = cub::Traits::TwiddleOut(bits); + return reinterpret_cast(bits); +} + template -_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater) +_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) { - static_assert(BitsPerPass <= sizeof(int) * 8 - 1); // so return type can be int - return (twiddle_in(x, greater) >> start_bit) & mask; + static_assert(BitsPerPass <= sizeof(int) * 8 - 1, + "BitsPerPass is too large that the result type could not be int"); + return (twiddle_in(x, select_min) >> start_bit) & mask; +} + +template +_RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len) +{ + // When writing is skipped, only read `in`(type T). + // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) + // and `out_idx_buf`(IdxT). + // The ratio between these cases determines whether to skip writing and hence the buffer size. 
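// (Worked example, added for illustration: with T = float and IdxT = int, both 4 bytes,
//  ratio = 2 + 2 * 4 / 4 = 4, so the candidate buffer holds len / 4 elements; with
//  T = double and IdxT = int, ratio = 2 + 2 * 4 / 8 = 3, i.e. len / 3. A pass whose
//  candidate count exceeds this length skips writing the buffer, and the next pass
//  reads from `in` again instead.)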
+ constexpr float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T); + return len / ratio; } /** @@ -108,17 +123,18 @@ _RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater) * @tparam IdxT indexing type * @tparam Func void (T x, IdxT idx) * + * @param thread_rank rank of the calling thread among all participating threads + * @param num_threads number of the threads that participate in processing * @param in the input data * @param len the number of elements to read * @param f the lambda taking two arguments (T x, IdxT idx) */ template -_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) +_RAFT_DEVICE void vectorized_process( + size_t thread_rank, size_t num_threads, const T* in, IdxT len, Func f) { - const IdxT stride = blockDim.x * gridDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (sizeof(T) >= VECTORIZED_READ_SIZE || VECTORIZED_READ_SIZE % sizeof(T) != 0) { - for (IdxT i = tid; i < len; i += stride) { + for (IdxT i = thread_rank; i < len; i += num_threads) { f(in[i], i); } } else { @@ -131,8 +147,8 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) const IdxT skip_cnt_left = std::min((IdxT)(align_bytes::roundUp(in) - in), len); // The main loop: process all aligned data - for (IdxT i = tid * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; - i += stride * wide_t::Ratio) { + for (IdxT i = thread_rank * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; + i += num_threads * wide_t::Ratio) { wide.load(in, i); #pragma unroll for (int j = 0; j < wide_t::Ratio; ++j) { @@ -142,30 +158,55 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) static_assert(WarpSize >= wide_t::Ratio); // Processes the skipped elements on the left - if (tid < skip_cnt_left) { f(in[tid], tid); } + if (thread_rank < skip_cnt_left) { f(in[thread_rank], thread_rank); } // Processes the skipped elements on the right const IdxT skip_cnt_right = align_elems::mod(len - skip_cnt_left); - const IdxT remain_i = len - skip_cnt_right + tid; + const IdxT remain_i = len - skip_cnt_right + thread_rank; if (remain_i < len) { f(in[remain_i], remain_i); } } } template -struct Counter { +struct alignas(128) Counter { + // We are processing the values in multiple passes, from most significant to least significant. In + // each pass, we keep the length of input (`len`) and the `k` of current pass, and update them at + // the end of the pass. IdxT k; IdxT len; + + // `previous_len` is the length of input in previous pass. Note that `previous_len` rather + // than `len` is used for the filtering step because filtering is indeed for previous pass (see + // comments before `radix_kernel`). IdxT previous_len; - int bucket; - IdxT filter_cnt; - unsigned int finished_block_cnt; - IdxT out_cnt; - IdxT out_back_cnt; + // We determine the bits of the k_th value inside the mask processed by the pass. The + // already known bits are stored in `kth_value_bits`. It's used to discriminate a element is a + // result (written to `out`), a candidate for next pass (written to `out_buf`), or not useful + // (discarded). The bits that are not yet processed do not matter for this purpose. + typename cub::Traits::UnsignedBits kth_value_bits; + + // Record how many elements have passed filtering. It's used to determine the position in the + // `out_buf` where an element should be written. + alignas(128) IdxT filter_cnt; + + // For a row inside a batch, we may launch multiple thread blocks. 
This counter is used to + // determine if the current block is the last running block. If so, this block will execute scan() + // and choose_bucket(). + alignas(128) unsigned int finished_block_cnt; + + // Record how many elements have been written to the front of `out`. Elements less (if + // select_min==true) than the k-th value are written from front to back. + alignas(128) IdxT out_cnt; + + // Record how many elements have been written to the back of `out`. Elements equal to the k-th + // value are written from back to front. We need to keep count of them separately because the + // number of elements that <= the k-th value might exceed k. + alignas(128) IdxT out_back_cnt; }; /** - * Fused filtering of the current phase and building histogram for the next phase - * (see steps 4-1 in `radix_kernel` description). + * Fused filtering of the current pass and building histogram for the next pass + * (see steps 4 & 1 in `radix_kernel` description). */ template _RAFT_DEVICE void filter_and_histogram(const T* in_buf, @@ -174,12 +215,12 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, IdxT* out_idx_buf, T* out, IdxT* out_idx, - IdxT len, + IdxT previous_len, Counter* counter, IdxT* histogram, - bool greater, + bool select_min, int pass, - int k) + bool early_stop) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -195,19 +236,20 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, // Passed to vectorized_process, this function executes in all blocks in parallel, // i.e. the work is split along the input (both, in batches and chunks of a single row). // Later, the histograms are merged using atomicAdd. - auto f = [greater, start_bit, mask](T value, IdxT) { - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, IdxT(1)); + auto f = [select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); }; - vectorized_process(in_buf, len, f); + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); } else { - const IdxT previous_len = counter->previous_len; - const int want_bucket = counter->bucket; - IdxT& filter_cnt = counter->filter_cnt; - IdxT& out_cnt = counter->out_cnt; - const IdxT counter_len = counter->len; + IdxT* p_filter_cnt = &counter->filter_cnt; + IdxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; const int previous_start_bit = calc_start_bit(pass - 1); - const unsigned previous_mask = calc_mask(pass - 1); // See the remark above on the distributed execution of `f` using vectorized_process. auto f = [in_idx_buf, @@ -215,38 +257,50 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, out_idx_buf, out, out_idx, - greater, - k, + select_min, start_bit, mask, previous_start_bit, - previous_mask, - want_bucket, - &filter_cnt, - &out_cnt, - counter_len](T value, IdxT i) { - int prev_bucket = - calc_bucket(value, previous_start_bit, previous_mask, greater); - if (prev_bucket == want_bucket) { - IdxT pos = atomicAdd(&filter_cnt, IdxT(1)); - out_buf[pos] = value; - if (out_idx_buf) { out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, IdxT(1)); - - if (counter_len == 1) { - out[k - 1] = value; - out_idx[k - 1] = in_idx_buf ? 
in_idx_buf[i] : i; + kth_value_bits, + p_filter_cnt, + p_out_cnt, + early_stop](T value, IdxT i) { + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + if (early_stop) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else { + if (out_buf) { + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); } - } else if (prev_bucket < want_bucket) { - IdxT pos = atomicAdd(&out_cnt, IdxT(1)); + } + // the condition `(out_buf || early_stop)` is a little tricky: + // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should skip writing to + // `out` too. So we won't write the same value to `out` multiple times in different passes. + // And if we keep skipping the writing, values will be written in `last_filter_kernel()` at + // last. But when `early_stop` is true, we need to write to `out` since it's the last chance. + else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } }; - - vectorized_process(in_buf, previous_len, f); + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); } + if (early_stop) { return; } __syncthreads(); // merge histograms produced by individual blocks @@ -256,69 +310,184 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, } /** - * Replace a part of the histogram with its own prefix sum, starting from the `start` and adding - * `current` to each entry of the result. 
+ * Replace histogram with its own prefix sum * (step 2 in `radix_kernel` description) */ template -_RAFT_DEVICE void scan(volatile IdxT* histogram, - const int start, - const int num_buckets, - const IdxT current) +_RAFT_DEVICE void scan(volatile IdxT* histogram) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; + constexpr int num_buckets = calc_num_buckets(); + if constexpr (num_buckets >= BlockSize) { + static_assert(num_buckets % BlockSize == 0); + constexpr int items_per_thread = num_buckets / BlockSize; + typedef cub::BlockLoad BlockLoad; + typedef cub::BlockStore + BlockStore; + typedef cub::BlockScan BlockScan; - IdxT thread_data = 0; - int index = start + threadIdx.x; - if (index < num_buckets) { thread_data = histogram[index]; } + __shared__ union { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + IdxT thread_data[items_per_thread]; - BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); - __syncthreads(); - if (index < num_buckets) { histogram[index] = thread_data + current; } - __syncthreads(); // This sync is necessary, as the content of histogram needs - // to be read after + BlockLoad(temp_storage.load).Load(histogram, thread_data); + __syncthreads(); + + BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + BlockStore(temp_storage.store).Store(histogram, thread_data); + } else { + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + IdxT thread_data = 0; + if (threadIdx.x < num_buckets) { thread_data = histogram[threadIdx.x]; } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + if (threadIdx.x < num_buckets) { histogram[threadIdx.x] = thread_data; } + } } /** * Calculate in which bucket the k-th value will fall - * (steps 2-3 in `radix_kernel` description) + * (steps 3 in `radix_kernel` description) */ -template -_RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) +template +_RAFT_DEVICE void choose_bucket(Counter* counter, + const IdxT* histogram, + const IdxT k, + const int pass) { constexpr int num_buckets = calc_num_buckets(); - int index = threadIdx.x; - IdxT last_prefix_sum = 0; - int num_pass = 1; - if constexpr (num_buckets >= BlockSize) { - static_assert(num_buckets % BlockSize == 0); - num_pass = num_buckets / BlockSize; + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + IdxT prev = (i == 0) ? 0 : histogram[i - 1]; + IdxT cur = histogram[i]; + + // one and only one thread will satisfy this condition, so counter is written by only one thread + if (prev < k && cur >= k) { + counter->k = k - prev; // how many values still are there to find + counter->len = cur - prev; // number of values in next pass + typename cub::Traits::UnsignedBits bucket = i; + int start_bit = calc_start_bit(pass); + counter->kth_value_bits |= bucket << start_bit; + } } +} + +// For one-block version, last_filter() could be called when pass < num_passes - 1. 
+// So `pass` could not be constexpr +template +_RAFT_DEVICE void last_filter(const T* in_buf, + const IdxT* in_idx_buf, + T* out, + IdxT* out_idx, + IdxT current_len, + IdxT k, + Counter* counter, + const bool select_min, + const int pass) +{ + const auto kth_value_bits = counter->kth_value_bits; + const int start_bit = calc_start_bit(pass); - for (int i = 0; i < num_pass && (last_prefix_sum < k); i++) { - // Turn the i-th chunk of the histogram into its prefix sum. - scan(histogram, i * BlockSize, num_buckets, last_prefix_sum); - if (index < num_buckets) { - // Number of values in the previous `index-1` buckets (see the `scan` op above) - IdxT prev = (index == 0) ? 0 : histogram[index - 1]; - // Number of values in `index` buckets - IdxT cur = histogram[index]; - - // one and only one thread will satisfy this condition, so only write once - if (prev < k && cur >= k) { - counter->k = k - prev; // how many values still are there to find - counter->previous_len = counter->len; - counter->len = cur - prev; // number of values in `index` bucket - counter->bucket = index; + // changed in choose_bucket(); need to reload + const IdxT needed_num_of_kth = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + // For one-block version, `in_idx_buf` could be nullptr at pass 0. + // For non one-block version, if writing has been skipped, `in_idx_buf` could be nullptr if + // `in_buf` is `in` + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } } - index += BlockSize; - // this will break the loop when the counter is set (cur >= k), because last_prefix_sum >= cur - last_prefix_sum = histogram[(i + 1) * BlockSize - 1]; } } +template +__global__ void last_filter_kernel(const T* in, + const IdxT* in_idx, + const T* in_buf, + const IdxT* in_idx_buf, + T* out, + IdxT* out_idx, + IdxT len, + IdxT k, + Counter* counters, + const bool select_min) +{ + const size_t batch_id = blockIdx.y; // size_t to avoid multiplication overflow + + Counter* counter = counters + batch_id; + IdxT previous_len = counter->previous_len; + if (previous_len == 0) { return; } + const IdxT buf_len = calc_buf_len(len); + if (previous_len > buf_len || in_buf == in) { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? 
(in_idx + batch_id * len) : nullptr; + previous_len = len; + } else { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + } + out += batch_id * k; + out_idx += batch_id * k; + + constexpr int pass = calc_num_passes() - 1; + constexpr int start_bit = calc_start_bit(pass); + + const auto kth_value_bits = counter->kth_value_bits; + const IdxT needed_num_of_kth = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + + auto f = [k, + select_min, + kth_value_bits, + needed_num_of_kth, + p_out_cnt, + p_out_back_cnt, + in_idx_buf, + out, + out_idx](T value, IdxT i) { + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + }; + + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); +} + /** * * It is expected to call this kernel multiple times (passes), in each pass we process a radix, @@ -347,35 +516,79 @@ _RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, cons * * In the implementation, the filtering step is delayed to the next pass so the filtering and * histogram computation are fused. In this way, inputs are read once rather than twice. + * + * During the filtering step, we won't write candidates (elements in bucket j) to `out_buf` if the + * number of candidates is larger than the length of `out_buf` (this could happen when the leading + * bits of input values are almost the same). And then in the next pass, inputs are read from `in` + * rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and + * their indices. */ -template -__global__ void __launch_bounds__(BlockSize) radix_kernel(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - Counter* counters, - IdxT* histograms, - const IdxT len, - const int k, - const bool greater, - const int pass) +template +__global__ void radix_kernel(const T* in, + const IdxT* in_idx, + const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counters, + IdxT* histograms, + const IdxT len, + const IdxT k, + const bool select_min, + const int pass) { - __shared__ bool isLastBlockDone; + const size_t batch_id = blockIdx.y; + auto counter = counters + batch_id; + IdxT current_k; + IdxT previous_len; + IdxT current_len; + if (pass == 0) { + current_k = k; + previous_len = len; + // Need to do this so setting counter->previous_len for the next pass is correct. + // This value is meaningless for pass 0, but it's fine because pass 0 won't be the + // last pass in this implementation so pass 0 won't hit the "if (pass == + // num_passes - 1)" branch. 
+ // Maybe it's better to reload counter->previous_len and use it rather than + // current_len in last_filter() + current_len = len; + } else { + current_k = counter->k; + current_len = counter->len; + previous_len = counter->previous_len; + } + if (current_len == 0) { return; } - constexpr int num_buckets = calc_num_buckets(); - constexpr int num_passes = calc_num_passes(); - const int batch_id = blockIdx.y; - in_buf += batch_id * len; - out_buf += batch_id * len; + // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle + // correctly the case that pass=0 and early_stop=true. However, this special case of k=len is + // handled in other way in select_k() so such case is not possible here. + const bool early_stop = (current_len == current_k); + const IdxT buf_len = calc_buf_len(len); + + // "previous_len > buf_len" means previous pass skips writing buffer + if (pass == 0 || pass == 1 || previous_len > buf_len) { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; + previous_len = len; + } else { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + } + // "current_len > buf_len" means current pass will skip writing buffer + if (pass == 0 || current_len > buf_len) { + out_buf = nullptr; + out_idx_buf = nullptr; + } else { + out_buf += batch_id * buf_len; + out_idx_buf += batch_id * buf_len; + } out += batch_id * k; out_idx += batch_id * k; - if (in_idx_buf) { in_idx_buf += batch_id * len; } - if (out_idx_buf) { out_idx_buf += batch_id * len; } - auto counter = counters + batch_id; - auto histogram = histograms + batch_id * num_buckets; + constexpr int num_buckets = calc_num_buckets(); + auto histogram = histograms + batch_id * num_buckets; filter_and_histogram(in_buf, in_idx_buf, @@ -383,126 +596,464 @@ __global__ void __launch_bounds__(BlockSize) radix_kernel(const T* in_buf, out_idx_buf, out, out_idx, - len, + previous_len, counter, histogram, - greater, + select_min, pass, - k); + early_stop); __threadfence(); + bool isLastBlock = false; if (threadIdx.x == 0) { unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); - isLastBlockDone = (finished == (gridDim.x - 1)); + isLastBlock = (finished == (gridDim.x - 1)); } - // Synchronize to make sure that each thread reads the correct value of - // isLastBlockDone. 
- __syncthreads(); - if (isLastBlockDone) { - if (counter->len == 1 && threadIdx.x == 0) { - counter->previous_len = 0; - counter->len = 0; - } - // init counter, other members of counter is initialized with 0 by - // cudaMemset() - if (pass == 0 && threadIdx.x == 0) { - counter->k = k; - counter->len = len; - counter->out_back_cnt = 0; + if (__syncthreads_or(isLastBlock)) { + if (early_stop) { + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len + counter->previous_len = 0; + counter->len = 0; + } + return; } + + scan(histogram); + __syncthreads(); + choose_bucket(counter, histogram, current_k, pass); __syncthreads(); - IdxT ori_k = counter->k; + constexpr int num_passes = calc_num_passes(); + // reset for next pass + if (pass != num_passes - 1) { + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + } + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len even in the last pass + counter->previous_len = current_len; + // not necessary for the last pass, but put it here anyway + counter->filter_cnt = 0; + } + + if constexpr (fused_last_filter) { + if (pass == num_passes - 1) { + last_filter(out_buf ? out_buf : in_buf, + out_idx_buf ? out_idx_buf : in_idx_buf, + out, + out_idx, + out_buf ? current_len : len, + k, + counter, + select_min, + pass); + } + } + } +} + +template +int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel) +{ + int active_blocks; + RAFT_CUDA_TRY( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks, kernel, BlockSize, 0)); + + constexpr int items_per_thread = 32; + constexpr int num_waves = 10; + int chunk_size = + std::max(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len); + return std::min(chunk_size, batch_size); +} + +template +unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) +{ + static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); + + int active_blocks; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &active_blocks, radix_kernel, BlockSize, 0)); + active_blocks *= sm_cnt; + + IdxT best_num_blocks = 0; + float best_tail_wave_penalty = 1.0f; + const IdxT max_num_blocks = ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize); + for (int num_waves = 1;; ++num_waves) { + IdxT num_blocks = std::min( + max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); + IdxT items_per_thread = ceildiv(len, num_blocks * BlockSize); + items_per_thread = alignTo(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T)); + num_blocks = ceildiv(len, items_per_thread * BlockSize); + float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; + float tail_wave_penalty = + (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); - if (counter->len > 0) { - choose_bucket(counter, histogram, ori_k); + // 0.15 is determined experimentally. It also ensures breaking the loop early, + // e.g. 
when num_waves > 7, tail_wave_penalty will always <0.15 + if (tail_wave_penalty < 0.15) { + best_num_blocks = num_blocks; + break; + } else if (tail_wave_penalty < best_tail_wave_penalty) { + best_num_blocks = num_blocks; + best_tail_wave_penalty = tail_wave_penalty; } - __syncthreads(); - if (pass == num_passes - 1) { - const IdxT previous_len = counter->previous_len; - const int want_bucket = counter->bucket; - int start_bit = calc_start_bit(pass); - unsigned mask = calc_mask(pass); - - // radix topk - IdxT& out_cnt = counter->out_cnt; - for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { - const T value = out_buf[i]; - int bucket = calc_bucket(value, start_bit, mask, greater); - if (bucket < want_bucket) { - IdxT pos = atomicAdd(&out_cnt, IdxT(1)); - out[pos] = value; - out_idx[pos] = out_idx_buf[i]; - } else if (bucket == want_bucket) { - IdxT needed_num_of_kth = counter->k; - IdxT back_pos = atomicAdd(&(counter->out_back_cnt), IdxT(1)); - if (back_pos < needed_num_of_kth) { - IdxT pos = k - 1 - back_pos; - out[pos] = value; - out_idx[pos] = out_idx_buf[i]; - } - } + if (num_blocks == max_num_blocks) { break; } + } + return best_num_blocks; +} + +template +_RAFT_HOST_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + T* buf1, + IdxT* idx_buf1, + T* buf2, + IdxT* idx_buf2, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = buf1; + out_idx_buf = idx_buf1; + } else if (pass % 2 == 0) { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } else { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } +} + +template +void radix_topk(const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + bool fused_last_filter, + unsigned grid_dim, + int sm_cnt, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + // TODO: is it possible to relax this restriction? 
+ static_assert(calc_num_passes() > 1); + constexpr int num_buckets = calc_num_buckets(); + + auto kernel = radix_kernel; + const size_t max_chunk_size = + calc_chunk_size(batch_size, len, sm_cnt, kernel); + if (max_chunk_size != static_cast(batch_size)) { + grid_dim = calc_grid_dim(max_chunk_size, len, sm_cnt); + } + const IdxT buf_len = calc_buf_len(len); + + size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); + size_t req_buf = max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + size_t mem_req = req_aux + req_buf + 256 * 6; // might need extra memory for alignment + + auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); + if (pool_guard) { + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + rmm::device_uvector> counters(max_chunk_size, stream, mr); + rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); + rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr); + + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { + int chunk_size = std::min(max_chunk_size, batch_size - offset); + RAFT_CUDA_TRY( + cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); + + const T* chunk_in = in + offset * len; + const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; + T* chunk_out = out + offset * k; + IdxT* chunk_out_idx = out_idx + offset * k; + + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + IdxT* out_idx_buf = nullptr; + + dim3 blocks(grid_dim, chunk_size); + constexpr int num_passes = calc_num_passes(); + + for (int pass = 0; pass < num_passes; ++pass) { + set_buf_pointers(chunk_in, + chunk_in_idx, + buf1.data(), + idx_buf1.data(), + buf2.data(), + idx_buf2.data(), + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + + if (fused_last_filter && pass == num_passes - 1) { + kernel = radix_kernel; } - __syncthreads(); - } else { - // reset for next pass - for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { - histogram[i] = 0; + + kernel<<>>(chunk_in, + chunk_in_idx, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + chunk_out, + chunk_out_idx, + counters.data(), + histograms.data(), + len, + k, + select_min, + pass); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + if (!fused_last_filter) { + last_filter_kernel<<>>(chunk_in, + chunk_in_idx, + out_buf, + out_idx_buf, + chunk_out, + chunk_out_idx, + len, + k, + counters.data(), + select_min); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + } +} + +// The following a few functions are for the one-block version, which uses single thread block for +// each row of a batch. 
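// To contrast the two code paths (illustrative summary of what is shown above and below):
//   radix_topk():           launches dim3(grid_dim, chunk_size) blocks, rows indexed by
//                           blockIdx.y; counters and histograms live in global memory and
//                           blocks coordinate via atomics on finished_block_cnt plus
//                           __threadfence().
//   radix_topk_one_block(): launches <<<chunk_size, BlockSize>>>, rows indexed by
//                           blockIdx.x; the counter and histogram live in __shared__
//                           memory and are synchronized with plain __syncthreads().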
+template +_RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass) +{ + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + IdxT* p_filter_cnt = &counter->filter_cnt; + if (threadIdx.x == 0) { *p_filter_cnt = 0; } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + const IdxT previous_len = counter->previous_len; + + if (pass == 0) { + auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + }; + vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); + } else { + // not use vectorized_process here because it increases #registers a lot + IdxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } else if (previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } - if (threadIdx.x == 0) { counter->filter_cnt = 0; } } } } -/** - * Calculate the minimal batch size, such that GPU is still fully occupied. - */ template -inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) +__global__ void radix_topk_one_block_kernel(const T* in, + const IdxT* in_idx, + const IdxT len, + const IdxT k, + T* out, + IdxT* out_idx, + const bool select_min, + T* buf1, + IdxT* idx_buf1, + T* buf2, + IdxT* idx_buf2) { - int dev_id, sm_count, occupancy, max_grid_dim_y; - RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&max_grid_dim_y, cudaDevAttrMaxGridDimY, dev_id)); - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &occupancy, radix_kernel, BlockSize, 0)); - - // number of block we'd use if the batch size is enough to occupy the gpu in any case - size_t blocks_per_row = ceildiv(len, BlockSize * ITEM_PER_THREAD); - - // fully occupy GPU - size_t opt_batch_size = ceildiv(sm_count * occupancy, blocks_per_row); - // round it up to the closest pow-of-two for better data alignment - opt_batch_size = isPo2(opt_batch_size) ? opt_batch_size : (1 << (log2(opt_batch_size) + 1)); - // Take a max possible pow-of-two grid_dim_y - max_grid_dim_y = isPo2(max_grid_dim_y) ? max_grid_dim_y : (1 << log2(max_grid_dim_y)); - // If the optimal batch size is very small compared to the requested batch size, we know - // the extra required memory is not significant and we can increase the batch size for - // better occupancy when the grid size is not multiple of the SM count. - // Also don't split the batch size when there is not much work overall. 
- const size_t safe_enlarge_factor = 9; - const size_t min_grid_size = 1024; - while ((opt_batch_size << safe_enlarge_factor) < req_batch_size || - blocks_per_row * opt_batch_size < min_grid_size) { - opt_batch_size <<= 1; + constexpr int num_buckets = calc_num_buckets(); + __shared__ Counter counter; + __shared__ IdxT histogram[num_buckets]; + + if (threadIdx.x == 0) { + counter.k = k; + counter.len = len; + counter.previous_len = len; + counter.kth_value_bits = 0; + counter.out_cnt = 0; + counter.out_back_cnt = 0; } + __syncthreads(); + + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + in += batch_id * len; + if (in_idx) { in_idx += batch_id * len; } + out += batch_id * k; + out_idx += batch_id * k; + buf1 += batch_id * len; + idx_buf1 += batch_id * len; + buf2 += batch_id * len; + idx_buf2 += batch_id * len; + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + IdxT* out_idx_buf = nullptr; + + constexpr int num_passes = calc_num_passes(); + for (int pass = 0; pass < num_passes; ++pass) { + set_buf_pointers( + in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); + + IdxT current_len = counter.len; + IdxT current_k = counter.k; + + filter_and_histogram_for_one_block(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + &counter, + histogram, + select_min, + pass); + __syncthreads(); - // Do not exceed the max grid size. - opt_batch_size = std::min(opt_batch_size, size_t(max_grid_dim_y)); - // Don't do more work than needed - opt_batch_size = std::min(opt_batch_size, req_batch_size); - // Let more blocks share one row if the required batch size is too small. - while (opt_batch_size * blocks_per_row < size_t(sm_count * occupancy) && - // Ensure we still can read data somewhat efficiently - len * sizeof(T) > 2 * VECTORIZED_READ_SIZE * BlockSize * blocks_per_row) { - blocks_per_row <<= 1; + scan(histogram); + __syncthreads(); + + choose_bucket(&counter, histogram, current_k, pass); + if (threadIdx.x == 0) { counter.previous_len = current_len; } + __syncthreads(); + + if (counter.len == counter.k || pass == num_passes - 1) { + last_filter(pass == 0 ? in : out_buf, + pass == 0 ? in_idx : out_idx_buf, + out, + out_idx, + current_len, + k, + &counter, + select_min, + pass); + break; + } } +} - return dim3(blocks_per_row, opt_batch_size); +// radix_topk() might use multiple thread blocks for one row of a batch. In contrast, the following +// one-block version uses single thread block for one row of a batch, so intermediate data, like +// counters and global histograms, can be kept in shared memory and cheap sync operations can be +// used. It's used when len is relatively small or when the number of blocks per row calculated by +// `calc_grid_dim()` is 1. 
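// The dispatch in select_k() further below boils down to (sketch):
//   if (len <= BlockSize * items_per_thread)              -> radix_topk_one_block()
//   else if (calc_grid_dim(batch_size, len, sm_cnt) == 1) -> radix_topk_one_block()
//   else                                                  -> radix_topk()
// with items_per_thread fixed at 32 in select_k().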
+template +void radix_topk_one_block(const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + int sm_cnt, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + static_assert(calc_num_passes() > 1); + + auto kernel = radix_topk_one_block_kernel; + const size_t max_chunk_size = + calc_chunk_size(batch_size, len, sm_cnt, kernel); + + auto pool_guard = + raft::get_pool_memory_resource(mr, + max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)) + + 256 * 4 // might need extra memory for alignment + ); + if (pool_guard) { + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + rmm::device_uvector buf1(len * max_chunk_size, stream, mr); + rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr); + rmm::device_uvector buf2(len * max_chunk_size, stream, mr); + rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); + + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { + int chunk_size = std::min(max_chunk_size, batch_size - offset); + kernel<<>>(in + offset * len, + in_idx ? (in_idx + offset * len) : nullptr, + len, + k, + out + offset * k, + out_idx + offset * k, + select_min, + buf1.data(), + idx_buf1.data(), + buf2.data(), + idx_buf2.data()); + } } +} // namespace impl + /** * Select k smallest or largest key/values from each row in the input data. * @@ -543,6 +1094,12 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) * the payload selected together with `out`. * @param select_min * whether to select k smallest (true) or largest (false) keys. + * @param fused_last_filter + * when it's true, the last filter is fused into the kernel in the last pass and only one thread + * block will do the filtering; when false, a standalone filter kernel with multiple thread + * blocks is called. The later case is preferable when leading bits of input data are almost the + * same. That is, when the value range of input data is narrow. In such case, there could be a + * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. * @param stream * @param mr an optional memory resource to use across the calls (you can provide a large enough * memory pool here to avoid memory allocations within the call). @@ -550,109 +1107,65 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) template void select_k(const T* in, const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, + int batch_size, + IdxT len, + IdxT k, T* out, IdxT* out_idx, bool select_min, + bool fused_last_filter, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = nullptr) { - // reduce the block size if the input length is too small. 
- if constexpr (BlockSize > calc_min_block_size()) { - if (BlockSize * ITEM_PER_THREAD > len) { - return select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + if (k == len) { + RAFT_CUDA_TRY( + cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); + if (in_idx) { + RAFT_CUDA_TRY(cudaMemcpyAsync( + out_idx, in_idx, sizeof(IdxT) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); + } else { + auto out_idx_view = + raft::make_device_vector_view(out_idx, static_cast(len) * batch_size); + raft::device_resources handle(stream); + raft::linalg::map_offset(handle, out_idx_view, raft::mod_const_op(len)); } + return; } - // TODO: is it possible to relax this restriction? - static_assert(calc_num_passes() > 1); - constexpr int num_buckets = calc_num_buckets(); - - dim3 blocks = get_optimal_grid_size(batch_size, len); - size_t max_chunk_size = blocks.y; - - size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); - size_t req_buf = max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)); - size_t mem_req = req_aux + req_buf; - size_t mem_free, mem_total; - RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); - std::optional managed_memory; - rmm::mr::device_memory_resource* mr_buf = nullptr; - if (mem_req > mem_free) { - // if there's not enough memory for buffers on the device, resort to the managed memory. - mem_req = req_aux; - managed_memory.emplace(); - mr_buf = &managed_memory.value(); + // TODO: use device_resources::get_device_properties() instead; should change it when we refactor + // resource management + int sm_cnt; + { + int dev; + RAFT_CUDA_TRY(cudaGetDevice(&dev)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); } - auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); - if (pool_guard) { - RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", - pool_guard->pool_size()); - } - if (mr_buf == nullptr) { mr_buf = mr; } - - rmm::device_uvector> counters(max_chunk_size, stream, mr); - rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); - rmm::device_uvector buf1(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector idx_buf1(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector buf2(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector idx_buf2(max_chunk_size * len, stream, mr_buf); - - for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) { - blocks.y = std::min(max_chunk_size, batch_size - offset); - - RAFT_CUDA_TRY( - cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); - - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; + constexpr int items_per_thread = 32; - constexpr int num_passes = calc_num_passes(); - - for (int pass = 0; pass < num_passes; ++pass) { - if (pass == 0) { - in_buf = in + offset * len; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in + offset * len; - in_idx_buf = in_idx ? 
in_idx + offset * len : nullptr; - out_buf = buf1.data(); - out_idx_buf = idx_buf1.data(); - } else if (pass % 2 == 0) { - in_buf = buf1.data(); - in_idx_buf = idx_buf1.data(); - out_buf = buf2.data(); - out_idx_buf = idx_buf2.data(); - } else { - in_buf = buf2.data(); - in_idx_buf = idx_buf2.data(); - out_buf = buf1.data(); - out_idx_buf = idx_buf1.data(); - } - - radix_kernel - <<>>(in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out + offset * k, - out_idx + offset * k, - counters.data(), - histograms.data(), - len, - k, - !select_min, - pass); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + if (len <= BlockSize * items_per_thread) { + impl::radix_topk_one_block( + in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + } else { + unsigned grid_dim = + impl::calc_grid_dim(batch_size, len, sm_cnt); + if (grid_dim == 1) { + impl::radix_topk_one_block( + in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + } else { + impl::radix_topk(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + fused_last_filter, + grid_dim, + sm_cnt, + stream, + mr); } } } diff --git a/cpp/include/raft/matrix/detail/select_radix_updated.cuh b/cpp/include/raft/matrix/detail/select_radix_updated.cuh deleted file mode 100644 index 7daa8cf6e6..0000000000 --- a/cpp/include/raft/matrix/detail/select_radix_updated.cuh +++ /dev/null @@ -1,1173 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -#include - -namespace raft::matrix::detail::select::radix { -namespace impl { - -constexpr int VECTORIZED_READ_SIZE = 16; - -template -_RAFT_HOST_DEVICE constexpr int calc_num_buckets() -{ - return 1 << BitsPerPass; -} - -template -_RAFT_HOST_DEVICE constexpr int calc_num_passes() -{ - return ceildiv(sizeof(T) * 8, BitsPerPass); -} - -/** - * Bit 0 is the least significant (rightmost); - * this implementation processes input from the most to the least significant bit. - * This way, we can skip some passes in the end at the cost of having an unsorted output. - * - * NB: Use pass=-1 for calc_mask(). - */ -template -_RAFT_DEVICE constexpr int calc_start_bit(int pass) -{ - int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; - if (start_bit < 0) { start_bit = 0; } - return start_bit; -} - -template -_RAFT_DEVICE constexpr unsigned calc_mask(int pass) -{ - static_assert(BitsPerPass <= 31); - int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); - return (1 << num_bits) - 1; -} - -/** - * Use CUB to twiddle bits - so that we can correctly compare bits of floating-point values as well - * as of integers. 
- */ -template -_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) -{ - auto bits = reinterpret_cast::UnsignedBits&>(key); - bits = cub::Traits::TwiddleIn(bits); - if (!select_min) { bits = ~bits; } - return bits; -} - -template -_RAFT_DEVICE T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) -{ - if (!select_min) { bits = ~bits; } - bits = cub::Traits::TwiddleOut(bits); - return reinterpret_cast(bits); -} - -template -_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) -{ - static_assert(BitsPerPass <= sizeof(int) * 8 - 1, - "BitsPerPass is too large that the result type could not be int"); - return (twiddle_in(x, select_min) >> start_bit) & mask; -} - -template -_RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len) -{ - // When writing is skipped, only read `in`(type T). - // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) - // and `out_idx_buf`(IdxT). - // The ratio between these cases determines whether to skip writing and hence the buffer size. - constexpr float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T); - return len / ratio; -} - -/** - * Map a Func over the input data, using vectorized load instructions if possible. - * - * NB: in future, we should move this to cpp/include/raft/linalg/detail/unary_op.cuh, which - * currently does not support the second lambda argument (index of an element) - * - * @tparam T element type - * @tparam IdxT indexing type - * @tparam Func void (T x, IdxT idx) - * - * @param thread_rank rank of the calling thread among all participating threads - * @param num_threads number of the threads that participate in processing - * @param in the input data - * @param len the number of elements to read - * @param f the lambda taking two arguments (T x, IdxT idx) - */ -template -_RAFT_DEVICE void vectorized_process( - size_t thread_rank, size_t num_threads, const T* in, IdxT len, Func f) -{ - if constexpr (sizeof(T) >= VECTORIZED_READ_SIZE || VECTORIZED_READ_SIZE % sizeof(T) != 0) { - for (IdxT i = thread_rank; i < len; i += num_threads) { - f(in[i], i); - } - } else { - using wide_t = TxN_t; - using align_bytes = Pow2<(size_t)VECTORIZED_READ_SIZE>; - using align_elems = Pow2; - wide_t wide; - - // how many elements to skip in order to do aligned vectorized load - const IdxT skip_cnt_left = std::min((IdxT)(align_bytes::roundUp(in) - in), len); - - // The main loop: process all aligned data - for (IdxT i = thread_rank * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; - i += num_threads * wide_t::Ratio) { - wide.load(in, i); -#pragma unroll - for (int j = 0; j < wide_t::Ratio; ++j) { - f(wide.val.data[j], i + j); - } - } - - static_assert(WarpSize >= wide_t::Ratio); - // Processes the skipped elements on the left - if (thread_rank < skip_cnt_left) { f(in[thread_rank], thread_rank); } - // Processes the skipped elements on the right - const IdxT skip_cnt_right = align_elems::mod(len - skip_cnt_left); - const IdxT remain_i = len - skip_cnt_right + thread_rank; - if (remain_i < len) { f(in[remain_i], remain_i); } - } -} - -template -struct alignas(128) Counter { - // We are processing the values in multiple passes, from most significant to least significant. In - // each pass, we keep the length of input (`len`) and the `k` of current pass, and update them at - // the end of the pass. - IdxT k; - IdxT len; - - // `previous_len` is the length of input in previous pass. 
Note that `previous_len` rather - // than `len` is used for the filtering step because filtering is indeed for previous pass (see - // comments before `radix_kernel`). - IdxT previous_len; - - // We determine the bits of the k_th value inside the mask processed by the pass. The - // already known bits are stored in `kth_value_bits`. It's used to discriminate a element is a - // result (written to `out`), a candidate for next pass (written to `out_buf`), or not useful - // (discarded). The bits that are not yet processed do not matter for this purpose. - typename cub::Traits::UnsignedBits kth_value_bits; - - // Record how many elements have passed filtering. It's used to determine the position in the - // `out_buf` where an element should be written. - alignas(128) IdxT filter_cnt; - - // For a row inside a batch, we may launch multiple thread blocks. This counter is used to - // determine if the current block is the last running block. If so, this block will execute scan() - // and choose_bucket(). - alignas(128) unsigned int finished_block_cnt; - - // Record how many elements have been written to the front of `out`. Elements less (if - // select_min==true) than the k-th value are written from front to back. - alignas(128) IdxT out_cnt; - - // Record how many elements have been written to the back of `out`. Elements equal to the k-th - // value are written from back to front. We need to keep count of them separately because the - // number of elements that <= the k-th value might exceed k. - alignas(128) IdxT out_back_cnt; -}; - -/** - * Fused filtering of the current pass and building histogram for the next pass - * (see steps 4 & 1 in `radix_kernel` description). - */ -template -_RAFT_DEVICE void filter_and_histogram(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - IdxT previous_len, - Counter* counter, - IdxT* histogram, - bool select_min, - int pass, - bool early_stop) -{ - constexpr int num_buckets = calc_num_buckets(); - __shared__ IdxT histogram_smem[num_buckets]; - for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { - histogram_smem[i] = 0; - } - __syncthreads(); - - const int start_bit = calc_start_bit(pass); - const unsigned mask = calc_mask(pass); - - if (pass == 0) { - // Passed to vectorized_process, this function executes in all blocks in parallel, - // i.e. the work is split along the input (both, in batches and chunks of a single row). - // Later, the histograms are merged using atomicAdd. - auto f = [select_min, start_bit, mask](T value, IdxT) { - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram_smem + bucket, static_cast(1)); - }; - vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, - static_cast(blockDim.x) * gridDim.x, - in_buf, - previous_len, - f); - } else { - IdxT* p_filter_cnt = &counter->filter_cnt; - IdxT* p_out_cnt = &counter->out_cnt; - const auto kth_value_bits = counter->kth_value_bits; - const int previous_start_bit = calc_start_bit(pass - 1); - - // See the remark above on the distributed execution of `f` using vectorized_process. 
- auto f = [in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - select_min, - start_bit, - mask, - previous_start_bit, - kth_value_bits, - p_filter_cnt, - p_out_cnt, - early_stop](T value, IdxT i) { - const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) - << previous_start_bit; - if (previous_bits == kth_value_bits) { - if (early_stop) { - IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); - out[pos] = value; - out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; - } else { - if (out_buf) { - IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); - out_buf[pos] = value; - out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; - } - - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram_smem + bucket, static_cast(1)); - } - } - // the condition `(out_buf || early_stop)` is a little tricky: - // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should skip writing to - // `out` too. So we won't write the same value to `out` multiple times in different passes. - // And if we keep skipping the writing, values will be written in `last_filter_kernel()` at - // last. But when `early_stop` is true, we need to write to `out` since it's the last chance. - else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { - IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); - out[pos] = value; - out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; - } - }; - vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, - static_cast(blockDim.x) * gridDim.x, - in_buf, - previous_len, - f); - } - if (early_stop) { return; } - __syncthreads(); - - // merge histograms produced by individual blocks - for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { - if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } - } -} - -/** - * Replace histogram with its own prefix sum - * (step 2 in `radix_kernel` description) - */ -template -_RAFT_DEVICE void scan(volatile IdxT* histogram) -{ - constexpr int num_buckets = calc_num_buckets(); - if constexpr (num_buckets >= BlockSize) { - static_assert(num_buckets % BlockSize == 0); - constexpr int items_per_thread = num_buckets / BlockSize; - typedef cub::BlockLoad BlockLoad; - typedef cub::BlockStore - BlockStore; - typedef cub::BlockScan BlockScan; - - __shared__ union { - typename BlockLoad::TempStorage load; - typename BlockScan::TempStorage scan; - typename BlockStore::TempStorage store; - } temp_storage; - IdxT thread_data[items_per_thread]; - - BlockLoad(temp_storage.load).Load(histogram, thread_data); - __syncthreads(); - - BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data); - __syncthreads(); - - BlockStore(temp_storage.store).Store(histogram, thread_data); - } else { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - - IdxT thread_data = 0; - if (threadIdx.x < num_buckets) { thread_data = histogram[threadIdx.x]; } - - BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); - __syncthreads(); - - if (threadIdx.x < num_buckets) { histogram[threadIdx.x] = thread_data; } - } -} - -/** - * Calculate in which bucket the k-th value will fall - * (steps 3 in `radix_kernel` description) - */ -template -_RAFT_DEVICE void choose_bucket(Counter* counter, - const IdxT* histogram, - const IdxT k, - const int pass) -{ - constexpr int num_buckets = calc_num_buckets(); - for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { - IdxT prev = (i == 0) ? 
0 : histogram[i - 1]; - IdxT cur = histogram[i]; - - // one and only one thread will satisfy this condition, so counter is written by only one thread - if (prev < k && cur >= k) { - counter->k = k - prev; // how many values still are there to find - counter->len = cur - prev; // number of values in next pass - typename cub::Traits::UnsignedBits bucket = i; - int start_bit = calc_start_bit(pass); - counter->kth_value_bits |= bucket << start_bit; - } - } -} - -// For one-block version, last_filter() could be called when pass < num_passes - 1. -// So `pass` could not be constexpr -template -_RAFT_DEVICE void last_filter(const T* in_buf, - const IdxT* in_idx_buf, - T* out, - IdxT* out_idx, - IdxT current_len, - IdxT k, - Counter* counter, - const bool select_min, - const int pass) -{ - const auto kth_value_bits = counter->kth_value_bits; - const int start_bit = calc_start_bit(pass); - - // changed in choose_bucket(); need to reload - const IdxT needed_num_of_kth = counter->k; - IdxT* p_out_cnt = &counter->out_cnt; - IdxT* p_out_back_cnt = &counter->out_back_cnt; - for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { - const T value = in_buf[i]; - const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; - if (bits < kth_value_bits) { - IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); - out[pos] = value; - // For one-block version, `in_idx_buf` could be nullptr at pass 0. - // For non one-block version, if writing has been skipped, `in_idx_buf` could be nullptr if - // `in_buf` is `in` - out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; - } else if (bits == kth_value_bits) { - IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); - if (back_pos < needed_num_of_kth) { - IdxT pos = k - 1 - back_pos; - out[pos] = value; - out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; - } - } - } -} - -template -__global__ void last_filter_kernel(const T* in, - const IdxT* in_idx, - const T* in_buf, - const IdxT* in_idx_buf, - T* out, - IdxT* out_idx, - IdxT len, - IdxT k, - Counter* counters, - const bool select_min) -{ - const size_t batch_id = blockIdx.y; // size_t to avoid multiplication overflow - - Counter* counter = counters + batch_id; - IdxT previous_len = counter->previous_len; - if (previous_len == 0) { return; } - const IdxT buf_len = calc_buf_len(len); - if (previous_len > buf_len || in_buf == in) { - in_buf = in + batch_id * len; - in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; - previous_len = len; - } else { - in_buf += batch_id * buf_len; - in_idx_buf += batch_id * buf_len; - } - out += batch_id * k; - out_idx += batch_id * k; - - constexpr int pass = calc_num_passes() - 1; - constexpr int start_bit = calc_start_bit(pass); - - const auto kth_value_bits = counter->kth_value_bits; - const IdxT needed_num_of_kth = counter->k; - IdxT* p_out_cnt = &counter->out_cnt; - IdxT* p_out_back_cnt = &counter->out_back_cnt; - - auto f = [k, - select_min, - kth_value_bits, - needed_num_of_kth, - p_out_cnt, - p_out_back_cnt, - in_idx_buf, - out, - out_idx](T value, IdxT i) { - const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; - if (bits < kth_value_bits) { - IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); - out[pos] = value; - out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; - } else if (bits == kth_value_bits) { - IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); - if (back_pos < needed_num_of_kth) { - IdxT pos = k - 1 - back_pos; - out[pos] = value; - out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; - } - } - }; - - vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, - static_cast(blockDim.x) * gridDim.x, - in_buf, - previous_len, - f); -} - -/** - * - * It is expected to call this kernel multiple times (passes), in each pass we process a radix, - * going from the most significant towards the least significant bits (MSD). - * - * Conceptually, each pass consists of 4 steps: - * - * 1. Calculate histogram - * First, transform bits into a digit, the value of which is in the range - * [0, 2^{BITS_PER_PASS}-1]. Then count the frequency of each digit value and the result is a - * histogram. That is, histogram[i] contains the count of inputs having value i. - * - * 2. Scan the histogram - * Inclusive prefix sum is computed for the histogram. After this step, histogram[i] contains - * the count of inputs having value <= i. - * - * 3. Find the bucket j of the histogram that the k-th value falls into - * - * 4. Filtering - * Input elements whose digit value -__global__ void radix_kernel(const T* in, - const IdxT* in_idx, - const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - Counter* counters, - IdxT* histograms, - const IdxT len, - const IdxT k, - const bool select_min, - const int pass) -{ - const size_t batch_id = blockIdx.y; - auto counter = counters + batch_id; - IdxT current_k; - IdxT previous_len; - IdxT current_len; - if (pass == 0) { - current_k = k; - previous_len = len; - // Need to do this so setting counter->previous_len for the next pass is correct. - // This value is meaningless for pass 0, but it's fine because pass 0 won't be the - // last pass in this implementation so pass 0 won't hit the "if (pass == - // num_passes - 1)" branch. - // Maybe it's better to reload counter->previous_len and use it rather than - // current_len in last_filter() - current_len = len; - } else { - current_k = counter->k; - current_len = counter->len; - previous_len = counter->previous_len; - } - if (current_len == 0) { return; } - - // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle - // correctly the case that pass=0 and early_stop=true. However, this special case of k=len is - // handled in other way in select_k() so such case is not possible here. - const bool early_stop = (current_len == current_k); - const IdxT buf_len = calc_buf_len(len); - - // "previous_len > buf_len" means previous pass skips writing buffer - if (pass == 0 || pass == 1 || previous_len > buf_len) { - in_buf = in + batch_id * len; - in_idx_buf = in_idx ? 
(in_idx + batch_id * len) : nullptr; - previous_len = len; - } else { - in_buf += batch_id * buf_len; - in_idx_buf += batch_id * buf_len; - } - // "current_len > buf_len" means current pass will skip writing buffer - if (pass == 0 || current_len > buf_len) { - out_buf = nullptr; - out_idx_buf = nullptr; - } else { - out_buf += batch_id * buf_len; - out_idx_buf += batch_id * buf_len; - } - out += batch_id * k; - out_idx += batch_id * k; - - constexpr int num_buckets = calc_num_buckets(); - auto histogram = histograms + batch_id * num_buckets; - - filter_and_histogram(in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - previous_len, - counter, - histogram, - select_min, - pass, - early_stop); - __threadfence(); - - bool isLastBlock = false; - if (threadIdx.x == 0) { - unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); - isLastBlock = (finished == (gridDim.x - 1)); - } - - if (__syncthreads_or(isLastBlock)) { - if (early_stop) { - if (threadIdx.x == 0) { - // `last_filter_kernel()` requires setting previous_len - counter->previous_len = 0; - counter->len = 0; - } - return; - } - - scan(histogram); - __syncthreads(); - choose_bucket(counter, histogram, current_k, pass); - __syncthreads(); - - constexpr int num_passes = calc_num_passes(); - // reset for next pass - if (pass != num_passes - 1) { - for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { - histogram[i] = 0; - } - } - if (threadIdx.x == 0) { - // `last_filter_kernel()` requires setting previous_len even in the last pass - counter->previous_len = current_len; - // not necessary for the last pass, but put it here anyway - counter->filter_cnt = 0; - } - - if constexpr (fused_last_filter) { - if (pass == num_passes - 1) { - last_filter(out_buf ? out_buf : in_buf, - out_idx_buf ? out_idx_buf : in_idx_buf, - out, - out_idx, - out_buf ? current_len : len, - k, - counter, - select_min, - pass); - } - } - } -} - -template -int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel) -{ - int active_blocks; - RAFT_CUDA_TRY( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks, kernel, BlockSize, 0)); - - constexpr int items_per_thread = 32; - constexpr int num_waves = 10; - int chunk_size = - std::max(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len); - return std::min(chunk_size, batch_size); -} - -template -unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) -{ - static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); - - int active_blocks; - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &active_blocks, radix_kernel, BlockSize, 0)); - active_blocks *= sm_cnt; - - IdxT best_num_blocks = 0; - float best_tail_wave_penalty = 1.0f; - const IdxT max_num_blocks = ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize); - for (int num_waves = 1;; ++num_waves) { - IdxT num_blocks = std::min( - max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); - IdxT items_per_thread = ceildiv(len, num_blocks * BlockSize); - items_per_thread = alignTo(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T)); - num_blocks = ceildiv(len, items_per_thread * BlockSize); - float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; - float tail_wave_penalty = - (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); - - // 0.15 is determined experimentally. It also ensures breaking the loop early, - // e.g. 
when num_waves > 7, tail_wave_penalty will always <0.15 - if (tail_wave_penalty < 0.15) { - best_num_blocks = num_blocks; - break; - } else if (tail_wave_penalty < best_tail_wave_penalty) { - best_num_blocks = num_blocks; - best_tail_wave_penalty = tail_wave_penalty; - } - - if (num_blocks == max_num_blocks) { break; } - } - return best_num_blocks; -} - -template -_RAFT_HOST_DEVICE void set_buf_pointers(const T* in, - const IdxT* in_idx, - T* buf1, - IdxT* idx_buf1, - T* buf2, - IdxT* idx_buf2, - int pass, - const T*& in_buf, - const IdxT*& in_idx_buf, - T*& out_buf, - IdxT*& out_idx_buf) -{ - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = buf1; - out_idx_buf = idx_buf1; - } else if (pass % 2 == 0) { - in_buf = buf1; - in_idx_buf = idx_buf1; - out_buf = buf2; - out_idx_buf = idx_buf2; - } else { - in_buf = buf2; - in_idx_buf = idx_buf2; - out_buf = buf1; - out_idx_buf = idx_buf1; - } -} - -template -void radix_topk(const T* in, - const IdxT* in_idx, - int batch_size, - IdxT len, - IdxT k, - T* out, - IdxT* out_idx, - bool select_min, - bool fused_last_filter, - unsigned grid_dim, - int sm_cnt, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - // TODO: is it possible to relax this restriction? - static_assert(calc_num_passes() > 1); - constexpr int num_buckets = calc_num_buckets(); - - auto kernel = radix_kernel; - const size_t max_chunk_size = - calc_chunk_size(batch_size, len, sm_cnt, kernel); - if (max_chunk_size != static_cast(batch_size)) { - grid_dim = calc_grid_dim(max_chunk_size, len, sm_cnt); - } - const IdxT buf_len = calc_buf_len(len); - - size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); - size_t req_buf = max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); - size_t mem_req = req_aux + req_buf + 256 * 6; // might need extra memory for alignment - - auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); - if (pool_guard) { - RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", - pool_guard->pool_size()); - } - - rmm::device_uvector> counters(max_chunk_size, stream, mr); - rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); - rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr); - - for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { - int chunk_size = std::min(max_chunk_size, batch_size - offset); - RAFT_CUDA_TRY( - cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); - - const T* chunk_in = in + offset * len; - const IdxT* chunk_in_idx = in_idx ? 
(in_idx + offset * len) : nullptr; - T* chunk_out = out + offset * k; - IdxT* chunk_out_idx = out_idx + offset * k; - - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; - - dim3 blocks(grid_dim, chunk_size); - constexpr int num_passes = calc_num_passes(); - - for (int pass = 0; pass < num_passes; ++pass) { - set_buf_pointers(chunk_in, - chunk_in_idx, - buf1.data(), - idx_buf1.data(), - buf2.data(), - idx_buf2.data(), - pass, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf); - - if (fused_last_filter && pass == num_passes - 1) { - kernel = radix_kernel; - } - - kernel<<>>(chunk_in, - chunk_in_idx, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - chunk_out, - chunk_out_idx, - counters.data(), - histograms.data(), - len, - k, - select_min, - pass); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - - if (!fused_last_filter) { - last_filter_kernel<<>>(chunk_in, - chunk_in_idx, - out_buf, - out_idx_buf, - chunk_out, - chunk_out_idx, - len, - k, - counters.data(), - select_min); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - } -} - -// The following a few functions are for the one-block version, which uses single thread block for -// each row of a batch. -template -_RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - Counter* counter, - IdxT* histogram, - bool select_min, - int pass) -{ - constexpr int num_buckets = calc_num_buckets(); - for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { - histogram[i] = 0; - } - IdxT* p_filter_cnt = &counter->filter_cnt; - if (threadIdx.x == 0) { *p_filter_cnt = 0; } - __syncthreads(); - - const int start_bit = calc_start_bit(pass); - const unsigned mask = calc_mask(pass); - const IdxT previous_len = counter->previous_len; - - if (pass == 0) { - auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram + bucket, static_cast(1)); - }; - vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); - } else { - // not use vectorized_process here because it increases #registers a lot - IdxT* p_out_cnt = &counter->out_cnt; - const auto kth_value_bits = counter->kth_value_bits; - const int previous_start_bit = calc_start_bit(pass - 1); - - for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { - const T value = in_buf[i]; - const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) - << previous_start_bit; - if (previous_bits == kth_value_bits) { - IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); - out_buf[pos] = value; - out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; - - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram + bucket, static_cast(1)); - } else if (previous_bits < kth_value_bits) { - IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); - out[pos] = value; - out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; - } - } - } -} - -template -__global__ void radix_topk_one_block_kernel(const T* in, - const IdxT* in_idx, - const IdxT len, - const IdxT k, - T* out, - IdxT* out_idx, - const bool select_min, - T* buf1, - IdxT* idx_buf1, - T* buf2, - IdxT* idx_buf2) -{ - constexpr int num_buckets = calc_num_buckets(); - __shared__ Counter counter; - __shared__ IdxT histogram[num_buckets]; - - if (threadIdx.x == 0) { - counter.k = k; - counter.len = len; - counter.previous_len = len; - counter.kth_value_bits = 0; - counter.out_cnt = 0; - counter.out_back_cnt = 0; - } - __syncthreads(); - - const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow - in += batch_id * len; - if (in_idx) { in_idx += batch_id * len; } - out += batch_id * k; - out_idx += batch_id * k; - buf1 += batch_id * len; - idx_buf1 += batch_id * len; - buf2 += batch_id * len; - idx_buf2 += batch_id * len; - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; - - constexpr int num_passes = calc_num_passes(); - for (int pass = 0; pass < num_passes; ++pass) { - set_buf_pointers( - in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); - - IdxT current_len = counter.len; - IdxT current_k = counter.k; - - filter_and_histogram_for_one_block(in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out, - out_idx, - &counter, - histogram, - select_min, - pass); - __syncthreads(); - - scan(histogram); - __syncthreads(); - - choose_bucket(&counter, histogram, current_k, pass); - if (threadIdx.x == 0) { counter.previous_len = current_len; } - __syncthreads(); - - if (counter.len == counter.k || pass == num_passes - 1) { - last_filter(pass == 0 ? in : out_buf, - pass == 0 ? in_idx : out_idx_buf, - out, - out_idx, - current_len, - k, - &counter, - select_min, - pass); - break; - } - } -} - -// radix_topk() might use multiple thread blocks for one row of a batch. In contrast, the following -// one-block version uses single thread block for one row of a batch, so intermediate data, like -// counters and global histograms, can be kept in shared memory and cheap sync operations can be -// used. It's used when len is relatively small or when the number of blocks per row calculated by -// `calc_grid_dim()` is 1. -template -void radix_topk_one_block(const T* in, - const IdxT* in_idx, - int batch_size, - IdxT len, - IdxT k, - T* out, - IdxT* out_idx, - bool select_min, - int sm_cnt, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - static_assert(calc_num_passes() > 1); - - auto kernel = radix_topk_one_block_kernel; - const size_t max_chunk_size = - calc_chunk_size(batch_size, len, sm_cnt, kernel); - - auto pool_guard = - raft::get_pool_memory_resource(mr, - max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)) + - 256 * 4 // might need extra memory for alignment - ); - if (pool_guard) { - RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", - pool_guard->pool_size()); - } - - rmm::device_uvector buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector buf2(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); - - for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { - int chunk_size = std::min(max_chunk_size, batch_size - offset); - kernel<<>>(in + offset * len, - in_idx ? 
(in_idx + offset * len) : nullptr, - len, - k, - out + offset * k, - out_idx + offset * k, - select_min, - buf1.data(), - idx_buf1.data(), - buf2.data(), - idx_buf2.data()); - } -} - -} // namespace impl - -/** - * Select k smallest or largest key/values from each row in the input data. - * - * If you think of the input data `in_keys` as a row-major matrix with len columns and - * batch_size rows, then this function selects k smallest/largest values in each row and fills - * in the row-major matrix `out` of size (batch_size, k). - * - * Note, the output is NOT sorted within the groups of `k` selected elements. - * - * @tparam T - * the type of the keys (what is being compared). - * @tparam IdxT - * the index type (what is being selected together with the keys). - * @tparam BitsPerPass - * The size of the radix; - * it affects the number of passes and number of buckets. - * @tparam BlockSize - * Number of threads in a kernel thread block. - * - * @param[in] in - * contiguous device array of inputs of size (len * batch_size); - * these are compared and selected. - * @param[in] in_idx - * contiguous device array of inputs of size (len * batch_size); - * typically, these are indices of the corresponding in_keys. - * @param batch_size - * number of input rows, i.e. the batch size. - * @param len - * length of a single input array (row); also sometimes referred as n_cols. - * Invariant: len >= k. - * @param k - * the number of outputs to select in each input row. - * @param[out] out - * contiguous device array of outputs of size (k * batch_size); - * the k smallest/largest values from each row of the `in_keys`. - * @param[out] out_idx - * contiguous device array of outputs of size (k * batch_size); - * the payload selected together with `out`. - * @param select_min - * whether to select k smallest (true) or largest (false) keys. - * @param fused_last_filter - * when it's true, the last filter is fused into the kernel in the last pass and only one thread - * block will do the filtering; when false, a standalone filter kernel with multiple thread - * blocks is called. The later case is preferable when leading bits of input data are almost the - * same. That is, when the value range of input data is narrow. In such case, there could be a - * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. - * @param stream - * @param mr an optional memory resource to use across the calls (you can provide a large enough - * memory pool here to avoid memory allocations within the call). 
- */ -template -void select_k_updated(const T* in, - const IdxT* in_idx, - int batch_size, - IdxT len, - IdxT k, - T* out, - IdxT* out_idx, - bool select_min, - bool fused_last_filter, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) -{ - if (k == len) { - RAFT_CUDA_TRY( - cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); - if (in_idx) { - RAFT_CUDA_TRY(cudaMemcpyAsync( - out_idx, in_idx, sizeof(IdxT) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); - } else { - auto out_idx_view = - raft::make_device_vector_view(out_idx, static_cast(len) * batch_size); - raft::device_resources handle(stream); - raft::linalg::map_offset(handle, out_idx_view, raft::mod_const_op(len)); - } - return; - } - - // TODO: use device_resources::get_device_properties() instead; should change it when we refactor - // resource management - int sm_cnt; - { - int dev; - RAFT_CUDA_TRY(cudaGetDevice(&dev)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); - } - - constexpr int items_per_thread = 32; - - if (len <= BlockSize * items_per_thread) { - impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); - } else { - unsigned grid_dim = - impl::calc_grid_dim(batch_size, len, sm_cnt); - if (grid_dim == 1) { - impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); - } else { - impl::radix_topk(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - fused_last_filter, - grid_dim, - sm_cnt, - stream, - mr); - } - } -} - -} // namespace raft::matrix::detail::select::radix diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index d63c77050a..1e18d2e026 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include @@ -53,8 +52,6 @@ enum class Algo { kPublicApi, kRadix8bits, kRadix11bits, - kRadix8bitsUpdated, - kRadix11bitsUpdated, kRadix11bitsExtraPass, kWarpAuto, kWarpImmediate, @@ -69,8 +66,6 @@ inline auto operator<<(std::ostream& os, const Algo& algo) -> std::ostream& case Algo::kPublicApi: return os << "kPublicApi"; case Algo::kRadix8bits: return os << "kRadix8bits"; case Algo::kRadix11bits: return os << "kRadix11bits"; - case Algo::kRadix8bitsUpdated: return os << "kRadix8bitsUpdated"; - case Algo::kRadix11bitsUpdated: return os << "kRadix11bitsUpdated"; case Algo::kRadix11bitsExtraPass: return os << "kRadix11bitsExtraPass"; case Algo::kWarpAuto: return os << "kWarpAuto"; case Algo::kWarpImmediate: return os << "kWarpImmediate"; @@ -112,44 +107,38 @@ void select_k_impl(const device_resources& handle, } } case Algo::kRadix8bits: - return detail::select::radix::select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + true, // fused_last_filter + stream); case Algo::kRadix11bits: - return detail::select::radix::select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - case Algo::kRadix8bitsUpdated: - return detail::select::radix::select_k_updated(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - true, // fused_last_filter - stream); - case Algo::kRadix11bitsUpdated: - return detail::select::radix::select_k_updated(in, - in_idx, - 
batch_size, - len, - k, - out, - out_idx, - select_min, - true, // fused_last_filter - stream); + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + true, // fused_last_filter + stream); case Algo::kRadix11bitsExtraPass: - return detail::select::radix::select_k_updated(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - false, // fused_last_filter - stream); + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + false, // fused_last_filter + stream); case Algo::kWarpAuto: return detail::select::warpsort::select_k( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index b1f406ed15..bd50e20822 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -328,8 +328,6 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Values(select::Algo::kPublicApi, select::Algo::kRadix8bits, select::Algo::kRadix11bits, - select::Algo::kRadix8bitsUpdated, - select::Algo::kRadix11bitsUpdated, select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, @@ -425,8 +423,6 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Combine(inputs_random_longlist, testing::Values(select::Algo::kRadix8bits, select::Algo::kRadix11bits, - select::Algo::kRadix8bitsUpdated, - select::Algo::kRadix11bitsUpdated, select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, @@ -442,8 +438,6 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Combine(inputs_random_longlist, testing::Values(select::Algo::kRadix8bits, select::Algo::kRadix11bits, - select::Algo::kRadix8bitsUpdated, - select::Algo::kRadix11bitsUpdated, select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, @@ -458,8 +452,8 @@ INSTANTIATE_TEST_CASE_P( // NOLINT ReferencedRandomDoubleInt, testing::Combine(inputs_random_largesize, testing::Values(select::Algo::kWarpAuto, - select::Algo::kRadix8bitsUpdated, - select::Algo::kRadix11bitsUpdated, + select::Algo::kRadix8bits, + select::Algo::kRadix11bits, select::Algo::kRadix11bitsExtraPass))); using ReferencedRandomFloatSizeT = @@ -469,8 +463,6 @@ INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT ReferencedRandomFloatSizeT, testing::Combine(inputs_random_largek, testing::Values(select::Algo::kRadix11bits, - select::Algo::kRadix8bitsUpdated, - select::Algo::kRadix11bitsUpdated, select::Algo::kRadix11bitsExtraPass))); } // namespace raft::matrix From f7061eb26c11e36a3466abccb37805a76b551e9b Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Sat, 11 Mar 2023 00:16:20 +0100 Subject: [PATCH 26/29] Add missing fused_last_filter arg while dispatching select_k --- cpp/include/raft/matrix/detail/select_radix.cuh | 2 +- cpp/include/raft/spatial/knn/knn.cuh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index dc1ce1920c..a844bad6db 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -1117,7 +1117,7 @@ void select_k(const T* in, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = nullptr) { - if (k == len) { + if (static_cast(k) == len) { RAFT_CUDA_TRY( cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); if (in_idx) { diff --git 
a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index ca2c248392..2a0186a649 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -154,12 +154,12 @@ template case SelectKAlgo::RADIX_8_BITS: matrix::detail::select::radix::select_k( - in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, stream); + in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, true, stream); break; case SelectKAlgo::RADIX_11_BITS: matrix::detail::select::radix::select_k( - in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, stream); + in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, true, stream); break; case SelectKAlgo::WARP_SORT: From d447fcbcd0f2f906b19d242d3ac708d9c96af324 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 23 Mar 2023 18:11:07 -0400 Subject: [PATCH 27/29] Checking in fix for select_k based on offline conversation w/ Yong Wang. --- cpp/include/raft/matrix/detail/select_radix.cuh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index dbb93b79e6..7ce087614a 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -275,7 +275,12 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } else { if (out_buf) { - IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + +#if CUDART_VERSION < 12000 && CUDART_VERSION > 11000 + // Avoiding potential compiler bug in CUDA 11 + volatile +#endif + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); out_buf[pos] = value; out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } From a11fb8e4eb5ce79ce97d24f0f60e2cf8e17875ef Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Fri, 24 Mar 2023 16:45:31 +0800 Subject: [PATCH 28/29] minor polish --- cpp/include/raft/matrix/detail/select_radix.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 7ce087614a..cec6b09f20 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -1123,7 +1123,7 @@ void select_k(const T* in, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = nullptr) { - if (static_cast(k) == len) { + if (k == len) { RAFT_CUDA_TRY( cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); if (in_idx) { From dd6ae518a5dc7e2c7d1835503d4e46085c5fb13f Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Fri, 24 Mar 2023 17:08:59 +0800 Subject: [PATCH 29/29] adjust the place of volatile --- cpp/include/raft/matrix/detail/select_radix.cuh | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index cec6b09f20..7ac40ac0eb 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -275,12 +275,7 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; } else { if (out_buf) { - -#if CUDART_VERSION < 12000 && CUDART_VERSION > 11000 - // Avoiding potential compiler bug in CUDA 11 - volatile -#endif - IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); out_buf[pos] = value; out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } @@ -905,7 +900,11 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { - IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); +#if CUDART_VERSION < 12000 + // Avoiding potential compiler bug in CUDA 11 + volatile +#endif + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); out_buf[pos] = value; out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i;