diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 4990d896ce..010b0a6f80 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -111,7 +111,8 @@ struct search : public search_plan_implitopk_size = muti_cta_itopk_size; search_width = 1; - num_cta_per_query = max(params.search_width, params.itopk_size / muti_cta_itopk_size); + num_cta_per_query = + max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)muti_cta_itopk_size)); result_buffer_size = itopk_size + search_width * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); @@ -184,7 +185,7 @@ struct search : public search_plan_impl= topk, "`num_cta_per_query` (%u) * 32 must be equal to or greater than " "`topk` (%u) when 'search_mode' is \"multi-cta\". " - "(`num_cta_per_query`=max(`search_width`, `itopk_size`/32))", + "(`num_cta_per_query`=max(`search_width`, ceildiv(`itopk_size`, 32)))", num_cta_per_query, topk); } diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index f2f51617f4..20df2adf61 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -154,7 +154,7 @@ struct search_plan_impl : public search_plan_impl_base { if (algo == search_algo::MULTI_CTA) { mc_itopk_size = 32; mc_search_width = 1; - mc_num_cta_per_query = max(search_width, itopk_size / 32); + mc_num_cta_per_query = max(search_width, raft::ceildiv(itopk_size, (size_t)32)); RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); RAFT_LOG_DEBUG("# mc_search_width: %u", mc_search_width); RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query);