From 104f3ba77a18b04f6d6003906151a53f560ce441 Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Tue, 4 Jul 2023 02:20:36 +0900 Subject: [PATCH] CAGRA max_queries auto configuration (#1613) This PR changes the behavior of `max_queries` in CAGRA. In the CAGRA search operation, a batch of queries is split into sub-batches of at most `max_queries` queries, which are processed one after another; increasing `max_queries` yields higher throughput. However, if one forgets to set `max_queries`, the default value, 1, is used, resulting in lower throughput. This is not user-friendly. This PR modifies CAGRA so that, when `max_queries` is left at its new default of 0, it is automatically set to the batch size. An alternative would be to remove `max_queries` entirely. Authors: - tsuki (https://github.com/enp1s0) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1613 --- cpp/include/raft/neighbors/cagra_types.hpp | 4 ++-- cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh | 8 +++++--- cpp/test/neighbors/ann_cagra.cuh | 6 +++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index a88a449a68..4a384b90e1 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -55,8 +55,8 @@ enum class search_algo { enum class hash_mode { HASH, SMALL, AUTO }; struct search_params : ann::search_params { - /** Maximum number of queries to search at the same time (batch size). */ - size_t max_queries = 1; + /** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/ + size_t max_queries = 0; /** Number of intermediate search results retained during the search. 
* diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 7b35af4417..1561a3bb8d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -65,7 +65,9 @@ void search_main(raft::resources const& res, static_cast(queries.extent(0)), static_cast(queries.extent(1))); RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match"); - uint32_t topk = neighbors.extent(1); + const uint32_t topk = neighbors.extent(1); + + if (params.max_queries == 0) { params.max_queries = queries.extent(0); } std::unique_ptr> plan = factory::create( @@ -74,8 +76,8 @@ void search_main(raft::resources const& res, plan->check(neighbors.extent(1)); RAFT_LOG_DEBUG("Cagra search"); - uint32_t max_queries = plan->max_queries; - uint32_t query_dim = queries.extent(1); + const uint32_t max_queries = plan->max_queries; + const uint32_t query_dim = queries.extent(1); for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) { const uint32_t n_queries = std::min(max_queries, queries.extent(0) - qid); diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index d3bd5ba31d..5d78074470 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -166,7 +166,7 @@ class AnnCagraTest : public ::testing::TestWithParam { protected: void testCagra() { - if (ps.algo == search_algo::MULTI_CTA && ps.max_queries > 1) { + if (ps.algo == search_algo::MULTI_CTA && ps.max_queries != 1) { GTEST_SKIP() << "Skipping test due to issue #1575"; } size_t queries_size = ps.n_queries * ps.k; @@ -377,9 +377,9 @@ inline std::vector generate_inputs() {100}, {1000}, {1, 8, 17}, - {1, 16}, // k + {1, 16}, // k {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, - {1, 10, 100}, // query size + {0, 1, 10, 100}, // query size {0}, {256}, {1},