From ed42bb5c706c9903b51388f5b214ceea360d9b0a Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Wed, 27 Sep 2023 00:48:53 +0800 Subject: [PATCH] Change CAGRA auto mode selection (#1830) This PR changes the rule to select a CAGRA search mode as written in the [CAGRA paper](https://arxiv.org/abs/2308.15136). Authors: - tsuki (https://github.com/enp1s0) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1830 --- .../raft/neighbors/detail/cagra/search_multi_cta.cuh | 7 ++++--- cpp/include/raft/neighbors/detail/cagra/search_plan.cuh | 7 ++++--- python/pylibraft/pylibraft/test/test_cagra.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 9a722a6dfe..c6478bef84 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -109,9 +109,10 @@ struct search : public search_plan_implitopk_size = 32; - search_width = 1; - num_cta_per_query = max(params.search_width, params.itopk_size / 32); + constexpr unsigned muti_cta_itopk_size = 32; + this->itopk_size = muti_cta_itopk_size; + search_width = 1; + num_cta_per_query = max(params.search_width, params.itopk_size / muti_cta_itopk_size); result_buffer_size = itopk_size + search_width * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 9419385836..a0f346ab51 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -38,12 +38,13 @@ struct search_plan_impl_base : public search_params { { set_max_dim_team(dim); if (algo == search_algo::AUTO) { - if (itopk_size <= 512) { + const size_t num_sm = raft::getMultiProcessorCount(); + if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) { algo = search_algo::SINGLE_CTA; RAFT_LOG_DEBUG("Auto strategy: selecting single-cta"); } else { - algo = search_algo::MULTI_KERNEL; - RAFT_LOG_DEBUG("Auto strategy: selecting multi-kernel"); + algo = search_algo::MULTI_CTA; + RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta"); } } } diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index f74fc5ae62..24126c0c5a 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -251,7 +251,7 @@ def test_cagra_index_params(params): "search_width": 4, "min_iterations": 0, "thread_block_size": 0, - "hashmap_mode": "small", + "hashmap_mode": "auto", "hashmap_min_bitlen": 0, "hashmap_max_fill_rate": 0.5, "num_random_samplings": 1,