Skip to content

Commit

Permalink
ivf-pq post integration hotfixes (#878)
Browse files Browse the repository at this point in the history
Fixes to ivf_pq::search:

  1. Fixed a typo in an argument passed to `select_clusters` (`n_queries` -> `queries_batch`), which led to illegal memory accesses for large batches.
  2. Removed unnecessary argument `max_batch_size` from the worker function.
  3. Replaced `<<<bracket-calling>>>` of dynamically selected kernels with `cudaLaunchKernel` invocations.
     The former led to the kernels silently not being called at all in some cases for no apparent reason (not reproducible in tests).

Fixes to ivf_pq::build:
  1. Added the missing include `raft/util/device_atomics.cuh`; its absence went unnoticed due to transitive dependencies.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #878
  • Loading branch information
achirkin authored Oct 4, 2022
1 parent 97303db commit e7bf57c
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 47 deletions.
1 change: 1 addition & 0 deletions cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <raft/random/rng.cuh>
#include <raft/stats/histogram.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/device_atomics.cuh>
#include <raft/util/pow2_utils.cuh>

#include <rmm/cuda_stream_view.hpp>
Expand Down
66 changes: 33 additions & 33 deletions cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -262,22 +262,21 @@ __launch_bounds__(BlockDim) __global__
template <typename IdxT>
struct calc_chunk_indices {
public:
using kernel_t = void (*)(uint32_t, const IdxT*, const uint32_t*, uint32_t*, uint32_t*);

struct configured {
kernel_t kernel;
uint32_t block_dim;
void* kernel;
dim3 block_dim;
dim3 grid_dim;
uint32_t n_probes;
uint32_t n_queries;

void operator()(const IdxT* cluster_offsets,
const uint32_t* clusters_to_probe,
uint32_t* chunk_indices,
uint32_t* n_samples,
rmm::cuda_stream_view stream)
{
kernel<<<n_queries, block_dim, 0, stream>>>(
n_probes, cluster_offsets, clusters_to_probe, chunk_indices, n_samples);
void* args[] = // NOLINT
{&n_probes, &cluster_offsets, &clusters_to_probe, &chunk_indices, &n_samples};
RAFT_CUDA_TRY(cudaLaunchKernel(kernel, grid_dim, block_dim, args, 0, stream));
}
};

Expand All @@ -293,7 +292,10 @@ struct calc_chunk_indices {
if constexpr (BlockDim >= WarpSize * 2) {
if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); }
}
return {calc_chunk_indices_kernel<BlockDim, IdxT>, BlockDim, n_probes, n_queries};
return {reinterpret_cast<void*>(calc_chunk_indices_kernel<BlockDim, IdxT>),
dim3(BlockDim, 1, 1),
dim3(n_queries, 1, 1),
n_probes};
}
};

Expand Down Expand Up @@ -830,16 +832,17 @@ struct ivfpq_compute_similarity {
};

struct selected {
kernel_t kernel;
uint32_t n_blocks;
uint32_t n_threads;
void* kernel;
dim3 grid_dim;
dim3 block_dim;
size_t smem_size;
size_t device_lut_size;

template <typename... Args>
void operator()(rmm::cuda_stream_view stream, Args&&... args)
void operator()(rmm::cuda_stream_view stream, Args... args)
{
kernel<<<n_blocks, n_threads, smem_size, stream>>>(std::forward<Args>(args)...);
void* xs[] = {&args...}; // NOLINT
RAFT_CUDA_TRY(cudaLaunchKernel(kernel, grid_dim, block_dim, xs, smem_size, stream));
}
};

Expand Down Expand Up @@ -967,7 +970,11 @@ struct ivfpq_compute_similarity {
}

uint32_t device_lut_size = use_smem_lut ? 0u : n_blocks * (pq_dim << pq_bits);
return {kernel, n_blocks, n_threads, smem_size, device_lut_size};
return {reinterpret_cast<void*>(kernel),
dim3(n_blocks, 1, 1),
dim3(n_threads, 1, 1),
smem_size,
device_lut_size};
}
};

Expand All @@ -984,7 +991,6 @@ void ivfpq_search_worker(const handle_t& handle,
const index<IdxT>& index,
uint32_t max_samples,
uint32_t n_probes,
uint32_t max_batch_size,
uint32_t topK,
uint32_t preferred_thread_block_size,
uint32_t n_queries,
Expand All @@ -994,10 +1000,6 @@ void ivfpq_search_worker(const handle_t& handle,
float* distances, // [n_queries, topK]
rmm::mr::device_memory_resource* mr)
{
RAFT_EXPECTS(n_queries <= max_batch_size,
"number of queries (%u) must be smaller the max batch size (%u)",
n_queries,
max_batch_size);
auto stream = handle.get_stream();

auto pq_centers = index.pq_centers().data_handle();
Expand All @@ -1006,10 +1008,10 @@ void ivfpq_search_worker(const handle_t& handle,
auto cluster_centers = index.centers_rot().data_handle();
auto cluster_offsets = index.list_offsets().data_handle();

bool manage_local_topk =
topK <= kMaxCapacity // depth is not too large
&& n_probes >= 16 // not too few clusters looked up
&& max_batch_size * n_probes >= 256 // overall amount of work is not too small
bool manage_local_topk = topK <= kMaxCapacity // depth is not too large
&& n_probes >= 16 // not too few clusters looked up
&&
n_queries * n_probes >= 256 // overall amount of work is not too small
;
auto topk_len = manage_local_topk ? n_probes * topK : max_samples;
if (manage_local_topk) {
Expand All @@ -1021,14 +1023,14 @@ void ivfpq_search_worker(const handle_t& handle,

rmm::device_uvector<uint32_t> index_list_sorted_buf(0, stream, mr);
uint32_t* index_list_sorted = nullptr;
rmm::device_uvector<uint32_t> num_samples(max_batch_size, stream, mr);
rmm::device_uvector<uint32_t> chunk_index(max_batch_size * n_probes, stream, mr);
rmm::device_uvector<uint32_t> num_samples(n_queries, stream, mr);
rmm::device_uvector<uint32_t> chunk_index(n_queries * n_probes, stream, mr);
// [n_queries, max_samples] or [n_queries, n_probes, topk]
rmm::device_uvector<ScoreT> distances_buf(max_batch_size * topk_len, stream, mr);
rmm::device_uvector<ScoreT> distances_buf(n_queries * topk_len, stream, mr);
rmm::device_uvector<IdxT> neighbors_buf(0, stream, mr);
IdxT* neighbors_ptr = nullptr;
if (manage_local_topk) {
neighbors_buf.resize(max_batch_size * topk_len, stream);
neighbors_buf.resize(n_queries * topk_len, stream);
neighbors_ptr = neighbors_buf.data();
}

Expand All @@ -1040,9 +1042,9 @@ void ivfpq_search_worker(const handle_t& handle,
// The goal is to increase the L2 cache hit rate to read the vectors
// of a cluster by processing the cluster at the same time as much as
// possible.
index_list_sorted_buf.resize(max_batch_size * n_probes, stream);
rmm::device_uvector<uint32_t> index_list_buf(max_batch_size * n_probes, stream, mr);
rmm::device_uvector<uint32_t> cluster_labels_out(max_batch_size * n_probes, stream, mr);
index_list_sorted_buf.resize(n_queries * n_probes, stream);
rmm::device_uvector<uint32_t> index_list_buf(n_queries * n_probes, stream, mr);
rmm::device_uvector<uint32_t> cluster_labels_out(n_queries * n_probes, stream, mr);
auto index_list = index_list_buf.data();
index_list_sorted = index_list_sorted_buf.data();
thrust::sequence(handle.get_thrust_policy(),
Expand Down Expand Up @@ -1150,7 +1152,6 @@ struct ivfpq_search {
uint32_t,
uint32_t,
uint32_t,
uint32_t,
const uint32_t*,
const float*,
IdxT*,
Expand Down Expand Up @@ -1319,7 +1320,7 @@ inline void search(const handle_t& handle,
select_clusters(handle,
clusters_to_probe.data(),
float_queries.data(),
n_queries,
queries_batch,
params.n_probes,
index.n_lists(),
dim,
Expand Down Expand Up @@ -1358,7 +1359,6 @@ inline void search(const handle_t& handle,
index,
max_samples,
params.n_probes,
max_batch_size,
k,
params.preferred_thread_block_size,
batch_size,
Expand Down
28 changes: 28 additions & 0 deletions cpp/test/spatial/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,34 @@ inline auto var_k() -> test_cases_t
});
}

/**
 * Cases brought up from downstream projects.
 *
 * Every case starts from a default-constructed `ivf_pq_inputs` element
 * (`xs.push_back({})`) and overrides only the fields assigned explicitly below;
 * all remaining fields keep whatever defaults `ivf_pq_inputs` provides.
 *
 * @return the list of special test cases (currently a single case).
 */
inline auto special_cases() -> test_cases_t
{
  test_cases_t xs;

  // Appends a fresh element to `xs` and lets `init` fill it in.
  // A local lambda is used rather than a function-like macro, so that no helper
  // name leaks out of this header into the translation units that include it.
  auto add_case = [&xs](auto&& init) {
    xs.push_back({});
    init(xs[xs.size() - 1]);
  };

  // A large configuration: 1183514 database vectors and 10000 queries.
  add_case([](ivf_pq_inputs& x) {
    x.num_db_vecs                = 1183514;
    x.dim                        = 100;
    x.num_queries                = 10000;
    x.k                          = 10;
    x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE;
    x.index_params.pq_dim        = 10;
    x.index_params.pq_bits       = 8;
    x.index_params.n_lists       = 1024;
    x.search_params.n_probes     = 50;
  });

  return xs;
}

/* Test instantiations */

#define TEST_BUILD_SEARCH(type) \
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/spatial/ann_ivf_pq/test_float_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ namespace raft::spatial::knn {
using f32_f32_u32 = ivf_pq_test<float, float, uint32_t>;

TEST_BUILD_SEARCH(f32_f32_u32)
INSTANTIATE(f32_f32_u32, defaults() + var_n_probes() + var_k());
INSTANTIATE(f32_f32_u32, defaults() + var_n_probes() + var_k() + special_cases());

} // namespace raft::spatial::knn
32 changes: 19 additions & 13 deletions cpp/test/spatial/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
#include <raft/spatial/knn/detail/topk.cuh>
#include <raft/util/cuda_utils.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

namespace raft::spatial::knn {

Expand Down Expand Up @@ -99,13 +101,13 @@ inline auto operator<<(std::ostream& os, const print_metric& p) -> std::ostream&
}

template <typename EvalT, typename DataT, typename IdxT>
__global__ void naiveDistanceKernel(EvalT* dist,
const DataT* x,
const DataT* y,
IdxT m,
IdxT n,
IdxT k,
raft::distance::DistanceType type)
__global__ void naive_distance_kernel(EvalT* dist,
const DataT* x,
const DataT* y,
IdxT m,
IdxT n,
IdxT k,
raft::distance::DistanceType type)
{
detail::utils::mapping<EvalT> f{};
IdxT midx = threadIdx.x + blockIdx.x * blockDim.x;
Expand Down Expand Up @@ -146,23 +148,26 @@ void naiveBfKnn(EvalT* dist_topk,
size_t dim,
uint32_t k,
raft::distance::DistanceType type,
cudaStream_t stream = 0)
rmm::cuda_stream_view stream)
{
rmm::mr::device_memory_resource* mr = nullptr;
auto pool_guard = raft::get_pool_memory_resource(mr, 1024 * 1024);

dim3 block_dim(16, 32, 1);
// maximum reasonable grid size in `y` direction
uint16_t grid_y =
auto grid_y =
static_cast<uint16_t>(std::min<size_t>(raft::ceildiv<size_t>(input_len, block_dim.y), 32768));

// bound the memory used by this function
size_t max_batch_size =
std::min<size_t>(n_inputs, raft::ceildiv<size_t>(size_t(1) << size_t(27), input_len));
rmm::device_uvector<EvalT> dist(max_batch_size * input_len, stream);
rmm::device_uvector<EvalT> dist(max_batch_size * input_len, stream, mr);

for (size_t offset = 0; offset < n_inputs; offset += max_batch_size) {
size_t batch_size = std::min(max_batch_size, n_inputs - offset);
dim3 grid_dim(raft::ceildiv<size_t>(batch_size, block_dim.x), grid_y, 1);

naiveDistanceKernel<EvalT, DataT, IdxT><<<grid_dim, block_dim, 0, stream>>>(
naive_distance_kernel<EvalT, DataT, IdxT><<<grid_dim, block_dim, 0, stream>>>(
dist.data(), x + offset * dim, y, batch_size, input_len, dim, type);

detail::select_topk<EvalT, IdxT>(dist.data(),
Expand All @@ -173,7 +178,8 @@ void naiveBfKnn(EvalT* dist_topk,
dist_topk + offset * k,
indices_topk + offset * k,
type != raft::distance::DistanceType::InnerProduct,
stream);
stream,
mr);
}
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
}
Expand All @@ -183,7 +189,7 @@ struct idx_dist_pair {
IdxT idx;
DistT dist;
CompareDist eq_compare;
bool operator==(const idx_dist_pair<IdxT, DistT, CompareDist>& a) const
auto operator==(const idx_dist_pair<IdxT, DistT, CompareDist>& a) const -> bool
{
if (idx == a.idx) return true;
if (eq_compare(dist, a.dist)) return true;
Expand Down

0 comments on commit e7bf57c

Please sign in to comment.