Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance optimization of IVF-flat / select_k #2221

Merged
merged 11 commits into from
Mar 20, 2024
8 changes: 5 additions & 3 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ void select_k(raft::resources const& handle,
T* out_val,
IdxT* out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
const IdxT* len_i = nullptr) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY
Expand All @@ -58,7 +59,8 @@ void select_k(raft::resources const& handle,
IdxT* out_idx, \
bool select_min, \
bool sorted, \
raft::matrix::SelectAlgo algo)
raft::matrix::SelectAlgo algo, \
const IdxT* len_i)
instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
Expand Down
16 changes: 10 additions & 6 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ void segmented_sort_by_key(raft::resources const& handle,
* whether to make sure selected pairs are sorted by value
* @param[in] algo
* the selection algorithm to use
* @param[in] len_i
* array of size (batch_size) providing lengths for each individual row
* only radix select-k supported
*/
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
Expand All @@ -240,8 +243,9 @@ void select_k(raft::resources const& handle,
T* out_val,
IdxT* out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
const IdxT* len_i = nullptr)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
Expand All @@ -262,9 +266,8 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
true // fused_last_filter
);

true, // fused_last_filter
len_i);
} else {
bool fused_last_filter = algo == SelectAlgo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(handle,
Expand All @@ -276,7 +279,8 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
fused_last_filter);
fused_last_filter,
len_i);
}
if (sorted) {
auto offsets = make_device_mdarray<IdxT, IdxT>(
Expand Down
35 changes: 31 additions & 4 deletions cpp/include/raft/matrix/detail/select_radix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ RAFT_KERNEL radix_kernel(const T* in,
Counter<T, IdxT>* counters,
IdxT* histograms,
const IdxT len,
const IdxT* len_i,
const IdxT k,
const bool select_min,
const int pass)
Expand Down Expand Up @@ -598,6 +599,14 @@ RAFT_KERNEL radix_kernel(const T* in,
in_buf += batch_id * buf_len;
in_idx_buf += batch_id * buf_len;
}

// in case we have individual len for each query defined we want to make sure
// that we only iterate valid elements.
if (len_i != nullptr) {
const IdxT max_len = max(len_i[batch_id], k);
if (max_len < previous_len) previous_len = len_i[batch_id];
}

// "current_len > buf_len" means current pass will skip writing buffer
if (pass == 0 || current_len > buf_len) {
out_buf = nullptr;
Expand Down Expand Up @@ -829,6 +838,7 @@ void radix_topk(const T* in,
IdxT* out_idx,
bool select_min,
bool fused_last_filter,
const IdxT* len_i,
unsigned grid_dim,
int sm_cnt,
rmm::cuda_stream_view stream,
Expand Down Expand Up @@ -868,6 +878,7 @@ void radix_topk(const T* in,
const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr;
T* chunk_out = out + offset * k;
IdxT* chunk_out_idx = out_idx + offset * k;
const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr;

const T* in_buf = nullptr;
const IdxT* in_idx_buf = nullptr;
Expand Down Expand Up @@ -905,6 +916,7 @@ void radix_topk(const T* in,
counters.data(),
histograms.data(),
len,
chunk_len_i,
k,
select_min,
pass);
Expand Down Expand Up @@ -1007,6 +1019,7 @@ template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
const IdxT* in_idx,
const IdxT len,
const IdxT* len_i,
const IdxT k,
T* out,
IdxT* out_idx,
Expand Down Expand Up @@ -1057,6 +1070,13 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
out_idx_buf = nullptr;
}

// in case we have individual len for each query defined we want to make sure
// that we only iterate valid elements.
if (len_i != nullptr) {
const IdxT max_len = max(len_i[batch_id], k);
if (max_len < previous_len) previous_len = len_i[batch_id];
}

filter_and_histogram_for_one_block<T, IdxT, BitsPerPass>(in_buf,
in_idx_buf,
out_buf,
Expand Down Expand Up @@ -1106,6 +1126,7 @@ void radix_topk_one_block(const T* in,
T* out,
IdxT* out_idx,
bool select_min,
const IdxT* len_i,
int sm_cnt,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
Expand All @@ -1121,10 +1142,12 @@ void radix_topk_one_block(const T* in,
max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr);

for (size_t offset = 0; offset < static_cast<size_t>(batch_size); offset += max_chunk_size) {
int chunk_size = std::min(max_chunk_size, batch_size - offset);
int chunk_size = std::min(max_chunk_size, batch_size - offset);
const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr;
kernel<<<chunk_size, BlockSize, 0, stream>>>(in + offset * len,
in_idx ? (in_idx + offset * len) : nullptr,
len,
chunk_len_i,
k,
out + offset * k,
out_idx + offset * k,
Expand Down Expand Up @@ -1188,6 +1211,8 @@ void radix_topk_one_block(const T* in,
* blocks is called. The later case is preferable when leading bits of input data are almost the
* same. That is, when the value range of input data is narrow. In such case, there could be a
* large number of inputs for the last filter, hence using multiple thread blocks is beneficial.
* @param len_i
* optional array of size (batch_size) providing lengths for each individual row
*/
template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
void select_k(raft::resources const& res,
Expand All @@ -1199,7 +1224,8 @@ void select_k(raft::resources const& res,
T* out,
IdxT* out_idx,
bool select_min,
bool fused_last_filter)
bool fused_last_filter,
const IdxT* len_i)
{
auto stream = resource::get_cuda_stream(res);
auto mr = resource::get_workspace_resource(res);
Expand All @@ -1223,13 +1249,13 @@ void select_k(raft::resources const& res,

if (len <= BlockSize * items_per_thread) {
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr);
} else {
unsigned grid_dim =
impl::calc_grid_dim<T, IdxT, BitsPerPass, BlockSize>(batch_size, len, sm_cnt);
if (grid_dim == 1) {
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr);
} else {
impl::radix_topk<T, IdxT, BitsPerPass, BlockSize>(in,
in_idx,
Expand All @@ -1240,6 +1266,7 @@ void select_k(raft::resources const& res,
out_idx,
select_min,
fused_last_filter,
len_i,
grid_dim,
sm_cnt,
stream,
Expand Down
20 changes: 10 additions & 10 deletions cpp/include/raft/neighbors/detail/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT
return ix_min;
}

template <int BlockDim, typename IdxT1, typename IdxT2 = uint32_t>
template <int BlockDim, typename IdxT>
__launch_bounds__(BlockDim) RAFT_KERNEL
postprocess_neighbors_kernel(IdxT1* neighbors_out, // [n_queries, topk]
const IdxT2* neighbors_in, // [n_queries, topk]
const IdxT1* const* db_indices, // [n_clusters][..]
postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk]
const uint32_t* neighbors_in, // [n_queries, topk]
const IdxT* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
Expand All @@ -170,7 +170,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices);
const bool valid = chunk_ix < n_probes;
neighbors_out[k] =
valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord<IdxT1>;
valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord<IdxT>;
}

/**
Expand All @@ -180,10 +180,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
* probed clusters / defined by the `chunk_indices`.
* We assume the searched sample sizes (for a single query) fit into `uint32_t`.
*/
template <typename IdxT1, typename IdxT2 = uint32_t>
void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, topk]
const IdxT2* neighbors_in, // [n_queries, topk]
const IdxT1* const* db_indices, // [n_clusters][..]
template <typename IdxT>
void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk]
const uint32_t* neighbors_in, // [n_queries, topk]
const IdxT* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
Expand All @@ -193,7 +193,7 @@ void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, to
{
constexpr int kPNThreads = 256;
const int pn_blocks = raft::div_rounding_up_unsafe<size_t>(n_queries * topk, kPNThreads);
postprocess_neighbors_kernel<kPNThreads, IdxT1, IdxT2>
postprocess_neighbors_kernel<kPNThreads, IdxT>
<<<pn_blocks, kPNThreads, 0, stream>>>(neighbors_out,
neighbors_in,
db_indices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream) RAFT_EXPLICIT;
Expand All @@ -70,7 +70,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
const uint32_t* chunk_indices, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,6 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
const uint32_t query_smem_elems,
const T* query,
const uint32_t* coarse_index,
const IdxT* const* list_indices_ptrs,
const T* const* list_data_ptrs,
const uint32_t* list_sizes,
const uint32_t queries_offset,
Expand All @@ -700,7 +699,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
const uint32_t* chunk_indices,
const uint32_t dim,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances)
{
extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[];
Expand All @@ -719,16 +718,16 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
distances += query_id * k * gridDim.x + blockIdx.x * k;
} else {
distances += query_id * uint64_t(max_samples);
chunk_indices += (n_probes * query_id);
}
chunk_indices += (n_probes * query_id);
coarse_index += query_id * n_probes;
}

// Copy a part of the query into shared memory for faster processing
copy_vectorized(query_shared, query, std::min(dim, query_smem_elems));
__syncthreads();

using local_topk_t = block_sort_t<Capacity, Ascending, float, IdxT>;
using local_topk_t = block_sort_t<Capacity, Ascending, float, uint32_t>;
local_topk_t queue(k);
{
using align_warp = Pow2<WarpSize>;
Expand All @@ -752,11 +751,9 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2

uint32_t sample_offset = 0;
if constexpr (!kManageLocalTopK) {
if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; }
assert(list_length == chunk_indices[probe_id] - sample_offset);
assert(sample_offset + list_length <= max_samples);
}
if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; }
assert(list_length == chunk_indices[probe_id] - sample_offset);
assert(sample_offset + list_length <= max_samples);

constexpr int kUnroll = WarpSize / Veclen;
constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize;
Expand Down Expand Up @@ -806,8 +803,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
// Enqueue one element per thread
const float val = valid ? static_cast<float>(dist) : local_topk_t::queue_t::kDummy;
if constexpr (kManageLocalTopK) {
const size_t idx = valid ? static_cast<size_t>(list_indices_ptrs[list_id][vec_id]) : 0;
mfoerste4 marked this conversation as resolved.
Show resolved Hide resolved
queue.add(val, idx);
queue.add(val, sample_offset + vec_id);
} else {
if (vec_id < list_length) distances[sample_offset + vec_id] = val;
}
Expand Down Expand Up @@ -873,7 +869,7 @@ void launch_kernel(Lambda lambda,
const uint32_t max_samples,
const uint32_t* chunk_indices,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
Expand Down Expand Up @@ -927,7 +923,6 @@ void launch_kernel(Lambda lambda,
query_smem_elems,
queries,
coarse_index,
index.inds_ptrs().data_handle(),
index.data_ptrs().data_handle(),
index.list_sizes().data_handle(),
queries_offset + query_offset,
Expand Down Expand Up @@ -1161,7 +1156,7 @@ void ivfflat_interleaved_scan(const index<T, IdxT>& index,
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
Expand Down
Loading
Loading