Skip to content

Commit

Permalink
CAGRA max_queries auto configuration (#1613)
Browse files Browse the repository at this point in the history
This PR changes the behavior of max_queries in CAGRA.

In the search operation of CAGRA, a batch of queries is divided into several `max_queries` queries sub-batches and operated; by increasing `max_queries,` higher throughput can be obtained. However, if one forgets to set `max_queries,` the default value, 1, will be used, resulting in lower throughput. This is not user-friendly.


This PR modifies CAGRA to automatically set `max_queries` as the batch size by default.

An alternative way is to remove `max_queries`.

Authors:
  - tsuki (https://github.com/enp1s0)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1613
  • Loading branch information
enp1s0 authored Jul 3, 2023
1 parent e9d86f1 commit 104f3ba
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
4 changes: 2 additions & 2 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ enum class search_algo {
enum class hash_mode { HASH, SMALL, AUTO };

struct search_params : ann::search_params {
/** Maximum number of queries to search at the same time (batch size). */
size_t max_queries = 1;
/** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/
size_t max_queries = 0;

/** Number of intermediate search results retained during the search.
*
Expand Down
8 changes: 5 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ void search_main(raft::resources const& res,
static_cast<size_t>(queries.extent(0)),
static_cast<size_t>(queries.extent(1)));
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match");
uint32_t topk = neighbors.extent(1);
const uint32_t topk = neighbors.extent(1);

if (params.max_queries == 0) { params.max_queries = queries.extent(0); }

std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT>> plan =
factory<T, internal_IdxT, DistanceT>::create(
Expand All @@ -74,8 +76,8 @@ void search_main(raft::resources const& res,
plan->check(neighbors.extent(1));

RAFT_LOG_DEBUG("Cagra search");
uint32_t max_queries = plan->max_queries;
uint32_t query_dim = queries.extent(1);
const uint32_t max_queries = plan->max_queries;
const uint32_t query_dim = queries.extent(1);

for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) {
const uint32_t n_queries = std::min<std::size_t>(max_queries, queries.extent(0) - qid);
Expand Down
6 changes: 3 additions & 3 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
protected:
void testCagra()
{
if (ps.algo == search_algo::MULTI_CTA && ps.max_queries > 1) {
if (ps.algo == search_algo::MULTI_CTA && ps.max_queries != 1) {
GTEST_SKIP() << "Skipping test due to issue #1575";
}
size_t queries_size = ps.n_queries * ps.k;
Expand Down Expand Up @@ -377,9 +377,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{100},
{1000},
{1, 8, 17},
{1, 16}, // k
{1, 16}, // k
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL},
{1, 10, 100}, // query size
{0, 1, 10, 100}, // query size
{0},
{256},
{1},
Expand Down

0 comments on commit 104f3ba

Please sign in to comment.