Skip to content

Commit

Permalink
Add RAFT_EXPLICIT macros for cagra single/multi_cta kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed May 1, 2023
1 parent 67eb2ef commit 5682bee
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,46 @@
*/
#pragma once

#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

namespace raft::neighbors::experimental::cagra::detail {
namespace multi_cta_search {

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

/**
 * Declaration (no body) of the CAGRA multi-CTA `search_kernel` template.
 *
 * The declaration is terminated by RAFT_EXPLICIT (provided by <raft/util/raft_explicit.hpp>)
 * and sits inside an `#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY` region, so it is only seen by
 * translation units that define that macro.
 * NOTE(review): presumably the matching definition lives in the corresponding -inl.cuh header
 * and this signature has to be kept in sync with it by hand -- confirm.
 *
 * Launch bounds: `__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT)` passes BLOCK_SIZE as the maximum
 * number of threads per block and BLOCK_COUNT as the minimum number of blocks per SM.
 *
 * Template parameters:
 *   TEAM_SIZE, MAX_ELEMENTS, MAX_DATASET_DIM -- compile-time unsigned constants; their exact
 *     meaning is not visible in this declaration (TODO confirm against the definition).
 *   DATA_T     -- element type of `dataset_ptr` and `queries_ptr`.
 *   DISTANCE_T -- element type of `result_distances_ptr`.
 *   INDEX_T    -- element type of `result_indices_ptr`, `knn_graph` and `seed_ptr`.
 *   LOAD_T     -- not used in the parameter list; presumably the vectorized load type used by
 *                 the definition (TODO confirm).
 *
 * Array extents, where known, are given in the trailing comment on each pointer parameter.
 * `num_executed_iterations` is an output pointer marked "stats" by the original author; its
 * extent is not stated in this declaration.
 */
template <unsigned TEAM_SIZE,
unsigned BLOCK_SIZE,
unsigned BLOCK_COUNT,
unsigned MAX_ELEMENTS,
unsigned MAX_DATASET_DIM,
class DATA_T,
class DISTANCE_T,
class INDEX_T,
class LOAD_T>
__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size]
DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size]
const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim]
const size_t dataset_dim,
const size_t dataset_size,
const DATA_T* const queries_ptr, // [num_queries, dataset_dim]
const INDEX_T* const knn_graph, // [dataset_size, graph_degree]
const uint32_t graph_degree,
const unsigned num_distilation,
const uint64_t rand_xor_mask,
const INDEX_T* seed_ptr, // [num_queries, num_seeds]
const uint32_t num_seeds,
uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen]
const uint32_t hash_bitlen,
const uint32_t itopk_size,
const uint32_t num_parents,
const uint32_t min_iteration,
const uint32_t max_iteration,
uint32_t* const num_executed_iterations /* stats */
) RAFT_EXPLICIT;

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \
BLOCK_SIZE, \
BLOCK_COUNT, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
*/
#pragma once

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "search_multi_cta_kernel-inl.cuh"
#endif

#ifdef RAFT_COMPILED
#include "search_multi_cta_kernel-ext.cuh"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,49 @@
*/
#pragma once

#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
namespace raft::neighbors::experimental::cagra::detail {
namespace single_cta_search {

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

/**
 * Declaration (no body) of the CAGRA single-CTA `search_kernel` template.
 *
 * The declaration is terminated by RAFT_EXPLICIT (provided by <raft/util/raft_explicit.hpp>)
 * and sits inside an `#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY` region, so it is only seen by
 * translation units that define that macro.
 * NOTE(review): presumably the matching definition lives in the corresponding -inl.cuh header
 * and this signature has to be kept in sync with it by hand -- confirm.
 *
 * Launch bounds: `__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT)` passes BLOCK_SIZE as the maximum
 * number of threads per block and BLOCK_COUNT as the minimum number of blocks per SM.
 *
 * Template parameters:
 *   TEAM_SIZE, MAX_ITOPK, MAX_CANDIDATES, TOPK_BY_BITONIC_SORT, MAX_DATASET_DIM -- compile-time
 *     unsigned constants; their exact meaning is not visible in this declaration (TODO confirm
 *     against the definition).
 *   DATA_T     -- element type of `dataset_ptr` and `queries_ptr`.
 *   DISTANCE_T -- element type of `result_distances_ptr`.
 *   INDEX_T    -- element type of `result_indices_ptr`, `knn_graph` and `seed_ptr`.
 *   LOAD_T     -- not used in the parameter list; presumably the vectorized load type used by
 *                 the definition (TODO confirm).
 *
 * Array extents, where known, are given in the trailing comment on each pointer parameter.
 * Note that `hash_bitlen` is declared near the end of the list, after
 * `num_executed_iterations`, even though it is referenced by the extent comment on
 * `visited_hashmap_ptr`; take care with positional argument order when launching.
 */
template <unsigned TEAM_SIZE,
unsigned BLOCK_SIZE,
unsigned BLOCK_COUNT,
unsigned MAX_ITOPK,
unsigned MAX_CANDIDATES,
unsigned TOPK_BY_BITONIC_SORT,
unsigned MAX_DATASET_DIM,
class DATA_T,
class DISTANCE_T,
class INDEX_T,
class LOAD_T>
__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__
void search_kernel(INDEX_T* const result_indices_ptr, // [num_queries, top_k]
DISTANCE_T* const result_distances_ptr, // [num_queries, top_k]
const std::uint32_t top_k,
const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim]
const std::size_t dataset_dim,
const std::size_t dataset_size,
const DATA_T* const queries_ptr, // [num_queries, dataset_dim]
const INDEX_T* const knn_graph, // [dataset_size, graph_degree]
const std::uint32_t graph_degree,
const unsigned num_distilation,
const uint64_t rand_xor_mask,
const INDEX_T* seed_ptr, // [num_queries, num_seeds]
const uint32_t num_seeds,
std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen]
const std::uint32_t internal_topk,
const std::uint32_t num_parents,
const std::uint32_t min_iteration,
const std::uint32_t max_iteration,
std::uint32_t* const num_executed_iterations, // [num_queries]
const std::uint32_t hash_bitlen,
const std::uint32_t small_hash_bitlen,
const std::uint32_t small_hash_reset_interval) RAFT_EXPLICIT;

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_single_cta_search_kernel(TEAM_SIZE, \
BLOCK_SIZE, \
BLOCK_COUNT, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
*/
#pragma once

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "search_single_cta_kernel-inl.cuh"
#endif

#ifdef RAFT_COMPILED
#include "search_single_cta_kernel-ext.cuh"
Expand Down

0 comments on commit 5682bee

Please sign in to comment.