Skip to content

Commit

Permalink
updates from codereview
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Jan 22, 2024
1 parent 7826591 commit ca5478f
Showing 1 changed file with 103 additions and 90 deletions.
193 changes: 103 additions & 90 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -598,95 +598,6 @@ void set_value_batch(T* const dev_ptr,
<<<grid_size, block_size, 0, cuda_stream>>>(dev_ptr, ld, val, count, batch_size);
}

template <class ValT>
inline void _find_topk(raft::resources const& handle,
uint32_t topK,
uint32_t sizeBatch,
uint32_t numElements,
const float* inputKeys, // [sizeBatch, ldIK,]
uint32_t ldIK, // (*) ldIK >= numElements
const ValT* inputVals, // [sizeBatch, ldIV,]
uint32_t ldIV, // (*) ldIV >= numElements
float* outputKeys, // [sizeBatch, ldOK,]
uint32_t ldOK, // (*) ldOK >= topK
ValT* outputVals, // [sizeBatch, ldOV,]
uint32_t ldOV, // (*) ldOV >= topK
void* workspace,
bool sort,
uint32_t* hints)
{
auto stream = resource::get_cuda_stream(handle);

// _cuann_find_topk right now is limited to a max-k of 1024.
// RAFT has a matrix::select_k function - which handles arbitrary sized values of k,
// but doesn't accept strided inputs unlike _cuann_find_topk
// The multi-kernel search path requires strided access - since its cleverly allocating memory
// (layout described in the search_plan_impl function below), such that both the
// neighbors and the internal_topk are adjacent - in a double buffered format.
// Since this layout doesn't work with the matrix::select_k code - we have to copy
// over to a contiguous (non-strided) access to handle topk larger than 1024, and
// potentially also copy back to a strided layout afterwards
if (topK <= 1024) {
return _cuann_find_topk(topK,
sizeBatch,
numElements,
inputKeys,
ldIK,
inputVals,
ldIV,
outputKeys,
ldOK,
outputVals,
ldOV,
workspace,
sort,
hints,
stream);
}

rmm::device_uvector<float> input_keys_storage(0, stream);
rmm::device_uvector<float> output_keys_storage(0, stream);
rmm::device_uvector<ValT> input_values_storage(0, stream);
rmm::device_uvector<ValT> output_values_storage(0, stream);

if (ldIK > numElements) {
input_keys_storage.resize(sizeBatch * numElements, stream);
batched_memcpy(
input_keys_storage.data(), numElements, inputKeys, ldIK, numElements, sizeBatch, stream);
inputKeys = input_keys_storage.data();
}

if (ldIV > numElements) {
input_values_storage.resize(sizeBatch * numElements, stream);
batched_memcpy(
input_values_storage.data(), numElements, inputVals, ldIV, numElements, sizeBatch, stream);
inputVals = input_values_storage.data();
}

if (ldOK > topK) { output_keys_storage.resize(sizeBatch * topK, stream); }

if (ldOV > topK) { output_values_storage.resize(sizeBatch * topK, stream); }

raft::matrix::select_k<float, ValT>(
handle,
raft::make_device_matrix_view<const float, int64_t>(inputKeys, sizeBatch, numElements),
raft::make_device_matrix_view<const ValT, int64_t>(inputVals, sizeBatch, numElements),
raft::make_device_matrix_view<float, int64_t>(
ldOK > topK ? output_keys_storage.data() : outputKeys, sizeBatch, topK),
raft::make_device_matrix_view<ValT, int64_t>(
ldOV > topK ? output_values_storage.data() : outputVals, sizeBatch, topK),
true, // select_min
sort);

if (ldOK > topK) {
batched_memcpy(outputKeys, ldOK, output_keys_storage.data(), topK, topK, sizeBatch, stream);
}

if (ldOV > topK) {
batched_memcpy(outputVals, ldOV, output_values_storage.data(), topK, topK, sizeBatch, stream);
}
}

// result_buffer (work buffer) for "multi-kernel"
// +--------------------+------------------------------+-------------------+
// | internal_top_k (A) | neighbors of internal_top_k | internal_topk (B) |
Expand Down Expand Up @@ -743,6 +654,12 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
rmm::device_scalar<uint32_t> terminate_flag; // dev_terminate_flag, host_terminate_flag.;
rmm::device_uvector<uint32_t> topk_workspace;

// temporary storage for _find_topk
rmm::device_uvector<float> input_keys_storage;
rmm::device_uvector<float> output_keys_storage;
rmm::device_uvector<INDEX_T> input_values_storage;
rmm::device_uvector<INDEX_T> output_values_storage;

search(raft::resources const& res,
search_params params,
int64_t dim,
Expand All @@ -755,7 +672,11 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
parent_node_list(0, resource::get_cuda_stream(res)),
topk_hint(0, resource::get_cuda_stream(res)),
topk_workspace(0, resource::get_cuda_stream(res)),
terminate_flag(resource::get_cuda_stream(res))
terminate_flag(resource::get_cuda_stream(res)),
input_keys_storage(0, resource::get_cuda_stream(res)),
output_keys_storage(0, resource::get_cuda_stream(res)),
input_values_storage(0, resource::get_cuda_stream(res)),
output_values_storage(0, resource::get_cuda_stream(res))
{
set_params(res);
}
Expand Down Expand Up @@ -785,6 +706,98 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {

~search() {}

inline void _find_topk(raft::resources const& handle,
uint32_t topK,
uint32_t sizeBatch,
uint32_t numElements,
const float* inputKeys, // [sizeBatch, ldIK,]
uint32_t ldIK, // (*) ldIK >= numElements
const INDEX_T* inputVals, // [sizeBatch, ldIV,]
uint32_t ldIV, // (*) ldIV >= numElements
float* outputKeys, // [sizeBatch, ldOK,]
uint32_t ldOK, // (*) ldOK >= topK
INDEX_T* outputVals, // [sizeBatch, ldOV,]
uint32_t ldOV, // (*) ldOV >= topK
void* workspace,
bool sort,
uint32_t* hints)
{
auto stream = resource::get_cuda_stream(handle);

// _cuann_find_topk right now is limited to a max-k of 1024.
// RAFT has a matrix::select_k function - which handles arbitrary sized values of k,
// but doesn't accept strided inputs unlike _cuann_find_topk
// The multi-kernel search path requires strided access - since its cleverly allocating memory
// (layout described in the search_plan_impl function below), such that both the
// neighbors and the internal_topk are adjacent - in a double buffered format.
// Since this layout doesn't work with the matrix::select_k code - we have to copy
// over to a contiguous (non-strided) access to handle topk larger than 1024, and
// potentially also copy back to a strided layout afterwards
if (topK <= 1024) {
return _cuann_find_topk(topK,
sizeBatch,
numElements,
inputKeys,
ldIK,
inputVals,
ldIV,
outputKeys,
ldOK,
outputVals,
ldOV,
workspace,
sort,
hints,
stream);
}

if (ldIK > numElements) {
if (input_keys_storage.size() != sizeBatch * numElements) {
input_keys_storage.resize(sizeBatch * numElements, stream);
}
batched_memcpy(
input_keys_storage.data(), numElements, inputKeys, ldIK, numElements, sizeBatch, stream);
inputKeys = input_keys_storage.data();
}

if (ldIV > numElements) {
if (input_values_storage.size() != sizeBatch * numElements) {
input_values_storage.resize(sizeBatch * numElements, stream);
}

batched_memcpy(
input_values_storage.data(), numElements, inputVals, ldIV, numElements, sizeBatch, stream);
inputVals = input_values_storage.data();
}

if ((ldOK > topK) && (output_keys_storage.size() != sizeBatch * topK)) {
output_keys_storage.resize(sizeBatch * topK, stream);
}

if ((ldOV > topK) && (output_values_storage.size() != sizeBatch * topK)) {
output_values_storage.resize(sizeBatch * topK, stream);
}

raft::matrix::select_k<float, INDEX_T>(
handle,
raft::make_device_matrix_view<const float, int64_t>(inputKeys, sizeBatch, numElements),
raft::make_device_matrix_view<const INDEX_T, int64_t>(inputVals, sizeBatch, numElements),
raft::make_device_matrix_view<float, int64_t>(
ldOK > topK ? output_keys_storage.data() : outputKeys, sizeBatch, topK),
raft::make_device_matrix_view<INDEX_T, int64_t>(
ldOV > topK ? output_values_storage.data() : outputVals, sizeBatch, topK),
true, // select_min
sort);

if (ldOK > topK) {
batched_memcpy(outputKeys, ldOK, output_keys_storage.data(), topK, topK, sizeBatch, stream);
}

if (ldOV > topK) {
batched_memcpy(outputVals, ldOV, output_values_storage.data(), topK, topK, sizeBatch, stream);
}
}

void operator()(raft::resources const& res,
raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset,
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph,
Expand Down

0 comments on commit ca5478f

Please sign in to comment.