diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 8ab6b19b98..2f34febdd2 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -158,7 +158,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( assert(blockDim.x == BLOCK_SIZE); assert(dataset_dim <= MAX_DATASET_DIM); - // const auto num_queries = gridDim.y; + const auto num_queries = gridDim.y; const auto query_id = blockIdx.y; const auto num_cta_per_query = gridDim.x; const auto cta_id = blockIdx.x; // local CTA ID @@ -225,6 +225,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( // compute distance to randomly selecting nodes _CLK_START(); const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; + uint32_t block_id = cta_id + (num_cta_per_query * query_id); + uint32_t num_blocks = num_cta_per_query * num_queries; device::compute_distance_to_random_nodes( result_indices_buffer, result_distances_buffer, @@ -240,8 +242,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( num_seeds, local_visited_hashmap_ptr, hash_bitlen, - cta_id, - num_cta_per_query); + block_id, + num_blocks); __syncthreads(); _CLK_REC(clk_compute_1st_distance); @@ -472,14 +474,14 @@ struct search : public search_plan_impl { topk_workspace(0, resource::get_cuda_stream(res)) { - set_params(res); + set_params(res, params); } - void set_params(raft::resources const& res) + void set_params(raft::resources const& res, const search_params& params) { this->itopk_size = 32; num_parents = 1; - num_cta_per_query = max(num_parents, itopk_size / 32); + num_cta_per_query = max(params.num_parents, params.itopk_size / 32); result_buffer_size = itopk_size + num_parents * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); @@ -532,8 +534,10 @@ struct search : public search_plan_impl { // Allocate memory for intermediate buffer and workspace. // uint32_t num_intermediate_results = num_cta_per_query * itopk_size; - intermediate_indices.resize(num_intermediate_results, resource::get_cuda_stream(res)); - intermediate_distances.resize(num_intermediate_results, resource::get_cuda_stream(res)); + intermediate_indices.resize(num_intermediate_results * max_queries, + resource::get_cuda_stream(res)); + intermediate_distances.resize(num_intermediate_results * max_queries, + resource::get_cuda_stream(res)); hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 83263c5a64..3e929f9f3b 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -166,9 +166,6 @@ class AnnCagraTest : public ::testing::TestWithParam { protected: void testCagra() { - if (ps.algo == search_algo::MULTI_CTA && ps.max_queries != 1) { - GTEST_SKIP() << "Skipping test due to issue #1575"; - } size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); std::vector indices_naive(queries_size);