Skip to content

Commit

Permalink
[BUG] Fix search parameter check in CAGRA (#1784)
Browse files Browse the repository at this point in the history
An error occurs when using CAGRA multi-CTA implementation with topk>32. This PR fixes the bug.

Authors:
  - tsuki (https://github.com/enp1s0)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Divye Gala (https://github.com/divyegala)
  - Micka (https://github.com/lowener)

URL: #1784
  • Loading branch information
enp1s0 authored Aug 30, 2023
1 parent f6d35ae commit b1f7374
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
10 changes: 10 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T> {
topk_workspace.resize(topk_workspace_size, resource::get_cuda_stream(res));
}

void check(const uint32_t topk) override
{
RAFT_EXPECTS(num_cta_per_query * 32 >= topk,
"`num_cta_per_query` (%u) * 32 must be equal to or greater than "
"`topk` (%u) when 'search_mode' is \"multi-cta\". "
"(`num_cta_per_query`=max(`search_width`, `itopk_size`/32))",
num_cta_per_query,
topk);
}

~search() {}

void operator()(raft::resources const& res,
Expand Down
11 changes: 2 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,10 @@ struct search_plan_impl : public search_plan_impl_base {
}
}

void check(uint32_t topk)
virtual void check(const uint32_t topk)
{
// For single-CTA and multi kernel
RAFT_EXPECTS(topk <= itopk_size, "topk must be smaller than itopk_size = %lu", itopk_size);
if (algo == search_algo::MULTI_CTA) {
uint32_t mc_num_cta_per_query = max(search_width, itopk_size / 32);
RAFT_EXPECTS(mc_num_cta_per_query * 32 >= topk,
"`mc_num_cta_per_query` (%u) * 32 must be equal to or greater than "
"`topk` /%u) when 'search_mode' is \"multi-cta\"",
mc_num_cta_per_query,
topk);
}
}

inline void check_params()
Expand Down

0 comments on commit b1f7374

Please sign in to comment.