Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance optimization of IVF-flat / select_k #2221

Merged
merged 11 commits into from
Mar 20, 2024
8 changes: 5 additions & 3 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ void select_k(raft::resources const& handle,
T* out_val,
IdxT* out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
const IdxT* len_i = nullptr) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY
Expand All @@ -58,7 +59,8 @@ void select_k(raft::resources const& handle,
IdxT* out_idx, \
bool select_min, \
bool sorted, \
raft::matrix::SelectAlgo algo)
raft::matrix::SelectAlgo algo, \
const IdxT* len_i)
instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
Expand Down
16 changes: 10 additions & 6 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ void segmented_sort_by_key(raft::resources const& handle,
* whether to make sure selected pairs are sorted by value
* @param[in] algo
* the selection algorithm to use
* @param[in] len_i
* array of size (batch_size) providing lengths for each individual row
* only radix select-k supported
*/
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
Expand All @@ -240,8 +243,9 @@ void select_k(raft::resources const& handle,
T* out_val,
IdxT* out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
const IdxT* len_i = nullptr)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
Expand All @@ -262,9 +266,8 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
true // fused_last_filter
);

true, // fused_last_filter
len_i);
} else {
bool fused_last_filter = algo == SelectAlgo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(handle,
Expand All @@ -276,7 +279,8 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
fused_last_filter);
fused_last_filter,
len_i);
}
if (sorted) {
auto offsets = make_device_mdarray<IdxT, IdxT>(
Expand Down
37 changes: 33 additions & 4 deletions cpp/include/raft/matrix/detail/select_radix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ RAFT_KERNEL last_filter_kernel(const T* in,
T* out,
IdxT* out_idx,
const IdxT len,
const IdxT* len_i,
mfoerste4 marked this conversation as resolved.
Show resolved Hide resolved
const IdxT k,
Counter<T, IdxT>* counters,
const bool select_min)
Expand Down Expand Up @@ -557,6 +558,7 @@ RAFT_KERNEL radix_kernel(const T* in,
Counter<T, IdxT>* counters,
IdxT* histograms,
const IdxT len,
const IdxT* len_i,
const IdxT k,
const bool select_min,
const int pass)
Expand Down Expand Up @@ -598,6 +600,14 @@ RAFT_KERNEL radix_kernel(const T* in,
in_buf += batch_id * buf_len;
in_idx_buf += batch_id * buf_len;
}

// in case we have individual len for each query defined we want to make sure
// that we only iterate valid elements.
if (len_i != nullptr) {
const IdxT max_len = max(len_i[batch_id], k);
if (max_len < previous_len) previous_len = len_i[batch_id];
}

// "current_len > buf_len" means current pass will skip writing buffer
if (pass == 0 || current_len > buf_len) {
out_buf = nullptr;
Expand Down Expand Up @@ -829,6 +839,7 @@ void radix_topk(const T* in,
IdxT* out_idx,
bool select_min,
bool fused_last_filter,
const IdxT* len_i,
unsigned grid_dim,
int sm_cnt,
rmm::cuda_stream_view stream,
Expand Down Expand Up @@ -868,6 +879,7 @@ void radix_topk(const T* in,
const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr;
T* chunk_out = out + offset * k;
IdxT* chunk_out_idx = out_idx + offset * k;
const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr;

const T* in_buf = nullptr;
const IdxT* in_idx_buf = nullptr;
Expand Down Expand Up @@ -905,6 +917,7 @@ void radix_topk(const T* in,
counters.data(),
histograms.data(),
len,
chunk_len_i,
k,
select_min,
pass);
Expand All @@ -919,6 +932,7 @@ void radix_topk(const T* in,
chunk_out,
chunk_out_idx,
len,
chunk_len_i,
k,
counters.data(),
select_min);
Expand Down Expand Up @@ -1007,6 +1021,7 @@ template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
const IdxT* in_idx,
const IdxT len,
const IdxT* len_i,
const IdxT k,
T* out,
IdxT* out_idx,
Expand Down Expand Up @@ -1057,6 +1072,13 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
out_idx_buf = nullptr;
}

// in case we have individual len for each query defined we want to make sure
// that we only iterate valid elements.
if (len_i != nullptr) {
const IdxT max_len = max(len_i[batch_id], k);
if (max_len < previous_len) previous_len = len_i[batch_id];
}

filter_and_histogram_for_one_block<T, IdxT, BitsPerPass>(in_buf,
in_idx_buf,
out_buf,
Expand Down Expand Up @@ -1106,6 +1128,7 @@ void radix_topk_one_block(const T* in,
T* out,
IdxT* out_idx,
bool select_min,
const IdxT* len_i,
int sm_cnt,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
Expand All @@ -1121,10 +1144,12 @@ void radix_topk_one_block(const T* in,
max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr);

for (size_t offset = 0; offset < static_cast<size_t>(batch_size); offset += max_chunk_size) {
int chunk_size = std::min(max_chunk_size, batch_size - offset);
int chunk_size = std::min(max_chunk_size, batch_size - offset);
const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr;
kernel<<<chunk_size, BlockSize, 0, stream>>>(in + offset * len,
in_idx ? (in_idx + offset * len) : nullptr,
len,
chunk_len_i,
k,
out + offset * k,
out_idx + offset * k,
Expand Down Expand Up @@ -1188,6 +1213,8 @@ void radix_topk_one_block(const T* in,
* blocks is called. The later case is preferable when leading bits of input data are almost the
* same. That is, when the value range of input data is narrow. In such case, there could be a
* large number of inputs for the last filter, hence using multiple thread blocks is beneficial.
* @param len_i
* optional array of size (batch_size) providing lengths for each individual row
*/
template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
void select_k(raft::resources const& res,
Expand All @@ -1199,7 +1226,8 @@ void select_k(raft::resources const& res,
T* out,
IdxT* out_idx,
bool select_min,
bool fused_last_filter)
bool fused_last_filter,
const IdxT* len_i)
{
auto stream = resource::get_cuda_stream(res);
auto mr = resource::get_workspace_resource(res);
Expand All @@ -1223,13 +1251,13 @@ void select_k(raft::resources const& res,

if (len <= BlockSize * items_per_thread) {
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr);
} else {
unsigned grid_dim =
impl::calc_grid_dim<T, IdxT, BitsPerPass, BlockSize>(batch_size, len, sm_cnt);
if (grid_dim == 1) {
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr);
} else {
impl::radix_topk<T, IdxT, BitsPerPass, BlockSize>(in,
in_idx,
Expand All @@ -1240,6 +1268,7 @@ void select_k(raft::resources const& res,
out_idx,
select_min,
fused_last_filter,
len_i,
grid_dim,
sm_cnt,
stream,
Expand Down
9 changes: 5 additions & 4 deletions cpp/include/raft/neighbors/detail/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ struct dummy_block_sort_t {
* in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total
* number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples.
*/
template <int BlockDim>
template <int BlockDim, typename IdxT>
__launch_bounds__(BlockDim) RAFT_KERNEL
calc_chunk_indices_kernel(uint32_t n_probes,
const uint32_t* cluster_sizes, // [n_clusters]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t* n_samples // [n_queries]
IdxT* n_samples // [n_queries]
)
{
using block_scan = cub::BlockScan<uint32_t, BlockDim>;
Expand All @@ -75,6 +75,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; }
}

template <typename IdxT>
struct calc_chunk_indices {
public:
struct configured {
Expand All @@ -86,7 +87,7 @@ struct calc_chunk_indices {
inline void operator()(const uint32_t* cluster_sizes,
const uint32_t* clusters_to_probe,
uint32_t* chunk_indices,
uint32_t* n_samples,
IdxT* n_samples,
rmm::cuda_stream_view stream)
{
void* args[] = // NOLINT
Expand All @@ -107,7 +108,7 @@ struct calc_chunk_indices {
if constexpr (BlockDim >= WarpSize * 2) {
if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); }
}
return {reinterpret_cast<void*>(calc_chunk_indices_kernel<BlockDim>),
return {reinterpret_cast<void*>(calc_chunk_indices_kernel<BlockDim, IdxT>),
dim3(BlockDim, 1, 1),
dim3(n_queries, 1, 1),
n_probes};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream) RAFT_EXPLICIT;
Expand All @@ -70,7 +70,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
const uint32_t* chunk_indices, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
const uint32_t* chunk_indices,
const uint32_t dim,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances)
{
extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[];
Expand Down Expand Up @@ -752,11 +752,9 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2

uint32_t sample_offset = 0;
if constexpr (!kManageLocalTopK) {
if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; }
assert(list_length == chunk_indices[probe_id] - sample_offset);
assert(sample_offset + list_length <= max_samples);
}
if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; }
assert(list_length == chunk_indices[probe_id] - sample_offset);
assert(sample_offset + list_length <= max_samples);

constexpr int kUnroll = WarpSize / Veclen;
constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize;
Expand Down Expand Up @@ -806,8 +804,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
// Enqueue one element per thread
const float val = valid ? static_cast<float>(dist) : local_topk_t::queue_t::kDummy;
if constexpr (kManageLocalTopK) {
const size_t idx = valid ? static_cast<size_t>(list_indices_ptrs[list_id][vec_id]) : 0;
mfoerste4 marked this conversation as resolved.
Show resolved Hide resolved
queue.add(val, idx);
queue.add(val, sample_offset + vec_id);
} else {
if (vec_id < list_length) distances[sample_offset + vec_id] = val;
}
Expand Down Expand Up @@ -873,7 +870,7 @@ void launch_kernel(Lambda lambda,
const uint32_t max_samples,
const uint32_t* chunk_indices,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
Expand Down Expand Up @@ -1161,7 +1158,7 @@ void ivfflat_interleaved_scan(const index<T, IdxT>& index,
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
Expand Down
Loading
Loading