Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IVF-PQ: Fix illegal memory access with large max_samples #1685

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ template <typename OutT,
int Capacity,
bool PrecompBaseDiff,
bool EnableSMemLut>
__global__ void compute_similarity_kernel(uint32_t n_rows,
uint32_t dim,
__global__ void compute_similarity_kernel(uint32_t dim,
uint32_t n_probes,
uint32_t pq_dim,
uint32_t n_queries,
Expand Down Expand Up @@ -82,7 +81,6 @@ struct selected {
template <typename OutT, typename LutT, typename IvfSampleFilterT>
void compute_similarity_run(selected<OutT, LutT, IvfSampleFilterT> s,
rmm::cuda_stream_view stream,
uint32_t n_rows,
uint32_t dim,
uint32_t n_probes,
uint32_t pq_dim,
Expand Down Expand Up @@ -156,7 +154,6 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim,
* Setting this to `false` allows reducing the shared memory usage (and maximum data dim)
* at the cost of reducing global memory reading throughput.
*
* @param n_rows the number of records in the dataset
* @param dim the dimensionality of the data (NB: after rotation transform, i.e. `index.rot_dim()`).
* @param n_probes the number of clusters to search for each query
* @param pq_dim
Expand Down Expand Up @@ -251,8 +250,7 @@ template <typename OutT,
int Capacity,
bool PrecompBaseDiff,
bool EnableSMemLut>
__global__ void compute_similarity_kernel(uint32_t n_rows,
uint32_t dim,
__global__ void compute_similarity_kernel(uint32_t dim,
uint32_t n_probes,
uint32_t pq_dim,
uint32_t n_queries,
Expand Down Expand Up @@ -327,14 +325,15 @@ __global__ void compute_similarity_kernel(uint32_t n_rows,
uint32_t* out_indices = nullptr;
if constexpr (kManageLocalTopK) {
// Store topk calculated distances to out_scores (and its indices to out_indices)
out_scores = _out_scores + topk * (probe_ix + (n_probes * query_ix));
out_indices = _out_indices + topk * (probe_ix + (n_probes * query_ix));
const uint64_t out_offset = probe_ix + n_probes * query_ix;
out_scores = _out_scores + out_offset * topk;
out_indices = _out_indices + out_offset * topk;
} else {
// Store all calculated distances to out_scores
out_scores = _out_scores + max_samples * query_ix;
out_scores = _out_scores + uint64_t(max_samples) * query_ix;
}
uint32_t label = cluster_labels[n_probes * query_ix + probe_ix];
const float* cluster_center = cluster_centers + (dim * label);
const float* cluster_center = cluster_centers + dim * label;
const float* pq_center;
if (codebook_kind == codebook_gen::PER_SUBSPACE) {
pq_center = pq_centers;
Expand Down Expand Up @@ -602,7 +601,6 @@ template <typename OutT,
typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter>
void compute_similarity_run(selected<OutT, LutT, IvfSampleFilterT> s,
rmm::cuda_stream_view stream,
uint32_t n_rows,
uint32_t dim,
uint32_t n_probes,
uint32_t pq_dim,
Expand All @@ -625,8 +623,7 @@ void compute_similarity_run(selected<OutT, LutT, IvfSampleFilterT> s,
OutT* _out_scores,
uint32_t* _out_indices)
{
s.kernel<<<s.grid_dim, s.block_dim, s.smem_size, stream>>>(n_rows,
dim,
s.kernel<<<s.grid_dim, s.block_dim, s.smem_size, stream>>>(dim,
n_probes,
pq_dim,
n_queries,
Expand Down
29 changes: 17 additions & 12 deletions cpp/include/raft/neighbors/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,10 @@ void ivfpq_search_worker(raft::resources const& handle,
auto stream = resource::get_cuda_stream(handle);
auto mr = resource::get_workspace_resource(handle);

bool manage_local_topk = is_local_topk_feasible(topK, n_probes, n_queries);
auto topk_len = manage_local_topk ? n_probes * topK : max_samples;
bool manage_local_topk = is_local_topk_feasible(topK, n_probes, n_queries);
auto topk_len = manage_local_topk ? n_probes * topK : max_samples;
std::size_t n_queries_probes = std::size_t(n_queries) * std::size_t(n_probes);
std::size_t n_queries_topk_len = std::size_t(n_queries) * std::size_t(topk_len);
if (manage_local_topk) {
RAFT_LOG_DEBUG("Fused version of the search kernel is selected (manage_local_topk == true)");
} else {
Expand All @@ -448,13 +450,13 @@ void ivfpq_search_worker(raft::resources const& handle,
rmm::device_uvector<uint32_t> index_list_sorted_buf(0, stream, mr);
uint32_t* index_list_sorted = nullptr;
rmm::device_uvector<uint32_t> num_samples(n_queries, stream, mr);
rmm::device_uvector<uint32_t> chunk_index(n_queries * n_probes, stream, mr);
rmm::device_uvector<uint32_t> chunk_index(n_queries_probes, stream, mr);
// [maxBatchSize, max_samples] or [maxBatchSize, n_probes, topk]
rmm::device_uvector<ScoreT> distances_buf(n_queries * topk_len, stream, mr);
rmm::device_uvector<ScoreT> distances_buf(n_queries_topk_len, stream, mr);
rmm::device_uvector<uint32_t> neighbors_buf(0, stream, mr);
uint32_t* neighbors_ptr = nullptr;
if (manage_local_topk) {
neighbors_buf.resize(n_queries * topk_len, stream);
neighbors_buf.resize(n_queries_topk_len, stream);
neighbors_ptr = neighbors_buf.data();
}
rmm::device_uvector<uint32_t> neighbors_uint32_buf(0, stream, mr);
Expand All @@ -479,10 +481,10 @@ void ivfpq_search_worker(raft::resources const& handle,
// The goal is to increase the L2 cache hit rate to read the vectors
// of a cluster by processing the cluster at the same time as much as
// possible.
index_list_sorted_buf.resize(n_queries * n_probes, stream);
index_list_sorted_buf.resize(n_queries_probes, stream);
auto index_list_buf =
make_device_mdarray<uint32_t>(handle, mr, make_extents<uint32_t>(n_queries * n_probes));
rmm::device_uvector<uint32_t> cluster_labels_out(n_queries * n_probes, stream, mr);
make_device_mdarray<uint32_t>(handle, mr, make_extents<uint32_t>(n_queries_probes));
rmm::device_uvector<uint32_t> cluster_labels_out(n_queries_probes, stream, mr);
auto index_list = index_list_buf.data_handle();
index_list_sorted = index_list_sorted_buf.data();

Expand All @@ -497,7 +499,7 @@ void ivfpq_search_worker(raft::resources const& handle,
cluster_labels_out.data(),
index_list,
index_list_sorted,
n_queries * n_probes,
n_queries_probes,
begin_bit,
end_bit,
stream);
Expand All @@ -508,7 +510,7 @@ void ivfpq_search_worker(raft::resources const& handle,
cluster_labels_out.data(),
index_list,
index_list_sorted,
n_queries * n_probes,
n_queries_probes,
begin_bit,
end_bit,
stream);
Expand Down Expand Up @@ -558,7 +560,6 @@ void ivfpq_search_worker(raft::resources const& handle,
}
compute_similarity_run(search_instance,
stream,
index.size(),
index.rot_dim(),
n_probes,
index.pq_dim(),
Expand Down Expand Up @@ -706,7 +707,11 @@ inline auto get_max_batch_size(raft::resources const& res,
}
// Check if the tmp distance buffer is not too big
auto ws_size = [k, n_probes, max_samples](uint32_t bs) -> uint64_t {
return uint64_t(is_local_topk_feasible(k, n_probes, bs) ? k * n_probes : max_samples) * bs;
const uint64_t buffers_fused = 12ull * k * n_probes;
const uint64_t buffers_non_fused = 4ull * max_samples;
const uint64_t other = 32ull * n_probes;
return static_cast<uint64_t>(bs) *
(other + (is_local_topk_feasible(k, n_probes, bs) ? buffers_fused : buffers_non_fused));
};
auto max_ws_size = resource::get_workspace_free_bytes(res);
if (ws_size(max_batch_size) > max_ws_size) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
template void raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \\
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \\
rmm::cuda_stream_view stream, \\
uint32_t n_rows, \\
uint32_t dim, \\
uint32_t n_probes, \\
uint32_t pq_dim, \\
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down