Skip to content

Commit

Permalink
Fix cagra multi CTA bug (#1628)
Browse files Browse the repository at this point in the history
rel: #1575

The main causes of the bug are the wrong memory allocation size and the wrong number of threads to be launched.

Authors:
  - tsuki (https://github.com/enp1s0)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1628
  • Loading branch information
enp1s0 authored Jul 4, 2023
1 parent 744881e commit cc1d1ed
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
20 changes: 12 additions & 8 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
assert(blockDim.x == BLOCK_SIZE);
assert(dataset_dim <= MAX_DATASET_DIM);

// const auto num_queries = gridDim.y;
const auto num_queries = gridDim.y;
const auto query_id = blockIdx.y;
const auto num_cta_per_query = gridDim.x;
const auto cta_id = blockIdx.x; // local CTA ID
Expand Down Expand Up @@ -225,6 +225,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
// compute distance to randomly selecting nodes
_CLK_START();
const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr;
uint32_t block_id = cta_id + (num_cta_per_query * query_id);
uint32_t num_blocks = num_cta_per_query * num_queries;
device::compute_distance_to_random_nodes<TEAM_SIZE, MAX_DATASET_DIM, LOAD_T>(
result_indices_buffer,
result_distances_buffer,
Expand All @@ -240,8 +242,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
num_seeds,
local_visited_hashmap_ptr,
hash_bitlen,
cta_id,
num_cta_per_query);
block_id,
num_blocks);
__syncthreads();
_CLK_REC(clk_compute_1st_distance);

Expand Down Expand Up @@ -472,14 +474,14 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T> {
topk_workspace(0, resource::get_cuda_stream(res))

{
set_params(res);
set_params(res, params);
}

void set_params(raft::resources const& res)
void set_params(raft::resources const& res, const search_params& params)
{
this->itopk_size = 32;
num_parents = 1;
num_cta_per_query = max(num_parents, itopk_size / 32);
num_cta_per_query = max(params.num_parents, params.itopk_size / 32);
result_buffer_size = itopk_size + num_parents * graph_degree;
typedef raft::Pow2<32> AlignBytes;
unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size);
Expand Down Expand Up @@ -532,8 +534,10 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T> {
// Allocate memory for intermediate buffer and workspace.
//
uint32_t num_intermediate_results = num_cta_per_query * itopk_size;
intermediate_indices.resize(num_intermediate_results, resource::get_cuda_stream(res));
intermediate_distances.resize(num_intermediate_results, resource::get_cuda_stream(res));
intermediate_indices.resize(num_intermediate_results * max_queries,
resource::get_cuda_stream(res));
intermediate_distances.resize(num_intermediate_results * max_queries,
resource::get_cuda_stream(res));

hashmap.resize(hashmap_size, resource::get_cuda_stream(res));

Expand Down
3 changes: 0 additions & 3 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,6 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
protected:
void testCagra()
{
if (ps.algo == search_algo::MULTI_CTA && ps.max_queries != 1) {
GTEST_SKIP() << "Skipping test due to issue #1575";
}
size_t queries_size = ps.n_queries * ps.k;
std::vector<IdxT> indices_Cagra(queries_size);
std::vector<IdxT> indices_naive(queries_size);
Expand Down

0 comments on commit cc1d1ed

Please sign in to comment.