From b87d4e0ca6964bc94f5bf893de2c92a419d4fffd Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Tue, 29 Aug 2023 20:27:42 +0900 Subject: [PATCH 1/3] Fix search papameter check in CAGRA --- .../raft/neighbors/detail/cagra/search_multi_cta.cuh | 10 ++++++++++ .../raft/neighbors/detail/cagra/search_plan.cuh | 11 ++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 3fd4fca0f3..2abc9e7de1 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -175,6 +175,16 @@ struct search : public search_plan_impl { topk_workspace.resize(topk_workspace_size, resource::get_cuda_stream(res)); } + virtual void check(const uint32_t topk) + { + RAFT_EXPECTS(num_cta_per_query * 32 >= topk, + "`num_cta_per_query` (%u) * 32 must be equal to or greater than " + "`topk` (%u) when 'search_mode' is \"multi-cta\". " + "(`num_cta_per_query`=max(`search_width`, `itopk_size`/32))", + num_cta_per_query, + topk); + } + ~search() {} void operator()(raft::resources const& res, diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index bc2102b9b0..f11409e733 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -250,17 +250,10 @@ struct search_plan_impl : public search_plan_impl_base { } } - void check(uint32_t topk) + virtual void check(const uint32_t topk) { + // For multi-CTA and multi kernel RAFT_EXPECTS(topk <= itopk_size, "topk must be smaller than itopk_size = %lu", itopk_size); - if (algo == search_algo::MULTI_CTA) { - uint32_t mc_num_cta_per_query = max(search_width, itopk_size / 32); - RAFT_EXPECTS(mc_num_cta_per_query * 32 >= topk, - "`mc_num_cta_per_query` (%u) * 32 must be equal to or greater than " - "`topk` /%u) when 'search_mode' is \"multi-cta\"", - mc_num_cta_per_query, - topk); - } } inline void check_params() From 80400deb75066eed885de48931911ad9bb52fe31 Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Wed, 30 Aug 2023 12:38:26 +0900 Subject: [PATCH 2/3] Update cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh Co-authored-by: Micka --- cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 2abc9e7de1..314ab6e6a6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -175,7 +175,7 @@ struct search : public search_plan_impl { topk_workspace.resize(topk_workspace_size, resource::get_cuda_stream(res)); } - virtual void check(const uint32_t topk) + void check(const uint32_t topk) override { RAFT_EXPECTS(num_cta_per_query * 32 >= topk, "`num_cta_per_query` (%u) * 32 must be equal to or greater than " From 0368e2c7093c2b6edc8a2911066d4f611be6bd49 Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Wed, 30 Aug 2023 12:39:10 +0900 Subject: [PATCH 3/3] Fix comment --- cpp/include/raft/neighbors/detail/cagra/search_plan.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index f11409e733..e6966987c8 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -252,7 +252,7 @@ struct search_plan_impl : public search_plan_impl_base { virtual void check(const uint32_t topk) { - // For multi-CTA and multi kernel + // For single-CTA and multi kernel RAFT_EXPECTS(topk <= itopk_size, "topk must be smaller than itopk_size = %lu", itopk_size); }