diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu
index d4873e2640..870119db52 100644
--- a/cpp/bench/matrix/select_k.cu
+++ b/cpp/bench/matrix/select_k.cu
@@ -35,6 +35,10 @@
 #include
 #include

+#include <cstdint>
+#include <cstring>
+#include <type_traits>
+
 namespace raft::matrix {

 using namespace raft::bench;  // NOLINT

@@ -50,7 +54,23 @@ struct selection : public fixture {
   {
     raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream);
     raft::random::RngState state{42};
-    raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0));
+
+    KeyT min_value = -1.0;
+    KeyT max_value = 1.0;
+    if (p.use_same_leading_bits) {
+      if constexpr (std::is_same_v<KeyT, float>) {
+        uint32_t min_bits = 0x3F800000;  // 1.0
+        uint32_t max_bits = 0x3F8000FF;  // 1.00003
+        memcpy(&min_value, &min_bits, sizeof(KeyT));
+        memcpy(&max_value, &max_bits, sizeof(KeyT));
+      } else if constexpr (std::is_same_v<KeyT, double>) {
+        uint64_t min_bits = 0x3FF0000000000000;  // 1.0
+        uint64_t max_bits = 0x3FF0000FFFFFFFFF;  // 1.000015
+        memcpy(&min_value, &min_bits, sizeof(KeyT));
+        memcpy(&max_value, &max_bits, sizeof(KeyT));
+      }
+    }
+    raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value);
   }

   void run_benchmark(::benchmark::State& state) override  // NOLINT
@@ -60,6 +80,7 @@ struct selection : public fixture {
     try {
       std::ostringstream label_stream;
       label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k;
+      if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; }
       state.SetLabel(label_stream.str());
       loop_on_state(state, [this, &handle]() {
         select::select_k_impl<KeyT, IdxT>(handle,
@@ -85,21 +106,55 @@ struct selection : public fixture {
 };

 const std::vector<select::params> kInputs{
-  {20000, 500, 1, true},  {20000, 500, 2, true},   {20000, 500, 4, true},
-  {20000, 500, 8, true},  {20000, 500, 16, true},  {20000, 500, 32, true},
-  {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true},
-
-  {1000, 10000, 1, true},  {1000, 10000, 2, true},   {1000, 10000, 4, true},
-  {1000, 10000, 8, true},  {1000, 10000, 16, true},  {1000, 10000, 32, true},
-  {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true},
-
-  {100, 100000, 1, true},  {100, 100000, 2, true},   {100, 100000, 4, true},
-  {100, 100000, 8, true},  {100, 100000, 16, true},  {100, 100000, 32, true},
-  {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true},
-
-  {10, 1000000, 1, true},  {10, 1000000, 2, true},   {10, 1000000, 4, true},
-  {10, 1000000, 8, true},  {10, 1000000, 16, true},  {10, 1000000, 32, true},
-  {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true},
+  {20000, 500, 1, true},
+  {20000, 500, 2, true},
+  {20000, 500, 4, true},
+  {20000, 500, 8, true},
+  {20000, 500, 16, true},
+  {20000, 500, 32, true},
+  {20000, 500, 64, true},
+  {20000, 500, 128, true},
+  {20000, 500, 256, true},
+
+  {1000, 10000, 1, true},
+  {1000, 10000, 2, true},
+  {1000, 10000, 4, true},
+  {1000, 10000, 8, true},
+  {1000, 10000, 16, true},
+  {1000, 10000, 32, true},
+  {1000, 10000, 64, true},
+  {1000, 10000, 128, true},
+  {1000, 10000, 256, true},
+
+  {100, 100000, 1, true},
+  {100, 100000, 2, true},
+  {100, 100000, 4, true},
+  {100, 100000, 8, true},
+  {100, 100000, 16, true},
+  {100, 100000, 32, true},
+  {100, 100000, 64, true},
+  {100, 100000, 128, true},
+  {100, 100000, 256, true},
+
+  {10, 1000000, 1, true},
+  {10, 1000000, 2, true},
+  {10, 1000000, 4, true},
+  {10, 1000000, 8, true},
+  {10, 1000000, 16, true},
+  {10, 1000000, 32, true},
+  {10, 1000000, 64, true},
+  {10, 1000000, 128, true},
+  {10, 1000000, 256, true},
+
+  {10, 1000000, 1, true, false, true},
+  {10, 1000000, 2, true, false, true},
+  {10, 1000000, 4, true, false, true},
+  {10, 1000000, 8, true, false, true},
+  {10, 1000000, 16, true, false, true},
+  {10, 1000000, 32, true, false, true},
+  {10, 1000000, 64, true, false, true},
+  {10, 1000000, 128, true, false, true},
+  {10, 1000000, 256, true, false, true},
 };

 #define SELECTION_REGISTER(KeyT, IdxT, A)                          \
   namespace BENCHMARK_PRIVATE_NAME(selection)                      \
   {                                                                \
     using SelectK = selection<KeyT, IdxT, select::Algo::A>;        \
@@ -109,24 +164,27 @@ const std::vector<select::params> kInputs{
     RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
   }

-SELECTION_REGISTER(float, uint32_t, kPublicApi);           // NOLINT
-SELECTION_REGISTER(float, uint32_t, kRadix8bits);          // NOLINT
-SELECTION_REGISTER(float, uint32_t, kRadix11bits);         // NOLINT
-SELECTION_REGISTER(float, uint32_t, kWarpAuto);            // NOLINT
-SELECTION_REGISTER(float, uint32_t, kWarpImmediate);       // NOLINT
-SELECTION_REGISTER(float, uint32_t, kWarpFiltered);        // NOLINT
-SELECTION_REGISTER(float, uint32_t, kWarpDistributed);     // NOLINT
-SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm);  // NOLINT
-
-SELECTION_REGISTER(double, uint32_t, kRadix8bits);         // NOLINT
-SELECTION_REGISTER(double, uint32_t, kRadix11bits);        // NOLINT
-SELECTION_REGISTER(double, uint32_t, kWarpAuto);           // NOLINT
-
-SELECTION_REGISTER(double, int64_t, kRadix8bits);          // NOLINT
-SELECTION_REGISTER(double, int64_t, kRadix11bits);         // NOLINT
-SELECTION_REGISTER(double, int64_t, kWarpImmediate);       // NOLINT
-SELECTION_REGISTER(double, int64_t, kWarpFiltered);        // NOLINT
-SELECTION_REGISTER(double, int64_t, kWarpDistributed);     // NOLINT
-SELECTION_REGISTER(double, int64_t, kWarpDistributedShm);  // NOLINT
+SELECTION_REGISTER(float, uint32_t, kPublicApi);             // NOLINT
+SELECTION_REGISTER(float, uint32_t, kRadix8bits);            // NOLINT
+SELECTION_REGISTER(float, uint32_t, kRadix11bits);           // NOLINT
+SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass);  // NOLINT
+SELECTION_REGISTER(float, uint32_t, kWarpAuto);              // NOLINT
+SELECTION_REGISTER(float, uint32_t, kWarpImmediate);         // NOLINT
+SELECTION_REGISTER(float, uint32_t, kWarpFiltered);          // NOLINT
+SELECTION_REGISTER(float, uint32_t, kWarpDistributed);       // NOLINT
+SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm);    // NOLINT
+
+SELECTION_REGISTER(double, uint32_t, kRadix8bits);            // NOLINT
+SELECTION_REGISTER(double, uint32_t, kRadix11bits);           // NOLINT
+SELECTION_REGISTER(double, uint32_t, kRadix11bitsExtraPass);  // NOLINT
+SELECTION_REGISTER(double, uint32_t, kWarpAuto);              // NOLINT
+
+SELECTION_REGISTER(double, int64_t, kRadix8bits);            // NOLINT
+SELECTION_REGISTER(double, int64_t, kRadix11bits);           // NOLINT
+SELECTION_REGISTER(double, int64_t, kRadix11bitsExtraPass);  // NOLINT
+SELECTION_REGISTER(double, int64_t, kWarpImmediate);         // NOLINT
+SELECTION_REGISTER(double, int64_t, kWarpFiltered);          // NOLINT
+SELECTION_REGISTER(double, int64_t, kWarpDistributed);       // NOLINT
+SELECTION_REGISTER(double, int64_t, kWarpDistributedShm);    // NOLINT

 }  // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/detail/select_k.cuh b/cpp/include/raft/matrix/detail/select_k.cuh
index ac1ba3dfa3..20c2fb119d 100644
--- a/cpp/include/raft/matrix/detail/select_k.cuh
+++ b/cpp/include/raft/matrix/detail/select_k.cuh
@@ -84,7 +84,7 @@ void select_k(const T* in_val,
       in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
   } else {
     select::radix::select_k<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
-      in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
+      in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr);
   }
 }
diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh
index 643a63d9db..7ac40ac0eb 100644
--- a/cpp/include/raft/matrix/detail/select_radix.cuh
+++ b/cpp/include/raft/matrix/detail/select_radix.cuh
@@ -16,11 +16,11 @@

 #pragma once

-#include
-
-#include
 #include
 #include
+#include
+#include
+#include
 #include
 #include
 #include
@@ -35,8 +35,8 @@
 #include

 namespace raft::matrix::detail::select::radix {
+namespace impl {

-constexpr int ITEM_PER_THREAD      = 32;
 constexpr int VECTORIZED_READ_SIZE = 16;

 template <typename T, int BitsPerPass>
@@ -51,13 +51,6 @@ _RAFT_HOST_DEVICE constexpr int calc_num_passes()
   return ceildiv<int>(sizeof(T) * 8, BitsPerPass);
 }

-// Minimum reasonable block size for the given radix size.
-template <int BitsPerPass>
-_RAFT_HOST_DEVICE constexpr int calc_min_block_size()
-{
-  return 1 << std::max<int>(BitsPerPass - 4, Pow2<WarpSize>::Log2 + 1);
-}
-
 /**
  * Bit 0 is the least significant (rightmost);
  * this implementation processes input from the most to the least significant bit.
@@ -82,23 +75,43 @@ _RAFT_DEVICE constexpr unsigned calc_mask(int pass)
 }

 /**
- * Use cub to twiddle bits - so that we can correctly compare bits of floating-point values as well
+ * Use CUB to twiddle bits - so that we can correctly compare bits of floating-point values as well
  * as of integers.
  */
 template <typename T>
-_RAFT_DEVICE typename cub::Traits<T>::UnsignedBits twiddle_in(T key, bool greater)
+_RAFT_DEVICE typename cub::Traits<T>::UnsignedBits twiddle_in(T key, bool select_min)
 {
   auto bits = reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(key);
   bits      = cub::Traits<T>::TwiddleIn(bits);
-  if (greater) { bits = ~bits; }
+  if (!select_min) { bits = ~bits; }
   return bits;
 }

+template <typename T>
+_RAFT_DEVICE T twiddle_out(typename cub::Traits<T>::UnsignedBits bits, bool select_min)
+{
+  if (!select_min) { bits = ~bits; }
+  bits = cub::Traits<T>::TwiddleOut(bits);
+  return reinterpret_cast<T&>(bits);
+}
+
 template <typename T, int BitsPerPass>
-_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater)
+_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min)
+{
+  static_assert(BitsPerPass <= sizeof(int) * 8 - 1,
+                "BitsPerPass is too large for the result type to be int");
+  return (twiddle_in(x, select_min) >> start_bit) & mask;
+}
+
+template <typename T, typename IdxT>
+_RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len)
 {
-  static_assert(BitsPerPass <= sizeof(int) * 8 - 1);  // so return type can be int
-  return (twiddle_in(x, greater) >> start_bit) & mask;
+  // When writing is skipped, only read `in` (type T).
+  // When writing is not skipped, read `in_buf` (T) and `in_idx_buf` (IdxT), and write `out_buf`
+  // (T) and `out_idx_buf` (IdxT).
+  // The ratio between these cases determines whether to skip writing and hence the buffer size.
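+  // Illustrative example (not from the original patch): with T = float and IdxT = uint32_t,
+  // ratio = 2 + 4 * 2.0 / 4 = 4, so the buffer holds len / 4 candidates. A pass writes its
+  // candidates to the buffer only when they fit; otherwise the next pass re-reads `in`.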
+  constexpr float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T);
+  return len / ratio;
 }

 /**
@@ -111,17 +124,18 @@ _RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater)
  * @tparam IdxT indexing type
  * @tparam Func void (T x, IdxT idx)
  *
+ * @param thread_rank rank of the calling thread among all participating threads
+ * @param num_threads number of the threads that participate in processing
  * @param in the input data
  * @param len the number of elements to read
  * @param f the lambda taking two arguments (T x, IdxT idx)
  */
 template <typename T, typename IdxT, typename Func>
-_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f)
+_RAFT_DEVICE void vectorized_process(
+  size_t thread_rank, size_t num_threads, const T* in, IdxT len, Func f)
 {
-  const IdxT stride = blockDim.x * gridDim.x;
-  const int tid     = blockIdx.x * blockDim.x + threadIdx.x;
   if constexpr (sizeof(T) >= VECTORIZED_READ_SIZE || VECTORIZED_READ_SIZE % sizeof(T) != 0) {
-    for (IdxT i = tid; i < len; i += stride) {
+    for (IdxT i = thread_rank; i < len; i += num_threads) {
       f(in[i], i);
     }
   } else {
@@ -134,8 +148,8 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f)
     const IdxT skip_cnt_left = std::min((IdxT)(align_bytes::roundUp(in) - in), len);

     // The main loop: process all aligned data
-    for (IdxT i = tid * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len;
-         i += stride * wide_t::Ratio) {
+    for (IdxT i = thread_rank * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len;
+         i += num_threads * wide_t::Ratio) {
       wide.load(in, i);
 #pragma unroll
       for (int j = 0; j < wide_t::Ratio; ++j) {
@@ -145,30 +159,55 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f)
     static_assert(WarpSize >= wide_t::Ratio);
     // Processes the skipped elements on the left
-    if (tid < skip_cnt_left) { f(in[tid], tid); }
+    if (thread_rank < skip_cnt_left) { f(in[thread_rank], thread_rank); }
     // Processes the skipped elements on the right
     const IdxT skip_cnt_right = align_elems::mod(len - skip_cnt_left);
-    const IdxT remain_i       = len - skip_cnt_right + tid;
+    const IdxT remain_i       = len - skip_cnt_right + thread_rank;
     if (remain_i < len) { f(in[remain_i], remain_i); }
   }
 }

 template <typename T, typename IdxT>
-struct Counter {
+struct alignas(128) Counter {
+  // We are processing the values in multiple passes, from most significant to least significant.
+  // In each pass, we keep the length of input (`len`) and the `k` of current pass, and update
+  // them at the end of the pass.
   IdxT k;
   IdxT len;
+
+  // `previous_len` is the length of input in previous pass. Note that `previous_len` rather
+  // than `len` is used for the filtering step because filtering is indeed for previous pass (see
+  // comments before `radix_kernel`).
   IdxT previous_len;
-  int bucket;

-  IdxT filter_cnt;
-  unsigned int finished_block_cnt;
-  IdxT out_cnt;
-  IdxT out_back_cnt;
+  // We determine the bits of the k-th value inside the mask processed by the pass. The
+  // already known bits are stored in `kth_value_bits`. It's used to determine whether an element
+  // is a result (written to `out`), a candidate for the next pass (written to `out_buf`), or not
+  // useful (discarded). The bits that are not yet processed do not matter for this purpose.
+  typename cub::Traits<T>::UnsignedBits kth_value_bits;
+
+  // Record how many elements have passed filtering. It's used to determine the position in the
+  // `out_buf` where an element should be written.
+  alignas(128) IdxT filter_cnt;
+
+  // For a row inside a batch, we may launch multiple thread blocks. This counter is used to
+  // determine if the current block is the last running block. If so, this block will execute
+  // scan() and choose_bucket().
+  alignas(128) unsigned int finished_block_cnt;
+
+  // Record how many elements have been written to the front of `out`. Elements less (if
+  // select_min==true) than the k-th value are written from front to back.
+  alignas(128) IdxT out_cnt;
+
+  // Record how many elements have been written to the back of `out`. Elements equal to the k-th
+  // value are written from back to front. We need to keep count of them separately because the
+  // number of elements <= the k-th value might exceed k.
+  alignas(128) IdxT out_back_cnt;
 };

 /**
- * Fused filtering of the current phase and building histogram for the next phase
- * (see steps 4-1 in `radix_kernel` description).
+ * Fused filtering of the current pass and building histogram for the next pass
+ * (see steps 4 & 1 in `radix_kernel` description).
  */
 template <typename T, typename IdxT, int BitsPerPass>
 _RAFT_DEVICE void filter_and_histogram(const T* in_buf,
@@ -177,12 +216,12 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf,
                                        IdxT* out_idx_buf,
                                        T* out,
                                        IdxT* out_idx,
-                                       IdxT len,
+                                       IdxT previous_len,
                                        Counter<T, IdxT>* counter,
                                        IdxT* histogram,
-                                       bool greater,
+                                       bool select_min,
                                        int pass,
-                                       int k)
+                                       bool early_stop)
 {
   constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
   __shared__ IdxT histogram_smem[num_buckets];
@@ -198,19 +237,20 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf,
     // Passed to vectorized_process, this function executes in all blocks in parallel,
     // i.e. the work is split along the input (both, in batches and chunks of a single row).
     // Later, the histograms are merged using atomicAdd.
-    auto f = [greater, start_bit, mask](T value, IdxT) {
-      int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, greater);
-      atomicAdd(histogram_smem + bucket, IdxT(1));
+    auto f = [select_min, start_bit, mask](T value, IdxT) {
+      int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
+      atomicAdd(histogram_smem + bucket, static_cast<IdxT>(1));
     };
-    vectorized_process(in_buf, len, f);
+    vectorized_process(static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x,
+                       static_cast<size_t>(blockDim.x) * gridDim.x,
+                       in_buf,
+                       previous_len,
+                       f);
   } else {
-    const IdxT previous_len = counter->previous_len;
-    const int want_bucket   = counter->bucket;
-    IdxT& filter_cnt        = counter->filter_cnt;
-    IdxT& out_cnt           = counter->out_cnt;
-    const IdxT counter_len  = counter->len;
+    IdxT* p_filter_cnt        = &counter->filter_cnt;
+    IdxT* p_out_cnt           = &counter->out_cnt;
+    const auto kth_value_bits = counter->kth_value_bits;
     const int previous_start_bit = calc_start_bit<T, BitsPerPass>(pass - 1);
-    const unsigned previous_mask = calc_mask<T, BitsPerPass>(pass - 1);

     // See the remark above on the distributed execution of `f` using vectorized_process.
     auto f = [in_idx_buf,
@@ -218,38 +258,50 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf,
               out_idx_buf,
               out,
               out_idx,
-              greater,
-              k,
+              select_min,
               start_bit,
               mask,
               previous_start_bit,
-              previous_mask,
-              want_bucket,
-              &filter_cnt,
-              &out_cnt,
-              counter_len](T value, IdxT i) {
-      int prev_bucket =
-        calc_bucket<T, BitsPerPass>(value, previous_start_bit, previous_mask, greater);
-      if (prev_bucket == want_bucket) {
-        IdxT pos     = atomicAdd(&filter_cnt, IdxT(1));
-        out_buf[pos] = value;
-        if (out_idx_buf) { out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; }
-        int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, greater);
-        atomicAdd(histogram_smem + bucket, IdxT(1));
-
-        if (counter_len == 1) {
-          out[k - 1]     = value;
-          out_idx[k - 1] = in_idx_buf ? in_idx_buf[i] : i;
+              kth_value_bits,
+              p_filter_cnt,
+              p_out_cnt,
+              early_stop](T value, IdxT i) {
+      const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
+                                 << previous_start_bit;
+      if (previous_bits == kth_value_bits) {
+        if (early_stop) {
+          IdxT pos     = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
+          out[pos]     = value;
+          out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
+        } else {
+          if (out_buf) {
+            IdxT pos         = atomicAdd(p_filter_cnt, static_cast<IdxT>(1));
+            out_buf[pos]     = value;
+            out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i;
+          }
+
+          int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
+          atomicAdd(histogram_smem + bucket, static_cast<IdxT>(1));
         }
-      } else if (prev_bucket < want_bucket) {
-        IdxT pos     = atomicAdd(&out_cnt, IdxT(1));
+      }
+      // the condition `(out_buf || early_stop)` is a little tricky:
+      // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should skip writing to
+      // `out` too. So we won't write the same value to `out` multiple times in different passes.
+      // And if we keep skipping the writing, values will eventually be written by
+      // `last_filter_kernel()`. But when `early_stop` is true, we need to write to `out` since
+      // it's the last chance.
+      else if ((out_buf || early_stop) && previous_bits < kth_value_bits) {
+        IdxT pos     = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
         out[pos]     = value;
         out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
       }
     };
-
-    vectorized_process(in_buf, previous_len, f);
+    vectorized_process(static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x,
+                       static_cast<size_t>(blockDim.x) * gridDim.x,
+                       in_buf,
+                       previous_len,
+                       f);
   }
+  if (early_stop) { return; }
   __syncthreads();

   // merge histograms produced by individual blocks
@@ -259,69 +311,184 @@
 }

 /**
- * Replace a part of the histogram with its own prefix sum, starting from the `start` and adding
- * `current` to each entry of the result.
+ * Replace histogram with its own prefix sum
  * (step 2 in `radix_kernel` description)
  */
 template <typename IdxT, int BitsPerPass, int BlockSize>
-_RAFT_DEVICE void scan(volatile IdxT* histogram,
-                       const int start,
-                       const int num_buckets,
-                       const IdxT current)
+_RAFT_DEVICE void scan(volatile IdxT* histogram)
 {
-  typedef cub::BlockScan<IdxT, BlockSize> BlockScan;
-  __shared__ typename BlockScan::TempStorage temp_storage;
+  constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
+  if constexpr (num_buckets >= BlockSize) {
+    static_assert(num_buckets % BlockSize == 0);
+    constexpr int items_per_thread = num_buckets / BlockSize;
+    typedef cub::BlockLoad<IdxT, BlockSize, items_per_thread, cub::BLOCK_LOAD_TRANSPOSE> BlockLoad;
+    typedef cub::BlockStore<IdxT, BlockSize, items_per_thread, cub::BLOCK_STORE_TRANSPOSE>
+      BlockStore;
+    typedef cub::BlockScan<IdxT, BlockSize> BlockScan;

-  IdxT thread_data = 0;
-  int index        = start + threadIdx.x;
-  if (index < num_buckets) { thread_data = histogram[index]; }
+    __shared__ union {
+      typename BlockLoad::TempStorage load;
+      typename BlockScan::TempStorage scan;
+      typename BlockStore::TempStorage store;
+    } temp_storage;
+    IdxT thread_data[items_per_thread];

-  BlockScan(temp_storage).InclusiveSum(thread_data, thread_data);
-  __syncthreads();
-  if (index < num_buckets) { histogram[index] = thread_data + current; }
-  __syncthreads();  // This sync is necessary, as the content of histogram needs
-                    // to be read after
+    BlockLoad(temp_storage.load).Load(histogram, thread_data);
+    __syncthreads();
+
+    BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data);
+    __syncthreads();
+
+    BlockStore(temp_storage.store).Store(histogram, thread_data);
+  } else {
+    typedef cub::BlockScan<IdxT, BlockSize> BlockScan;
+    __shared__ typename BlockScan::TempStorage temp_storage;
+
+    IdxT thread_data = 0;
+    if (threadIdx.x < num_buckets) { thread_data = histogram[threadIdx.x]; }
+
+    BlockScan(temp_storage).InclusiveSum(thread_data, thread_data);
+    __syncthreads();
+
+    if (threadIdx.x < num_buckets) { histogram[threadIdx.x] = thread_data; }
+  }
 }

 /**
  * Calculate in which bucket the k-th value will fall
- * (steps 2-3 in `radix_kernel` description)
+ * (step 3 in `radix_kernel` description)
  */
-template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
-_RAFT_DEVICE void choose_bucket(Counter<T, IdxT>* counter, IdxT* histogram, const IdxT k)
+template <typename T, typename IdxT, int BitsPerPass>
+_RAFT_DEVICE void choose_bucket(Counter<T, IdxT>* counter,
+                                const IdxT* histogram,
+                                const IdxT k,
+                                const int pass)
 {
   constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
-  int index            = threadIdx.x;
-  IdxT last_prefix_sum = 0;
-  int num_pass         = 1;
-  if constexpr (num_buckets >= BlockSize) {
-    static_assert(num_buckets % BlockSize == 0);
-    num_pass = num_buckets / BlockSize;
+  for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) {
+    IdxT prev = (i == 0) ? 0 : histogram[i - 1];
+    IdxT cur  = histogram[i];
+
+    // one and only one thread will satisfy this condition, so counter is written by only one
+    // thread
+    if (prev < k && cur >= k) {
+      counter->k   = k - prev;    // how many values are still there to find
+      counter->len = cur - prev;  // number of values in next pass
+      typename cub::Traits<T>::UnsignedBits bucket = i;
+      int start_bit = calc_start_bit<T, BitsPerPass>(pass);
+      counter->kth_value_bits |= bucket << start_bit;
+    }
   }
+}
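// --- Illustrative aside, not part of the patch: suppose after scan() the inclusive prefix
// sums of a 4-bucket histogram are {3, 7, 12, 20} and k = 10. Only bucket 2 satisfies
// prev < k <= cur (7 < 10 <= 12), so the k-th value falls in bucket 2, and the next pass
// continues with counter->k = 10 - 7 = 3 and counter->len = 12 - 7 = 5. ---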
-  for (int i = 0; i < num_pass && (last_prefix_sum < k); i++) {
-    // Turn the i-th chunk of the histogram into its prefix sum.
-    scan(histogram, i * BlockSize, num_buckets, last_prefix_sum);
-    if (index < num_buckets) {
-      // Number of values in the previous `index-1` buckets (see the `scan` op above)
-      IdxT prev = (index == 0) ? 0 : histogram[index - 1];
-      // Number of values in `index` buckets
-      IdxT cur = histogram[index];
-
-      // one and only one thread will satisfy this condition, so only write once
-      if (prev < k && cur >= k) {
-        counter->k            = k - prev;  // how many values are still there to find
-        counter->previous_len = counter->len;
-        counter->len          = cur - prev;  // number of values in `index` bucket
-        counter->bucket       = index;
+
+// For the one-block version, last_filter() could be called when pass < num_passes - 1.
+// So `pass` cannot be constexpr.
+template <typename T, typename IdxT, int BitsPerPass>
+_RAFT_DEVICE void last_filter(const T* in_buf,
+                              const IdxT* in_idx_buf,
+                              T* out,
+                              IdxT* out_idx,
+                              IdxT current_len,
+                              IdxT k,
+                              Counter<T, IdxT>* counter,
+                              const bool select_min,
+                              const int pass)
+{
+  const auto kth_value_bits = counter->kth_value_bits;
+  const int start_bit       = calc_start_bit<T, BitsPerPass>(pass);
+
+  // changed in choose_bucket(); need to reload
+  const IdxT needed_num_of_kth = counter->k;
+  IdxT* p_out_cnt              = &counter->out_cnt;
+  IdxT* p_out_back_cnt         = &counter->out_back_cnt;
+  for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) {
+    const T value   = in_buf[i];
+    const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit;
+    if (bits < kth_value_bits) {
+      IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
+      out[pos] = value;
+      // For the one-block version, `in_idx_buf` could be nullptr at pass 0.
+      // For the non-one-block version, if writing has been skipped, `in_idx_buf` could be
+      // nullptr if `in_buf` is `in`
+      out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
+    } else if (bits == kth_value_bits) {
+      IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
+      if (back_pos < needed_num_of_kth) {
+        IdxT pos     = k - 1 - back_pos;
+        out[pos]     = value;
+        out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
       }
     }
-    index += BlockSize;
-    // this will break the loop when the counter is set (cur >= k), because last_prefix_sum >= cur
-    last_prefix_sum = histogram[(i + 1) * BlockSize - 1];
   }
 }

+template <typename T, typename IdxT, int BitsPerPass>
+__global__ void last_filter_kernel(const T* in,
+                                   const IdxT* in_idx,
+                                   const T* in_buf,
+                                   const IdxT* in_idx_buf,
+                                   T* out,
+                                   IdxT* out_idx,
+                                   IdxT len,
+                                   IdxT k,
+                                   Counter<T, IdxT>* counters,
+                                   const bool select_min)
+{
+  const size_t batch_id = blockIdx.y;  // size_t to avoid multiplication overflow
+
+  Counter<T, IdxT>* counter = counters + batch_id;
+  IdxT previous_len         = counter->previous_len;
+  if (previous_len == 0) { return; }
+  const IdxT buf_len = calc_buf_len<T>(len);
+  if (previous_len > buf_len || in_buf == in) {
+    in_buf       = in + batch_id * len;
+    in_idx_buf   = in_idx ? (in_idx + batch_id * len) : nullptr;
+    previous_len = len;
+  } else {
+    in_buf += batch_id * buf_len;
+    in_idx_buf += batch_id * buf_len;
+  }
+  out += batch_id * k;
+  out_idx += batch_id * k;
+
+  constexpr int pass      = calc_num_passes<T, BitsPerPass>() - 1;
+  constexpr int start_bit = calc_start_bit<T, BitsPerPass>(pass);
+
+  const auto kth_value_bits    = counter->kth_value_bits;
+  const IdxT needed_num_of_kth = counter->k;
+  IdxT* p_out_cnt              = &counter->out_cnt;
+  IdxT* p_out_back_cnt         = &counter->out_back_cnt;
+
+  auto f = [k,
+            select_min,
+            kth_value_bits,
+            needed_num_of_kth,
+            p_out_cnt,
+            p_out_back_cnt,
+            in_idx_buf,
+            out,
+            out_idx](T value, IdxT i) {
+    const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit;
+    if (bits < kth_value_bits) {
+      IdxT pos     = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
+      out[pos]     = value;
+      out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
+    } else if (bits == kth_value_bits) {
+      IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
+      if (back_pos < needed_num_of_kth) {
+        IdxT pos     = k - 1 - back_pos;
+        out[pos]     = value;
+        out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
+      }
+    }
+  };
+
+  vectorized_process(static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x,
+                     static_cast<size_t>(blockDim.x) * gridDim.x,
+                     in_buf,
+                     previous_len,
+                     f);
+}
+
 /**
  *
  * It is expected to call this kernel multiple times (passes), in each pass we process a radix,
@@ -350,35 +517,79 @@ _RAFT_DEVICE void choose_bucket(Counter<T, IdxT>* counter, IdxT* histogram, cons
  *
  * In the implementation, the filtering step is delayed to the next pass so the filtering and
  * histogram computation are fused. In this way, inputs are read once rather than twice.
+ *
+ * During the filtering step, we won't write candidates (elements in bucket j) to `out_buf` if the
+ * number of candidates is larger than the length of `out_buf` (this could happen when the leading
+ * bits of input values are almost the same). And then in the next pass, inputs are read from `in`
+ * rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and
+ * their indices.
  */
-template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
-__global__ void __launch_bounds__(BlockSize) radix_kernel(const T* in_buf,
-                                                          const IdxT* in_idx_buf,
-                                                          T* out_buf,
-                                                          IdxT* out_idx_buf,
-                                                          T* out,
-                                                          IdxT* out_idx,
-                                                          Counter<T, IdxT>* counters,
-                                                          IdxT* histograms,
-                                                          const IdxT len,
-                                                          const int k,
-                                                          const bool greater,
-                                                          const int pass)
+template <typename T, typename IdxT, int BitsPerPass, int BlockSize, bool fused_last_filter>
+__global__ void radix_kernel(const T* in,
+                             const IdxT* in_idx,
+                             const T* in_buf,
+                             const IdxT* in_idx_buf,
+                             T* out_buf,
+                             IdxT* out_idx_buf,
+                             T* out,
+                             IdxT* out_idx,
+                             Counter<T, IdxT>* counters,
+                             IdxT* histograms,
+                             const IdxT len,
+                             const IdxT k,
+                             const bool select_min,
+                             const int pass)
 {
-  __shared__ bool isLastBlockDone;
+  const size_t batch_id = blockIdx.y;
+  auto counter          = counters + batch_id;
+  IdxT current_k;
+  IdxT previous_len;
+  IdxT current_len;
+  if (pass == 0) {
+    current_k    = k;
+    previous_len = len;
+    // Need to do this so setting counter->previous_len for the next pass is correct.
+    // This value is meaningless for pass 0, but it's fine because pass 0 won't be the
+    // last pass in this implementation so pass 0 won't hit the "if (pass ==
+    // num_passes - 1)" branch.
+    // Maybe it's better to reload counter->previous_len and use it rather than
+    // current_len in last_filter()
+    current_len = len;
+  } else {
+    current_k    = counter->k;
+    current_len  = counter->len;
+    previous_len = counter->previous_len;
+  }
+  if (current_len == 0) { return; }

-  constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
-  constexpr int num_passes  = calc_num_passes<T, BitsPerPass>();
-  const int batch_id        = blockIdx.y;
-  in_buf += batch_id * len;
-  out_buf += batch_id * len;
+  // When k == len, early_stop will be true at pass 0. It means filter_and_histogram() should
+  // handle correctly the case that pass == 0 and early_stop == true. However, this special case
+  // of k == len is handled in another way in select_k(), so such a case is not possible here.
+  const bool early_stop = (current_len == current_k);
+  const IdxT buf_len    = calc_buf_len<T>(len);
+
+  // "previous_len > buf_len" means the previous pass skipped writing the buffer
+  if (pass == 0 || pass == 1 || previous_len > buf_len) {
+    in_buf       = in + batch_id * len;
+    in_idx_buf   = in_idx ? (in_idx + batch_id * len) : nullptr;
+    previous_len = len;
+  } else {
+    in_buf += batch_id * buf_len;
+    in_idx_buf += batch_id * buf_len;
+  }
+  // "current_len > buf_len" means the current pass will skip writing the buffer
+  if (pass == 0 || current_len > buf_len) {
+    out_buf     = nullptr;
+    out_idx_buf = nullptr;
+  } else {
+    out_buf += batch_id * buf_len;
+    out_idx_buf += batch_id * buf_len;
+  }
   out += batch_id * k;
   out_idx += batch_id * k;
-  if (in_idx_buf) { in_idx_buf += batch_id * len; }
-  if (out_idx_buf) { out_idx_buf += batch_id * len; }
-  auto counter   = counters + batch_id;
-  auto histogram = histograms + batch_id * num_buckets;
+  constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
+  auto histogram            = histograms + batch_id * num_buckets;

   filter_and_histogram(in_buf,
                        in_idx_buf,
@@ -386,126 +597,468 @@ __global__ void __launch_bounds__(BlockSize) radix_kernel(const T* in_buf,
                        out_idx_buf,
                        out,
                        out_idx,
-                       len,
+                       previous_len,
                        counter,
                        histogram,
-                       greater,
+                       select_min,
                        pass,
-                       k);
+                       early_stop);
   __threadfence();

+  bool isLastBlock = false;
   if (threadIdx.x == 0) {
     unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1);
-    isLastBlockDone       = (finished == (gridDim.x - 1));
+    isLastBlock           = (finished == (gridDim.x - 1));
   }

-  // Synchronize to make sure that each thread reads the correct value of
-  // isLastBlockDone.
-  __syncthreads();
-  if (isLastBlockDone) {
-    if (counter->len == 1 && threadIdx.x == 0) {
-      counter->previous_len = 0;
-      counter->len          = 0;
-    }
-    // init counter, other members of counter is initialized with 0 by
-    // cudaMemset()
-    if (pass == 0 && threadIdx.x == 0) {
-      counter->k            = k;
-      counter->len          = len;
-      counter->out_back_cnt = 0;
+  if (__syncthreads_or(isLastBlock)) {
+    if (early_stop) {
+      if (threadIdx.x == 0) {
+        // `last_filter_kernel()` requires setting previous_len
+        counter->previous_len = 0;
+        counter->len          = 0;
+      }
+      return;
     }
+
+    scan<IdxT, BitsPerPass, BlockSize>(histogram);
     __syncthreads();
+    choose_bucket<T, IdxT, BitsPerPass>(counter, histogram, current_k, pass);
+    __syncthreads();
+
+    constexpr int num_passes = calc_num_passes<T, BitsPerPass>();
+    // reset for next pass
+    if (pass != num_passes - 1) {
+      for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) {
+        histogram[i] = 0;
+      }
+    }
+    if (threadIdx.x == 0) {
+      // `last_filter_kernel()` requires setting previous_len even in the last pass
+      counter->previous_len = current_len;
+      // not necessary for the last pass, but put it here anyway
+      counter->filter_cnt = 0;
+    }
+
+    if constexpr (fused_last_filter) {
+      if (pass == num_passes - 1) {
+        last_filter<T, IdxT, BitsPerPass>(out_buf ? out_buf : in_buf,
+                                          out_idx_buf ? out_idx_buf : in_idx_buf,
+                                          out,
+                                          out_idx,
+                                          out_buf ? current_len : len,
+                                          k,
+                                          counter,
+                                          select_min,
+                                          pass);
+      }
+    }
+  }
+}
+
+template <typename T, typename IdxT, int BlockSize, typename Kernel>
+int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel)
+{
+  int active_blocks;
+  RAFT_CUDA_TRY(
+    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks, kernel, BlockSize, 0));
+
+  constexpr int items_per_thread = 32;
+  constexpr int num_waves        = 10;
+  int chunk_size =
+    std::max(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len);
+  return std::min(chunk_size, batch_size);
+}

-  IdxT ori_k = counter->k;

+template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
+unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt)
+{
+  static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1);
+
+  int active_blocks;
+  RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+    &active_blocks, radix_kernel<T, IdxT, BitsPerPass, BlockSize, false>, BlockSize, 0));
+  active_blocks *= sm_cnt;

-  if (counter->len > 0) {
-    choose_bucket<T, IdxT, BitsPerPass, BlockSize>(counter, histogram, ori_k);

+  IdxT best_num_blocks         = 0;
+  float best_tail_wave_penalty = 1.0f;
+  const IdxT max_num_blocks = ceildiv<IdxT>(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize);
+  for (int num_waves = 1;; ++num_waves) {
+    IdxT num_blocks = std::min(
+      max_num_blocks, static_cast<IdxT>(std::max(num_waves * active_blocks / batch_size, 1)));
+    IdxT items_per_thread = ceildiv<IdxT>(len, num_blocks * BlockSize);
+    items_per_thread      = alignTo<IdxT>(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T));
+    num_blocks            = ceildiv<IdxT>(len, items_per_thread * BlockSize);
+    float actual_num_waves = static_cast<float>(num_blocks) * batch_size / active_blocks;
+    float tail_wave_penalty =
+      (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves);
+
+    // 0.15 is determined experimentally. It also ensures breaking the loop early,
+    // e.g. when num_waves > 7, tail_wave_penalty will always be < 0.15
+    if (tail_wave_penalty < 0.15) {
+      best_num_blocks = num_blocks;
+      break;
+    } else if (tail_wave_penalty < best_tail_wave_penalty) {
+      best_num_blocks        = num_blocks;
+      best_tail_wave_penalty = tail_wave_penalty;
    }

-    __syncthreads();
-    if (pass == num_passes - 1) {
-      const IdxT previous_len = counter->previous_len;
-      const int want_bucket   = counter->bucket;
-      int start_bit           = calc_start_bit<T, BitsPerPass>(pass);
-      unsigned mask           = calc_mask<T, BitsPerPass>(pass);
-
-      // radix topk
-      IdxT& out_cnt = counter->out_cnt;
-      for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) {
-        const T value = out_buf[i];
-        int bucket    = calc_bucket<T, BitsPerPass>(value, start_bit, mask, greater);
-        if (bucket < want_bucket) {
-          IdxT pos     = atomicAdd(&out_cnt, IdxT(1));
-          out[pos]     = value;
-          out_idx[pos] = out_idx_buf[i];
-        } else if (bucket == want_bucket) {
-          IdxT needed_num_of_kth = counter->k;
-          IdxT back_pos          = atomicAdd(&(counter->out_back_cnt), IdxT(1));
-          if (back_pos < needed_num_of_kth) {
-            IdxT pos     = k - 1 - back_pos;
-            out[pos]     = value;
-            out_idx[pos] = out_idx_buf[i];
-          }
-        }
+    if (num_blocks == max_num_blocks) { break; }
+  }
+  return best_num_blocks;
+}
+
+template <typename T, typename IdxT>
+_RAFT_HOST_DEVICE void set_buf_pointers(const T* in,
+                                        const IdxT* in_idx,
+                                        T* buf1,
+                                        IdxT* idx_buf1,
+                                        T* buf2,
+                                        IdxT* idx_buf2,
+                                        int pass,
+                                        const T*& in_buf,
+                                        const IdxT*& in_idx_buf,
+                                        T*& out_buf,
+                                        IdxT*& out_idx_buf)
+{
+  if (pass == 0) {
+    in_buf      = in;
+    in_idx_buf  = nullptr;
+    out_buf     = nullptr;
+    out_idx_buf = nullptr;
+  } else if (pass == 1) {
+    in_buf      = in;
+    in_idx_buf  = in_idx;
+    out_buf     = buf1;
+    out_idx_buf = idx_buf1;
+  } else if (pass % 2 == 0) {
+    in_buf      = buf1;
+    in_idx_buf  = idx_buf1;
+    out_buf     = buf2;
+    out_idx_buf = idx_buf2;
+  } else {
+    in_buf      = buf2;
+    in_idx_buf  = idx_buf2;
+    out_buf     = buf1;
+    out_idx_buf = idx_buf1;
+  }
+}
+
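// --- Illustrative aside, not part of the patch: a minimal host-only sketch of the
// pass-to-buffer mapping implemented by set_buf_pointers() above (pass 0 reads `in` and
// writes nothing; pass 1 reads `in` and writes buf1; later passes ping-pong buf1/buf2).
#include <cstdio>

int main()
{
  const int num_passes = 4;  // e.g. 32-bit keys at 8 bits per pass
  for (int pass = 0; pass < num_passes; ++pass) {
    const char* src = (pass <= 1) ? "in" : (pass % 2 == 0 ? "buf1" : "buf2");
    const char* dst = (pass == 0) ? "(none)" : (pass % 2 == 0 ? "buf2" : "buf1");
    std::printf("pass %d: read %s, write %s\n", pass, src, dst);
  }
  return 0;
}
// --- end aside ---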
+template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
+void radix_topk(const T* in,
+                const IdxT* in_idx,
+                int batch_size,
+                IdxT len,
+                IdxT k,
+                T* out,
+                IdxT* out_idx,
+                bool select_min,
+                bool fused_last_filter,
+                unsigned grid_dim,
+                int sm_cnt,
+                rmm::cuda_stream_view stream,
+                rmm::mr::device_memory_resource* mr)
+{
+  // TODO: is it possible to relax this restriction?
+  static_assert(calc_num_passes<T, BitsPerPass>() > 1);
+  constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
+
+  auto kernel = radix_kernel<T, IdxT, BitsPerPass, BlockSize, false>;
+  const size_t max_chunk_size =
+    calc_chunk_size<T, IdxT, BlockSize>(batch_size, len, sm_cnt, kernel);
+  if (max_chunk_size != static_cast<size_t>(batch_size)) {
+    grid_dim = calc_grid_dim<T, IdxT, BitsPerPass, BlockSize>(max_chunk_size, len, sm_cnt);
+  }
+  const IdxT buf_len = calc_buf_len<T>(len);
+
+  size_t req_aux = max_chunk_size * (sizeof(Counter<T, IdxT>) + num_buckets * sizeof(IdxT));
+  size_t req_buf = max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT));
+  size_t mem_req = req_aux + req_buf + 256 * 6;  // might need extra memory for alignment
+
+  auto pool_guard = raft::get_pool_memory_resource(mr, mem_req);
+  if (pool_guard) {
+    RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes",
+                   pool_guard->pool_size());
+  }
+
+  rmm::device_uvector<Counter<T, IdxT>> counters(max_chunk_size, stream, mr);
+  rmm::device_uvector<IdxT> histograms(max_chunk_size * num_buckets, stream, mr);
+  rmm::device_uvector<T> buf1(max_chunk_size * buf_len, stream, mr);
+  rmm::device_uvector<IdxT> idx_buf1(max_chunk_size * buf_len, stream, mr);
+  rmm::device_uvector<T> buf2(max_chunk_size * buf_len, stream, mr);
+  rmm::device_uvector<IdxT> idx_buf2(max_chunk_size * buf_len, stream, mr);
+
+  for (size_t offset = 0; offset < static_cast<size_t>(batch_size); offset += max_chunk_size) {
+    int chunk_size = std::min(max_chunk_size, batch_size - offset);
+    RAFT_CUDA_TRY(
+      cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter<T, IdxT>), stream));
+    RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream));
+
+    const T* chunk_in        = in + offset * len;
+    const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr;
+    T* chunk_out             = out + offset * k;
+    IdxT* chunk_out_idx      = out_idx + offset * k;
+
+    const T* in_buf        = nullptr;
+    const IdxT* in_idx_buf = nullptr;
+    T* out_buf             = nullptr;
+    IdxT* out_idx_buf      = nullptr;
+
+    dim3 blocks(grid_dim, chunk_size);
+    constexpr int num_passes = calc_num_passes<T, BitsPerPass>();
+
+    for (int pass = 0; pass < num_passes; ++pass) {
+      set_buf_pointers(chunk_in,
+                       chunk_in_idx,
+                       buf1.data(),
+                       idx_buf1.data(),
+                       buf2.data(),
+                       idx_buf2.data(),
+                       pass,
+                       in_buf,
+                       in_idx_buf,
+                       out_buf,
+                       out_idx_buf);
+
+      if (fused_last_filter && pass == num_passes - 1) {
+        kernel = radix_kernel<T, IdxT, BitsPerPass, BlockSize, true>;
      }
+
+      kernel<<<blocks, BlockSize, 0, stream>>>(chunk_in,
+                                               chunk_in_idx,
+                                               in_buf,
+                                               in_idx_buf,
+                                               out_buf,
+                                               out_idx_buf,
+                                               chunk_out,
+                                               chunk_out_idx,
+                                               counters.data(),
+                                               histograms.data(),
+                                               len,
+                                               k,
+                                               select_min,
+                                               pass);
+      RAFT_CUDA_TRY(cudaPeekAtLastError());
+    }
+
+    if (!fused_last_filter) {
+      last_filter_kernel<T, IdxT, BitsPerPass><<<blocks, BlockSize, 0, stream>>>(chunk_in,
+                                                                                 chunk_in_idx,
+                                                                                 out_buf,
+                                                                                 out_idx_buf,
+                                                                                 chunk_out,
+                                                                                 chunk_out_idx,
+                                                                                 len,
+                                                                                 k,
+                                                                                 counters.data(),
+                                                                                 select_min);
+      RAFT_CUDA_TRY(cudaPeekAtLastError());
+    }
+  }
+}
+
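// --- Illustrative aside, not part of the patch: the chunk loop in radix_topk() above covers
// the batch in slices of max_chunk_size rows. E.g. with batch_size = 1000 and a hypothetical
// max_chunk_size = 256 returned by calc_chunk_size(), the loop runs at offsets 0, 256, 512,
// 768 with chunk sizes 256, 256, 256, 232, resetting counters and histograms per slice. ---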
+// The following few functions are for the one-block version, which uses a single thread block
+// for each row of a batch.
+template <typename T, typename IdxT, int BitsPerPass>
+_RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf,
+                                                     const IdxT* in_idx_buf,
+                                                     T* out_buf,
+                                                     IdxT* out_idx_buf,
+                                                     T* out,
+                                                     IdxT* out_idx,
+                                                     Counter<T, IdxT>* counter,
+                                                     IdxT* histogram,
+                                                     bool select_min,
+                                                     int pass)
+{
+  constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
+  for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) {
+    histogram[i] = 0;
+  }
+  IdxT* p_filter_cnt = &counter->filter_cnt;
+  if (threadIdx.x == 0) { *p_filter_cnt = 0; }
+  __syncthreads();
+
+  const int start_bit     = calc_start_bit<T, BitsPerPass>(pass);
+  const unsigned mask     = calc_mask<T, BitsPerPass>(pass);
+  const IdxT previous_len = counter->previous_len;
+
+  if (pass == 0) {
+    auto f = [histogram, select_min, start_bit, mask](T value, IdxT) {
+      int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
+      atomicAdd(histogram + bucket, static_cast<IdxT>(1));
+    };
+    vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f);
+  } else {
+    // don't use vectorized_process here because it increases register usage a lot
+    IdxT* p_out_cnt              = &counter->out_cnt;
+    const auto kth_value_bits    = counter->kth_value_bits;
+    const int previous_start_bit = calc_start_bit<T, BitsPerPass>(pass - 1);
+
+    for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) {
+      const T value            = in_buf[i];
+      const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
+                                 << previous_start_bit;
+      if (previous_bits == kth_value_bits) {
+#if CUDART_VERSION < 12000
+        // Avoiding potential compiler bug in CUDA 11
+        volatile
+#endif
+        IdxT pos         = atomicAdd(p_filter_cnt, static_cast<IdxT>(1));
+        out_buf[pos]     = value;
+        out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i;
+
+        int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
+        atomicAdd(histogram + bucket, static_cast<IdxT>(1));
+      } else if (previous_bits < kth_value_bits) {
+        IdxT pos     = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
+        out[pos]     = value;
+        out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
      }
    }
  }
}

-/**
- * Calculate the minimal batch size, such that GPU is still fully occupied.
- */
 template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
-inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len)
+__global__ void radix_topk_one_block_kernel(const T* in,
+                                            const IdxT* in_idx,
+                                            const IdxT len,
+                                            const IdxT k,
+                                            T* out,
+                                            IdxT* out_idx,
+                                            const bool select_min,
+                                            T* buf1,
+                                            IdxT* idx_buf1,
+                                            T* buf2,
+                                            IdxT* idx_buf2)
 {
-  int dev_id, sm_count, occupancy, max_grid_dim_y;
-  RAFT_CUDA_TRY(cudaGetDevice(&dev_id));
-  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id));
-  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&max_grid_dim_y, cudaDevAttrMaxGridDimY, dev_id));
-  RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-    &occupancy, radix_kernel<T, IdxT, BitsPerPass, BlockSize>, BlockSize, 0));
-
-  // number of block we'd use if the batch size is enough to occupy the gpu in any case
-  size_t blocks_per_row = ceildiv(len, BlockSize * ITEM_PER_THREAD);
-
-  // fully occupy GPU
-  size_t opt_batch_size = ceildiv(sm_count * occupancy, blocks_per_row);
-  // round it up to the closest pow-of-two for better data alignment
-  opt_batch_size = isPo2(opt_batch_size) ? opt_batch_size : (1 << (log2(opt_batch_size) + 1));
-  // Take a max possible pow-of-two grid_dim_y
-  max_grid_dim_y = isPo2(max_grid_dim_y) ? max_grid_dim_y : (1 << log2(max_grid_dim_y));
-  // If the optimal batch size is very small compared to the requested batch size, we know
-  // the extra required memory is not significant and we can increase the batch size for
-  // better occupancy when the grid size is not multiple of the SM count.
-  // Also don't split the batch size when there is not much work overall.
-  const size_t safe_enlarge_factor = 9;
-  const size_t min_grid_size       = 1024;
-  while ((opt_batch_size << safe_enlarge_factor) < req_batch_size ||
-         blocks_per_row * opt_batch_size < min_grid_size) {
-    opt_batch_size <<= 1;
+  constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
+  __shared__ Counter<T, IdxT> counter;
+  __shared__ IdxT histogram[num_buckets];
+
+  if (threadIdx.x == 0) {
+    counter.k              = k;
+    counter.len            = len;
+    counter.previous_len   = len;
+    counter.kth_value_bits = 0;
+    counter.out_cnt        = 0;
+    counter.out_back_cnt   = 0;
   }
+  __syncthreads();
+
+  const size_t batch_id = blockIdx.x;  // size_t to avoid multiplication overflow
+  in += batch_id * len;
+  if (in_idx) { in_idx += batch_id * len; }
+  out += batch_id * k;
+  out_idx += batch_id * k;
+  buf1 += batch_id * len;
+  idx_buf1 += batch_id * len;
+  buf2 += batch_id * len;
+  idx_buf2 += batch_id * len;
+  const T* in_buf        = nullptr;
+  const IdxT* in_idx_buf = nullptr;
+  T* out_buf             = nullptr;
+  IdxT* out_idx_buf      = nullptr;
+
+  constexpr int num_passes = calc_num_passes<T, BitsPerPass>();
+  for (int pass = 0; pass < num_passes; ++pass) {
+    set_buf_pointers(
+      in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, out_idx_buf);
+
+    IdxT current_len = counter.len;
+    IdxT current_k   = counter.k;
+
+    filter_and_histogram_for_one_block<T, IdxT, BitsPerPass>(in_buf,
+                                                             in_idx_buf,
+                                                             out_buf,
+                                                             out_idx_buf,
+                                                             out,
+                                                             out_idx,
+                                                             &counter,
+                                                             histogram,
+                                                             select_min,
+                                                             pass);
+    __syncthreads();
+
+    scan<IdxT, BitsPerPass, BlockSize>(histogram);
+    __syncthreads();
+
+    choose_bucket<T, IdxT, BitsPerPass>(&counter, histogram, current_k, pass);
+    if (threadIdx.x == 0) { counter.previous_len = current_len; }
+    __syncthreads();

-  // Do not exceed the max grid size.
-  opt_batch_size = std::min(opt_batch_size, size_t(max_grid_dim_y));
-  // Don't do more work than needed
-  opt_batch_size = std::min(opt_batch_size, req_batch_size);
-  // Let more blocks share one row if the required batch size is too small.
-  while (opt_batch_size * blocks_per_row < size_t(sm_count * occupancy) &&
-         // Ensure we still can read data somewhat efficiently
-         len * sizeof(T) > 2 * VECTORIZED_READ_SIZE * BlockSize * blocks_per_row) {
-    blocks_per_row <<= 1;
+    if (counter.len == counter.k || pass == num_passes - 1) {
+      last_filter<T, IdxT, BitsPerPass>(pass == 0 ? in : out_buf,
+                                        pass == 0 ? in_idx : out_idx_buf,
+                                        out,
+                                        out_idx,
+                                        current_len,
+                                        k,
+                                        &counter,
+                                        select_min,
+                                        pass);
+      break;
+    }
   }
+}

-  return dim3(blocks_per_row, opt_batch_size);
-}
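// --- Illustrative aside, not part of the patch: the kernels above compare twiddled bit
// prefixes (kth_value_bits). The transform below mirrors what twiddle_in() obtains from
// cub::Traits for float keys; unsigned comparison of twiddled bits then matches the ordering
// of the original floats. A standalone, host-compilable sketch:
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint32_t twiddle_in_float(float key)
{
  uint32_t bits;
  std::memcpy(&bits, &key, sizeof(bits));
  // Negative values: flip all bits; non-negative values: set the sign bit.
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

int main()
{
  const float v[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
  for (int i = 0; i + 1 < 5; ++i) {
    // Each comparison prints 1: twiddled bits sort in the same order as the floats.
    std::printf("%d\n", twiddle_in_float(v[i]) < twiddle_in_float(v[i + 1]));
  }
  return 0;
}
// --- end aside ---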
+// radix_topk() might use multiple thread blocks for one row of a batch. In contrast, the
+// following one-block version uses a single thread block for one row of a batch, so intermediate
+// data, like counters and global histograms, can be kept in shared memory and cheap sync
+// operations can be used. It's used when len is relatively small or when the number of blocks
+// per row calculated by `calc_grid_dim()` is 1.
+template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
+void radix_topk_one_block(const T* in,
+                          const IdxT* in_idx,
+                          int batch_size,
+                          IdxT len,
+                          IdxT k,
+                          T* out,
+                          IdxT* out_idx,
+                          bool select_min,
+                          int sm_cnt,
+                          rmm::cuda_stream_view stream,
+                          rmm::mr::device_memory_resource* mr)
+{
+  static_assert(calc_num_passes<T, BitsPerPass>() > 1);
+
+  auto kernel = radix_topk_one_block_kernel<T, IdxT, BitsPerPass, BlockSize>;
+  const size_t max_chunk_size =
+    calc_chunk_size<T, IdxT, BlockSize>(batch_size, len, sm_cnt, kernel);
+
+  auto pool_guard =
+    raft::get_pool_memory_resource(mr,
+                                   max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)) +
+                                     256 * 4  // might need extra memory for alignment
+    );
+  if (pool_guard) {
+    RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes",
+                   pool_guard->pool_size());
+  }
+
+  rmm::device_uvector<T> buf1(len * max_chunk_size, stream, mr);
+  rmm::device_uvector<IdxT> idx_buf1(len * max_chunk_size, stream, mr);
+  rmm::device_uvector<T> buf2(len * max_chunk_size, stream, mr);
+  rmm::device_uvector<IdxT> idx_buf2(len * max_chunk_size, stream, mr);
+
+  for (size_t offset = 0; offset < static_cast<size_t>(batch_size); offset += max_chunk_size) {
+    int chunk_size = std::min(max_chunk_size, batch_size - offset);
+    kernel<<<chunk_size, BlockSize, 0, stream>>>(in + offset * len,
+                                                 in_idx ? (in_idx + offset * len) : nullptr,
+                                                 len,
+                                                 k,
+                                                 out + offset * k,
+                                                 out_idx + offset * k,
+                                                 select_min,
+                                                 buf1.data(),
+                                                 idx_buf1.data(),
+                                                 buf2.data(),
+                                                 idx_buf2.data());
+  }
+}
+
+}  // namespace impl
+
 /**
  * Select k smallest or largest key/values from each row in the input data.
  *
@@ -546,6 +1099,12 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len)
  *   the payload selected together with `out`.
  * @param select_min
  *   whether to select k smallest (true) or largest (false) keys.
+ * @param fused_last_filter
+ *   when it's true, the last filter is fused into the kernel in the last pass and only one thread
+ *   block will do the filtering; when false, a standalone filter kernel with multiple thread
+ *   blocks is called. The latter case is preferable when the leading bits of input data are
+ *   almost the same. That is, when the value range of input data is narrow. In such a case,
+ *   there could be a large number of inputs for the last filter, hence using multiple thread
+ *   blocks is beneficial.
  * @param stream
  * @param mr an optional memory resource to use across the calls (you can provide a large enough
  *   memory pool here to avoid memory allocations within the call).
@@ -553,109 +1112,65 @@ template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
 void select_k(const T* in,
               const IdxT* in_idx,
-              size_t batch_size,
-              size_t len,
-              int k,
+              int batch_size,
+              IdxT len,
+              IdxT k,
               T* out,
               IdxT* out_idx,
               bool select_min,
+              bool fused_last_filter,
               rmm::cuda_stream_view stream,
               rmm::mr::device_memory_resource* mr = nullptr)
 {
-  // reduce the block size if the input length is too small.
-  if constexpr (BlockSize > calc_min_block_size<BitsPerPass>()) {
-    if (BlockSize * ITEM_PER_THREAD > len) {
-      return select_k<T, IdxT, BitsPerPass, BlockSize / 2>(
-        in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
+  if (k == len) {
+    RAFT_CUDA_TRY(
+      cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream));
+    if (in_idx) {
+      RAFT_CUDA_TRY(cudaMemcpyAsync(
+        out_idx, in_idx, sizeof(IdxT) * batch_size * len, cudaMemcpyDeviceToDevice, stream));
+    } else {
+      auto out_idx_view =
+        raft::make_device_vector_view(out_idx, static_cast<size_t>(len) * batch_size);
+      raft::device_resources handle(stream);
+      raft::linalg::map_offset(handle, out_idx_view, raft::mod_const_op<IdxT>(len));
     }
+    return;
   }

-  // TODO: is it possible to relax this restriction?
-  static_assert(calc_num_passes<T, BitsPerPass>() > 1);
-  constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
-
-  dim3 blocks = get_optimal_grid_size<T, IdxT, BitsPerPass, BlockSize>(batch_size, len);
-  size_t max_chunk_size = blocks.y;
-
-  size_t req_aux = max_chunk_size * (sizeof(Counter<T, IdxT>) + num_buckets * sizeof(IdxT));
-  size_t req_buf = max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT));
-  size_t mem_req = req_aux + req_buf;
-  size_t mem_free, mem_total;
-  RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total));
-  std::optional<rmm::mr::managed_memory_resource> managed_memory;
-  rmm::mr::device_memory_resource* mr_buf = nullptr;
-  if (mem_req > mem_free) {
-    // if there's not enough memory for buffers on the device, resort to the managed memory.
-    mem_req = req_aux;
-    managed_memory.emplace();
-    mr_buf = &managed_memory.value();
-  }
-
-  auto pool_guard = raft::get_pool_memory_resource(mr, mem_req);
-  if (pool_guard) {
-    RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes",
-                   pool_guard->pool_size());
+  // TODO: use device_resources::get_device_properties() instead; should change it when we
+  // refactor resource management
+  int sm_cnt;
+  {
+    int dev;
+    RAFT_CUDA_TRY(cudaGetDevice(&dev));
+    RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev));
   }
-  if (mr_buf == nullptr) { mr_buf = mr; }
-
-  rmm::device_uvector<Counter<T, IdxT>> counters(max_chunk_size, stream, mr);
-  rmm::device_uvector<IdxT> histograms(max_chunk_size * num_buckets, stream, mr);
-  rmm::device_uvector<T> buf1(max_chunk_size * len, stream, mr_buf);
-  rmm::device_uvector<IdxT> idx_buf1(max_chunk_size * len, stream, mr_buf);
-  rmm::device_uvector<T> buf2(max_chunk_size * len, stream, mr_buf);
-  rmm::device_uvector<IdxT> idx_buf2(max_chunk_size * len, stream, mr_buf);

-  for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) {
-    blocks.y = std::min(max_chunk_size, batch_size - offset);
+  constexpr int items_per_thread = 32;

-    RAFT_CUDA_TRY(
-      cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter<T, IdxT>), stream));
-    RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream));
-
-    const T* in_buf        = nullptr;
-    const IdxT* in_idx_buf = nullptr;
-    T* out_buf             = nullptr;
-    IdxT* out_idx_buf      = nullptr;
-
-    constexpr int num_passes = calc_num_passes<T, BitsPerPass>();
-
-    for (int pass = 0; pass < num_passes; ++pass) {
-      if (pass == 0) {
-        in_buf      = in + offset * len;
-        in_idx_buf  = nullptr;
-        out_buf     = nullptr;
-        out_idx_buf = nullptr;
-      } else if (pass == 1) {
-        in_buf      = in + offset * len;
-        in_idx_buf  = in_idx ? in_idx + offset * len : nullptr;
-        out_buf     = buf1.data();
-        out_idx_buf = idx_buf1.data();
-      } else if (pass % 2 == 0) {
-        in_buf      = buf1.data();
-        in_idx_buf  = idx_buf1.data();
-        out_buf     = buf2.data();
-        out_idx_buf = idx_buf2.data();
-      } else {
-        in_buf      = buf2.data();
-        in_idx_buf  = idx_buf2.data();
-        out_buf     = buf1.data();
-        out_idx_buf = idx_buf1.data();
-      }
-
-      radix_kernel<T, IdxT, BitsPerPass, BlockSize>
-        <<<blocks, BlockSize, 0, stream>>>(in_buf,
-                                           in_idx_buf,
-                                           out_buf,
-                                           out_idx_buf,
-                                           out + offset * k,
-                                           out_idx + offset * k,
-                                           counters.data(),
-                                           histograms.data(),
-                                           len,
-                                           k,
-                                           !select_min,
-                                           pass);
-      RAFT_CUDA_TRY(cudaPeekAtLastError());
+  if (len <= BlockSize * items_per_thread) {
+    impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
+      in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
+  } else {
+    unsigned grid_dim =
+      impl::calc_grid_dim<T, IdxT, BitsPerPass, BlockSize>(batch_size, len, sm_cnt);
+    if (grid_dim == 1) {
+      impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
+        in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
+    } else {
+      impl::radix_topk<T, IdxT, BitsPerPass, BlockSize>(in,
+                                                        in_idx,
+                                                        batch_size,
+                                                        len,
+                                                        k,
+                                                        out,
+                                                        out_idx,
+                                                        select_min,
+                                                        fused_last_filter,
+                                                        grid_dim,
+                                                        sm_cnt,
+                                                        stream,
+                                                        mr);
     }
   }
 }
diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh
index 692d262043..a7bbfd9500 100644
--- a/cpp/include/raft/spatial/knn/knn.cuh
+++ b/cpp/include/raft/spatial/knn/knn.cuh
@@ -153,12 +153,12 @@ template <typename value_idx = std::int64_t, typename value_t = float>
     case SelectKAlgo::RADIX_8_BITS:
       matrix::detail::select::radix::select_k<value_t, value_idx, 8, 512>(
-        in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, stream);
+        in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, true, stream);
       break;

     case SelectKAlgo::RADIX_11_BITS:
       matrix::detail::select::radix::select_k<value_t, value_idx, 11, 512>(
-        in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, stream);
+        in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, true, stream);
       break;

     case SelectKAlgo::WARP_SORT:
diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh
index ede6382c33..188122c9b4 100644
--- a/cpp/internal/raft_internal/matrix/select_k.cuh
+++ b/cpp/internal/raft_internal/matrix/select_k.cuh
@@ -33,7 +33,8 @@ struct params {
   size_t len;
   int k;
   bool select_min;
-  bool use_index_input = true;
+  bool use_index_input       = true;
+  bool use_same_leading_bits = false;
 };

 inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream&
@@ -42,7 +43,8 @@ inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream&
   os << ", len: " << ss.len;
   os << ", k: " << ss.k;
   os << (ss.select_min ? ", asc" : ", dsc");
-  os << (ss.use_index_input ? "}" : ", no-input-index}");
+  os << (ss.use_index_input ? "" : ", no-input-index");
+  os << (ss.use_same_leading_bits ? ", same-leading-bits}" : "}");
", same-leading-bits}" : "}"); return os; } @@ -50,6 +52,7 @@ enum class Algo { kPublicApi, kRadix8bits, kRadix11bits, + kRadix11bitsExtraPass, kWarpAuto, kWarpImmediate, kWarpFiltered, @@ -63,6 +66,7 @@ inline auto operator<<(std::ostream& os, const Algo& algo) -> std::ostream& case Algo::kPublicApi: return os << "kPublicApi"; case Algo::kRadix8bits: return os << "kRadix8bits"; case Algo::kRadix11bits: return os << "kRadix11bits"; + case Algo::kRadix11bitsExtraPass: return os << "kRadix11bitsExtraPass"; case Algo::kWarpAuto: return os << "kWarpAuto"; case Algo::kWarpImmediate: return os << "kWarpImmediate"; case Algo::kWarpFiltered: return os << "kWarpFiltered"; @@ -103,11 +107,38 @@ void select_k_impl(const device_resources& handle, } } case Algo::kRadix8bits: - return detail::select::radix::select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + true, // fused_last_filter + stream); case Algo::kRadix11bits: - return detail::select::radix::select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + true, // fused_last_filter + stream); + case Algo::kRadix11bitsExtraPass: + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + false, // fused_last_filter + stream); case Algo::kWarpAuto: return detail::select::warpsort::select_k( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 392464eb27..2a40d70abc 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -332,6 +332,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Values(select::Algo::kPublicApi, select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed))); @@ -426,6 +427,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Combine(inputs_random_longlist, testing::Values(select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed, @@ -440,6 +442,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Combine(inputs_random_longlist, testing::Values(select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed, @@ -451,7 +454,11 @@ TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomDoubleInt, - testing::Combine(inputs_random_largesize, testing::Values(select::Algo::kWarpAuto))); + testing::Combine(inputs_random_largesize, + testing::Values(select::Algo::kWarpAuto, + select::Algo::kRadix8bits, + select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass))); using ReferencedRandomFloatSizeT = SelectK::params_random>; @@ -459,6 +466,7 @@ TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT ReferencedRandomFloatSizeT, testing::Combine(inputs_random_largek, - testing::Values(select::Algo::kRadix11bits))); + testing::Values(select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass))); } // namespace raft::matrix