From b1f73741e1696cfe400c62a0610102b982ddc993 Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Wed, 30 Aug 2023 23:25:59 +0900 Subject: [PATCH] [BUG] Fix search parameter check in CAGRA (#1784) An error occurs when using CAGRA multi-CTA implementation with topk>32. This PR fixes the bug. Authors: - tsuki (https://github.com/enp1s0) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Divye Gala (https://github.com/divyegala) - Micka (https://github.com/lowener) URL: https://github.com/rapidsai/raft/pull/1784 --- .../raft/neighbors/detail/cagra/search_multi_cta.cuh | 10 ++++++++++ .../raft/neighbors/detail/cagra/search_plan.cuh | 11 ++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 3fd4fca0f3..314ab6e6a6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -175,6 +175,16 @@ struct search : public search_plan_impl { topk_workspace.resize(topk_workspace_size, resource::get_cuda_stream(res)); } + void check(const uint32_t topk) override + { + RAFT_EXPECTS(num_cta_per_query * 32 >= topk, + "`num_cta_per_query` (%u) * 32 must be equal to or greater than " + "`topk` (%u) when 'search_mode' is \"multi-cta\". " + "(`num_cta_per_query`=max(`search_width`, `itopk_size`/32))", + num_cta_per_query, + topk); + } + ~search() {} void operator()(raft::resources const& res, diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index bc2102b9b0..e6966987c8 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -250,17 +250,10 @@ struct search_plan_impl : public search_plan_impl_base { } } - void check(uint32_t topk) + virtual void check(const uint32_t topk) { + // For single-CTA and multi kernel RAFT_EXPECTS(topk <= itopk_size, "topk must be smaller than itopk_size = %lu", itopk_size); - if (algo == search_algo::MULTI_CTA) { - uint32_t mc_num_cta_per_query = max(search_width, itopk_size / 32); - RAFT_EXPECTS(mc_num_cta_per_query * 32 >= topk, - "`mc_num_cta_per_query` (%u) * 32 must be equal to or greater than " - "`topk` /%u) when 'search_mode' is \"multi-cta\"", - mc_num_cta_per_query, - topk); - } } inline void check_params()