Commit
Replace CAGRA search kernel dispatch macros with functions
tfeher committed Jul 17, 2023
1 parent 87c501a commit 700112f
Showing 5 changed files with 356 additions and 227 deletions.
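
For readers skimming the diff, the core of the change is that the nested SET_KERNEL / SET_MC_KERNEL dispatch macros are replaced by ordinary template functions that return a pointer to the selected kernel instantiation; runtime parameters (result buffer size, thread block size) then pick the compile-time instantiation through plain if/else chains. Below is a minimal, self-contained sketch of that pattern with hypothetical names — it is not the RAFT code itself, only an illustration of the two-level dispatch used by search_kernel_config further down.

#include <cstdint>
#include <stdexcept>

// Stand-in for a templated CUDA kernel: the value parameters affect only the
// body, not the signature, so every instantiation shares one pointer type.
template <unsigned BLOCK_SIZE, unsigned MAX_ELEMENTS>
void fake_kernel(const float* /*queries*/, uint32_t /*num_queries*/) {}

using kernel_t = decltype(&fake_kernel<64, 64>);

// Inner dispatch: pick the block size for a fixed MAX_ELEMENTS.
template <unsigned MAX_ELEMENTS>
kernel_t choose_block_size(unsigned block_size)
{
  switch (block_size) {
    case 64: return &fake_kernel<64, MAX_ELEMENTS>;
    case 128: return &fake_kernel<128, MAX_ELEMENTS>;
    case 256: return &fake_kernel<256, MAX_ELEMENTS>;
    case 512: return &fake_kernel<512, MAX_ELEMENTS>;
    default: return &fake_kernel<1024, MAX_ELEMENTS>;
  }
}

// Outer dispatch: pick MAX_ELEMENTS from the result buffer size, mirroring
// search_kernel_config::choose_buffer_size in the diff below.
kernel_t choose_kernel(unsigned result_buffer_size, unsigned block_size)
{
  if (result_buffer_size <= 64) return choose_block_size<64>(block_size);
  if (result_buffer_size <= 128) return choose_block_size<128>(block_size);
  if (result_buffer_size <= 256) return choose_block_size<256>(block_size);
  throw std::invalid_argument("result_buffer_size larger than 256");
}

Unlike the old macro chains, each branch here is an ordinary statement the compiler type-checks, and the dispatcher can be stepped through in a debugger.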
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/detail/cagra/hashmap.hpp
@@ -18,6 +18,7 @@
#include "utils.hpp"
#include <cstdint>
#include <raft/core/detail/macros.hpp>
#include <raft/util/device_atomics.cuh>

cjnolet (Member) commented on Jul 18, 2023:

Little nitpick: We should rename this file to cuh extension since it's bringing in cuh files.


// #pragma GCC diagnostic push
// #pragma GCC diagnostic ignored
62 changes: 25 additions & 37 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
@@ -189,43 +189,31 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T> {
uint32_t topk)
{
cudaStream_t stream = resource::get_cuda_stream(res);
uint32_t block_size = thread_block_size;

SET_MC_KERNEL;
RAFT_CUDA_TRY(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Initialize hash table
const uint32_t hash_size = hashmap::get_size(hash_bitlen);
set_value_batch(
hashmap.data(), hash_size, utils::get_max_value<INDEX_T>(), hash_size, num_queries, stream);

dim3 block_dims(block_size, 1, 1);
dim3 grid_dims(num_cta_per_query, num_queries, 1);
RAFT_LOG_DEBUG("Launching kernel with %u threads, (%u, %u) blocks %lu smem",
block_size,
num_cta_per_query,
num_queries,
smem_size);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(intermediate_indices.data(),
intermediate_distances.data(),
dataset.data_handle(),
dataset.extent(1),
dataset.extent(0),
dataset.stride(0),
queries_ptr,
graph.data_handle(),
graph.extent(1),
num_random_samplings,
rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap.data(),
hash_bitlen,
itopk_size,
num_parents,
min_iterations,
max_iterations,
num_executed_iterations);

select_and_run<TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T>(
dataset,
graph,
intermediate_indices.data(),
intermediate_distances.data(),
queries_ptr,
num_queries,
dev_seed_ptr,
num_executed_iterations,
topk,
thread_block_size,
result_buffer_size,
smem_size,
hash_bitlen,
hashmap.data(),
num_cta_per_query,
num_random_samplings,
rand_xor_mask,
num_seeds,
itopk_size,
num_parents,
min_iterations,
max_iterations,
stream);
RAFT_CUDA_TRY(cudaPeekAtLastError());

// Select the top-k results from the intermediate results
215 changes: 159 additions & 56 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
@@ -329,62 +329,6 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
#endif
}

#define SET_MC_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS) \
kernel = search_kernel<TEAM_SIZE, \
BLOCK_SIZE, \
BLOCK_COUNT, \
MAX_ELEMENTS, \
MAX_DATASET_DIM, \
DATA_T, \
DISTANCE_T, \
INDEX_T, \
device::LOAD_128BIT_T>;

#define SET_MC_KERNEL_1(MAX_ELEMENTS) \
/* if ( block_size == 32 ) { \
SET_MC_KERNEL_3( 32, 32, MAX_ELEMENTS ) \
} else */ \
if (block_size == 64) { \
SET_MC_KERNEL_3(64, 16, MAX_ELEMENTS) \
} else if (block_size == 128) { \
SET_MC_KERNEL_3(128, 8, MAX_ELEMENTS) \
} else if (block_size == 256) { \
SET_MC_KERNEL_3(256, 4, MAX_ELEMENTS) \
} else if (block_size == 512) { \
SET_MC_KERNEL_3(512, 2, MAX_ELEMENTS) \
} else { \
SET_MC_KERNEL_3(1024, 1, MAX_ELEMENTS) \
}

#define SET_MC_KERNEL \
typedef void (*search_kernel_t)(INDEX_T* const result_indices_ptr, \
DISTANCE_T* const result_distances_ptr, \
const DATA_T* const dataset_ptr, \
const size_t dataset_dim, \
const size_t dataset_size, \
const size_t dataset_ld, \
const DATA_T* const queries_ptr, \
const INDEX_T* const knn_graph, \
const uint32_t graph_degree, \
const unsigned num_distilation, \
const uint64_t rand_xor_mask, \
const INDEX_T* seed_ptr, \
const uint32_t num_seeds, \
INDEX_T* const visited_hashmap_ptr, \
const uint32_t hash_bitlen, \
const uint32_t itopk_size, \
const uint32_t num_parents, \
const uint32_t min_iteration, \
const uint32_t max_iteration, \
uint32_t* const num_executed_iterations); \
search_kernel_t kernel = nullptr; \
if (result_buffer_size <= 64) { \
SET_MC_KERNEL_1(64) \
} else if (result_buffer_size <= 128) { \
SET_MC_KERNEL_1(128) \
} else if (result_buffer_size <= 256) { \
SET_MC_KERNEL_1(256) \
}
template <class T>
__global__ void set_value_batch_kernel(T* const dev_ptr,
const std::size_t ld,
@@ -413,5 +357,164 @@ void set_value_batch(T* const dev_ptr,
<<<grid_size, block_size, 0, cuda_stream>>>(dev_ptr, ld, val, count, batch_size);
}

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
typename DATA_T,
typename INDEX_T,
typename DISTANCE_T>
struct search_kernel_config {
// Search kernel function type. Note that the actual values for the template value
// parameters do not matter, because they are not part of the function signature. The
// second to fourth value parameters will be selected by the choose_* functions below.
using kernel_t = decltype(&search_kernel<TEAM_SIZE,
64,
16,
128,
MAX_DATASET_DIM,
DATA_T,
DISTANCE_T,
INDEX_T,
device::LOAD_128BIT_T>);

static auto choose_buffer_size(unsigned result_buffer_size, unsigned block_size) -> kernel_t
{
if (result_buffer_size <= 64) {
return choose_max_elements<64>(block_size);
} else if (result_buffer_size <= 128) {
return choose_max_elements<128>(block_size);
} else if (result_buffer_size <= 256) {
return choose_max_elements<256>(block_size);
}
THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256);
}

template <unsigned MAX_ELEMENTS>
  // TODO: rename this to choose_block_size
static auto choose_max_elements(unsigned block_size) -> kernel_t
{
if (block_size == 64) {
return search_kernel<TEAM_SIZE,
64,
16,
MAX_ELEMENTS,
MAX_DATASET_DIM,
DATA_T,
DISTANCE_T,
INDEX_T,
device::LOAD_128BIT_T>;
} else if (block_size == 128) {
return search_kernel<TEAM_SIZE,
128,
8,
MAX_ELEMENTS,
MAX_DATASET_DIM,
DATA_T,
DISTANCE_T,
INDEX_T,
device::LOAD_128BIT_T>;
} else if (block_size == 256) {
return search_kernel<TEAM_SIZE,
256,
4,
MAX_ELEMENTS,
MAX_DATASET_DIM,
DATA_T,
DISTANCE_T,
INDEX_T,
device::LOAD_128BIT_T>;
} else if (block_size == 512) {
return search_kernel<TEAM_SIZE,
512,
2,
MAX_ELEMENTS,
MAX_DATASET_DIM,
DATA_T,
DISTANCE_T,
INDEX_T,
device::LOAD_128BIT_T>;
} else {
return search_kernel<TEAM_SIZE,
1024,
1,
MAX_ELEMENTS,
MAX_DATASET_DIM,
DATA_T,
DISTANCE_T,
INDEX_T,
device::LOAD_128BIT_T>;
}
}
};

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
typename DATA_T,
typename INDEX_T,
typename DISTANCE_T>
void select_and_run( // raft::resources const& res,
raft::device_matrix_view<const DATA_T, INDEX_T, layout_stride> dataset,
raft::device_matrix_view<const INDEX_T, INDEX_T, row_major> graph,
INDEX_T* const topk_indices_ptr, // [num_queries, topk]
DISTANCE_T* const topk_distances_ptr, // [num_queries, topk]
const DATA_T* const queries_ptr, // [num_queries, dataset_dim]
const uint32_t num_queries,
const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
uint32_t topk,
// multi_cta_search (params struct)
uint32_t block_size, //
uint32_t result_buffer_size,
uint32_t smem_size,
int64_t hash_bitlen,
INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t num_parents,
size_t min_iterations,
size_t max_iterations,
cudaStream_t stream)
{
auto kernel = search_kernel_config<TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T>::
choose_buffer_size(result_buffer_size, block_size);

RAFT_CUDA_TRY(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Initialize hash table
const uint32_t hash_size = hashmap::get_size(hash_bitlen);
set_value_batch(
hashmap_ptr, hash_size, utils::get_max_value<INDEX_T>(), hash_size, num_queries, stream);

dim3 block_dims(block_size, 1, 1);
dim3 grid_dims(num_cta_per_query, num_queries, 1);
RAFT_LOG_DEBUG("Launching kernel with %u threads, (%u, %u) blocks %lu smem",
block_size,
num_cta_per_query,
num_queries,
smem_size);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(topk_indices_ptr,
topk_distances_ptr,
dataset.data_handle(),
dataset.extent(1),
dataset.extent(0),
dataset.stride(0),
queries_ptr,
graph.data_handle(),
graph.extent(1),
num_random_samplings,
rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap_ptr,
hash_bitlen,
itopk_size,
num_parents,
min_iterations,
max_iterations,
num_executed_iterations);
}

} // namespace multi_cta_search
} // namespace raft::neighbors::experimental::cagra::detail
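
A side note on the kernel_t alias in search_kernel_config above: decltype of a pointer to any one instantiation yields a host-side function-pointer type that fits every instantiation, because the template value parameters are not part of the kernel's signature. The selected pointer then behaves like a normal kernel handle — it can be passed to cudaFuncSetAttribute and launched with the triple-chevron syntax, which is what select_and_run does. A toy CUDA sketch of that idea (hypothetical kernel and helper names, not the RAFT code):

#include <cstdint>
#include <cuda_runtime.h>

template <unsigned BLOCK_SIZE, unsigned MAX_ELEMENTS>
__global__ void toy_kernel(float* out, const float* in, uint32_t n)
{
  // In real code BLOCK_SIZE / MAX_ELEMENTS would size shared-memory buffers.
  const uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { out[i] = in[i]; }
}

// One pointer type covers all instantiations, since the value parameters do
// not change the function signature.
using toy_kernel_t = decltype(&toy_kernel<64, 64>);

void launch_copy(float* out, const float* in, uint32_t n, unsigned block_size, cudaStream_t stream)
{
  toy_kernel_t kernel = (block_size == 128) ? &toy_kernel<128, 64> : &toy_kernel<64, 64>;
  // The selected pointer is an ordinary kernel handle: set attributes on it
  // (0 is a placeholder; the real code passes the dynamic smem_size it will
  // launch with), then launch it through the usual triple-chevron syntax.
  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 0);
  kernel<<<(n + block_size - 1) / block_size, block_size, 0, stream>>>(out, in, n);
}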
57 changes: 25 additions & 32 deletions cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh
@@ -237,38 +237,31 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T> {
uint32_t topk)
{
cudaStream_t stream = resource::get_cuda_stream(res);
uint32_t block_size = thread_block_size;
SET_KERNEL;
RAFT_CUDA_TRY(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
dim3 thread_dims(block_size, 1, 1);
dim3 block_dims(1, num_queries, 1);
RAFT_LOG_DEBUG(
"Launching kernel with %u threads, %u block %lu smem", block_size, num_queries, smem_size);
kernel<<<block_dims, thread_dims, smem_size, stream>>>(result_indices_ptr,
result_distances_ptr,
topk,
dataset.data_handle(),
dataset.extent(1),
dataset.extent(0),
dataset.stride(0),
queries_ptr,
graph.data_handle(),
graph.extent(1),
num_random_samplings,
rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap.data(),
itopk_size,
num_parents,
min_iterations,
max_iterations,
num_executed_iterations,
hash_bitlen,
small_hash_bitlen,
small_hash_reset_interval);
RAFT_CUDA_TRY(cudaPeekAtLastError());
select_and_run<TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T>(
dataset,
graph,
result_indices_ptr,
result_distances_ptr,
queries_ptr,
num_queries,
dev_seed_ptr,
num_executed_iterations,
topk,
num_itopk_candidates,
static_cast<uint32_t>(thread_block_size),
smem_size,
hash_bitlen,
hashmap.data(),
small_hash_bitlen,
small_hash_reset_interval,
num_random_samplings,
rand_xor_mask,
num_seeds,
itopk_size,
num_parents,
min_iterations,
max_iterations,
stream);
}
};

