Skip to content

Commit

Permalink
Performance optimization of IVF-flat / select_k (rapidsai#2221)
Browse files Browse the repository at this point in the history
This PR is a followup to rapidsai#2169. To enable IVF-flat with k>256 we need an additional select_k invocation which was unexpectedly slow. There are two reasons for that:

First problem is the data handed to select_k: The valid data length per row is much smaller than the conservative maximum that could be achieved by probing the N largest probes. Therefore each query row contains roughly ~50% dummy values. This is also the case for IVF-PQ, but did not show up as prominent due to the second reason.

The second problem, and also a difference to the IVF-PQ algorithm - is that a 64bit payload data type is used for selectK. The performance of selectK with 64bit index type is significantly slower than with 32bit, especially when many elements are in the same range:
```
Benchmark                                                           Time             CPU   Iterations
-----------------------------------------------------------------------------------------------------
SelectK/float/uint32_t/kRadix11bitsExtraPass/1/manual_time       1.68 ms         1.74 ms          413 1357#200000#512
SelectK/float/uint32_t/kRadix11bitsExtraPass/3/manual_time       2.31 ms         2.37 ms          302 1357#200000#512#same-leading-bits
SelectK/float/int64_t/kRadix11bitsExtraPass/1/manual_time        5.92 ms         5.98 ms          116 1357#200000#512
SelectK/float/int64_t/kRadix11bitsExtraPass/3/manual_time        83.7 ms         83.8 ms            8 1357#200000#512#same-leading-bits
-----------------------------------------------------------------------------------------------------
```
The data distribution within a IVF-flat benchmark resulted in a select_k time of ~24ms. 

### scope:
* additional parameter added to select_k to optionally pass individual row lengths for every batch entry. This parameter is utilized by both IVF-Flat and IVF-PQ and results in a ~2x speedup (50 nodes out of 5000) of the final `select_k`. 
* refactor ivf-flat search to work with 32bit indices by storing positions instead of actual indices. This allows to utilize 32bit index type select_k for ~10x speedup in the final `select_k`.

FYI @tfeher @achirkin 

### not in scope:
* General optimization of select_k: In the current implementation there is no difference in the type of the payload and the actual index type. Especially the type of the histogram has a large effect on performance (due to the atomics).

Authors:
  - Malte Förster (https://github.com/mfoerste4)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: rapidsai#2221
  • Loading branch information
mfoerste4 authored Mar 20, 2024
1 parent 0b9692b commit 335236c
Show file tree
Hide file tree
Showing 22 changed files with 221 additions and 99 deletions.
8 changes: 5 additions & 3 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ void select_k(raft::resources const& handle,
T* out_val,
IdxT* out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
const IdxT* len_i = nullptr) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY
Expand All @@ -58,7 +59,8 @@ void select_k(raft::resources const& handle,
IdxT* out_idx, \
bool select_min, \
bool sorted, \
raft::matrix::SelectAlgo algo)
raft::matrix::SelectAlgo algo, \
const IdxT* len_i)
instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
Expand Down
16 changes: 10 additions & 6 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ void segmented_sort_by_key(raft::resources const& handle,
* whether to make sure selected pairs are sorted by value
* @param[in] algo
* the selection algorithm to use
* @param[in] len_i
* array of size (batch_size) providing lengths for each individual row
* only radix select-k supported
*/
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
Expand All @@ -240,8 +243,9 @@ void select_k(raft::resources const& handle,
T* out_val,
IdxT* out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
const IdxT* len_i = nullptr)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
Expand All @@ -262,9 +266,8 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
true // fused_last_filter
);

true, // fused_last_filter
len_i);
} else {
bool fused_last_filter = algo == SelectAlgo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(handle,
Expand All @@ -276,7 +279,8 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
fused_last_filter);
fused_last_filter,
len_i);
}
if (sorted) {
auto offsets = make_device_mdarray<IdxT, IdxT>(
Expand Down
35 changes: 31 additions & 4 deletions cpp/include/raft/matrix/detail/select_radix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ RAFT_KERNEL radix_kernel(const T* in,
Counter<T, IdxT>* counters,
IdxT* histograms,
const IdxT len,
const IdxT* len_i,
const IdxT k,
const bool select_min,
const int pass)
Expand Down Expand Up @@ -598,6 +599,14 @@ RAFT_KERNEL radix_kernel(const T* in,
in_buf += batch_id * buf_len;
in_idx_buf += batch_id * buf_len;
}

// in case we have individual len for each query defined we want to make sure
// that we only iterate valid elements.
if (len_i != nullptr) {
const IdxT max_len = max(len_i[batch_id], k);
if (max_len < previous_len) previous_len = max_len;
}

// "current_len > buf_len" means current pass will skip writing buffer
if (pass == 0 || current_len > buf_len) {
out_buf = nullptr;
Expand Down Expand Up @@ -829,6 +838,7 @@ void radix_topk(const T* in,
IdxT* out_idx,
bool select_min,
bool fused_last_filter,
const IdxT* len_i,
unsigned grid_dim,
int sm_cnt,
rmm::cuda_stream_view stream,
Expand Down Expand Up @@ -868,6 +878,7 @@ void radix_topk(const T* in,
const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr;
T* chunk_out = out + offset * k;
IdxT* chunk_out_idx = out_idx + offset * k;
const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr;

const T* in_buf = nullptr;
const IdxT* in_idx_buf = nullptr;
Expand Down Expand Up @@ -905,6 +916,7 @@ void radix_topk(const T* in,
counters.data(),
histograms.data(),
len,
chunk_len_i,
k,
select_min,
pass);
Expand Down Expand Up @@ -1007,6 +1019,7 @@ template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
const IdxT* in_idx,
const IdxT len,
const IdxT* len_i,
const IdxT k,
T* out,
IdxT* out_idx,
Expand Down Expand Up @@ -1057,6 +1070,13 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
out_idx_buf = nullptr;
}

// in case we have individual len for each query defined we want to make sure
// that we only iterate valid elements.
if (len_i != nullptr) {
const IdxT max_len = max(len_i[batch_id], k);
if (max_len < previous_len) previous_len = max_len;
}

filter_and_histogram_for_one_block<T, IdxT, BitsPerPass>(in_buf,
in_idx_buf,
out_buf,
Expand Down Expand Up @@ -1106,6 +1126,7 @@ void radix_topk_one_block(const T* in,
T* out,
IdxT* out_idx,
bool select_min,
const IdxT* len_i,
int sm_cnt,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
Expand All @@ -1121,10 +1142,12 @@ void radix_topk_one_block(const T* in,
max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr);

for (size_t offset = 0; offset < static_cast<size_t>(batch_size); offset += max_chunk_size) {
int chunk_size = std::min(max_chunk_size, batch_size - offset);
int chunk_size = std::min(max_chunk_size, batch_size - offset);
const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr;
kernel<<<chunk_size, BlockSize, 0, stream>>>(in + offset * len,
in_idx ? (in_idx + offset * len) : nullptr,
len,
chunk_len_i,
k,
out + offset * k,
out_idx + offset * k,
Expand Down Expand Up @@ -1188,6 +1211,8 @@ void radix_topk_one_block(const T* in,
* blocks is called. The later case is preferable when leading bits of input data are almost the
* same. That is, when the value range of input data is narrow. In such case, there could be a
* large number of inputs for the last filter, hence using multiple thread blocks is beneficial.
* @param len_i
* optional array of size (batch_size) providing lengths for each individual row
*/
template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
void select_k(raft::resources const& res,
Expand All @@ -1199,7 +1224,8 @@ void select_k(raft::resources const& res,
T* out,
IdxT* out_idx,
bool select_min,
bool fused_last_filter)
bool fused_last_filter,
const IdxT* len_i)
{
auto stream = resource::get_cuda_stream(res);
auto mr = resource::get_workspace_resource(res);
Expand All @@ -1223,13 +1249,13 @@ void select_k(raft::resources const& res,

if (len <= BlockSize * items_per_thread) {
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr);
} else {
unsigned grid_dim =
impl::calc_grid_dim<T, IdxT, BitsPerPass, BlockSize>(batch_size, len, sm_cnt);
if (grid_dim == 1) {
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr);
} else {
impl::radix_topk<T, IdxT, BitsPerPass, BlockSize>(in,
in_idx,
Expand All @@ -1240,6 +1266,7 @@ void select_k(raft::resources const& res,
out_idx,
select_min,
fused_last_filter,
len_i,
grid_dim,
sm_cnt,
stream,
Expand Down
20 changes: 10 additions & 10 deletions cpp/include/raft/neighbors/detail/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT
return ix_min;
}

template <int BlockDim, typename IdxT1, typename IdxT2 = uint32_t>
template <int BlockDim, typename IdxT>
__launch_bounds__(BlockDim) RAFT_KERNEL
postprocess_neighbors_kernel(IdxT1* neighbors_out, // [n_queries, topk]
const IdxT2* neighbors_in, // [n_queries, topk]
const IdxT1* const* db_indices, // [n_clusters][..]
postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk]
const uint32_t* neighbors_in, // [n_queries, topk]
const IdxT* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
Expand All @@ -170,7 +170,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices);
const bool valid = chunk_ix < n_probes;
neighbors_out[k] =
valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord<IdxT1>;
valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord<IdxT>;
}

/**
Expand All @@ -180,10 +180,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
* probed clusters / defined by the `chunk_indices`.
* We assume the searched sample sizes (for a single query) fit into `uint32_t`.
*/
template <typename IdxT1, typename IdxT2 = uint32_t>
void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, topk]
const IdxT2* neighbors_in, // [n_queries, topk]
const IdxT1* const* db_indices, // [n_clusters][..]
template <typename IdxT>
void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk]
const uint32_t* neighbors_in, // [n_queries, topk]
const IdxT* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
Expand All @@ -193,7 +193,7 @@ void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, to
{
constexpr int kPNThreads = 256;
const int pn_blocks = raft::div_rounding_up_unsafe<size_t>(n_queries * topk, kPNThreads);
postprocess_neighbors_kernel<kPNThreads, IdxT1, IdxT2>
postprocess_neighbors_kernel<kPNThreads, IdxT>
<<<pn_blocks, kPNThreads, 0, stream>>>(neighbors_out,
neighbors_in,
db_indices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream) RAFT_EXPLICIT;
Expand All @@ -70,7 +70,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
const uint32_t* chunk_indices, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)
Expand Down
25 changes: 10 additions & 15 deletions cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,6 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
const uint32_t query_smem_elems,
const T* query,
const uint32_t* coarse_index,
const IdxT* const* list_indices_ptrs,
const T* const* list_data_ptrs,
const uint32_t* list_sizes,
const uint32_t queries_offset,
Expand All @@ -700,7 +699,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
const uint32_t* chunk_indices,
const uint32_t dim,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances)
{
extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[];
Expand All @@ -719,16 +718,16 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
distances += query_id * k * gridDim.x + blockIdx.x * k;
} else {
distances += query_id * uint64_t(max_samples);
chunk_indices += (n_probes * query_id);
}
chunk_indices += (n_probes * query_id);
coarse_index += query_id * n_probes;
}

// Copy a part of the query into shared memory for faster processing
copy_vectorized(query_shared, query, std::min(dim, query_smem_elems));
__syncthreads();

using local_topk_t = block_sort_t<Capacity, Ascending, float, IdxT>;
using local_topk_t = block_sort_t<Capacity, Ascending, float, uint32_t>;
local_topk_t queue(k);
{
using align_warp = Pow2<WarpSize>;
Expand All @@ -752,11 +751,9 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2

uint32_t sample_offset = 0;
if constexpr (!kManageLocalTopK) {
if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; }
assert(list_length == chunk_indices[probe_id] - sample_offset);
assert(sample_offset + list_length <= max_samples);
}
if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; }
assert(list_length == chunk_indices[probe_id] - sample_offset);
assert(sample_offset + list_length <= max_samples);

constexpr int kUnroll = WarpSize / Veclen;
constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize;
Expand Down Expand Up @@ -806,8 +803,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
// Enqueue one element per thread
const float val = valid ? static_cast<float>(dist) : local_topk_t::queue_t::kDummy;
if constexpr (kManageLocalTopK) {
const size_t idx = valid ? static_cast<size_t>(list_indices_ptrs[list_id][vec_id]) : 0;
queue.add(val, idx);
queue.add(val, sample_offset + vec_id);
} else {
if (vec_id < list_length) distances[sample_offset + vec_id] = val;
}
Expand Down Expand Up @@ -873,7 +869,7 @@ void launch_kernel(Lambda lambda,
const uint32_t max_samples,
const uint32_t* chunk_indices,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
Expand Down Expand Up @@ -927,7 +923,6 @@ void launch_kernel(Lambda lambda,
query_smem_elems,
queries,
coarse_index,
index.inds_ptrs().data_handle(),
index.data_ptrs().data_handle(),
index.list_sizes().data_handle(),
queries_offset + query_offset,
Expand All @@ -945,8 +940,8 @@ void launch_kernel(Lambda lambda,
distances += grid_dim_y * grid_dim_x * k;
} else {
distances += grid_dim_y * max_samples;
chunk_indices += grid_dim_y * n_probes;
}
chunk_indices += grid_dim_y * n_probes;
coarse_index += grid_dim_y * n_probes;
}
}
Expand Down Expand Up @@ -1161,7 +1156,7 @@ void ivfflat_interleaved_scan(const index<T, IdxT>& index,
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
Expand Down
Loading

0 comments on commit 335236c

Please sign in to comment.