Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CAGRA max_queries auto configuration #1613

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ enum class search_algo {
enum class hash_mode { HASH, SMALL, AUTO };

struct search_params : ann::search_params {
/** Maximum number of queries to search at the same time (batch size). */
size_t max_queries = 1;
/** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/
size_t max_queries = 0;

/** Number of intermediate search results retained during the search.
*
Expand Down
8 changes: 5 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ void search_main(raft::resources const& res,
static_cast<size_t>(queries.extent(0)),
static_cast<size_t>(queries.extent(1)));
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match");
uint32_t topk = neighbors.extent(1);
const uint32_t topk = neighbors.extent(1);

if (params.max_queries == 0) { params.max_queries = queries.extent(0); }

std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT>> plan =
factory<T, internal_IdxT, DistanceT>::create(
Expand All @@ -74,8 +76,8 @@ void search_main(raft::resources const& res,
plan->check(neighbors.extent(1));

RAFT_LOG_DEBUG("Cagra search");
uint32_t max_queries = plan->max_queries;
uint32_t query_dim = queries.extent(1);
const uint32_t max_queries = plan->max_queries;
const uint32_t query_dim = queries.extent(1);

for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) {
const uint32_t n_queries = std::min<std::size_t>(max_queries, queries.extent(0) - qid);
Expand Down
6 changes: 3 additions & 3 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
protected:
void testCagra()
{
if (ps.algo == search_algo::MULTI_CTA && ps.max_queries > 1) {
if (ps.algo == search_algo::MULTI_CTA && ps.max_queries != 1) {
GTEST_SKIP() << "Skipping test due to issue #1575";
}
size_t queries_size = ps.n_queries * ps.k;
Expand Down Expand Up @@ -377,9 +377,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{100},
{1000},
{1, 8, 17},
{1, 16}, // k
{1, 16}, // k
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL},
{1, 10, 100}, // query size
{0, 1, 10, 100}, // query size
{0},
{256},
{1},
Expand Down