diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index edde924892..1f6357a563 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -203,6 +203,9 @@ struct alignas(128) Counter { // value are written from back to front. We need to keep count of them separately because the // number of elements that <= the k-th value might exceed k. alignas(128) IdxT out_back_cnt; + + // Number of infinities found in the zero pass. + alignas(128) IdxT bound_cnt; }; /** @@ -221,7 +224,8 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, IdxT* histogram, bool select_min, int pass, - bool early_stop) + bool early_stop, + IdxT k) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -232,12 +236,24 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, const int start_bit = calc_start_bit(pass); const unsigned mask = calc_mask(pass); + // The last possible value (k-th cannot be further). + const auto bound = select_min ? upper_bound() : lower_bound(); if (pass == 0) { + IdxT* p_bound_cnt = &counter->bound_cnt; // Passed to vectorized_process, this function executes in all blocks in parallel, // i.e. the work is split along the input (both, in batches and chunks of a single row). // Later, the histograms are merged using atomicAdd. - auto f = [select_min, start_bit, mask](T value, IdxT) { + auto f = [in_idx_buf, out, out_idx, select_min, start_bit, mask, bound, p_bound_cnt, k]( + T value, IdxT i) { + if (value == bound) { + if (i < k) { + IdxT pos = k - 1 - atomicAdd(p_bound_cnt, IdxT{1}); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + return; + } int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); }; @@ -265,7 +281,9 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, kth_value_bits, p_filter_cnt, p_out_cnt, + bound, early_stop](T value, IdxT i) { + if (value == bound) { return; } const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { @@ -370,7 +388,7 @@ _RAFT_DEVICE void choose_bucket(Counter* counter, IdxT cur = histogram[i]; // one and only one thread will satisfy this condition, so counter is written by only one thread - if (prev < k && cur >= k) { + if (prev < k && (cur >= k || i + 1 == num_buckets)) { counter->k = k - prev; // how many values still are there to find counter->len = cur - prev; // number of values in next pass typename cub::Traits::UnsignedBits bucket = i; @@ -395,13 +413,15 @@ _RAFT_DEVICE void last_filter(const T* in_buf, { const auto kth_value_bits = counter->kth_value_bits; const int start_bit = calc_start_bit(pass); + const auto bound = select_min ? upper_bound() : lower_bound(); // changed in choose_bucket(); need to reload const IdxT needed_num_of_kth = counter->k; IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { - const T value = in_buf[i]; + const T value = in_buf[i]; + if (value == bound) { continue; } const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; if (bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); @@ -413,7 +433,7 @@ _RAFT_DEVICE void last_filter(const T* in_buf, } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); if (back_pos < needed_num_of_kth) { - IdxT pos = k - 1 - back_pos; + IdxT pos = k - needed_num_of_kth + back_pos; out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } @@ -457,6 +477,7 @@ __global__ void last_filter_kernel(const T* in, const IdxT needed_num_of_kth = counter->k; IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; + const auto bound = select_min ? upper_bound() : lower_bound(); auto f = [k, select_min, @@ -465,8 +486,10 @@ __global__ void last_filter_kernel(const T* in, p_out_cnt, p_out_back_cnt, in_idx_buf, + bound, out, out_idx](T value, IdxT i) { + if (value == bound) { return; } const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; if (bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); @@ -475,7 +498,7 @@ __global__ void last_filter_kernel(const T* in, } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); if (back_pos < needed_num_of_kth) { - IdxT pos = k - 1 - back_pos; + IdxT pos = k - needed_num_of_kth + back_pos; out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } @@ -560,12 +583,11 @@ __global__ void radix_kernel(const T* in, current_len = counter->len; previous_len = counter->previous_len; } - if (current_len == 0) { return; } // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle // correctly the case that pass=0 and early_stop=true. However, this special case of k=len is // handled in other way in select_k() so such case is not possible here. - const bool early_stop = (current_len == current_k); + const bool early_stop = (current_len <= current_k); const IdxT buf_len = calc_buf_len(len); // "previous_len > buf_len" means previous pass skips writing buffer @@ -602,7 +624,8 @@ __global__ void radix_kernel(const T* in, histogram, select_min, pass, - early_stop); + early_stop, + k); __threadfence(); bool isLastBlock = false; @@ -723,7 +746,7 @@ _RAFT_HOST_DEVICE void set_buf_pointers(const T* in, { if (pass == 0) { in_buf = in; - in_idx_buf = nullptr; + in_idx_buf = in_idx; out_buf = nullptr; out_idx_buf = nullptr; } else if (pass == 1) { @@ -860,16 +883,17 @@ void radix_topk(const T* in, // The following a few functions are for the one-block version, which uses single thread block for // each row of a batch. template -_RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - Counter* counter, - IdxT* histogram, - bool select_min, - int pass) +_RAFT_DEVICE __noinline__ void filter_and_histogram_for_one_block(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass, + IdxT k) { constexpr int num_buckets = calc_num_buckets(); for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { @@ -882,12 +906,25 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, const int start_bit = calc_start_bit(pass); const unsigned mask = calc_mask(pass); const IdxT previous_len = counter->previous_len; + // The last possible value (k-th cannot be further). + const auto bound = select_min ? upper_bound() : lower_bound(); if (pass == 0) { - auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram + bucket, static_cast(1)); - }; + IdxT* p_bound_cnt = &counter->bound_cnt; + auto f = + [histogram, in_idx_buf, out, out_idx, select_min, start_bit, mask, bound, p_bound_cnt, k]( + T value, IdxT i) { + if (value == bound) { + if (i < k) { + IdxT pos = k - 1 - atomicAdd(p_bound_cnt, IdxT{1}); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + return; + } + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + }; vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); } else { // not use vectorized_process here because it increases #registers a lot @@ -896,7 +933,8 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, const int previous_start_bit = calc_start_bit(pass - 1); for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { - const T value = in_buf[i]; + const T value = in_buf[i]; + if (value == bound) { continue; } const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { @@ -943,6 +981,7 @@ __global__ void radix_topk_one_block_kernel(const T* in, counter.kth_value_bits = 0; counter.out_cnt = 0; counter.out_back_cnt = 0; + counter.bound_cnt = 0; } __syncthreads(); @@ -977,7 +1016,8 @@ __global__ void radix_topk_one_block_kernel(const T* in, &counter, histogram, select_min, - pass); + pass, + k); __syncthreads(); scan(histogram); @@ -987,7 +1027,7 @@ __global__ void radix_topk_one_block_kernel(const T* in, if (threadIdx.x == 0) { counter.previous_len = current_len; } __syncthreads(); - if (counter.len == counter.k || pass == num_passes - 1) { + if (counter.len <= counter.k || pass == num_passes - 1) { last_filter(pass == 0 ? in : out_buf, pass == 0 ? in_idx : out_idx_buf, out,