diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index a88a449a68..4a384b90e1 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -55,8 +55,8 @@ enum class search_algo { enum class hash_mode { HASH, SMALL, AUTO }; struct search_params : ann::search_params { - /** Maximum number of queries to search at the same time (batch size). */ - size_t max_queries = 1; + /** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/ + size_t max_queries = 0; /** Number of intermediate search results retained during the search. * diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 7b35af4417..1561a3bb8d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -65,7 +65,9 @@ void search_main(raft::resources const& res, static_cast(queries.extent(0)), static_cast(queries.extent(1))); RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match"); - uint32_t topk = neighbors.extent(1); + const uint32_t topk = neighbors.extent(1); + + if (params.max_queries == 0) { params.max_queries = queries.extent(0); } std::unique_ptr> plan = factory::create( @@ -74,8 +76,8 @@ void search_main(raft::resources const& res, plan->check(neighbors.extent(1)); RAFT_LOG_DEBUG("Cagra search"); - uint32_t max_queries = plan->max_queries; - uint32_t query_dim = queries.extent(1); + const uint32_t max_queries = plan->max_queries; + const uint32_t query_dim = queries.extent(1); for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) { const uint32_t n_queries = std::min(max_queries, queries.extent(0) - qid); diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index d3bd5ba31d..5d78074470 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -166,7 +166,7 @@ class AnnCagraTest : public ::testing::TestWithParam { protected: void testCagra() { - if (ps.algo == search_algo::MULTI_CTA && ps.max_queries > 1) { + if (ps.algo == search_algo::MULTI_CTA && ps.max_queries != 1) { GTEST_SKIP() << "Skipping test due to issue #1575"; } size_t queries_size = ps.n_queries * ps.k; @@ -377,9 +377,9 @@ inline std::vector generate_inputs() {100}, {1000}, {1, 8, 17}, - {1, 16}, // k + {1, 16}, // k {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, - {1, 10, 100}, // query size + {0, 1, 10, 100}, // query size {0}, {256}, {1},