From bf20fd31d725844319a169f14474690ff20721f8 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 16 Jan 2024 10:40:08 -0800 Subject: [PATCH 1/3] Allow topk larger than 1024 in CAGRA This change allows CAGRA search to have an arbitrarily large top-k, instead of being limited to 1024 like in the previous code. This works by using the multi-kernel search path, and replacing the _cuann_find_topk code with the matrix::select_k code - which can handle large K values. --- .../neighbors/detail/cagra/cagra_search.cuh | 2 +- .../detail/cagra/search_multi_kernel.cuh | 152 ++++++++++++++---- .../neighbors/detail/cagra/search_plan.cuh | 10 +- cpp/test/neighbors/ann_cagra.cuh | 20 ++- 4 files changed, 148 insertions(+), 36 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 41a43c9bce..40cc7c76fb 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -131,7 +131,7 @@ void search_main(raft::resources const& res, factory::create( res, params, index.dim(), index.graph_degree(), topk); - plan->check(neighbors.extent(1)); + plan->check(topk); RAFT_LOG_DEBUG("Cagra search"); const uint32_t max_queries = plan->max_queries; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index 7be3fedfa2..a8405126d9 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ #include "topk_for_cagra/topk_core.cuh" //todo replace with raft kernel #include "utils.hpp" #include +#include #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp @@ -597,6 +598,95 @@ void set_value_batch(T* const dev_ptr, <<>>(dev_ptr, ld, val, count, batch_size); } +template +inline void _find_topk(raft::resources const& handle, + uint32_t topK, + uint32_t sizeBatch, + uint32_t numElements, + const float* inputKeys, // [sizeBatch, ldIK,] + uint32_t ldIK, // (*) ldIK >= numElements + const ValT* inputVals, // [sizeBatch, ldIV,] + uint32_t ldIV, // (*) ldIV >= numElements + float* outputKeys, // [sizeBatch, ldOK,] + uint32_t ldOK, // (*) ldOK >= topK + ValT* outputVals, // [sizeBatch, ldOV,] + uint32_t ldOV, // (*) ldOV >= topK + void* workspace, + bool sort, + uint32_t* hints) +{ + auto stream = resource::get_cuda_stream(handle); + + // _cuann_find_topk right now is limited to a max-k of 1024. + // RAFT has a matrix::select_k function - which handles arbitrary sized values of k, + // but doesn't accept strided inputs unlike _cuann_find_topk + // The multi-kernel search path requires strided access - since its cleverly allocating memory + // (layout described in the search_plan_impl function below), such that both the + // neighbors and the internal_topk are adjacent - in a double buffered format. + // Since this layout doesn't work with the matrix::select_k code - we have to copy + // over to a contiguous (non-strided) access to handle topk larger than 1024, and + // potentially also copy back to a strided layout afterwards + if (topK <= 1024) { + return _cuann_find_topk(topK, + sizeBatch, + numElements, + inputKeys, + ldIK, + inputVals, + ldIV, + outputKeys, + ldOK, + outputVals, + ldOV, + workspace, + sort, + hints, + stream); + } + + rmm::device_uvector input_keys_storage(0, stream); + rmm::device_uvector output_keys_storage(0, stream); + rmm::device_uvector input_values_storage(0, stream); + rmm::device_uvector output_values_storage(0, stream); + + if (ldIK > numElements) { + input_keys_storage.resize(sizeBatch * numElements, stream); + batched_memcpy( + input_keys_storage.data(), numElements, inputKeys, ldIK, numElements, sizeBatch, stream); + inputKeys = input_keys_storage.data(); + } + + if (ldIV > numElements) { + input_values_storage.resize(sizeBatch * numElements, stream); + batched_memcpy( + input_values_storage.data(), numElements, inputVals, ldIV, numElements, sizeBatch, stream); + inputVals = input_values_storage.data(); + } + + if (ldOK > topK) { output_keys_storage.resize(sizeBatch * topK, stream); } + + if (ldOV > topK) { output_values_storage.resize(sizeBatch * topK, stream); } + + raft::matrix::select_k( + handle, + raft::make_device_matrix_view(inputKeys, sizeBatch, numElements), + raft::make_device_matrix_view(inputVals, sizeBatch, numElements), + raft::make_device_matrix_view( + ldOK > topK ? output_keys_storage.data() : outputKeys, sizeBatch, topK), + raft::make_device_matrix_view( + ldOV > topK ? output_values_storage.data() : outputVals, sizeBatch, topK), + true, // select_min + sort); + + if (ldOK > topK) { + batched_memcpy(outputKeys, ldOK, output_keys_storage.data(), topK, topK, sizeBatch, stream); + } + + if (ldOV > topK) { + batched_memcpy(outputVals, ldOV, output_values_storage.data(), topK, topK, sizeBatch, stream); + } +} + // result_buffer (work buffer) for "multi-kernel" // +--------------------+------------------------------+-------------------+ // | internal_top_k (A) | neighbors of internal_top_k | internal_topk (B) | @@ -746,21 +836,21 @@ struct search : search_plan_impl { unsigned iter = 0; while (1) { // Make an index list of internal top-k nodes - _cuann_find_topk(itopk_size, - num_queries, - result_buffer_size, - result_distances.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_indices.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - topk_workspace.data(), - true, - top_hint_ptr, - stream); + _find_topk(res, + itopk_size, + num_queries, + result_buffer_size, + result_distances.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_indices.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_buffer_allocation_size, + result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_buffer_allocation_size, + topk_workspace.data(), + true, + top_hint_ptr); // termination (1) if ((iter + 1 == max_iterations)) { @@ -841,21 +931,21 @@ struct search : search_plan_impl { result_indices_ptr = result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size; result_distances_ptr = result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size; - _cuann_find_topk(itopk_size, - num_queries, - result_buffer_size, - result_distances.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_indices.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_distances_ptr, - result_buffer_allocation_size, - result_indices_ptr, - result_buffer_allocation_size, - topk_workspace.data(), - true, - top_hint_ptr, - stream); + _find_topk(res, + itopk_size, + num_queries, + result_buffer_size, + result_distances.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_indices.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_distances_ptr, + result_buffer_allocation_size, + result_indices_ptr, + result_buffer_allocation_size, + topk_workspace.data(), + true, + top_hint_ptr); } else { // Remove parent bit in search results remove_parent_bit( diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index f2f51617f4..ab420261e4 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -42,9 +42,12 @@ struct search_plan_impl_base : public search_params { if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) { algo = search_algo::SINGLE_CTA; RAFT_LOG_DEBUG("Auto strategy: selecting single-cta"); - } else { + } else if (topk <= 1024) { algo = search_algo::MULTI_CTA; RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta"); + } else { + algo = search_algo::MULTI_KERNEL; + RAFT_LOG_DEBUG("Auto strategy: selecting multi kernel"); } } } @@ -255,7 +258,8 @@ struct search_plan_impl : public search_plan_impl_base { virtual void check(const uint32_t topk) { // For single-CTA and multi kernel - RAFT_EXPECTS(topk <= itopk_size, "topk must be smaller than itopk_size = %lu", itopk_size); + RAFT_EXPECTS( + topk <= itopk_size, "topk = %u must be smaller than itopk_size = %lu", topk, itopk_size); } inline void check_params() @@ -263,7 +267,7 @@ struct search_plan_impl : public search_plan_impl_base { std::string error_message = ""; if (itopk_size > 1024) { - if (algo == search_algo::MULTI_CTA) { + if ((algo == search_algo::MULTI_CTA) || (algo == search_algo::MULTI_KERNEL)) { } else { error_message += std::string("- `internal_topk` (" + std::to_string(itopk_size) + ") must be smaller or equal to 1024"); diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index a9790b07b5..25793972cb 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -230,6 +230,7 @@ class AnnCagraTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; + search_params.itopk_size = ps.itopk_size; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); @@ -789,6 +790,23 @@ inline std::vector generate_inputs() {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + inputs2 = + raft::util::itertools::product({100}, + {20000}, + {32}, + {2048}, // k + {graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0}, + {4096}, // itopk_size + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {false}, + {0.995}); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + return inputs; } From fea490a4cbe674443e81a3d6df3a9a8b3e9ec31a Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 16 Jan 2024 16:51:35 -0800 Subject: [PATCH 2/3] Fix cagra filter tests --- cpp/test/neighbors/ann_cagra.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 25793972cb..ff22736670 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -468,6 +468,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; + search_params.itopk_size = ps.itopk_size; search_params.hashmap_mode = cagra::hash_mode::HASH; auto database_view = raft::make_device_matrix_view( @@ -583,6 +584,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; + search_params.itopk_size = ps.itopk_size; search_params.hashmap_mode = cagra::hash_mode::HASH; auto database_view = raft::make_device_matrix_view( From ca5478fbb0d74ac9162f49c382e69989102b7ccf Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 22 Jan 2024 15:20:40 -0800 Subject: [PATCH 3/3] updates from codereview --- .../detail/cagra/search_multi_kernel.cuh | 193 ++++++++++-------- 1 file changed, 103 insertions(+), 90 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index a8405126d9..f9bf525503 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -598,95 +598,6 @@ void set_value_batch(T* const dev_ptr, <<>>(dev_ptr, ld, val, count, batch_size); } -template -inline void _find_topk(raft::resources const& handle, - uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const ValT* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - ValT* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK - void* workspace, - bool sort, - uint32_t* hints) -{ - auto stream = resource::get_cuda_stream(handle); - - // _cuann_find_topk right now is limited to a max-k of 1024. - // RAFT has a matrix::select_k function - which handles arbitrary sized values of k, - // but doesn't accept strided inputs unlike _cuann_find_topk - // The multi-kernel search path requires strided access - since its cleverly allocating memory - // (layout described in the search_plan_impl function below), such that both the - // neighbors and the internal_topk are adjacent - in a double buffered format. - // Since this layout doesn't work with the matrix::select_k code - we have to copy - // over to a contiguous (non-strided) access to handle topk larger than 1024, and - // potentially also copy back to a strided layout afterwards - if (topK <= 1024) { - return _cuann_find_topk(topK, - sizeBatch, - numElements, - inputKeys, - ldIK, - inputVals, - ldIV, - outputKeys, - ldOK, - outputVals, - ldOV, - workspace, - sort, - hints, - stream); - } - - rmm::device_uvector input_keys_storage(0, stream); - rmm::device_uvector output_keys_storage(0, stream); - rmm::device_uvector input_values_storage(0, stream); - rmm::device_uvector output_values_storage(0, stream); - - if (ldIK > numElements) { - input_keys_storage.resize(sizeBatch * numElements, stream); - batched_memcpy( - input_keys_storage.data(), numElements, inputKeys, ldIK, numElements, sizeBatch, stream); - inputKeys = input_keys_storage.data(); - } - - if (ldIV > numElements) { - input_values_storage.resize(sizeBatch * numElements, stream); - batched_memcpy( - input_values_storage.data(), numElements, inputVals, ldIV, numElements, sizeBatch, stream); - inputVals = input_values_storage.data(); - } - - if (ldOK > topK) { output_keys_storage.resize(sizeBatch * topK, stream); } - - if (ldOV > topK) { output_values_storage.resize(sizeBatch * topK, stream); } - - raft::matrix::select_k( - handle, - raft::make_device_matrix_view(inputKeys, sizeBatch, numElements), - raft::make_device_matrix_view(inputVals, sizeBatch, numElements), - raft::make_device_matrix_view( - ldOK > topK ? output_keys_storage.data() : outputKeys, sizeBatch, topK), - raft::make_device_matrix_view( - ldOV > topK ? output_values_storage.data() : outputVals, sizeBatch, topK), - true, // select_min - sort); - - if (ldOK > topK) { - batched_memcpy(outputKeys, ldOK, output_keys_storage.data(), topK, topK, sizeBatch, stream); - } - - if (ldOV > topK) { - batched_memcpy(outputVals, ldOV, output_values_storage.data(), topK, topK, sizeBatch, stream); - } -} - // result_buffer (work buffer) for "multi-kernel" // +--------------------+------------------------------+-------------------+ // | internal_top_k (A) | neighbors of internal_top_k | internal_topk (B) | @@ -743,6 +654,12 @@ struct search : search_plan_impl { rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; rmm::device_uvector topk_workspace; + // temporary storage for _find_topk + rmm::device_uvector input_keys_storage; + rmm::device_uvector output_keys_storage; + rmm::device_uvector input_values_storage; + rmm::device_uvector output_values_storage; + search(raft::resources const& res, search_params params, int64_t dim, @@ -755,7 +672,11 @@ struct search : search_plan_impl { parent_node_list(0, resource::get_cuda_stream(res)), topk_hint(0, resource::get_cuda_stream(res)), topk_workspace(0, resource::get_cuda_stream(res)), - terminate_flag(resource::get_cuda_stream(res)) + terminate_flag(resource::get_cuda_stream(res)), + input_keys_storage(0, resource::get_cuda_stream(res)), + output_keys_storage(0, resource::get_cuda_stream(res)), + input_values_storage(0, resource::get_cuda_stream(res)), + output_values_storage(0, resource::get_cuda_stream(res)) { set_params(res); } @@ -785,6 +706,98 @@ struct search : search_plan_impl { ~search() {} + inline void _find_topk(raft::resources const& handle, + uint32_t topK, + uint32_t sizeBatch, + uint32_t numElements, + const float* inputKeys, // [sizeBatch, ldIK,] + uint32_t ldIK, // (*) ldIK >= numElements + const INDEX_T* inputVals, // [sizeBatch, ldIV,] + uint32_t ldIV, // (*) ldIV >= numElements + float* outputKeys, // [sizeBatch, ldOK,] + uint32_t ldOK, // (*) ldOK >= topK + INDEX_T* outputVals, // [sizeBatch, ldOV,] + uint32_t ldOV, // (*) ldOV >= topK + void* workspace, + bool sort, + uint32_t* hints) + { + auto stream = resource::get_cuda_stream(handle); + + // _cuann_find_topk right now is limited to a max-k of 1024. + // RAFT has a matrix::select_k function - which handles arbitrary sized values of k, + // but doesn't accept strided inputs unlike _cuann_find_topk + // The multi-kernel search path requires strided access - since its cleverly allocating memory + // (layout described in the search_plan_impl function below), such that both the + // neighbors and the internal_topk are adjacent - in a double buffered format. + // Since this layout doesn't work with the matrix::select_k code - we have to copy + // over to a contiguous (non-strided) access to handle topk larger than 1024, and + // potentially also copy back to a strided layout afterwards + if (topK <= 1024) { + return _cuann_find_topk(topK, + sizeBatch, + numElements, + inputKeys, + ldIK, + inputVals, + ldIV, + outputKeys, + ldOK, + outputVals, + ldOV, + workspace, + sort, + hints, + stream); + } + + if (ldIK > numElements) { + if (input_keys_storage.size() != sizeBatch * numElements) { + input_keys_storage.resize(sizeBatch * numElements, stream); + } + batched_memcpy( + input_keys_storage.data(), numElements, inputKeys, ldIK, numElements, sizeBatch, stream); + inputKeys = input_keys_storage.data(); + } + + if (ldIV > numElements) { + if (input_values_storage.size() != sizeBatch * numElements) { + input_values_storage.resize(sizeBatch * numElements, stream); + } + + batched_memcpy( + input_values_storage.data(), numElements, inputVals, ldIV, numElements, sizeBatch, stream); + inputVals = input_values_storage.data(); + } + + if ((ldOK > topK) && (output_keys_storage.size() != sizeBatch * topK)) { + output_keys_storage.resize(sizeBatch * topK, stream); + } + + if ((ldOV > topK) && (output_values_storage.size() != sizeBatch * topK)) { + output_values_storage.resize(sizeBatch * topK, stream); + } + + raft::matrix::select_k( + handle, + raft::make_device_matrix_view(inputKeys, sizeBatch, numElements), + raft::make_device_matrix_view(inputVals, sizeBatch, numElements), + raft::make_device_matrix_view( + ldOK > topK ? output_keys_storage.data() : outputKeys, sizeBatch, topK), + raft::make_device_matrix_view( + ldOV > topK ? output_values_storage.data() : outputVals, sizeBatch, topK), + true, // select_min + sort); + + if (ldOK > topK) { + batched_memcpy(outputKeys, ldOK, output_keys_storage.data(), topK, topK, sizeBatch, stream); + } + + if (ldOV > topK) { + batched_memcpy(outputVals, ldOV, output_values_storage.data(), topK, topK, sizeBatch, stream); + } + } + void operator()(raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph,