Use parameter prioritize_smaller_indice for device and host functions
ChristinaZ committed Jan 17, 2024
1 parent f70a258 commit 8885046
Showing 1 changed file with 108 additions and 41 deletions.
149 changes: 108 additions & 41 deletions cpp/include/raft/matrix/detail/select_radix.cuh
@@ -401,7 +401,7 @@ _RAFT_DEVICE void choose_bucket(Counter<T, IdxT>* counter,

// For one-block version, last_filter() could be called when pass < num_passes - 1.
// So `pass` could not be constexpr
template <typename T, typename IdxT, int BitsPerPass, bool stable_last_filter = false>
template <typename T, typename IdxT, int BitsPerPass, bool prioritize_smaller_indice = false>
_RAFT_DEVICE void last_filter(const T* in_buf,
const IdxT* in_idx_buf,
T* out,
@@ -436,9 +436,9 @@ _RAFT_DEVICE void last_filter(const T* in_buf,
if (back_pos < num_of_kth_needed) {
IdxT pos = k - 1 - back_pos;
out[pos] = value;
if constexpr (!stable_last_filter) { out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; }
if constexpr (!prioritize_smaller_indice) { out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; }
}
if constexpr (stable_last_filter) {
if constexpr (prioritize_smaller_indice) {
// We don't need to overwrite the value here
// because we presume that two values that are equal have the same bits.
static_assert(std::is_scalar_v<T>);
@@ -454,7 +454,7 @@ _RAFT_DEVICE void last_filter(const T* in_buf,
}
}

template <typename T, typename IdxT, int BitsPerPass, bool stable_last_filter = false>
template <typename T, typename IdxT, int BitsPerPass, bool prioritize_smaller_indice = false>
RAFT_KERNEL last_filter_kernel(const T* in,
const IdxT* in_idx,
const T* in_buf,
@@ -513,9 +513,9 @@ RAFT_KERNEL last_filter_kernel(const T* in,
if (back_pos < num_of_kth_needed) {
IdxT pos = k - 1 - back_pos;
out[pos] = value;
if constexpr (!stable_last_filter) { out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; }
if constexpr (!prioritize_smaller_indice) { out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; }
}
if constexpr (stable_last_filter) {
if constexpr (prioritize_smaller_indice) {
static_assert(std::is_trivially_copyable_v<T>);
IdxT new_idx = in_idx_buf ? in_idx_buf[i] : i;
if (new_idx < ref_last.load(cuda::memory_order_relaxed)) {
@@ -575,7 +575,7 @@ template <typename T,
int BitsPerPass,
int BlockSize,
bool fused_last_filter,
bool stable_last_filter = false>
bool prioritize_smaller_indice = false>
RAFT_KERNEL radix_kernel(const T* in,
const IdxT* in_idx,
const T* in_buf,
@@ -692,7 +692,7 @@ RAFT_KERNEL radix_kernel(const T* in,
}

if (pass == num_passes - 1) {
if constexpr (stable_last_filter) {
if constexpr (prioritize_smaller_indice) {
const IdxT num_of_kth_needed = counter->k;
for (IdxT i = threadIdx.x; i < num_of_kth_needed; i += blockDim.x) {
out_idx[k - num_of_kth_needed + i] = cuda::std::numeric_limits<IdxT>::max();
@@ -701,7 +701,7 @@ RAFT_KERNEL radix_kernel(const T* in,
}

if constexpr (fused_last_filter) {
last_filter<T, IdxT, BitsPerPass, stable_last_filter>(
last_filter<T, IdxT, BitsPerPass, prioritize_smaller_indice>(
out_buf ? out_buf : in_buf,
out_idx_buf ? out_idx_buf : in_idx_buf,
out,
@@ -871,7 +871,8 @@ void radix_topk(const T* in,
unsigned grid_dim,
int sm_cnt,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
rmm::mr::device_memory_resource* mr,
bool prioritize_smaller_indice = false)
{
// TODO: is it possible to relax this restriction?
static_assert(calc_num_passes<T, BitsPerPass>() > 1);
@@ -901,8 +902,19 @@ void radix_topk(const T* in,
RAFT_CUDA_TRY(
cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter<T, IdxT>), stream));
RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream));
auto kernel = radix_kernel<T, IdxT, BitsPerPass, BlockSize, false>;

auto kernel = prioritize_smaller_indice ? radix_kernel<T,
IdxT,
BitsPerPass,
BlockSize,
/*fused_last_filter*/ false,
/*prioritize_smaller_indice*/ true>
: radix_kernel<T,
IdxT,
BitsPerPass,
BlockSize,
/*fused_last_filter*/ false,
/*prioritize_smaller_indice*/ false>;
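// Editorial note: prioritize_smaller_indice is a runtime argument of radix_topk, but
// radix_kernel takes it as a template parameter so the tie-breaking branch can use
// `if constexpr` and compile away when unused; the ternary above therefore selects one of
// two kernel instantiations instead of forwarding the bool at runtime.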
const T* chunk_in = in + offset * len;
const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr;
T* chunk_out = out + offset * k;
@@ -930,7 +942,18 @@ void radix_topk(const T* in,
out_idx_buf);

if (fused_last_filter && pass == num_passes - 1) {
kernel = radix_kernel<T, IdxT, BitsPerPass, BlockSize, true>;
kernel = prioritize_smaller_indice ? radix_kernel<T,
IdxT,
BitsPerPass,
BlockSize,
/*fused_last_filter*/ true,
/*prioritize_smaller_indice*/ true>
: radix_kernel<T,
IdxT,
BitsPerPass,
BlockSize,
/*fused_last_filter*/ true,
/*prioritize_smaller_indice*/ false>;
}

kernel<<<blocks, BlockSize, 0, stream>>>(chunk_in,
@@ -951,16 +974,20 @@ void radix_topk(const T* in,
}

if (!fused_last_filter) {
last_filter_kernel<T, IdxT, BitsPerPass><<<blocks, BlockSize, 0, stream>>>(chunk_in,
chunk_in_idx,
out_buf,
out_idx_buf,
chunk_out,
chunk_out_idx,
len,
k,
counters.data(),
select_min);
auto kernel =
prioritize_smaller_indice
? last_filter_kernel<T, IdxT, BitsPerPass, /*prioritize_smaller_indice*/ true>
: last_filter_kernel<T, IdxT, BitsPerPass, /*prioritize_smaller_indice*/ false>;
kernel<<<blocks, BlockSize, 0, stream>>>(chunk_in,
chunk_in_idx,
out_buf,
out_idx_buf,
chunk_out,
chunk_out_idx,
len,
k,
counters.data(),
select_min);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}
}
@@ -1046,7 +1073,7 @@ template <typename T,
typename IdxT,
int BitsPerPass,
int BlockSize,
bool stable_last_filter = false>
bool prioritize_smaller_indice = false>
RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
const IdxT* in_idx,
const IdxT len,
@@ -1121,22 +1148,22 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
__syncthreads();

if ((pass == num_passes - 1)) {
if constexpr (stable_last_filter) {
if constexpr (prioritize_smaller_indice) {
const IdxT num_of_kth_needed = counter.k;
for (IdxT i = threadIdx.x; i < num_of_kth_needed; i += blockDim.x) {
out_idx[k - num_of_kth_needed + i] = cuda::std::numeric_limits<IdxT>::max();
}
__syncthreads();
}
last_filter<T, IdxT, BitsPerPass, stable_last_filter>(out_buf ? out_buf : in,
out_buf ? out_idx_buf : in_idx,
out,
out_idx,
out_buf ? current_len : len,
k,
&counter,
select_min,
pass);
last_filter<T, IdxT, BitsPerPass, prioritize_smaller_indice>(out_buf ? out_buf : in,
out_buf ? out_idx_buf : in_idx,
out,
out_idx,
out_buf ? current_len : len,
k,
&counter,
select_min,
pass);
break;
} else if (counter.len == counter.k) {
last_filter<T, IdxT, BitsPerPass, false>(out_buf ? out_buf : in,
@@ -1169,11 +1196,22 @@ void radix_topk_one_block(const T* in,
bool select_min,
int sm_cnt,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
rmm::mr::device_memory_resource* mr,
bool prioritize_smaller_indice = false)
{
static_assert(calc_num_passes<T, BitsPerPass>() > 1);

auto kernel = radix_topk_one_block_kernel<T, IdxT, BitsPerPass, BlockSize>;
auto kernel = prioritize_smaller_indice
? radix_topk_one_block_kernel<T,
IdxT,
BitsPerPass,
BlockSize,
/*prioritize_smaller_indice*/ true>
: radix_topk_one_block_kernel<T,
IdxT,
BitsPerPass,
BlockSize,
/*prioritize_smaller_indice*/ false>;
const IdxT buf_len = calc_buf_len<T, IdxT, unsigned>(len);
const size_t max_chunk_size =
calc_chunk_size<T, IdxT, BlockSize>(batch_size, len, sm_cnt, kernel, true);
@@ -1245,6 +1283,13 @@ void radix_topk_one_block(const T* in,
* @param stream
* @param mr an optional memory resource to use across the calls (you can provide a large enough
* memory pool here to avoid memory allocations within the call).
* @param prioritize_smaller_indice
* For the Kth smallest/largest value there may be more than one element holding that value,
* so more than k elements can be eligible to appear in the output.
* By default we only guarantee that k elements are written, so only some of the elements with
* the Kth smallest/largest value make it into the results, and which ones are picked is
* effectively arbitrary. With this flag set (prioritize_smaller_indice=true), ties on the Kth
* smallest/largest value are always resolved in favor of the smallest indices
* (see the usage sketch after the diff).
*/
template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
void select_k(const T* in,
@@ -1257,7 +1302,8 @@ void select_k(const T* in,
bool select_min,
bool fused_last_filter,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr)
rmm::mr::device_memory_resource* mr = nullptr,
bool prioritize_smaller_indice = false)
{
if (k == len) {
RAFT_CUDA_TRY(
@@ -1287,14 +1333,34 @@ void select_k(const T* in,
constexpr int items_per_thread = 32;

if (len <= BlockSize * items_per_thread) {
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(in,
in_idx,
batch_size,
len,
k,
out,
out_idx,
select_min,
sm_cnt,
stream,
mr,
prioritize_smaller_indice);
} else {
unsigned grid_dim =
impl::calc_grid_dim<T, IdxT, BitsPerPass, BlockSize>(batch_size, len, sm_cnt);
if (grid_dim == 1) {
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(in,
in_idx,
batch_size,
len,
k,
out,
out_idx,
select_min,
sm_cnt,
stream,
mr,
prioritize_smaller_indice);
} else {
impl::radix_topk<T, IdxT, BitsPerPass, BlockSize>(in,
in_idx,
Expand All @@ -1308,7 +1374,8 @@ void select_k(const T* in,
grid_dim,
sm_cnt,
stream,
mr);
mr,
prioritize_smaller_indice);
}
}
}
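
For illustration only (an editorial sketch, not part of this commit): a minimal example of how a caller might pass the new prioritize_smaller_indice flag through select_k. The element/index types, the BitsPerPass and BlockSize values, the buffer names, and the use of rmm::mr::get_current_device_resource() are assumptions made for the example; only the parameter order follows the signature shown in the diff above.

// Editorial sketch (assumed types and tuning values; compile with nvcc).
#include <raft/matrix/detail/select_radix.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

// Select the k smallest values per row and break ties on the kth value by smallest index.
// select_k refers to the function declared in this header; namespace qualification omitted.
void topk_smallest_stable_ties(const float* d_in,   // [batch_size * len] input values (device)
                               const int* d_in_idx, // optional input indices, may be nullptr
                               int batch_size,
                               int len,
                               int k,
                               float* d_out,        // [batch_size * k] output values (device)
                               int* d_out_idx,      // [batch_size * k] output indices (device)
                               rmm::cuda_stream_view stream)
{
  // BitsPerPass = 11 and BlockSize = 512 are typical choices for this implementation,
  // assumed here for the example.
  select_k<float, int, /*BitsPerPass=*/11, /*BlockSize=*/512>(
    d_in,
    d_in_idx,
    batch_size,
    len,
    k,
    d_out,
    d_out_idx,
    /*select_min=*/true,
    /*fused_last_filter=*/false,
    stream,
    rmm::mr::get_current_device_resource(),
    /*prioritize_smaller_indice=*/true);
}

Leaving prioritize_smaller_indice at its default (false) preserves the behavior from before this commit.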