diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 40cc7c76fb..41a43c9bce 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -131,7 +131,7 @@ void search_main(raft::resources const& res, factory::create( res, params, index.dim(), index.graph_degree(), topk); - plan->check(topk); + plan->check(neighbors.extent(1)); RAFT_LOG_DEBUG("Cagra search"); const uint32_t max_queries = plan->max_queries; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index f9bf525503..7be3fedfa2 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,7 +37,6 @@ #include "topk_for_cagra/topk_core.cuh" //todo replace with raft kernel #include "utils.hpp" #include -#include #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp @@ -654,12 +653,6 @@ struct search : search_plan_impl { rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; rmm::device_uvector topk_workspace; - // temporary storage for _find_topk - rmm::device_uvector input_keys_storage; - rmm::device_uvector output_keys_storage; - rmm::device_uvector input_values_storage; - rmm::device_uvector output_values_storage; - search(raft::resources const& res, search_params params, int64_t dim, @@ -672,11 +665,7 @@ struct search : search_plan_impl { parent_node_list(0, resource::get_cuda_stream(res)), topk_hint(0, resource::get_cuda_stream(res)), topk_workspace(0, resource::get_cuda_stream(res)), - terminate_flag(resource::get_cuda_stream(res)), - input_keys_storage(0, resource::get_cuda_stream(res)), - output_keys_storage(0, resource::get_cuda_stream(res)), - input_values_storage(0, resource::get_cuda_stream(res)), - output_values_storage(0, resource::get_cuda_stream(res)) + terminate_flag(resource::get_cuda_stream(res)) { set_params(res); } @@ -706,98 +695,6 @@ struct search : search_plan_impl { ~search() {} - inline void _find_topk(raft::resources const& handle, - uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const INDEX_T* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - INDEX_T* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK - void* workspace, - bool sort, - uint32_t* hints) - { - auto stream = resource::get_cuda_stream(handle); - - // _cuann_find_topk right now is limited to a max-k of 1024. - // RAFT has a matrix::select_k function - which handles arbitrary sized values of k, - // but doesn't accept strided inputs unlike _cuann_find_topk - // The multi-kernel search path requires strided access - since its cleverly allocating memory - // (layout described in the search_plan_impl function below), such that both the - // neighbors and the internal_topk are adjacent - in a double buffered format. - // Since this layout doesn't work with the matrix::select_k code - we have to copy - // over to a contiguous (non-strided) access to handle topk larger than 1024, and - // potentially also copy back to a strided layout afterwards - if (topK <= 1024) { - return _cuann_find_topk(topK, - sizeBatch, - numElements, - inputKeys, - ldIK, - inputVals, - ldIV, - outputKeys, - ldOK, - outputVals, - ldOV, - workspace, - sort, - hints, - stream); - } - - if (ldIK > numElements) { - if (input_keys_storage.size() != sizeBatch * numElements) { - input_keys_storage.resize(sizeBatch * numElements, stream); - } - batched_memcpy( - input_keys_storage.data(), numElements, inputKeys, ldIK, numElements, sizeBatch, stream); - inputKeys = input_keys_storage.data(); - } - - if (ldIV > numElements) { - if (input_values_storage.size() != sizeBatch * numElements) { - input_values_storage.resize(sizeBatch * numElements, stream); - } - - batched_memcpy( - input_values_storage.data(), numElements, inputVals, ldIV, numElements, sizeBatch, stream); - inputVals = input_values_storage.data(); - } - - if ((ldOK > topK) && (output_keys_storage.size() != sizeBatch * topK)) { - output_keys_storage.resize(sizeBatch * topK, stream); - } - - if ((ldOV > topK) && (output_values_storage.size() != sizeBatch * topK)) { - output_values_storage.resize(sizeBatch * topK, stream); - } - - raft::matrix::select_k( - handle, - raft::make_device_matrix_view(inputKeys, sizeBatch, numElements), - raft::make_device_matrix_view(inputVals, sizeBatch, numElements), - raft::make_device_matrix_view( - ldOK > topK ? output_keys_storage.data() : outputKeys, sizeBatch, topK), - raft::make_device_matrix_view( - ldOV > topK ? output_values_storage.data() : outputVals, sizeBatch, topK), - true, // select_min - sort); - - if (ldOK > topK) { - batched_memcpy(outputKeys, ldOK, output_keys_storage.data(), topK, topK, sizeBatch, stream); - } - - if (ldOV > topK) { - batched_memcpy(outputVals, ldOV, output_values_storage.data(), topK, topK, sizeBatch, stream); - } - } - void operator()(raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, @@ -849,21 +746,21 @@ struct search : search_plan_impl { unsigned iter = 0; while (1) { // Make an index list of internal top-k nodes - _find_topk(res, - itopk_size, - num_queries, - result_buffer_size, - result_distances.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_indices.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - topk_workspace.data(), - true, - top_hint_ptr); + _cuann_find_topk(itopk_size, + num_queries, + result_buffer_size, + result_distances.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_indices.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_buffer_allocation_size, + result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_buffer_allocation_size, + topk_workspace.data(), + true, + top_hint_ptr, + stream); // termination (1) if ((iter + 1 == max_iterations)) { @@ -944,21 +841,21 @@ struct search : search_plan_impl { result_indices_ptr = result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size; result_distances_ptr = result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size; - _find_topk(res, - itopk_size, - num_queries, - result_buffer_size, - result_distances.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_indices.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_distances_ptr, - result_buffer_allocation_size, - result_indices_ptr, - result_buffer_allocation_size, - topk_workspace.data(), - true, - top_hint_ptr); + _cuann_find_topk(itopk_size, + num_queries, + result_buffer_size, + result_distances.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_indices.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_distances_ptr, + result_buffer_allocation_size, + result_indices_ptr, + result_buffer_allocation_size, + topk_workspace.data(), + true, + top_hint_ptr, + stream); } else { // Remove parent bit in search results remove_parent_bit( diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 271a1f4955..20df2adf61 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -42,12 +42,9 @@ struct search_plan_impl_base : public search_params { if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) { algo = search_algo::SINGLE_CTA; RAFT_LOG_DEBUG("Auto strategy: selecting single-cta"); - } else if (topk <= 1024) { + } else { algo = search_algo::MULTI_CTA; RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta"); - } else { - algo = search_algo::MULTI_KERNEL; - RAFT_LOG_DEBUG("Auto strategy: selecting multi kernel"); } } } @@ -258,8 +255,7 @@ struct search_plan_impl : public search_plan_impl_base { virtual void check(const uint32_t topk) { // For single-CTA and multi kernel - RAFT_EXPECTS( - topk <= itopk_size, "topk = %u must be smaller than itopk_size = %lu", topk, itopk_size); + RAFT_EXPECTS(topk <= itopk_size, "topk must be smaller than itopk_size = %lu", itopk_size); } inline void check_params() @@ -267,7 +263,7 @@ struct search_plan_impl : public search_plan_impl_base { std::string error_message = ""; if (itopk_size > 1024) { - if ((algo == search_algo::MULTI_CTA) || (algo == search_algo::MULTI_KERNEL)) { + if (algo == search_algo::MULTI_CTA) { } else { error_message += std::string("- `internal_topk` (" + std::to_string(itopk_size) + ") must be smaller or equal to 1024"); diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index ef4f27ae64..915ef8a394 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -259,7 +259,6 @@ class AnnCagraTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; - search_params.itopk_size = ps.itopk_size; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); @@ -497,7 +496,6 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; - search_params.itopk_size = ps.itopk_size; search_params.hashmap_mode = cagra::hash_mode::HASH; auto database_view = raft::make_device_matrix_view( @@ -613,7 +611,6 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; - search_params.itopk_size = ps.itopk_size; search_params.hashmap_mode = cagra::hash_mode::HASH; auto database_view = raft::make_device_matrix_view( @@ -821,23 +818,6 @@ inline std::vector generate_inputs() {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - raft::util::itertools::product({100}, - {20000}, - {32}, - {2048}, // k - {graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0}, - {4096}, // itopk_size - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false}, - {false}, - {0.995}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - return inputs; }