Add the deterministic version of AIR Top-K (the radix-based topk in RAFT)
ChristinaZ committed Jan 17, 2024
1 parent addb059 commit 629123f
Showing 1 changed file with 98 additions and 34 deletions.
cpp/include/raft/matrix/detail/select_radix.cuh (132 changes: 98 additions & 34 deletions)
@@ -30,6 +30,7 @@
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/radix_rank_sort_operations.cuh>
+#include <cuda/atomic>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
@@ -400,7 +401,7 @@ _RAFT_DEVICE void choose_bucket(Counter<T, IdxT>* counter,

// For one-block version, last_filter() could be called when pass < num_passes - 1.
// So `pass` could not be constexpr
-template <typename T, typename IdxT, int BitsPerPass>
+template <typename T, typename IdxT, int BitsPerPass, bool stable_last_filter = false>
_RAFT_DEVICE void last_filter(const T* in_buf,
const IdxT* in_idx_buf,
T* out,
@@ -418,6 +419,8 @@ _RAFT_DEVICE void last_filter(const T* in_buf,
const IdxT num_of_kth_needed = counter->k;
IdxT* p_out_cnt = &counter->out_cnt;
IdxT* p_out_back_cnt = &counter->out_back_cnt;
+IdxT* p_equal = out_idx + k - num_of_kth_needed;
+cuda::atomic_ref<IdxT, cuda::thread_scope_block> ref_last(p_equal[num_of_kth_needed - 1]);
for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) {
const T value = in_buf[i];
const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit;
@@ -431,15 +434,27 @@
} else if (bits == kth_value_bits) {
IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
if (back_pos < num_of_kth_needed) {
-IdxT pos = k - 1 - back_pos;
-out[pos] = value;
-out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
+IdxT pos = k - 1 - back_pos;
+out[pos] = value;
+if constexpr (!stable_last_filter) { out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; }
}
+if constexpr (stable_last_filter) {
+  // We don't need to overwrite the value here
+  // because we presume that two values that are equal have the same bits.
+  static_assert(std::is_scalar_v<T>);
+  IdxT new_idx = in_idx_buf ? in_idx_buf[i] : i;
+  if (new_idx < ref_last.load(cuda::memory_order_relaxed)) {
+    for (int j = 0; j < num_of_kth_needed; j++) {
+      IdxT pre_idx = atomicMin_block(&p_equal[j], new_idx);
+      if (pre_idx > new_idx) { new_idx = pre_idx; }
+    }
+  }
+}
}
}
}
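Note on the added tie-break: among candidates whose bits equal the k-th value's bits, only `num_of_kth_needed` output slots remain, and the `atomicMin` loop above converges so that those slots end up holding the smallest original indices, in ascending order (each `atomicMin` returns the previous slot value; a displaced larger index is carried into the next slot). The following standalone sketch reproduces the pattern, including the slot priming and barrier that the calling kernels, later in this diff, perform before invoking the stable `last_filter`. The names `tie_break_demo`, `slots`, and `kNeeded` are illustrative and error checking is omitted; this is not part of the commit.

#include <cstdint>
#include <cstdio>
#include <cuda/atomic>

constexpr int kNeeded = 4;  // stands in for num_of_kth_needed

__global__ void tie_break_demo(const uint32_t* candidates, int n, uint32_t* slots)
{
  // Prime the contended slots with the atomicMin identity (max), as the callers
  // do for the tail of out_idx before invoking the stable last_filter.
  for (int i = threadIdx.x; i < kNeeded; i += blockDim.x) {
    slots[i] = UINT32_MAX;
  }
  __syncthreads();  // every slot must be primed before any atomicMin below

  cuda::atomic_ref<uint32_t, cuda::thread_scope_block> ref_last(slots[kNeeded - 1]);
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    uint32_t new_idx = candidates[i];
    if (new_idx < ref_last.load(cuda::memory_order_relaxed)) {  // cheap reject
      for (int j = 0; j < kNeeded; j++) {
        uint32_t pre_idx = atomicMin_block(&slots[j], new_idx);
        if (pre_idx > new_idx) { new_idx = pre_idx; }  // carry displaced index onward
      }
    }
  }
}

int main()
{
  const uint32_t h_cand[] = {57, 8, 23, 91, 4, 15, 66, 42};
  const int n = sizeof(h_cand) / sizeof(h_cand[0]);
  uint32_t *d_cand, *d_slots;
  cudaMalloc(&d_cand, sizeof(h_cand));
  cudaMalloc(&d_slots, kNeeded * sizeof(uint32_t));
  cudaMemcpy(d_cand, h_cand, sizeof(h_cand), cudaMemcpyHostToDevice);
  tie_break_demo<<<1, 128>>>(d_cand, n, d_slots);
  uint32_t h_slots[kNeeded];
  cudaMemcpy(h_slots, d_slots, sizeof(h_slots), cudaMemcpyDeviceToHost);
  for (uint32_t v : h_slots) { printf("%u ", v); }  // prints "4 8 15 23"
  printf("\n");
  cudaFree(d_cand);
  cudaFree(d_slots);
  return 0;
}

The multi-block `last_filter_kernel` below follows the same pattern but with device-scope atomics (`atomicMin` and a default-scope `cuda::atomic_ref`), consistent with one row's candidates being filtered by threads of more than one block.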

-template <typename T, typename IdxT, int BitsPerPass>
+template <typename T, typename IdxT, int BitsPerPass, bool stable_last_filter = false>
RAFT_KERNEL last_filter_kernel(const T* in,
const IdxT* in_idx,
const T* in_buf,
@@ -475,7 +490,8 @@ RAFT_KERNEL last_filter_kernel(const T* in,
const IdxT num_of_kth_needed = counter->k;
IdxT* p_out_cnt = &counter->out_cnt;
IdxT* p_out_back_cnt = &counter->out_back_cnt;
-
+IdxT* p_equal = out_idx + k - num_of_kth_needed;
+cuda::atomic_ref<IdxT> ref_last(p_equal[num_of_kth_needed - 1]);
auto f = [k,
select_min,
kth_value_bits,
@@ -484,7 +500,9 @@ RAFT_KERNEL last_filter_kernel(const T* in,
p_out_back_cnt,
in_idx_buf,
out,
-out_idx](T value, IdxT i) {
+out_idx,
+p_equal,
+ref_last](T value, IdxT i) {
const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit;
if (bits < kth_value_bits) {
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
@@ -493,9 +511,19 @@
} else if (bits == kth_value_bits) {
IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
if (back_pos < num_of_kth_needed) {
-IdxT pos = k - 1 - back_pos;
-out[pos] = value;
-out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
+IdxT pos = k - 1 - back_pos;
+out[pos] = value;
+if constexpr (!stable_last_filter) { out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; }
}
+if constexpr (stable_last_filter) {
+  static_assert(std::is_trivially_copyable_v<T>);
+  IdxT new_idx = in_idx_buf ? in_idx_buf[i] : i;
+  if (new_idx < ref_last.load(cuda::memory_order_relaxed)) {
+    for (int j = 0; j < num_of_kth_needed; j++) {
+      IdxT pre_idx = atomicMin(&p_equal[j], new_idx);
+      if (pre_idx > new_idx) { new_idx = pre_idx; }
+    }
+  }
+}
}
};
@@ -542,7 +570,12 @@ RAFT_KERNEL last_filter_kernel(const T* in,
* rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and
* their indices.
*/
-template <typename T, typename IdxT, int BitsPerPass, int BlockSize, bool fused_last_filter>
+template <typename T,
+          typename IdxT,
+          int BitsPerPass,
+          int BlockSize,
+          bool fused_last_filter,
+          bool stable_last_filter = false>
RAFT_KERNEL radix_kernel(const T* in,
const IdxT* in_idx,
const T* in_buf,
@@ -658,17 +691,26 @@ RAFT_KERNEL radix_kernel(const T* in,
counter->filter_cnt = 0;
}

-if constexpr (fused_last_filter) {
-  if (pass == num_passes - 1) {
-    last_filter<T, IdxT, BitsPerPass>(out_buf ? out_buf : in_buf,
-                                      out_idx_buf ? out_idx_buf : in_idx_buf,
-                                      out,
-                                      out_idx,
-                                      out_buf ? current_len : len,
-                                      k,
-                                      counter,
-                                      select_min,
-                                      pass);
+if (pass == num_passes - 1) {
+  if constexpr (stable_last_filter) {
+    const IdxT num_of_kth_needed = counter->k;
+    for (IdxT i = threadIdx.x; i < num_of_kth_needed; i += blockDim.x) {
+      out_idx[k - num_of_kth_needed + i] = cuda::std::numeric_limits<IdxT>::max();
+    }
+    __syncthreads();
+  }
+
+  if constexpr (fused_last_filter) {
+    last_filter<T, IdxT, BitsPerPass, stable_last_filter>(
+      out_buf ? out_buf : in_buf,
+      out_idx_buf ? out_idx_buf : in_idx_buf,
+      out,
+      out_idx,
+      out_buf ? current_len : len,
+      k,
+      counter,
+      select_min,
+      pass);
}
}
}
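Priming the tail of `out_idx` with `numeric_limits<IdxT>::max()` matters because `max()` is the identity of `atomicMin`, so an untouched slot reads as empty, and the `__syncthreads()` keeps any thread from racing ahead of the initialization. A sequential model of the insertion (hypothetical host-side helper `offer`, not from this commit) makes the invariant easy to check: after all candidates are offered, the slots hold the smallest indices seen, in ascending order.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

// Sequential model of one candidate insertion: atomicMin returns the previous
// slot value; a displaced larger index bubbles into the next slot.
void offer(std::vector<uint32_t>& slots, uint32_t new_idx)
{
  if (new_idx >= slots.back()) { return; }  // mirrors the ref_last.load() early-out
  for (auto& slot : slots) {
    uint32_t pre_idx = slot;
    slot = std::min(slot, new_idx);
    if (pre_idx > new_idx) { new_idx = pre_idx; }
  }
}

int main()
{
  std::vector<uint32_t> slots(3, std::numeric_limits<uint32_t>::max());
  for (uint32_t idx : {42u, 7u, 99u, 13u, 3u}) { offer(slots, idx); }
  for (uint32_t v : slots) { printf("%u ", v); }  // prints "3 7 13"
  printf("\n");
  return 0;
}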
Expand Down Expand Up @@ -1000,7 +1042,11 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf,
}
}

-template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
+template <typename T,
+          typename IdxT,
+          int BitsPerPass,
+          int BlockSize,
+          bool stable_last_filter = false>
RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
const IdxT* in_idx,
const IdxT len,
@@ -1074,16 +1120,34 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
if (threadIdx.x == 0) { counter.previous_len = current_len; }
__syncthreads();

-if (counter.len == counter.k || pass == num_passes - 1) {
-  last_filter<T, IdxT, BitsPerPass>(out_buf ? out_buf : in,
-                                    out_buf ? out_idx_buf : in_idx,
-                                    out,
-                                    out_idx,
-                                    out_buf ? current_len : len,
-                                    k,
-                                    &counter,
-                                    select_min,
-                                    pass);
+if ((pass == num_passes - 1)) {
+  if constexpr (stable_last_filter) {
+    const IdxT num_of_kth_needed = counter.k;
+    for (IdxT i = threadIdx.x; i < num_of_kth_needed; i += blockDim.x) {
+      out_idx[k - num_of_kth_needed + i] = cuda::std::numeric_limits<IdxT>::max();
+    }
+    __syncthreads();
+  }
+  last_filter<T, IdxT, BitsPerPass, stable_last_filter>(out_buf ? out_buf : in,
+                                                        out_buf ? out_idx_buf : in_idx,
+                                                        out,
+                                                        out_idx,
+                                                        out_buf ? current_len : len,
+                                                        k,
+                                                        &counter,
+                                                        select_min,
+                                                        pass);
+  break;
+} else if (counter.len == counter.k) {
+  last_filter<T, IdxT, BitsPerPass, false>(out_buf ? out_buf : in,
+                                           out_buf ? out_idx_buf : in_idx,
+                                           out,
+                                           out_idx,
+                                           out_buf ? current_len : len,
+                                           k,
+                                           &counter,
+                                           select_min,
+                                           pass);
break;
}
}
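With this dispatch, the one-block kernel is stable only on the genuine last pass; the early exit when `counter.len == counter.k` keeps the non-stable filter, presumably because every remaining candidate is emitted there, so which elements are selected no longer depends on thread timing (only their placement within `out` does). As a host-side reference for the intended selection semantics (illustrative only, not part of RAFT), a stable sort expresses the tie-break rule directly: among values tied at the k-th position, smaller original indices win.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Reference model: indices of the k smallest values, with ties at the boundary
// broken by smaller original index (the selection the stable radix top-k targets).
std::vector<uint32_t> stable_topk_ref(const std::vector<float>& v, size_t k)
{
  std::vector<uint32_t> idx(v.size());
  std::iota(idx.begin(), idx.end(), 0u);
  // stable_sort keeps equal keys in original index order: exactly the tie-break rule
  std::stable_sort(idx.begin(), idx.end(),
                   [&](uint32_t a, uint32_t b) { return v[a] < v[b]; });
  idx.resize(k);
  return idx;
}

int main()
{
  const std::vector<float> v = {1.0f, 2.0f, 2.0f, 0.5f, 2.0f};
  for (uint32_t i : stable_topk_ref(v, 3)) { printf("%u ", i); }  // prints "3 0 1"
  printf("\n");
  return 0;
}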
@@ -1249,4 +1313,4 @@ void select_k(const T* in,
}
}

-} // namespace raft::matrix::detail::select::radix
\ No newline at end of file
+} // namespace raft::matrix::detail::select::radix
