Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Change CAGRA auto mode selection #1830

Merged
merged 8 commits on Sep 26, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILT

void set_params(raft::resources const& res, const search_params& params)
{
this->itopk_size = 32;
search_width = 1;
num_cta_per_query = max(params.search_width, params.itopk_size / 32);
constexpr unsigned muti_cta_itopk_size = 32;
this->itopk_size = muti_cta_itopk_size;
search_width = 1;
num_cta_per_query = max(params.search_width, params.itopk_size / muti_cta_itopk_size);
result_buffer_size = itopk_size + search_width * graph_degree;
typedef raft::Pow2<32> AlignBytes;
unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size);
Expand Down
7 changes: 4 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ struct search_plan_impl_base : public search_params {
{
set_max_dim_team(dim);
if (algo == search_algo::AUTO) {
if (itopk_size <= 512) {
const size_t num_sm = raft::getMultiProcessorCount();
if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) {
algo = search_algo::SINGLE_CTA;
RAFT_LOG_DEBUG("Auto strategy: selecting single-cta");
} else {
algo = search_algo::MULTI_KERNEL;
RAFT_LOG_DEBUG("Auto strategy: selecting multi-kernel");
algo = search_algo::MULTI_CTA;
RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta");
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/test/test_cagra.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_cagra_index_params(params):
"search_width": 4,
"min_iterations": 0,
"thread_block_size": 0,
"hashmap_mode": "small",
"hashmap_mode": "auto",
"hashmap_min_bitlen": 0,
"hashmap_max_fill_rate": 0.5,
"num_random_samplings": 1,
Expand Down