diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 9a722a6dfe..c6478bef84 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -109,9 +109,10 @@ struct search : public search_plan_implitopk_size = 32; - search_width = 1; - num_cta_per_query = max(params.search_width, params.itopk_size / 32); + constexpr unsigned muti_cta_itopk_size = 32; + this->itopk_size = muti_cta_itopk_size; + search_width = 1; + num_cta_per_query = max(params.search_width, params.itopk_size / muti_cta_itopk_size); result_buffer_size = itopk_size + search_width * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 9419385836..a0f346ab51 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -38,12 +38,13 @@ struct search_plan_impl_base : public search_params { { set_max_dim_team(dim); if (algo == search_algo::AUTO) { - if (itopk_size <= 512) { + const size_t num_sm = raft::getMultiProcessorCount(); + if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) { algo = search_algo::SINGLE_CTA; RAFT_LOG_DEBUG("Auto strategy: selecting single-cta"); } else { - algo = search_algo::MULTI_KERNEL; - RAFT_LOG_DEBUG("Auto strategy: selecting multi-kernel"); + algo = search_algo::MULTI_CTA; + RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta"); } } } diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index 74e9f53b91..ae813b5c7b 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -241,7 +241,7 @@ def test_cagra_index_params(params): "search_width": 4, "min_iterations": 0, "thread_block_size": 0, - "hashmap_mode": "small", + "hashmap_mode": "auto", "hashmap_min_bitlen": 0, "hashmap_max_fill_rate": 0.5, "num_random_samplings": 1,