Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow topk larger than 1024 in CAGRA #2097

Merged
merged 8 commits into from
Jan 23, 2024
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ void search_main(raft::resources const& res,
factory<T, internal_IdxT, DistanceT, CagraSampleFilterT_s>::create(
res, params, index.dim(), index.graph_degree(), topk);

plan->check(neighbors.extent(1));
plan->check(topk);

RAFT_LOG_DEBUG("Cagra search");
const uint32_t max_queries = plan->max_queries;
Expand Down
167 changes: 135 additions & 32 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,6 +37,7 @@
#include "topk_for_cagra/topk_core.cuh" //todo replace with raft kernel
#include "utils.hpp"
#include <raft/core/logger.hpp>
#include <raft/matrix/select_k.cuh>
#include <raft/util/cuda_rt_essentials.hpp>
#include <raft/util/cudart_utils.hpp> // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp

Expand Down Expand Up @@ -653,6 +654,12 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
rmm::device_scalar<uint32_t> terminate_flag; // dev_terminate_flag, host_terminate_flag.;
rmm::device_uvector<uint32_t> topk_workspace;

// temporary storage for _find_topk
rmm::device_uvector<float> input_keys_storage;
rmm::device_uvector<float> output_keys_storage;
rmm::device_uvector<INDEX_T> input_values_storage;
rmm::device_uvector<INDEX_T> output_values_storage;

search(raft::resources const& res,
search_params params,
int64_t dim,
Expand All @@ -665,7 +672,11 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
parent_node_list(0, resource::get_cuda_stream(res)),
topk_hint(0, resource::get_cuda_stream(res)),
topk_workspace(0, resource::get_cuda_stream(res)),
terminate_flag(resource::get_cuda_stream(res))
terminate_flag(resource::get_cuda_stream(res)),
input_keys_storage(0, resource::get_cuda_stream(res)),
output_keys_storage(0, resource::get_cuda_stream(res)),
input_values_storage(0, resource::get_cuda_stream(res)),
output_values_storage(0, resource::get_cuda_stream(res))
{
set_params(res);
}
Expand Down Expand Up @@ -695,6 +706,98 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {

~search() {}

inline void _find_topk(raft::resources const& handle,
uint32_t topK,
uint32_t sizeBatch,
uint32_t numElements,
const float* inputKeys, // [sizeBatch, ldIK,]
uint32_t ldIK, // (*) ldIK >= numElements
const INDEX_T* inputVals, // [sizeBatch, ldIV,]
uint32_t ldIV, // (*) ldIV >= numElements
float* outputKeys, // [sizeBatch, ldOK,]
uint32_t ldOK, // (*) ldOK >= topK
INDEX_T* outputVals, // [sizeBatch, ldOV,]
uint32_t ldOV, // (*) ldOV >= topK
void* workspace,
bool sort,
uint32_t* hints)
{
auto stream = resource::get_cuda_stream(handle);

// _cuann_find_topk right now is limited to a max-k of 1024.
// RAFT has a matrix::select_k function - which handles arbitrary sized values of k,
// but doesn't accept strided inputs unlike _cuann_find_topk
// The multi-kernel search path requires strided access - since its cleverly allocating memory
// (layout described in the search_plan_impl function below), such that both the
// neighbors and the internal_topk are adjacent - in a double buffered format.
// Since this layout doesn't work with the matrix::select_k code - we have to copy
// over to a contiguous (non-strided) access to handle topk larger than 1024, and
// potentially also copy back to a strided layout afterwards
if (topK <= 1024) {
return _cuann_find_topk(topK,
sizeBatch,
numElements,
inputKeys,
ldIK,
inputVals,
ldIV,
outputKeys,
ldOK,
outputVals,
ldOV,
workspace,
sort,
hints,
stream);
}

if (ldIK > numElements) {
if (input_keys_storage.size() != sizeBatch * numElements) {
input_keys_storage.resize(sizeBatch * numElements, stream);
}
batched_memcpy(
input_keys_storage.data(), numElements, inputKeys, ldIK, numElements, sizeBatch, stream);
inputKeys = input_keys_storage.data();
}

if (ldIV > numElements) {
if (input_values_storage.size() != sizeBatch * numElements) {
input_values_storage.resize(sizeBatch * numElements, stream);
}

batched_memcpy(
input_values_storage.data(), numElements, inputVals, ldIV, numElements, sizeBatch, stream);
inputVals = input_values_storage.data();
}

if ((ldOK > topK) && (output_keys_storage.size() != sizeBatch * topK)) {
output_keys_storage.resize(sizeBatch * topK, stream);
}

if ((ldOV > topK) && (output_values_storage.size() != sizeBatch * topK)) {
output_values_storage.resize(sizeBatch * topK, stream);
}

raft::matrix::select_k<float, INDEX_T>(
handle,
raft::make_device_matrix_view<const float, int64_t>(inputKeys, sizeBatch, numElements),
raft::make_device_matrix_view<const INDEX_T, int64_t>(inputVals, sizeBatch, numElements),
raft::make_device_matrix_view<float, int64_t>(
ldOK > topK ? output_keys_storage.data() : outputKeys, sizeBatch, topK),
raft::make_device_matrix_view<INDEX_T, int64_t>(
ldOV > topK ? output_values_storage.data() : outputVals, sizeBatch, topK),
true, // select_min
sort);

if (ldOK > topK) {
batched_memcpy(outputKeys, ldOK, output_keys_storage.data(), topK, topK, sizeBatch, stream);
}

if (ldOV > topK) {
batched_memcpy(outputVals, ldOV, output_values_storage.data(), topK, topK, sizeBatch, stream);
}
}

void operator()(raft::resources const& res,
raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset,
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph,
Expand Down Expand Up @@ -746,21 +849,21 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
unsigned iter = 0;
while (1) {
// Make an index list of internal top-k nodes
_cuann_find_topk(itopk_size,
num_queries,
result_buffer_size,
result_distances.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size,
result_buffer_allocation_size,
result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size,
result_buffer_allocation_size,
topk_workspace.data(),
true,
top_hint_ptr,
stream);
_find_topk(res,
itopk_size,
num_queries,
result_buffer_size,
result_distances.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size,
result_buffer_allocation_size,
result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size,
result_buffer_allocation_size,
topk_workspace.data(),
true,
top_hint_ptr);

// termination (1)
if ((iter + 1 == max_iterations)) {
Expand Down Expand Up @@ -841,21 +944,21 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {

result_indices_ptr = result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size;
result_distances_ptr = result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size;
_cuann_find_topk(itopk_size,
num_queries,
result_buffer_size,
result_distances.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_distances_ptr,
result_buffer_allocation_size,
result_indices_ptr,
result_buffer_allocation_size,
topk_workspace.data(),
true,
top_hint_ptr,
stream);
_find_topk(res,
itopk_size,
num_queries,
result_buffer_size,
result_distances.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_distances_ptr,
result_buffer_allocation_size,
result_indices_ptr,
result_buffer_allocation_size,
topk_workspace.data(),
true,
top_hint_ptr);
} else {
// Remove parent bit in search results
remove_parent_bit(
Expand Down
10 changes: 7 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ struct search_plan_impl_base : public search_params {
if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) {
algo = search_algo::SINGLE_CTA;
RAFT_LOG_DEBUG("Auto strategy: selecting single-cta");
} else {
} else if (topk <= 1024) {
algo = search_algo::MULTI_CTA;
RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta");
} else {
algo = search_algo::MULTI_KERNEL;
RAFT_LOG_DEBUG("Auto strategy: selecting multi kernel");
}
}
}
Expand Down Expand Up @@ -255,15 +258,16 @@ struct search_plan_impl : public search_plan_impl_base {
virtual void check(const uint32_t topk)
{
// For single-CTA and multi kernel
RAFT_EXPECTS(topk <= itopk_size, "topk must be smaller than itopk_size = %lu", itopk_size);
RAFT_EXPECTS(
topk <= itopk_size, "topk = %u must be smaller than itopk_size = %lu", topk, itopk_size);
}

inline void check_params()
{
std::string error_message = "";

if (itopk_size > 1024) {
if (algo == search_algo::MULTI_CTA) {
if ((algo == search_algo::MULTI_CTA) || (algo == search_algo::MULTI_KERNEL)) {
} else {
error_message += std::string("- `internal_topk` (" + std::to_string(itopk_size) +
") must be smaller or equal to 1024");
Expand Down
20 changes: 20 additions & 0 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
search_params.team_size = ps.team_size;
search_params.itopk_size = ps.itopk_size;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);
Expand Down Expand Up @@ -496,6 +497,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
search_params.team_size = ps.team_size;
search_params.itopk_size = ps.itopk_size;
search_params.hashmap_mode = cagra::hash_mode::HASH;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
Expand Down Expand Up @@ -611,6 +613,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
search_params.team_size = ps.team_size;
search_params.itopk_size = ps.itopk_size;
search_params.hashmap_mode = cagra::hash_mode::HASH;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
Expand Down Expand Up @@ -818,6 +821,23 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

inputs2 =
raft::util::itertools::product<AnnCagraInputs>({100},
{20000},
{32},
{2048}, // k
{graph_build_algo::NN_DESCENT},
{search_algo::AUTO},
{10},
{0},
{4096}, // itopk_size
{1},
{raft::distance::DistanceType::L2Expanded},
{false},
{false},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

return inputs;
}

Expand Down