From d89ab1be41fc4b1dc67263c74a424dda60ecfdeb Mon Sep 17 00:00:00 2001 From: Micka Date: Tue, 23 Jan 2024 20:55:34 +0100 Subject: [PATCH] [BUG] Fix `num_cta_per_query` div (#2107) This bug happened when trying to run a CAGRA index with `itopk=100` and `topk=100`. The `num_cta_per_query` variable was equal to 3 because 100 / 32 = 3.125 instead of ceildiv(100, 32) = 4. This resulted in the following error: ``` RuntimeError: RAFT failure at file=/opt/conda/conda-bld/work/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh line=183: `num_cta_per_query` (3) * 32 must be equal to or greater than `topk` (100) when 'search_mode' is "multi-cta". (`num_cta_per_query`=max(`search_width`, `itopk_size`/32)) ``` Authors: - Micka (https://github.com/lowener) Approvers: - tsuki (https://github.com/enp1s0) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2107 --- .../raft/neighbors/detail/cagra/search_multi_cta.cuh | 7 ++++--- cpp/include/raft/neighbors/detail/cagra/search_plan.cuh | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 4990d896ce..010b0a6f80 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -111,7 +111,8 @@ struct search : public search_plan_implitopk_size = muti_cta_itopk_size; search_width = 1; - num_cta_per_query = max(params.search_width, params.itopk_size / muti_cta_itopk_size); + num_cta_per_query = + max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)muti_cta_itopk_size)); result_buffer_size = itopk_size + search_width * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); @@ -184,7 +185,7 @@ struct search : public search_plan_impl= topk, "`num_cta_per_query` (%u) * 32 must be equal to or greater than " "`topk` (%u) when 'search_mode' is \"multi-cta\". " - "(`num_cta_per_query`=max(`search_width`, `itopk_size`/32))", + "(`num_cta_per_query`=max(`search_width`, ceildiv(`itopk_size`, 32)))", num_cta_per_query, topk); } diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index f2f51617f4..20df2adf61 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -154,7 +154,7 @@ struct search_plan_impl : public search_plan_impl_base { if (algo == search_algo::MULTI_CTA) { mc_itopk_size = 32; mc_search_width = 1; - mc_num_cta_per_query = max(search_width, itopk_size / 32); + mc_num_cta_per_query = max(search_width, raft::ceildiv(itopk_size, (size_t)32)); RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); RAFT_LOG_DEBUG("# mc_search_width: %u", mc_search_width); RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query);