From 2739c83a1a86d0ff2a29c884f017110da38a93b5 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Mon, 3 Jul 2023 02:04:28 +0200 Subject: [PATCH 1/6] Add template instantiations for cagra search_singe_cta --- .../neighbors/detail/cagra/graph_core.cuh | 16 +- .../detail/cagra/search_single_cta.cuh | 770 +------------- .../cagra/search_single_cta_kernel-ext.cuh | 982 ++++++++++++++++++ .../cagra/search_single_cta_kernel-inl.cuh | 740 +++++++++++++ .../detail/cagra/search_single_cta_kernel.cuh | 24 + .../neighbors/detail/cagra/topk_by_radix.cuh | 97 ++ .../cagra/search_single_cta_00_generate.py | 137 +++ ...rch_single_cta_float_uint32_dim1024_t32.cu | 145 +++ ...earch_single_cta_float_uint32_dim128_t8.cu | 145 +++ ...arch_single_cta_float_uint32_dim256_t16.cu | 145 +++ ...arch_single_cta_float_uint32_dim512_t32.cu | 145 +++ ...arch_single_cta_int8_uint32_dim1024_t32.cu | 151 +++ ...search_single_cta_int8_uint32_dim128_t8.cu | 145 +++ ...earch_single_cta_int8_uint32_dim256_t16.cu | 145 +++ ...earch_single_cta_int8_uint32_dim512_t32.cu | 145 +++ ...rch_single_cta_uint8_uint32_dim1024_t32.cu | 182 ++++ ...earch_single_cta_uint8_uint32_dim128_t8.cu | 145 +++ ...arch_single_cta_uint8_uint32_dim256_t16.cu | 151 +++ ...arch_single_cta_uint8_uint32_dim512_t32.cu | 151 +++ cpp/test/CMakeLists.txt | 62 +- cpp/test/neighbors/ann_cagra.cuh | 2 + 21 files changed, 3834 insertions(+), 791 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index d915634df9..949dcfda8b 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -72,6 +72,7 @@ template __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, dataset_dim] const IdxT dataset_size, const uint32_t dataset_dim, + const uint32_t dataset_ld, IdxT* const knn_graph, // [graph_chunk_size, graph_degree] const uint32_t graph_size, const uint32_t graph_degree) @@ -90,9 +91,9 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, float dist = 0.0; for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) { float diff = spatial::knn::detail::utils::mapping{}( - dataset[d + static_cast(dataset_dim) * srcNode]) - + dataset[d + static_cast(dataset_ld) * srcNode]) - spatial::knn::detail::utils::mapping{}( - dataset[d + static_cast(dataset_dim) * dstNode]); + dataset[d + static_cast(dataset_ld) * dstNode]); dist += diff * diff; } dist += __shfl_xor_sync(0xffffffff, dist, 1); @@ -238,6 +239,7 @@ void sort_knn_graph(raft::resources const& res, "dataset size is expected to have the same number of graph index size"); const uint32_t dataset_size = dataset.extent(0); const uint32_t dataset_dim = dataset.extent(1); + const uint32_t dataset_ld = dataset.stride(0); const DataT* dataset_ptr = dataset.data_handle(); const IdxT graph_size = dataset_size; @@ -263,8 +265,13 @@ void sort_knn_graph(raft::resources const& res, graph_size * input_graph_degree, resource::get_cuda_stream(res)); - void (*kernel_sort)( - const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t); + void (*kernel_sort)(const DataT* const, + const IdxT, + const uint32_t, + const uint32_t, + IdxT* const, + const uint32_t, + const uint32_t); if (input_graph_degree <= 32) { constexpr int numElementsPerThread = 1; kernel_sort = kern_sort; @@ -299,6 +306,7 @@ void sort_knn_graph(raft::resources const& res, d_dataset.data_handle(), dataset_size, dataset_dim, + dataset_ld, d_input_graph.data_handle(), graph_size, input_graph_degree); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index 219a1dd717..a1e02db055 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -34,6 +34,8 @@ #include "device_common.hpp" #include "hashmap.hpp" #include "search_plan.cuh" +#include "search_single_cta_kernel.cuh" +#include "topk_by_radix.cuh" #include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk #include "utils.hpp" #include @@ -43,774 +45,6 @@ namespace raft::neighbors::experimental::cagra::detail { namespace single_cta_search { -// #define _CLK_BREAKDOWN - -template -__device__ void pickup_next_parents(std::uint32_t* const terminate_flag, - INDEX_T* const next_parent_indices, - INDEX_T* const internal_topk_indices, - const std::size_t internal_topk_size, - const std::size_t dataset_size, - const std::uint32_t num_parents) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - // if (threadIdx.x >= 32) return; - - for (std::uint32_t i = threadIdx.x; i < num_parents; i += 32) { - next_parent_indices[i] = utils::get_max_value(); - } - std::uint32_t itopk_max = internal_topk_size; - if (itopk_max % 32) { itopk_max += 32 - (itopk_max % 32); } - std::uint32_t num_new_parents = 0; - for (std::uint32_t j = threadIdx.x; j < itopk_max; j += 32) { - std::uint32_t jj = j; - if (TOPK_BY_BITONIC_SORT) { jj = device::swizzling(j); } - INDEX_T index; - int new_parent = 0; - if (j < internal_topk_size) { - index = internal_topk_indices[jj]; - if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set - new_parent = 1; - } - } - const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; - if (i < num_parents) { - next_parent_indices[i] = index; - // set most significant bit as used node - internal_topk_indices[jj] |= index_msb_1_mask; - } - } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= num_parents) { break; } - } - if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } -} - -template -struct topk_by_radix_sort_base { - static constexpr std::uint32_t smem_size = MAX_INTERNAL_TOPK * 2 + 2048 + 8; - static constexpr std::uint32_t state_bit_lenght = 0; - static constexpr std::uint32_t vecLen = 2; // TODO -}; -template -struct topk_by_radix_sort : topk_by_radix_sort_base {}; - -template -struct topk_by_radix_sort> - : topk_by_radix_sort_base { - __device__ void operator()(uint32_t topk, - uint32_t batch_size, - uint32_t len_x, - const uint32_t* _x, - const IdxT* _in_vals, - uint32_t* _y, - IdxT* _out_vals, - uint32_t* work, - uint32_t* _hints, - bool sort, - uint32_t* _smem) - { - std::uint8_t* const state = reinterpret_cast(work); - topk_cta_11_core::state_bit_lenght, - topk_by_radix_sort_base::vecLen, - 64, - 32, - IdxT>(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); - } -}; - -#define TOP_FUNC_PARTIAL_SPECIALIZATION(V) \ - template \ - struct topk_by_radix_sort< \ - MAX_INTERNAL_TOPK, \ - BLOCK_SIZE, \ - IdxT, \ - std::enable_if_t<((MAX_INTERNAL_TOPK <= V) && (2 * MAX_INTERNAL_TOPK > V))>> \ - : topk_by_radix_sort_base { \ - __device__ void operator()(uint32_t topk, \ - uint32_t batch_size, \ - uint32_t len_x, \ - const uint32_t* _x, \ - const IdxT* _in_vals, \ - uint32_t* _y, \ - IdxT* _out_vals, \ - uint32_t* work, \ - uint32_t* _hints, \ - bool sort, \ - uint32_t* _smem) \ - { \ - assert(BLOCK_SIZE >= V / 4); \ - std::uint8_t* state = (std::uint8_t*)work; \ - topk_cta_11_core::state_bit_lenght, \ - topk_by_radix_sort_base::vecLen, \ - V, \ - V / 4, \ - IdxT>( \ - topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); \ - } \ - }; -TOP_FUNC_PARTIAL_SPECIALIZATION(128); -TOP_FUNC_PARTIAL_SPECIALIZATION(256); -TOP_FUNC_PARTIAL_SPECIALIZATION(512); -TOP_FUNC_PARTIAL_SPECIALIZATION(1024); - -template -__device__ inline void topk_by_bitonic_sort_1st(float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - const std::uint32_t num_itopk) -{ - const unsigned lane_id = threadIdx.x % 32; - const unsigned warp_id = threadIdx.x / 32; - if (MULTI_WARPS == 0) { - if (warp_id > 0) { return; } - constexpr unsigned N = (MAX_CANDIDATES + 31) / 32; - float key[N]; - IdxT val[N]; - /* Candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); - if (j < num_candidates) { - key[i] = candidate_distances[j]; - val[i] = candidate_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Sort */ - bitonic::warp_sort(key, val); - /* Reg -> Temp_itopk */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_candidates && j < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } else { - // Use two warps (64 threads) - constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; - constexpr unsigned N = (max_candidates_per_warp + 31) / 32; - float key[N]; - IdxT val[N]; - if (warp_id < 2) { - /* Candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = lane_id + (32 * i); - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates) { - key[i] = candidate_distances[j]; - val[i] = candidate_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Sort */ - bitonic::warp_sort(key, val); - /* Reg -> Temp_candidates */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates && jl < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } - __syncthreads(); - - unsigned num_warps_used = (num_itopk + max_candidates_per_warp - 1) / max_candidates_per_warp; - if (warp_id < num_warps_used) { - /* Temp_candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned kl = max_candidates_per_warp - 1 - jl; - unsigned j = jl + (max_candidates_per_warp * warp_id); - unsigned k = MAX_CANDIDATES - 1 - j; - if (j >= num_candidates || k >= num_candidates || kl >= num_itopk) continue; - float temp_key = candidate_distances[device::swizzling(k)]; - if (key[i] == temp_key) continue; - if ((warp_id == 0) == (key[i] > temp_key)) { - key[i] = temp_key; - val[i] = candidate_indices[device::swizzling(k)]; - } - } - } - if (num_warps_used > 1) { __syncthreads(); } - if (warp_id < num_warps_used) { - /* Merge */ - bitonic::warp_merge(key, val, 32); - /* Reg -> Temp_itopk */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates && j < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } - if (num_warps_used > 1) { __syncthreads(); } - } -} - -template -__device__ inline void topk_by_bitonic_sort_2nd(float* itopk_distances, // [num_itopk] - IdxT* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first) -{ - const unsigned lane_id = threadIdx.x % 32; - const unsigned warp_id = threadIdx.x / 32; - if (MULTI_WARPS == 0) { - if (warp_id > 0) { return; } - constexpr unsigned N = (MAX_ITOPK + 31) / 32; - float key[N]; - IdxT val[N]; - if (first) { - /* Load itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); - if (j < num_itopk) { - key[i] = itopk_distances[j]; - val[i] = itopk_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - } else { - /* Load itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - key[i] = itopk_distances[device::swizzling(j)]; - val[i] = itopk_indices[device::swizzling(j)]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - } - /* Merge candidates */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; // [0:MAX_ITOPK-1] - unsigned k = MAX_ITOPK - 1 - j; - if (k >= num_itopk || k >= num_candidates) continue; - float candidate_key = candidate_distances[device::swizzling(k)]; - if (key[i] > candidate_key) { - key[i] = candidate_key; - val[i] = candidate_indices[device::swizzling(k)]; - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, 32); - /* Store new itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - } else { - // Use two warps (64 threads) or more - constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; - constexpr unsigned N = (max_itopk_per_warp + 31) / 32; - float key[N]; - IdxT val[N]; - if (first) { - /* Load itop results (not sorted) */ - if (warp_id < 2) { - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i) + (max_itopk_per_warp * warp_id); - if (j < num_itopk) { - key[i] = itopk_distances[j]; - val[i] = itopk_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - /* Store intermedidate results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - if (j >= num_itopk) continue; - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - __syncthreads(); - if (warp_id < 2) { - /* Load intermedidate results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - unsigned k = MAX_ITOPK - 1 - j; - if (k >= num_itopk) continue; - float temp_key = itopk_distances[device::swizzling(k)]; - if (key[i] == temp_key) continue; - if ((warp_id == 0) == (key[i] > temp_key)) { - key[i] = temp_key; - val[i] = itopk_indices[device::swizzling(k)]; - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, 32); - } - __syncthreads(); - /* Store itopk results (sorted) */ - if (warp_id < 2) { - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - if (j >= num_itopk) continue; - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - } - const uint32_t num_itopk_div2 = num_itopk / 2; - if (threadIdx.x < 3) { - // work_buf is used to obtain turning points in 1st and 2nd half of itopk afer merge. - work_buf[threadIdx.x] = num_itopk_div2; - } - __syncthreads(); - - // Merge candidates (using whole threads) - for (unsigned k = threadIdx.x; k < min(num_candidates, num_itopk); k += blockDim.x) { - const unsigned j = num_itopk - 1 - k; - const float itopk_key = itopk_distances[device::swizzling(j)]; - const float candidate_key = candidate_distances[device::swizzling(k)]; - if (itopk_key > candidate_key) { - itopk_distances[device::swizzling(j)] = candidate_key; - itopk_indices[device::swizzling(j)] = candidate_indices[device::swizzling(k)]; - if (j < num_itopk_div2) { - atomicMin(work_buf + 2, j); - } else { - atomicMin(work_buf + 1, j - num_itopk_div2); - } - } - } - __syncthreads(); - - // Merge 1st and 2nd half of itopk (using whole threads) - for (unsigned j = threadIdx.x; j < num_itopk_div2; j += blockDim.x) { - const unsigned k = j + num_itopk_div2; - float key_0 = itopk_distances[device::swizzling(j)]; - float key_1 = itopk_distances[device::swizzling(k)]; - if (key_0 > key_1) { - itopk_distances[device::swizzling(j)] = key_1; - itopk_distances[device::swizzling(k)] = key_0; - IdxT val_0 = itopk_indices[device::swizzling(j)]; - IdxT val_1 = itopk_indices[device::swizzling(k)]; - itopk_indices[device::swizzling(j)] = val_1; - itopk_indices[device::swizzling(k)] = val_0; - atomicMin(work_buf + 0, j); - } - } - if (threadIdx.x == blockDim.x - 1) { - if (work_buf[2] < num_itopk_div2) { work_buf[1] = work_buf[2]; } - } - __syncthreads(); - // if ((blockIdx.x == 0) && (threadIdx.x == 0)) { - // RAFT_LOG_DEBUG( "work_buf: %u, %u, %u\n", work_buf[0], work_buf[1], work_buf[2] ); - // } - - // Warp-0 merges 1st half of itopk, warp-1 does 2nd half. - if (warp_id < 2) { - // Load intermedidate itopk results - const uint32_t turning_point = work_buf[warp_id]; // turning_point <= num_itopk_div2 - for (unsigned i = 0; i < N; i++) { - unsigned k = num_itopk; - unsigned j = (N * lane_id) + i; - if (j < turning_point) { - k = j + (num_itopk_div2 * warp_id); - } else if (j >= (MAX_ITOPK / 2 - num_itopk_div2)) { - j -= (MAX_ITOPK / 2 - num_itopk_div2); - if ((turning_point <= j) && (j < num_itopk_div2)) { k = j + (num_itopk_div2 * warp_id); } - } - if (k < num_itopk) { - key[i] = itopk_distances[device::swizzling(k)]; - val[i] = itopk_indices[device::swizzling(k)]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, 32); - /* Store new itopk results */ - for (unsigned i = 0; i < N; i++) { - const unsigned j = (N * lane_id) + i; - if (j < num_itopk_div2) { - unsigned k = j + (num_itopk_div2 * warp_id); - itopk_distances[device::swizzling(k)] = key[i]; - itopk_indices[device::swizzling(k)] = val[i]; - } - } - } - } -} - -template -__device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] - IdxT* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first) -{ - // The results in candidate_distances/indices are sorted by bitonic sort. - topk_by_bitonic_sort_1st( - candidate_distances, candidate_indices, num_candidates, num_itopk); - - // The results sorted above are merged with the internal intermediate top-k - // results so far using bitonic merge. - topk_by_bitonic_sort_2nd(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); -} - -template -__device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr, - const size_t hashmap_bitlen, - const INDEX_T* itopk_indices, - uint32_t itopk_size) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - if (threadIdx.x < FIRST_TID || threadIdx.x >= LAST_TID) return; - for (unsigned i = threadIdx.x - FIRST_TID; i < itopk_size; i += LAST_TID - FIRST_TID) { - auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit - hashmap::insert(hashmap_ptr, hashmap_bitlen, key); - } -} - -template -__device__ inline void set_value_device(T* const ptr, const T fill, const std::uint32_t count) -{ - for (std::uint32_t i = threadIdx.x; i < count; i += BLOCK_SIZE) { - ptr[i] = fill; - } -} - -// One query one thread block -template -__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ - void search_kernel(INDEX_T* const result_indices_ptr, // [num_queries, top_k] - DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] - const std::uint32_t top_k, - const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const std::size_t dataset_dim, - const std::size_t dataset_size, - const std::size_t dataset_ld, // stride of dataset - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t internal_topk, - const std::uint32_t num_parents, - const std::uint32_t min_iteration, - const std::uint32_t max_iteration, - std::uint32_t* const num_executed_iterations, // [num_queries] - const std::uint32_t hash_bitlen, - const std::uint32_t small_hash_bitlen, - const std::uint32_t small_hash_reset_interval) -{ - const auto query_id = blockIdx.y; - -#ifdef _CLK_BREAKDOWN - std::uint64_t clk_init = 0; - std::uint64_t clk_compute_1st_distance = 0; - std::uint64_t clk_topk = 0; - std::uint64_t clk_reset_hash = 0; - std::uint64_t clk_pickup_parents = 0; - std::uint64_t clk_restore_hash = 0; - std::uint64_t clk_compute_distance = 0; - std::uint64_t clk_start; -#define _CLK_START() clk_start = clock64() -#define _CLK_REC(V) V += clock64() - clk_start; -#else -#define _CLK_START() -#define _CLK_REC(V) -#endif - _CLK_START(); - - extern __shared__ std::uint32_t smem[]; - - // Layout of result_buffer - // +----------------------+------------------------------+---------+ - // | internal_top_k | neighbors of internal_top_k | padding | - // | | | upto 32 | - // +----------------------+------------------------------+---------+ - // |<--- result_buffer_size --->| - std::uint32_t result_buffer_size = internal_topk + (num_parents * graph_degree); - std::uint32_t result_buffer_size_32 = result_buffer_size; - if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); } - const auto small_hash_size = hashmap::get_size(small_hash_bitlen); - auto query_buffer = reinterpret_cast(smem); - auto result_indices_buffer = reinterpret_cast(query_buffer + MAX_DATASET_DIM); - auto result_distances_buffer = - reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto visited_hash_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); - auto topk_ws = reinterpret_cast(parent_list_buffer + num_parents); - auto terminate_flag = reinterpret_cast(topk_ws + 3); - auto smem_working_ptr = reinterpret_cast(terminate_flag + 1); - - const DATA_T* const query_ptr = queries_ptr + query_id * dataset_dim; - for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { - unsigned j = device::swizzling(i); - if (i < dataset_dim) { - query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); - } else { - query_buffer[j] = 0.0; - } - } - if (threadIdx.x == 0) { - terminate_flag[0] = 0; - topk_ws[0] = ~0u; - } - - // Init hashmap - INDEX_T* local_visited_hashmap_ptr; - if (small_hash_bitlen) { - local_visited_hashmap_ptr = visited_hash_buffer; - } else { - local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); - } - hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); - __syncthreads(); - _CLK_REC(clk_init); - - // compute distance to randomly selecting nodes - _CLK_START(); - const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; - device::compute_distance_to_random_nodes( - result_indices_buffer, - result_distances_buffer, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_size, - dataset_ld, - result_buffer_size, - num_distilation, - rand_xor_mask, - local_seed_ptr, - num_seeds, - local_visited_hashmap_ptr, - hash_bitlen); - __syncthreads(); - _CLK_REC(clk_compute_1st_distance); - - std::uint32_t iter = 0; - while (1) { - // sort - if (TOPK_BY_BITONIC_SORT) { - // [Notice] - // It is good to use multiple warps in topk_by_bitonic_sort() when - // batch size is small (short-latency), but it might not be always good - // when batch size is large (high-throughput). - // topk_by_bitonic_sort() consists of two operations: - // if MAX_CANDIDATES is greater than 128, the first operation uses two warps; - // if MAX_ITOPK is greater than 256, the second operation used two warps. - constexpr unsigned multi_warps_1 = ((BLOCK_SIZE >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; - constexpr unsigned multi_warps_2 = ((BLOCK_SIZE >= 64) && (MAX_ITOPK > 256)) ? 1 : 0; - - // reset small-hash table. - if ((iter + 1) % small_hash_reset_interval == 0) { - // Depending on the block size and the number of warps used in - // topk_by_bitonic_sort(), determine which warps are used to reset - // the small hash and whether they are performed in overlap with - // topk_by_bitonic_sort(). - _CLK_START(); - if (BLOCK_SIZE == 32) { - hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); - } else if (BLOCK_SIZE == 64) { - if (multi_warps_1 || multi_warps_2) { - hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); - } else { - hashmap::init<32, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); - } - } else { - if (multi_warps_1 || multi_warps_2) { - hashmap::init<64, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); - } else { - hashmap::init<32, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); - } - } - _CLK_REC(clk_reset_hash); - } - - // topk with bitonic sort - _CLK_START(); - topk_by_bitonic_sort( - result_distances_buffer, - result_indices_buffer, - internal_topk, - result_distances_buffer + internal_topk, - result_indices_buffer + internal_topk, - num_parents * graph_degree, - topk_ws, - (iter == 0)); - _CLK_REC(clk_topk); - - } else { - _CLK_START(); - // topk with radix block sort - topk_by_radix_sort{}( - internal_topk, - gridDim.x, - result_buffer_size, - reinterpret_cast(result_distances_buffer), - result_indices_buffer, - reinterpret_cast(result_distances_buffer), - result_indices_buffer, - nullptr, - topk_ws, - true, - reinterpret_cast(smem_working_ptr)); - _CLK_REC(clk_topk); - - // reset small-hash table - if ((iter + 1) % small_hash_reset_interval == 0) { - _CLK_START(); - hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); - _CLK_REC(clk_reset_hash); - } - } - __syncthreads(); - - if (iter + 1 == max_iteration) { break; } - - // pick up next parents - if (threadIdx.x < 32) { - _CLK_START(); - pickup_next_parents(terminate_flag, - parent_list_buffer, - result_indices_buffer, - internal_topk, - dataset_size, - num_parents); - _CLK_REC(clk_pickup_parents); - } - - // restore small-hash table by putting internal-topk indices in it - if ((iter + 1) % small_hash_reset_interval == 0) { - constexpr unsigned first_tid = ((BLOCK_SIZE <= 32) ? 0 : 32); - _CLK_START(); - hashmap_restore( - local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk); - _CLK_REC(clk_restore_hash); - } - __syncthreads(); - - if (*terminate_flag && iter >= min_iteration) { break; } - - // compute the norms between child nodes and query node - _CLK_START(); - constexpr unsigned max_n_frags = 16; - device:: - compute_distance_to_child_nodes( - result_indices_buffer + internal_topk, - result_distances_buffer + internal_topk, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_ld, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - parent_list_buffer, - num_parents); - __syncthreads(); - _CLK_REC(clk_compute_distance); - - iter++; - } - for (std::uint32_t i = threadIdx.x; i < top_k; i += BLOCK_SIZE) { - unsigned j = i + (top_k * query_id); - unsigned ii = i; - if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } - - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - result_indices_ptr[j] = - result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit - } - if (threadIdx.x == 0 && num_executed_iterations != nullptr) { - num_executed_iterations[query_id] = iter + 1; - } -#ifdef _CLK_BREAKDOWN - if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && ((query_id * 3) % gridDim.y < 3)) { - RAFT_LOG_DEBUG( - "query, %d, thread, %d" - ", init, %d" - ", 1st_distance, %lu" - ", topk, %lu" - ", reset_hash, %lu" - ", pickup_parents, %lu" - ", restore_hash, %lu" - ", distance, %lu" - "\n", - query_id, - threadIdx.x, - clk_init, - clk_compute_1st_distance, - clk_topk, - clk_reset_hash, - clk_pickup_parents, - clk_restore_hash, - clk_compute_distance); - } -#endif -} - #define SET_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ITOPK, MAX_CANDIDATES, TOPK_BY_BITONIC_SORT) \ kernel = search_kernel // RAFT_EXPLICIT +namespace raft::neighbors::experimental::cagra::detail { +namespace single_cta_search { + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA + +template +__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ + void search_kernel(INDEX_T* const result_indices_ptr, // [num_queries, top_k] + DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] + const std::uint32_t top_k, + const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] + const std::size_t dataset_dim, + const std::size_t dataset_size, + const std::size_t dataset_ld, + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t internal_topk, + const std::uint32_t num_parents, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, // [num_queries] + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval) RAFT_EXPLICIT; + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + extern template __global__ void search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const std::size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +// search_single_cta_float_uint32_dim1024_t32.cu +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 1024, float, float, uint32_t, uint4); + +// search_single_cta_float_uint32_dim128_t8.cu +instantiate_single_cta_search_kernel(8, 64, 16, 64, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 32, 0, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 32, 0, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 32, 0, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 32, 0, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 32, 0, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 32, 0, 128, float, float, uint32_t, uint4); + +// search_single_cta_float_uint32_dim256_t16.cu +instantiate_single_cta_search_kernel(16, 64, 16, 64, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 32, 0, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 32, 0, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 32, 0, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 32, 0, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 32, 0, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 32, 0, 256, float, float, uint32_t, uint4); + +// search_single_cta_float_uint32_dim512_t32.cu +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, float, float, uint32_t, uint4); + +// search_single_cta_int8_uint32_dim1024_t32.cu +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 1024, int8_t, float, uint32_t, uint4); + +// search_single_cta_int8_uint32_dim128_t8.cu +instantiate_single_cta_search_kernel(8, 64, 16, 64, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 32, 0, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 32, 0, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 32, 0, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 32, 0, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 32, 0, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 32, 0, 128, int8_t, float, uint32_t, uint4); + +// search_single_cta_int8_uint32_dim256_t16.cu +instantiate_single_cta_search_kernel(16, 64, 16, 64, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 32, 0, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 32, 0, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 32, 0, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 32, 0, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 32, 0, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 32, 0, 256, int8_t, float, uint32_t, uint4); + +// search_single_cta_int8_uint32_dim512_t32.cu +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, int8_t, float, uint32_t, uint4); + +// search_single_cta_uint8_uint32_dim1024_t32.cu +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 128, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 256, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 512, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 128, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 256, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 512, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 128, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 256, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 512, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 128, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 256, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 512, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 128, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 256, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 512, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 128, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 256, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 512, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 128, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 256, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 512, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 128, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 256, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 512, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 64, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 64, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 32, 0, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 32, 0, 1024, uint8_t, float, uint32_t, uint4); + +// search_single_cta_uint8_uint32_dim128_t8.cu +instantiate_single_cta_search_kernel(8, 64, 16, 64, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 32, 0, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 32, 0, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 32, 0, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 32, 0, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 32, 0, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 32, 0, 128, uint8_t, float, uint32_t, uint4); + +// search_single_cta_uint8_uint32_dim256_t16.cu +instantiate_single_cta_search_kernel(16, 64, 16, 64, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 128, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 256, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 512, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 128, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 256, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 512, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 32, 0, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 32, 0, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 32, 0, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 32, 0, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 32, 0, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 32, 0, 256, uint8_t, float, uint32_t, uint4); + +// search_single_cta_uint8_uint32_dim512_t32.cu +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, uint8_t, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace single_cta_search +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh new file mode 100644 index 0000000000..8100db9dcb --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -0,0 +1,740 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bitonic.hpp" +#include "compute_distance.hpp" +#include "device_common.hpp" +#include "hashmap.hpp" +#include "search_plan.cuh" +#include "topk_by_radix.cuh" +#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk +#include "utils.hpp" +#include +#include +#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp + +namespace raft::neighbors::experimental::cagra::detail { +namespace single_cta_search { + +// #define _CLK_BREAKDOWN + +template +__device__ void pickup_next_parents(std::uint32_t* const terminate_flag, + INDEX_T* const next_parent_indices, + INDEX_T* const internal_topk_indices, + const std::size_t internal_topk_size, + const std::size_t dataset_size, + const std::uint32_t num_parents) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + // if (threadIdx.x >= 32) return; + + for (std::uint32_t i = threadIdx.x; i < num_parents; i += 32) { + next_parent_indices[i] = utils::get_max_value(); + } + std::uint32_t itopk_max = internal_topk_size; + if (itopk_max % 32) { itopk_max += 32 - (itopk_max % 32); } + std::uint32_t num_new_parents = 0; + for (std::uint32_t j = threadIdx.x; j < itopk_max; j += 32) { + std::uint32_t jj = j; + if (TOPK_BY_BITONIC_SORT) { jj = device::swizzling(j); } + INDEX_T index; + int new_parent = 0; + if (j < internal_topk_size) { + index = internal_topk_indices[jj]; + if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set + new_parent = 1; + } + } + const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); + if (new_parent) { + const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; + if (i < num_parents) { + next_parent_indices[i] = index; + // set most significant bit as used node + internal_topk_indices[jj] |= index_msb_1_mask; + } + } + num_new_parents += __popc(ballot_mask); + if (num_new_parents >= num_parents) { break; } + } + if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } +} + +template +__device__ inline void topk_by_bitonic_sort_1st(float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + const unsigned lane_id = threadIdx.x % 32; + const unsigned warp_id = threadIdx.x / 32; + if (MULTI_WARPS == 0) { + if (warp_id > 0) { return; } + constexpr unsigned N = (MAX_CANDIDATES + 31) / 32; + float key[N]; + IdxT val[N]; + /* Candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (32 * i); + if (j < num_candidates) { + key[i] = candidate_distances[j]; + val[i] = candidate_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Sort */ + bitonic::warp_sort(key, val); + /* Reg -> Temp_itopk */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_candidates && j < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } else { + // Use two warps (64 threads) + constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; + constexpr unsigned N = (max_candidates_per_warp + 31) / 32; + float key[N]; + IdxT val[N]; + if (warp_id < 2) { + /* Candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = lane_id + (32 * i); + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates) { + key[i] = candidate_distances[j]; + val[i] = candidate_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Sort */ + bitonic::warp_sort(key, val); + /* Reg -> Temp_candidates */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates && jl < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } + __syncthreads(); + + unsigned num_warps_used = (num_itopk + max_candidates_per_warp - 1) / max_candidates_per_warp; + if (warp_id < num_warps_used) { + /* Temp_candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned kl = max_candidates_per_warp - 1 - jl; + unsigned j = jl + (max_candidates_per_warp * warp_id); + unsigned k = MAX_CANDIDATES - 1 - j; + if (j >= num_candidates || k >= num_candidates || kl >= num_itopk) continue; + float temp_key = candidate_distances[device::swizzling(k)]; + if (key[i] == temp_key) continue; + if ((warp_id == 0) == (key[i] > temp_key)) { + key[i] = temp_key; + val[i] = candidate_indices[device::swizzling(k)]; + } + } + } + if (num_warps_used > 1) { __syncthreads(); } + if (warp_id < num_warps_used) { + /* Merge */ + bitonic::warp_merge(key, val, 32); + /* Reg -> Temp_itopk */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates && j < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } + if (num_warps_used > 1) { __syncthreads(); } + } +} + +template +__device__ inline void topk_by_bitonic_sort_2nd(float* itopk_distances, // [num_itopk] + IdxT* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + const unsigned lane_id = threadIdx.x % 32; + const unsigned warp_id = threadIdx.x / 32; + if (MULTI_WARPS == 0) { + if (warp_id > 0) { return; } + constexpr unsigned N = (MAX_ITOPK + 31) / 32; + float key[N]; + IdxT val[N]; + if (first) { + /* Load itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (32 * i); + if (j < num_itopk) { + key[i] = itopk_distances[j]; + val[i] = itopk_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + } else { + /* Load itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_itopk) { + key[i] = itopk_distances[device::swizzling(j)]; + val[i] = itopk_indices[device::swizzling(j)]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + } + /* Merge candidates */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; // [0:MAX_ITOPK-1] + unsigned k = MAX_ITOPK - 1 - j; + if (k >= num_itopk || k >= num_candidates) continue; + float candidate_key = candidate_distances[device::swizzling(k)]; + if (key[i] > candidate_key) { + key[i] = candidate_key; + val[i] = candidate_indices[device::swizzling(k)]; + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, 32); + /* Store new itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_itopk) { + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + } else { + // Use two warps (64 threads) or more + constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; + constexpr unsigned N = (max_itopk_per_warp + 31) / 32; + float key[N]; + IdxT val[N]; + if (first) { + /* Load itop results (not sorted) */ + if (warp_id < 2) { + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (32 * i) + (max_itopk_per_warp * warp_id); + if (j < num_itopk) { + key[i] = itopk_distances[j]; + val[i] = itopk_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + /* Store intermedidate results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + if (j >= num_itopk) continue; + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + __syncthreads(); + if (warp_id < 2) { + /* Load intermedidate results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + unsigned k = MAX_ITOPK - 1 - j; + if (k >= num_itopk) continue; + float temp_key = itopk_distances[device::swizzling(k)]; + if (key[i] == temp_key) continue; + if ((warp_id == 0) == (key[i] > temp_key)) { + key[i] = temp_key; + val[i] = itopk_indices[device::swizzling(k)]; + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, 32); + } + __syncthreads(); + /* Store itopk results (sorted) */ + if (warp_id < 2) { + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + if (j >= num_itopk) continue; + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + } + const uint32_t num_itopk_div2 = num_itopk / 2; + if (threadIdx.x < 3) { + // work_buf is used to obtain turning points in 1st and 2nd half of itopk afer merge. + work_buf[threadIdx.x] = num_itopk_div2; + } + __syncthreads(); + + // Merge candidates (using whole threads) + for (unsigned k = threadIdx.x; k < min(num_candidates, num_itopk); k += blockDim.x) { + const unsigned j = num_itopk - 1 - k; + const float itopk_key = itopk_distances[device::swizzling(j)]; + const float candidate_key = candidate_distances[device::swizzling(k)]; + if (itopk_key > candidate_key) { + itopk_distances[device::swizzling(j)] = candidate_key; + itopk_indices[device::swizzling(j)] = candidate_indices[device::swizzling(k)]; + if (j < num_itopk_div2) { + atomicMin(work_buf + 2, j); + } else { + atomicMin(work_buf + 1, j - num_itopk_div2); + } + } + } + __syncthreads(); + + // Merge 1st and 2nd half of itopk (using whole threads) + for (unsigned j = threadIdx.x; j < num_itopk_div2; j += blockDim.x) { + const unsigned k = j + num_itopk_div2; + float key_0 = itopk_distances[device::swizzling(j)]; + float key_1 = itopk_distances[device::swizzling(k)]; + if (key_0 > key_1) { + itopk_distances[device::swizzling(j)] = key_1; + itopk_distances[device::swizzling(k)] = key_0; + IdxT val_0 = itopk_indices[device::swizzling(j)]; + IdxT val_1 = itopk_indices[device::swizzling(k)]; + itopk_indices[device::swizzling(j)] = val_1; + itopk_indices[device::swizzling(k)] = val_0; + atomicMin(work_buf + 0, j); + } + } + if (threadIdx.x == blockDim.x - 1) { + if (work_buf[2] < num_itopk_div2) { work_buf[1] = work_buf[2]; } + } + __syncthreads(); + // if ((blockIdx.x == 0) && (threadIdx.x == 0)) { + // RAFT_LOG_DEBUG( "work_buf: %u, %u, %u\n", work_buf[0], work_buf[1], work_buf[2] ); + // } + + // Warp-0 merges 1st half of itopk, warp-1 does 2nd half. + if (warp_id < 2) { + // Load intermedidate itopk results + const uint32_t turning_point = work_buf[warp_id]; // turning_point <= num_itopk_div2 + for (unsigned i = 0; i < N; i++) { + unsigned k = num_itopk; + unsigned j = (N * lane_id) + i; + if (j < turning_point) { + k = j + (num_itopk_div2 * warp_id); + } else if (j >= (MAX_ITOPK / 2 - num_itopk_div2)) { + j -= (MAX_ITOPK / 2 - num_itopk_div2); + if ((turning_point <= j) && (j < num_itopk_div2)) { k = j + (num_itopk_div2 * warp_id); } + } + if (k < num_itopk) { + key[i] = itopk_distances[device::swizzling(k)]; + val[i] = itopk_indices[device::swizzling(k)]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, 32); + /* Store new itopk results */ + for (unsigned i = 0; i < N; i++) { + const unsigned j = (N * lane_id) + i; + if (j < num_itopk_div2) { + unsigned k = j + (num_itopk_div2 * warp_id); + itopk_distances[device::swizzling(k)] = key[i]; + itopk_indices[device::swizzling(k)] = val[i]; + } + } + } + } +} + +template +__device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] + IdxT* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + // The results in candidate_distances/indices are sorted by bitonic sort. + topk_by_bitonic_sort_1st( + candidate_distances, candidate_indices, num_candidates, num_itopk); + + // The results sorted above are merged with the internal intermediate top-k + // results so far using bitonic merge. + topk_by_bitonic_sort_2nd(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +template +__device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr, + const size_t hashmap_bitlen, + const INDEX_T* itopk_indices, + uint32_t itopk_size) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + if (threadIdx.x < FIRST_TID || threadIdx.x >= LAST_TID) return; + for (unsigned i = threadIdx.x - FIRST_TID; i < itopk_size; i += LAST_TID - FIRST_TID) { + auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit + hashmap::insert(hashmap_ptr, hashmap_bitlen, key); + } +} + +template +__device__ inline void set_value_device(T* const ptr, const T fill, const std::uint32_t count) +{ + for (std::uint32_t i = threadIdx.x; i < count; i += BLOCK_SIZE) { + ptr[i] = fill; + } +} + +// One query one thread block +template +__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ + void search_kernel(INDEX_T* const result_indices_ptr, // [num_queries, top_k] + DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] + const std::uint32_t top_k, + const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] + const std::size_t dataset_dim, + const std::size_t dataset_size, + const std::size_t dataset_ld, // stride of dataset + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t internal_topk, + const std::uint32_t num_parents, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, // [num_queries] + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval) +{ + const auto query_id = blockIdx.y; + +#ifdef _CLK_BREAKDOWN + std::uint64_t clk_init = 0; + std::uint64_t clk_compute_1st_distance = 0; + std::uint64_t clk_topk = 0; + std::uint64_t clk_reset_hash = 0; + std::uint64_t clk_pickup_parents = 0; + std::uint64_t clk_restore_hash = 0; + std::uint64_t clk_compute_distance = 0; + std::uint64_t clk_start; +#define _CLK_START() clk_start = clock64() +#define _CLK_REC(V) V += clock64() - clk_start; +#else +#define _CLK_START() +#define _CLK_REC(V) +#endif + _CLK_START(); + + extern __shared__ std::uint32_t smem[]; + + // Layout of result_buffer + // +----------------------+------------------------------+---------+ + // | internal_top_k | neighbors of internal_top_k | padding | + // | | | upto 32 | + // +----------------------+------------------------------+---------+ + // |<--- result_buffer_size --->| + std::uint32_t result_buffer_size = internal_topk + (num_parents * graph_degree); + std::uint32_t result_buffer_size_32 = result_buffer_size; + if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); } + const auto small_hash_size = hashmap::get_size(small_hash_bitlen); + auto query_buffer = reinterpret_cast(smem); + auto result_indices_buffer = reinterpret_cast(query_buffer + MAX_DATASET_DIM); + auto result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto visited_hash_buffer = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); + auto topk_ws = reinterpret_cast(parent_list_buffer + num_parents); + auto terminate_flag = reinterpret_cast(topk_ws + 3); + auto smem_working_ptr = reinterpret_cast(terminate_flag + 1); + + const DATA_T* const query_ptr = queries_ptr + query_id * dataset_dim; // dataset_dim + for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { + unsigned j = device::swizzling(i); + if (i < dataset_dim) { + query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); + } else { + query_buffer[j] = 0.0; + } + } + if (threadIdx.x == 0) { + terminate_flag[0] = 0; + topk_ws[0] = ~0u; + } + + // Init hashmap + INDEX_T* local_visited_hashmap_ptr; + if (small_hash_bitlen) { + local_visited_hashmap_ptr = visited_hash_buffer; + } else { + local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); + } + hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + __syncthreads(); + _CLK_REC(clk_init); + + // compute distance to randomly selecting nodes + _CLK_START(); + const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; + device::compute_distance_to_random_nodes( + result_indices_buffer, + result_distances_buffer, + query_buffer, + dataset_ptr, + dataset_dim, + dataset_size, + dataset_ld, + result_buffer_size, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + hash_bitlen); + __syncthreads(); + _CLK_REC(clk_compute_1st_distance); + + std::uint32_t iter = 0; + while (1) { + // sort + if (TOPK_BY_BITONIC_SORT) { + // [Notice] + // It is good to use multiple warps in topk_by_bitonic_sort() when + // batch size is small (short-latency), but it might not be always good + // when batch size is large (high-throughput). + // topk_by_bitonic_sort() consists of two operations: + // if MAX_CANDIDATES is greater than 128, the first operation uses two warps; + // if MAX_ITOPK is greater than 256, the second operation used two warps. + constexpr unsigned multi_warps_1 = ((BLOCK_SIZE >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; + constexpr unsigned multi_warps_2 = ((BLOCK_SIZE >= 64) && (MAX_ITOPK > 256)) ? 1 : 0; + + // reset small-hash table. + if ((iter + 1) % small_hash_reset_interval == 0) { + // Depending on the block size and the number of warps used in + // topk_by_bitonic_sort(), determine which warps are used to reset + // the small hash and whether they are performed in overlap with + // topk_by_bitonic_sort(). + _CLK_START(); + if (BLOCK_SIZE == 32) { + hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + } else if (BLOCK_SIZE == 64) { + if (multi_warps_1 || multi_warps_2) { + hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + } else { + hashmap::init<32, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + } + } else { + if (multi_warps_1 || multi_warps_2) { + hashmap::init<64, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + } else { + hashmap::init<32, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + } + } + _CLK_REC(clk_reset_hash); + } + + // topk with bitonic sort + _CLK_START(); + topk_by_bitonic_sort( + result_distances_buffer, + result_indices_buffer, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + num_parents * graph_degree, + topk_ws, + (iter == 0)); + _CLK_REC(clk_topk); + + } else { + _CLK_START(); + // topk with radix block sort + topk_by_radix_sort{}( + internal_topk, + gridDim.x, + result_buffer_size, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + nullptr, + topk_ws, + true, + reinterpret_cast(smem_working_ptr)); + _CLK_REC(clk_topk); + + // reset small-hash table + if ((iter + 1) % small_hash_reset_interval == 0) { + _CLK_START(); + hashmap::init<0, BLOCK_SIZE>(local_visited_hashmap_ptr, hash_bitlen); + _CLK_REC(clk_reset_hash); + } + } + __syncthreads(); + + if (iter + 1 == max_iteration) { break; } + + // pick up next parents + if (threadIdx.x < 32) { + _CLK_START(); + pickup_next_parents(terminate_flag, + parent_list_buffer, + result_indices_buffer, + internal_topk, + dataset_size, + num_parents); + _CLK_REC(clk_pickup_parents); + } + + // restore small-hash table by putting internal-topk indices in it + if ((iter + 1) % small_hash_reset_interval == 0) { + constexpr unsigned first_tid = ((BLOCK_SIZE <= 32) ? 0 : 32); + _CLK_START(); + hashmap_restore( + local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk); + _CLK_REC(clk_restore_hash); + } + __syncthreads(); + + if (*terminate_flag && iter >= min_iteration) { break; } + + // compute the norms between child nodes and query node + _CLK_START(); + constexpr unsigned max_n_frags = 16; + device:: + compute_distance_to_child_nodes( + result_indices_buffer + internal_topk, + result_distances_buffer + internal_topk, + query_buffer, + dataset_ptr, + dataset_dim, + dataset_ld, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + parent_list_buffer, + num_parents); + __syncthreads(); + _CLK_REC(clk_compute_distance); + + iter++; + } + for (std::uint32_t i = threadIdx.x; i < top_k; i += BLOCK_SIZE) { + unsigned j = i + (top_k * query_id); + unsigned ii = i; + if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } + if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + result_indices_ptr[j] = + result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit + } + if (threadIdx.x == 0 && num_executed_iterations != nullptr) { + num_executed_iterations[query_id] = iter + 1; + } +#ifdef _CLK_BREAKDOWN + if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && ((query_id * 3) % gridDim.y < 3)) { + RAFT_LOG_DEBUG( + "query, %d, thread, %d" + ", init, %d" + ", 1st_distance, %lu" + ", topk, %lu" + ", reset_hash, %lu" + ", pickup_parents, %lu" + ", restore_hash, %lu" + ", distance, %lu" + "\n", + query_id, + threadIdx.x, + clk_init, + clk_compute_1st_distance, + clk_topk, + clk_reset_hash, + clk_pickup_parents, + clk_restore_hash, + clk_compute_distance); + } +#endif +} + +} // namespace single_cta_search +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh new file mode 100644 index 0000000000..e57fa31763 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA +#include "search_single_cta_kernel-inl.cuh" +#endif + +#ifdef RAFT_COMPILED_CAGRA +#include "search_single_cta_kernel-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh new file mode 100644 index 0000000000..d151cc8ee7 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "topk_for_cagra/topk_core.cuh" + +namespace raft::neighbors::experimental::cagra::detail { +namespace single_cta_search { + +template +struct topk_by_radix_sort_base { + static constexpr std::uint32_t smem_size = MAX_INTERNAL_TOPK * 2 + 2048 + 8; + static constexpr std::uint32_t state_bit_lenght = 0; + static constexpr std::uint32_t vecLen = 2; // TODO +}; +template +struct topk_by_radix_sort : topk_by_radix_sort_base {}; + +template +struct topk_by_radix_sort> + : topk_by_radix_sort_base { + __device__ void operator()(uint32_t topk, + uint32_t batch_size, + uint32_t len_x, + const uint32_t* _x, + const IdxT* _in_vals, + uint32_t* _y, + IdxT* _out_vals, + uint32_t* work, + uint32_t* _hints, + bool sort, + uint32_t* _smem) + { + std::uint8_t* const state = reinterpret_cast(work); + topk_cta_11_core::state_bit_lenght, + topk_by_radix_sort_base::vecLen, + 64, + 32, + IdxT>(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); + } +}; + +#define TOP_FUNC_PARTIAL_SPECIALIZATION(V) \ + template \ + struct topk_by_radix_sort< \ + MAX_INTERNAL_TOPK, \ + BLOCK_SIZE, \ + IdxT, \ + std::enable_if_t<((MAX_INTERNAL_TOPK <= V) && (2 * MAX_INTERNAL_TOPK > V))>> \ + : topk_by_radix_sort_base { \ + __device__ void operator()(uint32_t topk, \ + uint32_t batch_size, \ + uint32_t len_x, \ + const uint32_t* _x, \ + const IdxT* _in_vals, \ + uint32_t* _y, \ + IdxT* _out_vals, \ + uint32_t* work, \ + uint32_t* _hints, \ + bool sort, \ + uint32_t* _smem) \ + { \ + assert(BLOCK_SIZE >= V / 4); \ + std::uint8_t* state = (std::uint8_t*)work; \ + topk_cta_11_core::state_bit_lenght, \ + topk_by_radix_sort_base::vecLen, \ + V, \ + V / 4, \ + IdxT>( \ + topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); \ + } \ + }; +TOP_FUNC_PARTIAL_SPECIALIZATION(128); +TOP_FUNC_PARTIAL_SPECIALIZATION(256); +TOP_FUNC_PARTIAL_SPECIALIZATION(512); +TOP_FUNC_PARTIAL_SPECIALIZATION(1024); + +} // namespace single_cta_search +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py new file mode 100644 index 0000000000..44da773157 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py @@ -0,0 +1,137 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \\ + BLOCK_SIZE, \\ + BLOCK_COUNT, \\ + MAX_ITOPK, \\ + MAX_CANDIDATES, \\ + TOPK_BY_BITONIC_SORT, \\ + MAX_DATASET_DIM, \\ + DATA_T, \\ + DISTANCE_T, \\ + INDEX_T, \\ + LOAD_T) \\ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \\ + DISTANCE_T* const result_distances_ptr, \\ + const std::uint32_t top_k, \\ + const DATA_T* const dataset_ptr, \\ + const std::size_t dataset_dim, \\ + const std::size_t dataset_size, \\ + const size_t dataset_ld, \\ + const DATA_T* const queries_ptr, \\ + const INDEX_T* const knn_graph, \\ + const std::uint32_t graph_degree, \\ + const unsigned num_distilation, \\ + const uint64_t rand_xor_mask, \\ + const INDEX_T* seed_ptr, \\ + const uint32_t num_seeds, \\ + std::uint32_t* const visited_hashmap_ptr, \\ + const std::uint32_t internal_topk, \\ + const std::uint32_t num_parents, \\ + const std::uint32_t min_iteration, \\ + const std::uint32_t max_iteration, \\ + std::uint32_t* const num_executed_iterations, \\ + const std::uint32_t hash_bitlen, \\ + const std::uint32_t small_hash_bitlen, \\ + const std::uint32_t small_hash_reset_interval); + +""" + +trailer = """ +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +""" + +mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] +block = [(64, 16), (128, 8), (256, 4), (512, 2), (1024, 1)] +itopk_candidates = [64, 128, 256] +itopk_size = [64, 128, 256, 512] +mxelem = [64, 128, 256] +load_types = ["uint4"] + +rblock = [(256, 4), (512, 2), (1024, 1)] +rcandidates = [32] +rsize = [256, 512] + +search_types = dict( + float_uint32=("float", "uint32_t", "float"), # data_t, idx_t, distance_t + int8_uint32=("int8_t", "uint32_t", "float"), + uint8_uint32=("uint8_t", "uint32_t", "float"), +) + +# knn +for type_path, (data_t, idx_t, distance_t) in search_types.items(): + for (mxdim, team) in mxdim_team: + path = f"search_single_cta_{type_path}_dim{mxdim}_t{team}.cu" + with open(path, "w") as f: + f.write(header) + for load_t in load_types: + for b in block: + for candidates in itopk_candidates: + for isize in itopk_size: + f.write( + f"instantiate_single_cta_search_kernel({team}, {b[0]}, {b[1]}, {isize}, {candidates}, 1, {mxdim},{data_t}, {distance_t}, {idx_t}, {load_t});\n" + ) + for b in rblock: + for candidates in rcandidates: + for isize in rsize: + f.write( + f"instantiate_single_cta_search_kernel({team}, {b[0]}, {b[1]}, {isize}, {candidates}, 0, {mxdim},{data_t}, {distance_t}, {idx_t}, {load_t});\n" + ) + f.write(trailer) + # For pasting into CMakeLists.txt + print(f"src/neighbors/detail/cagra/{path}") diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu new file mode 100644 index 0000000000..f948e398b3 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 1024, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 1024, float, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu new file mode 100644 index 0000000000..efa9e6accf --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(8, 64, 16, 64, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 64, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 128, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 256, 1, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 32, 0, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 32, 0, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 32, 0, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 32, 0, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 32, 0, 128, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 32, 0, 128, float, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu new file mode 100644 index 0000000000..f5ea9af226 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(16, 64, 16, 64, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 64, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 128, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 256, 1, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 32, 0, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 32, 0, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 32, 0, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 32, 0, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 32, 0, 256, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 32, 0, 256, float, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu new file mode 100644 index 0000000000..15a4d37ede --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, float, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, float, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu new file mode 100644 index 0000000000..d0bc01448d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu @@ -0,0 +1,151 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 128, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 256, 1, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 1024, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 1024, int8_t, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu new file mode 100644 index 0000000000..7176dc84e6 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(8, 64, 16, 64, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 64, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 128, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 256, 1, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 32, 0, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 32, 0, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 32, 0, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 32, 0, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 32, 0, 128, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 32, 0, 128, int8_t, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu new file mode 100644 index 0000000000..60b8ea3999 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(16, 64, 16, 64, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 64, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 128, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 256, 1, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 32, 0, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 32, 0, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 32, 0, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 32, 0, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 32, 0, 256, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 32, 0, 256, int8_t, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu new file mode 100644 index 0000000000..16ae7f5a2b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, int8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, int8_t, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu new file mode 100644 index 0000000000..a2573f7d6b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu @@ -0,0 +1,182 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 128, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 256, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 512, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 128, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 256, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 64, 16, 512, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 128, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 256, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 512, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 128, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 256, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 128, 8, 512, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 128, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 256, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 512, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 128, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 256, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 256, 4, 512, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 128, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 256, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 512, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 128, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 256, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 512, 2, 512, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 64, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 64, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 128, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 64, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 256, 1, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 32, 0, 1024, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 32, 0, 1024, uint8_t, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu new file mode 100644 index 0000000000..e3728f3765 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(8, 64, 16, 64, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 64, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 128, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 256, 1, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 32, 0, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 32, 0, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 32, 0, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 32, 0, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 32, 0, 128, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 32, 0, 128, uint8_t, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu new file mode 100644 index 0000000000..d04185aa50 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu @@ -0,0 +1,151 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(16, 64, 16, 64, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 64, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 128, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 256, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 512, 128, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 128, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 256, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 16, 1024, 1, 512, 256, 1, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 32, 0, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 32, 0, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 32, 0, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 32, 0, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 32, 0, 256, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 32, 0, 256, uint8_t, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu new file mode 100644 index 0000000000..831027b44f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu @@ -0,0 +1,151 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + std::uint32_t* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 128, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 128, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 256, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel( + 32, 1024, 1, 512, 256, 1, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, uint8_t, float, uint32_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, uint8_t, float, uint32_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 33d4dd9423..413cd89ce7 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -316,25 +316,49 @@ if(BUILD_TESTS) NEIGHBORS_TEST PATH test/neighbors/ann_cagra/test_float_uint32_t.cu - test/neighbors/ann_cagra/test_int8_t_uint32_t.cu - test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu - test/neighbors/ann_cagra/test_float_int64_t.cu - test/neighbors/ann_ivf_flat/test_float_int64_t.cu - test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu - test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu - test/neighbors/ann_ivf_pq/test_float_int64_t.cu - test/neighbors/ann_ivf_pq/test_float_uint32_t.cu - test/neighbors/ann_ivf_pq/test_float_int64_t.cu - test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu - test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu - test/neighbors/knn.cu - test/neighbors/fused_l2_knn.cu - test/neighbors/tiled_knn.cu - test/neighbors/haversine.cu - test/neighbors/ball_cover.cu - test/neighbors/epsilon_neighborhood.cu - test/neighbors/refine.cu - test/neighbors/selection.cu + # test/neighbors/ann_cagra/test_int8_t_uint32_t.cu + # test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu + # test/neighbors/ann_cagra/test_float_int64_t.cu + # test/neighbors/ann_ivf_flat/test_float_int64_t.cu + # test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu + # test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu + # test/neighbors/ann_ivf_pq/test_float_int64_t.cu + # test/neighbors/ann_ivf_pq/test_float_uint32_t.cu + # test/neighbors/ann_ivf_pq/test_float_int64_t.cu + # test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu + # test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu + # test/neighbors/knn.cu + # test/neighbors/fused_l2_knn.cu + # test/neighbors/tiled_knn.cu + # test/neighbors/haversine.cu + # test/neighbors/ball_cover.cu + # test/neighbors/epsilon_neighborhood.cu + # test/neighbors/refine.cu + # test/neighbors/selection.cu + # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu + # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu + # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu + # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu + # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu + # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu + # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu + # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu + # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu + # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu + # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu + # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu + src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu + src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu + src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu + src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu + src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu + src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu + src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu + src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu + src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu + src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu + src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu + src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index d3bd5ba31d..10d6d4b679 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -15,6 +15,8 @@ */ #pragma once +#define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA +#define RAFT_COMPILED_CAGRA #include "../test_utils.cuh" #include "ann_utils.cuh" #include From ec5689b1ac8170474e2df243a093e383b7a30a68 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Mon, 3 Jul 2023 23:29:46 +0200 Subject: [PATCH 2/6] Fix instantiations --- .../neighbors/detail/cagra/graph_core.cuh | 16 ++---- .../cagra/search_single_cta_kernel-inl.cuh | 1 + .../cagra/search_single_cta_00_generate.py | 1 + cpp/test/CMakeLists.txt | 50 +++++++------------ cpp/test/neighbors/ann_cagra.cuh | 4 +- .../neighbors/ann_cagra/test_float_int64_t.cu | 4 +- .../ann_cagra/test_float_uint32_t.cu | 2 + .../ann_cagra/test_int8_t_uint32_t.cu | 2 + .../ann_cagra/test_uint8_t_uint32_t.cu | 2 + 9 files changed, 36 insertions(+), 46 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index 949dcfda8b..d915634df9 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -72,7 +72,6 @@ template __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, dataset_dim] const IdxT dataset_size, const uint32_t dataset_dim, - const uint32_t dataset_ld, IdxT* const knn_graph, // [graph_chunk_size, graph_degree] const uint32_t graph_size, const uint32_t graph_degree) @@ -91,9 +90,9 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, float dist = 0.0; for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) { float diff = spatial::knn::detail::utils::mapping{}( - dataset[d + static_cast(dataset_ld) * srcNode]) - + dataset[d + static_cast(dataset_dim) * srcNode]) - spatial::knn::detail::utils::mapping{}( - dataset[d + static_cast(dataset_ld) * dstNode]); + dataset[d + static_cast(dataset_dim) * dstNode]); dist += diff * diff; } dist += __shfl_xor_sync(0xffffffff, dist, 1); @@ -239,7 +238,6 @@ void sort_knn_graph(raft::resources const& res, "dataset size is expected to have the same number of graph index size"); const uint32_t dataset_size = dataset.extent(0); const uint32_t dataset_dim = dataset.extent(1); - const uint32_t dataset_ld = dataset.stride(0); const DataT* dataset_ptr = dataset.data_handle(); const IdxT graph_size = dataset_size; @@ -265,13 +263,8 @@ void sort_knn_graph(raft::resources const& res, graph_size * input_graph_degree, resource::get_cuda_stream(res)); - void (*kernel_sort)(const DataT* const, - const IdxT, - const uint32_t, - const uint32_t, - IdxT* const, - const uint32_t, - const uint32_t); + void (*kernel_sort)( + const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t); if (input_graph_degree <= 32) { constexpr int numElementsPerThread = 1; kernel_sort = kern_sort; @@ -306,7 +299,6 @@ void sort_knn_graph(raft::resources const& res, d_dataset.data_handle(), dataset_size, dataset_dim, - dataset_ld, d_input_graph.data_handle(), graph_size, input_graph_degree); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 8100db9dcb..db3bcc3e9c 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -17,6 +17,7 @@ #include #include +#include #include #include #include diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py index 44da773157..bb5e7d6838 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py @@ -111,6 +111,7 @@ float_uint32=("float", "uint32_t", "float"), # data_t, idx_t, distance_t int8_uint32=("int8_t", "uint32_t", "float"), uint8_uint32=("uint8_t", "uint32_t", "float"), + float_uint64=("float", "uint64_t", "float"), ) # knn diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 413cd89ce7..01f2380bfd 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -316,37 +316,25 @@ if(BUILD_TESTS) NEIGHBORS_TEST PATH test/neighbors/ann_cagra/test_float_uint32_t.cu - # test/neighbors/ann_cagra/test_int8_t_uint32_t.cu - # test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu - # test/neighbors/ann_cagra/test_float_int64_t.cu - # test/neighbors/ann_ivf_flat/test_float_int64_t.cu - # test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu - # test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu - # test/neighbors/ann_ivf_pq/test_float_int64_t.cu - # test/neighbors/ann_ivf_pq/test_float_uint32_t.cu - # test/neighbors/ann_ivf_pq/test_float_int64_t.cu - # test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu - # test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu - # test/neighbors/knn.cu - # test/neighbors/fused_l2_knn.cu - # test/neighbors/tiled_knn.cu - # test/neighbors/haversine.cu - # test/neighbors/ball_cover.cu - # test/neighbors/epsilon_neighborhood.cu - # test/neighbors/refine.cu - # test/neighbors/selection.cu - # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu - # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu - # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu - # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu - # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu - # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu - # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu - # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu - # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu - # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu - # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu - # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu + test/neighbors/ann_cagra/test_int8_t_uint32_t.cu + test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu + test/neighbors/ann_cagra/test_float_int64_t.cu + test/neighbors/ann_ivf_flat/test_float_int64_t.cu + test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu + test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu + test/neighbors/ann_ivf_pq/test_float_int64_t.cu + test/neighbors/ann_ivf_pq/test_float_uint32_t.cu + test/neighbors/ann_ivf_pq/test_float_int64_t.cu + test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu + test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu + test/neighbors/knn.cu + test/neighbors/fused_l2_knn.cu + test/neighbors/tiled_knn.cu + test/neighbors/haversine.cu + test/neighbors/ball_cover.cu + test/neighbors/epsilon_neighborhood.cu + test/neighbors/refine.cu + test/neighbors/selection.cu src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 10d6d4b679..ed8d85dd7b 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -15,8 +15,8 @@ */ #pragma once -#define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA -#define RAFT_COMPILED_CAGRA +// #define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA +// #define RAFT_COMPILED_CAGRA #include "../test_utils.cuh" #include "ann_utils.cuh" #include diff --git a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu index e473a72b2b..b76ff478db 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu @@ -16,7 +16,9 @@ #include -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA +#undef RAFT_COMPILED_CAGRA + #include "../ann_cagra.cuh" namespace raft::neighbors::experimental::cagra { diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index dbaf4dedd9..28d28d473d 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -14,6 +14,8 @@ * limitations under the License. */ +#define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA +#define RAFT_COMPILED_CAGRA #include #include "../ann_cagra.cuh" diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu index ba60131677..30e76643e2 100644 --- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu @@ -14,6 +14,8 @@ * limitations under the License. */ +#define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA +#define RAFT_COMPILED_CAGRA #include #include "../ann_cagra.cuh" diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu index cc172e4833..21af8ae14c 100644 --- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA +#define RAFT_COMPILED_CAGRA #include From fb7510acee8e931b4a09a43ced0e923e8eedc6c6 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 4 Jul 2023 10:55:54 +0200 Subject: [PATCH 3/6] Move CAGRA obj files to libraft.so --- cpp/CMakeLists.txt | 12 ++++++++++++ cpp/test/CMakeLists.txt | 12 ------------ cpp/test/neighbors/ann_cagra.cuh | 2 -- cpp/test/neighbors/ann_cagra/test_float_int64_t.cu | 4 ++-- cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu | 2 -- cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu | 2 -- .../neighbors/ann_cagra/test_uint8_t_uint32_t.cu | 2 -- 7 files changed, 14 insertions(+), 22 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6fa1b5830e..98f6286716 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -307,6 +307,18 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu src/neighbors/brute_force_knn_int_float_int.cu src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu + src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu + src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu + src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu + src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu + src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu + src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu + src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu + src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu + src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu + src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu + src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu + src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 01f2380bfd..33d4dd9423 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -335,18 +335,6 @@ if(BUILD_TESTS) test/neighbors/epsilon_neighborhood.cu test/neighbors/refine.cu test/neighbors/selection.cu - src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu - src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu - src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index ed8d85dd7b..d3bd5ba31d 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -15,8 +15,6 @@ */ #pragma once -// #define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA -// #define RAFT_COMPILED_CAGRA #include "../test_utils.cuh" #include "ann_utils.cuh" #include diff --git a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu index b76ff478db..d7405c166c 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu @@ -16,8 +16,8 @@ #include -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA -#undef RAFT_COMPILED_CAGRA +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY +#undef RAFT_COMPILED #include "../ann_cagra.cuh" diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index 28d28d473d..dbaf4dedd9 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -14,8 +14,6 @@ * limitations under the License. */ -#define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA -#define RAFT_COMPILED_CAGRA #include #include "../ann_cagra.cuh" diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu index 30e76643e2..ba60131677 100644 --- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu @@ -14,8 +14,6 @@ * limitations under the License. */ -#define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA -#define RAFT_COMPILED_CAGRA #include #include "../ann_cagra.cuh" diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu index 21af8ae14c..cc172e4833 100644 --- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu @@ -13,8 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA -#define RAFT_COMPILED_CAGRA #include From 95760b362cafd3f5734bf79dcd292e5b2ffa5a76 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 4 Jul 2023 15:13:09 +0200 Subject: [PATCH 4/6] Fix template instantiations for uint64_t --- .../raft/neighbors/detail/cagra/hashmap.hpp | 1 + .../cagra/search_single_cta_kernel-ext.cuh | 2 +- .../detail/cagra/search_single_cta_kernel.cuh | 4 +- .../cagra/search_single_cta_00_generate.py | 2 +- ...rch_single_cta_float_uint32_dim1024_t32.cu | 2 +- ...earch_single_cta_float_uint32_dim128_t8.cu | 2 +- ...arch_single_cta_float_uint32_dim256_t16.cu | 2 +- ...arch_single_cta_float_uint32_dim512_t32.cu | 2 +- ...rch_single_cta_float_uint64_dim1024_t32.cu | 145 ++++++++++++++++++ ...earch_single_cta_float_uint64_dim128_t8.cu | 145 ++++++++++++++++++ ...arch_single_cta_float_uint64_dim256_t16.cu | 145 ++++++++++++++++++ ...arch_single_cta_float_uint64_dim512_t32.cu | 145 ++++++++++++++++++ ...arch_single_cta_int8_uint32_dim1024_t32.cu | 2 +- ...search_single_cta_int8_uint32_dim128_t8.cu | 2 +- ...earch_single_cta_int8_uint32_dim256_t16.cu | 2 +- ...earch_single_cta_int8_uint32_dim512_t32.cu | 2 +- ...rch_single_cta_uint8_uint32_dim1024_t32.cu | 2 +- ...earch_single_cta_uint8_uint32_dim128_t8.cu | 2 +- ...arch_single_cta_uint8_uint32_dim256_t16.cu | 2 +- ...arch_single_cta_uint8_uint32_dim512_t32.cu | 2 +- cpp/test/CMakeLists.txt | 4 + 21 files changed, 601 insertions(+), 16 deletions(-) create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu diff --git a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp index cd2c8ec491..5992aaaf1d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp @@ -18,6 +18,7 @@ #include "utils.hpp" #include #include +#include // #pragma GCC diagnostic push // #pragma GCC diagnostic ignored diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh index 56c6f41c9e..82dff6d78d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh @@ -19,7 +19,7 @@ namespace raft::neighbors::experimental::cagra::detail { namespace single_cta_search { -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY template python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 1024, float, float, uint64_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu new file mode 100644 index 0000000000..9102e83923 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(8, 64, 16, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 32, 0, 128, float, float, uint64_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu new file mode 100644 index 0000000000..eb3a6440be --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(16, 64, 16, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 32, 0, 256, float, float, uint64_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu new file mode 100644 index 0000000000..a801ef936d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu @@ -0,0 +1,145 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_single_cta_00_generate.py + * + */ + +#include + +namespace raft::neighbors::experimental::cagra::detail::single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, float, float, uint64_t, uint4); + +#undef instantiate_single_cta_search_kernel + +} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu index d0bc01448d..8cfa9217ab 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu @@ -63,7 +63,7 @@ namespace raft::neighbors::experimental::cagra::detail::single_cta_search { const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - std::uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const std::uint32_t internal_topk, \ const std::uint32_t num_parents, \ const std::uint32_t min_iteration, \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu index 7176dc84e6..2626dec869 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu @@ -63,7 +63,7 @@ namespace raft::neighbors::experimental::cagra::detail::single_cta_search { const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - std::uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const std::uint32_t internal_topk, \ const std::uint32_t num_parents, \ const std::uint32_t min_iteration, \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu index 60b8ea3999..e7daee7986 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu @@ -63,7 +63,7 @@ namespace raft::neighbors::experimental::cagra::detail::single_cta_search { const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - std::uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const std::uint32_t internal_topk, \ const std::uint32_t num_parents, \ const std::uint32_t min_iteration, \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu index 16ae7f5a2b..9e7bc0ee3d 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu @@ -63,7 +63,7 @@ namespace raft::neighbors::experimental::cagra::detail::single_cta_search { const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - std::uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const std::uint32_t internal_topk, \ const std::uint32_t num_parents, \ const std::uint32_t min_iteration, \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu index a2573f7d6b..baeb459d9b 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu @@ -63,7 +63,7 @@ namespace raft::neighbors::experimental::cagra::detail::single_cta_search { const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - std::uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const std::uint32_t internal_topk, \ const std::uint32_t num_parents, \ const std::uint32_t min_iteration, \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu index e3728f3765..031d9d0a68 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu @@ -63,7 +63,7 @@ namespace raft::neighbors::experimental::cagra::detail::single_cta_search { const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - std::uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const std::uint32_t internal_topk, \ const std::uint32_t num_parents, \ const std::uint32_t min_iteration, \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu index d04185aa50..049e0d16cf 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu @@ -63,7 +63,7 @@ namespace raft::neighbors::experimental::cagra::detail::single_cta_search { const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - std::uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const std::uint32_t internal_topk, \ const std::uint32_t num_parents, \ const std::uint32_t min_iteration, \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu index 831027b44f..78863b82bd 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu @@ -63,7 +63,7 @@ namespace raft::neighbors::experimental::cagra::detail::single_cta_search { const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - std::uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const std::uint32_t internal_topk, \ const std::uint32_t num_parents, \ const std::uint32_t min_iteration, \ diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 33d4dd9423..04104c09db 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -319,6 +319,10 @@ if(BUILD_TESTS) test/neighbors/ann_cagra/test_int8_t_uint32_t.cu test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu test/neighbors/ann_cagra/test_float_int64_t.cu + src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu + # src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu + # src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu + # src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu From 868b5f1219c1abe132ea8e1bb3f9e0dcec45d708 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 4 Jul 2023 22:52:49 +0200 Subject: [PATCH 5/6] Add multi_cta_search_kernel instantiations --- cpp/CMakeLists.txt | 13 + .../detail/cagra/search_multi_cta.cuh | 289 +------------- .../cagra/search_multi_cta_kernel-ext.cuh | 371 ++++++++++++++++++ .../cagra/search_multi_cta_kernel-inl.cuh | 332 ++++++++++++++++ .../detail/cagra/search_multi_cta_kernel.cuh | 24 ++ .../cagra/search_single_cta_kernel-ext.cuh | 272 +++++++++++++ .../cagra/search_multi_cta_00_generate.py | 113 ++++++ ...arch_multi_cta_float_uint32_dim1024_t32.cu | 85 ++++ ...search_multi_cta_float_uint32_dim128_t8.cu | 85 ++++ ...earch_multi_cta_float_uint32_dim256_t16.cu | 85 ++++ ...earch_multi_cta_float_uint32_dim512_t32.cu | 85 ++++ ...arch_multi_cta_float_uint64_dim1024_t32.cu | 85 ++++ ...search_multi_cta_float_uint64_dim128_t8.cu | 85 ++++ ...earch_multi_cta_float_uint64_dim256_t16.cu | 85 ++++ ...earch_multi_cta_float_uint64_dim512_t32.cu | 85 ++++ ...earch_multi_cta_int8_uint32_dim1024_t32.cu | 85 ++++ .../search_multi_cta_int8_uint32_dim128_t8.cu | 85 ++++ ...search_multi_cta_int8_uint32_dim256_t16.cu | 85 ++++ ...search_multi_cta_int8_uint32_dim512_t32.cu | 85 ++++ ...arch_multi_cta_uint8_uint32_dim1024_t32.cu | 85 ++++ ...search_multi_cta_uint8_uint32_dim128_t8.cu | 85 ++++ ...earch_multi_cta_uint8_uint32_dim256_t16.cu | 85 ++++ ...earch_multi_cta_uint8_uint32_dim512_t32.cu | 85 ++++ cpp/test/CMakeLists.txt | 10 +- .../neighbors/ann_cagra/test_float_int64_t.cu | 3 - 25 files changed, 2493 insertions(+), 294 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh create mode 100644 cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 98f6286716..ae1f5b9c7a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -307,6 +307,18 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu src/neighbors/brute_force_knn_int_float_int.cu src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu + src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu + src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu + src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu + src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu + src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu + src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu + src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu + src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu + src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu + src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu + src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu + src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu @@ -319,6 +331,7 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu + src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 2f34febdd2..4598c0ffa9 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -33,6 +33,7 @@ #include "compute_distance.hpp" #include "device_common.hpp" #include "hashmap.hpp" +#include "search_multi_cta_kernel.cuh" #include "search_plan.cuh" #include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible #include "utils.hpp" @@ -43,294 +44,6 @@ namespace raft::neighbors::experimental::cagra::detail { namespace multi_cta_search { -// #define _CLK_BREAKDOWN - -template -__device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num_parents] - const uint32_t num_parents, - INDEX_T* const itopk_indices, // [num_itopk] - const size_t num_itopk, - uint32_t* const terminate_flag) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const unsigned warp_id = threadIdx.x / 32; - if (warp_id > 0) { return; } - const unsigned lane_id = threadIdx.x % 32; - for (uint32_t i = lane_id; i < num_parents; i += 32) { - next_parent_indices[i] = utils::get_max_value(); - } - uint32_t max_itopk = num_itopk; - if (max_itopk % 32) { max_itopk += 32 - (max_itopk % 32); } - uint32_t num_new_parents = 0; - for (uint32_t j = lane_id; j < max_itopk; j += 32) { - INDEX_T index; - int new_parent = 0; - if (j < num_itopk) { - index = itopk_indices[j]; - if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set - new_parent = 1; - } - } - const uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; - if (i < num_parents) { - next_parent_indices[i] = index; - itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node - } - } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= num_parents) { break; } - } - if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } -} - -template -__device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements] - INDEX_T* indices, // [num_elements] - const uint32_t num_elements, - const uint32_t num_itopk // num_itopk <= num_elements -) -{ - const unsigned warp_id = threadIdx.x / 32; - if (warp_id > 0) { return; } - const unsigned lane_id = threadIdx.x % 32; - constexpr unsigned N = (MAX_ELEMENTS + 31) / 32; - float key[N]; - INDEX_T val[N]; - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); - if (j < num_elements) { - key[i] = distances[j]; - val[i] = indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - /* Store itopk sorted results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - distances[j] = key[i]; - indices[j] = val[i]; - } - } -} - -// -// multiple CTAs per single query -// -template -__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( - INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] - DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] - const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const size_t dataset_dim, - const size_t dataset_size, - const size_t dataset_ld, - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const uint32_t graph_degree, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const uint32_t hash_bitlen, - const uint32_t itopk_size, - const uint32_t num_parents, - const uint32_t min_iteration, - const uint32_t max_iteration, - uint32_t* const num_executed_iterations /* stats */ -) -{ - assert(blockDim.x == BLOCK_SIZE); - assert(dataset_dim <= MAX_DATASET_DIM); - - const auto num_queries = gridDim.y; - const auto query_id = blockIdx.y; - const auto num_cta_per_query = gridDim.x; - const auto cta_id = blockIdx.x; // local CTA ID - -#ifdef _CLK_BREAKDOWN - uint64_t clk_init = 0; - uint64_t clk_compute_1st_distance = 0; - uint64_t clk_topk = 0; - uint64_t clk_pickup_parents = 0; - uint64_t clk_compute_distance = 0; - uint64_t clk_start; -#define _CLK_START() clk_start = clock64() -#define _CLK_REC(V) V += clock64() - clk_start; -#else -#define _CLK_START() -#define _CLK_REC(V) -#endif - _CLK_START(); - - extern __shared__ uint32_t smem[]; - - // Layout of result_buffer - // +----------------+------------------------------+---------+ - // | internal_top_k | neighbors of parent nodes | padding | - // | | | upto 32 | - // +----------------+------------------------------+---------+ - // |<--- result_buffer_size --->| - uint32_t result_buffer_size = itopk_size + (num_parents * graph_degree); - uint32_t result_buffer_size_32 = result_buffer_size; - if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); } - assert(result_buffer_size_32 <= MAX_ELEMENTS); - - auto query_buffer = reinterpret_cast(smem); - auto result_indices_buffer = reinterpret_cast(query_buffer + MAX_DATASET_DIM); - auto result_distances_buffer = - reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto parent_indices_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto terminate_flag = reinterpret_cast(parent_indices_buffer + num_parents); - -#if 0 - /* debug */ - for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += BLOCK_SIZE) { - result_indices_buffer[i] = utils::get_max_value(); - result_distances_buffer[i] = utils::get_max_value(); - } -#endif - - const DATA_T* const query_ptr = queries_ptr + (dataset_dim * query_id); - for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { - unsigned j = device::swizzling(i); - if (i < dataset_dim) { - query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); - } else { - query_buffer[j] = 0.0; - } - } - if (threadIdx.x == 0) { terminate_flag[0] = 0; } - INDEX_T* const local_visited_hashmap_ptr = - visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); - __syncthreads(); - _CLK_REC(clk_init); - - // compute distance to randomly selecting nodes - _CLK_START(); - const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; - uint32_t block_id = cta_id + (num_cta_per_query * query_id); - uint32_t num_blocks = num_cta_per_query * num_queries; - device::compute_distance_to_random_nodes( - result_indices_buffer, - result_distances_buffer, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_size, - dataset_ld, - result_buffer_size, - num_distilation, - rand_xor_mask, - local_seed_ptr, - num_seeds, - local_visited_hashmap_ptr, - hash_bitlen, - block_id, - num_blocks); - __syncthreads(); - _CLK_REC(clk_compute_1st_distance); - - uint32_t iter = 0; - while (1) { - // topk with bitonic sort - _CLK_START(); - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (num_parents * graph_degree), - itopk_size); - _CLK_REC(clk_topk); - - if (iter + 1 == max_iteration) { - __syncthreads(); - break; - } - - // pick up next parents - _CLK_START(); - pickup_next_parents( - parent_indices_buffer, num_parents, result_indices_buffer, itopk_size, terminate_flag); - _CLK_REC(clk_pickup_parents); - - __syncthreads(); - if (*terminate_flag && iter >= min_iteration) { break; } - - // compute the norms between child nodes and query node - _CLK_START(); - // constexpr unsigned max_n_frags = 16; - constexpr unsigned max_n_frags = 0; - device:: - compute_distance_to_child_nodes( - result_indices_buffer + itopk_size, - result_distances_buffer + itopk_size, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_ld, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - parent_indices_buffer, - num_parents); - _CLK_REC(clk_compute_distance); - __syncthreads(); - - iter++; - } - - for (uint32_t i = threadIdx.x; i < itopk_size; i += BLOCK_SIZE) { - uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } - - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - result_indices_ptr[j] = - result_indices_buffer[i] & ~index_msb_1_mask; // clear most significant bit - } - - if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { - num_executed_iterations[query_id] = iter + 1; - } - -#ifdef _CLK_BREAKDOWN - if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && (blockIdx.x == 0) && - ((query_id * 3) % gridDim.y < 3)) { - RAFT_LOG_DEBUG( - "query, %d, thread, %d" - ", init, %d" - ", 1st_distance, %lu" - ", topk, %lu" - ", pickup_parents, %lu" - ", distance, %lu" - "\n", - query_id, - threadIdx.x, - clk_init, - clk_compute_1st_distance, - clk_topk, - clk_pickup_parents, - clk_compute_distance); - } -#endif -} - #define SET_MC_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS) \ kernel = search_kernel // RAFT_EXPLICIT + +namespace raft::neighbors::experimental::cagra::detail { +namespace multi_cta_search { + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +template +__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( + INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] + DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] + const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] + const size_t dataset_dim, + const size_t dataset_size, + const size_t dataset_ld, + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const uint32_t hash_bitlen, + const uint32_t itopk_size, + const uint32_t num_parents, + const uint32_t min_iteration, + const uint32_t max_iteration, + uint32_t* const num_executed_iterations /* stats */ + ) RAFT_EXPLICIT; + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + extern template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +// search_multi_cta_float_uint32_dim1024_t32.cu +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, float, float, uint32_t, uint4); + +// search_multi_cta_float_uint32_dim128_t8.cu +instantiate_multi_cta_search_kernel(8, 64, 16, 64, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 128, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 256, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 64, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 128, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 256, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 64, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 128, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 256, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 64, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 128, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 256, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 64, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 128, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 256, 128, float, float, uint32_t, uint4); + +// search_multi_cta_float_uint32_dim256_t16.cu +instantiate_multi_cta_search_kernel(16, 64, 16, 64, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 128, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 256, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 64, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 128, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 256, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 64, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 128, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 256, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 64, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 128, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 256, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 64, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 128, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 256, 256, float, float, uint32_t, uint4); + +// search_multi_cta_float_uint32_dim512_t32.cu +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, float, float, uint32_t, uint4); + +// search_multi_cta_int8_uint32_dim1024_t32.cu +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, int8_t, float, uint32_t, uint4); + +// search_multi_cta_int8_uint32_dim128_t8.cu +instantiate_multi_cta_search_kernel(8, 64, 16, 64, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 128, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 256, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 64, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 128, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 256, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 64, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 128, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 256, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 64, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 128, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 256, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 64, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 128, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 256, 128, int8_t, float, uint32_t, uint4); + +// search_multi_cta_int8_uint32_dim256_t16.cu +instantiate_multi_cta_search_kernel(16, 64, 16, 64, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 128, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 256, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 64, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 128, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 256, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 64, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 128, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 256, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 64, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 128, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 256, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 64, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 128, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 256, 256, int8_t, float, uint32_t, uint4); + +// search_multi_cta_int8_uint32_dim512_t32.cu +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, int8_t, float, uint32_t, uint4); + +// search_multi_cta_uint8_uint32_dim1024_t32.cu +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, uint8_t, float, uint32_t, uint4); + +// search_multi_cta_uint8_uint32_dim128_t8.cu +instantiate_multi_cta_search_kernel(8, 64, 16, 64, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 128, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 256, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 64, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 128, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 256, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 64, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 128, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 256, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 64, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 128, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 256, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 64, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 128, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 256, 128, uint8_t, float, uint32_t, uint4); + +// search_multi_cta_uint8_uint32_dim256_t16.cu +instantiate_multi_cta_search_kernel(16, 64, 16, 64, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 128, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 256, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 64, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 128, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 256, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 64, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 128, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 256, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 64, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 128, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 256, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 64, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 128, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 256, 256, uint8_t, float, uint32_t, uint4); + +// search_multi_cta_uint8_uint32_dim512_t32.cu +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, uint8_t, float, uint32_t, uint4); + +// search_multi_cta_float_uint64_dim1024_t32.cu +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, float, float, uint64_t, uint4); + +// search_multi_cta_float_uint64_dim128_t8.cu +instantiate_multi_cta_search_kernel(8, 64, 16, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 256, 128, float, float, uint64_t, uint4); + +// search_multi_cta_float_uint64_dim256_t16.cu +instantiate_multi_cta_search_kernel(16, 64, 16, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 256, 256, float, float, uint64_t, uint4); + +// search_multi_cta_float_uint64_dim512_t32.cu +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, float, float, uint64_t, uint4); + +#undef instantiate_multi_cta_search_kernel +} // namespace multi_cta_search +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh new file mode 100644 index 0000000000..47ff9a3339 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "bitonic.hpp" +#include "compute_distance.hpp" +#include "device_common.hpp" +#include "hashmap.hpp" +#include "search_plan.cuh" +#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible +#include "utils.hpp" +#include +#include +#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp + +namespace raft::neighbors::experimental::cagra::detail { +namespace multi_cta_search { + +// #define _CLK_BREAKDOWN + +template +__device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num_parents] + const uint32_t num_parents, + INDEX_T* const itopk_indices, // [num_itopk] + const size_t num_itopk, + uint32_t* const terminate_flag) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const unsigned warp_id = threadIdx.x / 32; + if (warp_id > 0) { return; } + const unsigned lane_id = threadIdx.x % 32; + for (uint32_t i = lane_id; i < num_parents; i += 32) { + next_parent_indices[i] = utils::get_max_value(); + } + uint32_t max_itopk = num_itopk; + if (max_itopk % 32) { max_itopk += 32 - (max_itopk % 32); } + uint32_t num_new_parents = 0; + for (uint32_t j = lane_id; j < max_itopk; j += 32) { + INDEX_T index; + int new_parent = 0; + if (j < num_itopk) { + index = itopk_indices[j]; + if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set + new_parent = 1; + } + } + const uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); + if (new_parent) { + const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; + if (i < num_parents) { + next_parent_indices[i] = index; + itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node + } + } + num_new_parents += __popc(ballot_mask); + if (num_new_parents >= num_parents) { break; } + } + if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } +} + +template +__device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements] + INDEX_T* indices, // [num_elements] + const uint32_t num_elements, + const uint32_t num_itopk // num_itopk <= num_elements +) +{ + const unsigned warp_id = threadIdx.x / 32; + if (warp_id > 0) { return; } + const unsigned lane_id = threadIdx.x % 32; + constexpr unsigned N = (MAX_ELEMENTS + 31) / 32; + float key[N]; + INDEX_T val[N]; + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (32 * i); + if (j < num_elements) { + key[i] = distances[j]; + val[i] = indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + /* Store itopk sorted results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_itopk) { + distances[j] = key[i]; + indices[j] = val[i]; + } + } +} + +// +// multiple CTAs per single query +// +template +__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( + INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] + DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] + const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] + const size_t dataset_dim, + const size_t dataset_size, + const size_t dataset_ld, + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const uint32_t hash_bitlen, + const uint32_t itopk_size, + const uint32_t num_parents, + const uint32_t min_iteration, + const uint32_t max_iteration, + uint32_t* const num_executed_iterations /* stats */ +) +{ + assert(blockDim.x == BLOCK_SIZE); + assert(dataset_dim <= MAX_DATASET_DIM); + + const auto num_queries = gridDim.y; + const auto query_id = blockIdx.y; + const auto num_cta_per_query = gridDim.x; + const auto cta_id = blockIdx.x; // local CTA ID + +#ifdef _CLK_BREAKDOWN + uint64_t clk_init = 0; + uint64_t clk_compute_1st_distance = 0; + uint64_t clk_topk = 0; + uint64_t clk_pickup_parents = 0; + uint64_t clk_compute_distance = 0; + uint64_t clk_start; +#define _CLK_START() clk_start = clock64() +#define _CLK_REC(V) V += clock64() - clk_start; +#else +#define _CLK_START() +#define _CLK_REC(V) +#endif + _CLK_START(); + + extern __shared__ uint32_t smem[]; + + // Layout of result_buffer + // +----------------+------------------------------+---------+ + // | internal_top_k | neighbors of parent nodes | padding | + // | | | upto 32 | + // +----------------+------------------------------+---------+ + // |<--- result_buffer_size --->| + uint32_t result_buffer_size = itopk_size + (num_parents * graph_degree); + uint32_t result_buffer_size_32 = result_buffer_size; + if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); } + assert(result_buffer_size_32 <= MAX_ELEMENTS); + + auto query_buffer = reinterpret_cast(smem); + auto result_indices_buffer = reinterpret_cast(query_buffer + MAX_DATASET_DIM); + auto result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto parent_indices_buffer = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto terminate_flag = reinterpret_cast(parent_indices_buffer + num_parents); + +#if 0 + /* debug */ + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += BLOCK_SIZE) { + result_indices_buffer[i] = utils::get_max_value(); + result_distances_buffer[i] = utils::get_max_value(); + } +#endif + const DATA_T* const query_ptr = queries_ptr + (dataset_dim * query_id); + for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { + unsigned j = device::swizzling(i); + if (i < dataset_dim) { + query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); + } else { + query_buffer[j] = 0.0; + } + } + if (threadIdx.x == 0) { terminate_flag[0] = 0; } + INDEX_T* const local_visited_hashmap_ptr = + visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); + __syncthreads(); + _CLK_REC(clk_init); + + // compute distance to randomly selecting nodes + _CLK_START(); + const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; + uint32_t block_id = cta_id + (num_cta_per_query * query_id); + uint32_t num_blocks = num_cta_per_query * num_queries; + device::compute_distance_to_random_nodes( + result_indices_buffer, + result_distances_buffer, + query_buffer, + dataset_ptr, + dataset_dim, + dataset_size, + dataset_ld, + result_buffer_size, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + hash_bitlen, + block_id, + num_blocks); + __syncthreads(); + _CLK_REC(clk_compute_1st_distance); + + uint32_t iter = 0; + while (1) { + // topk with bitonic sort + _CLK_START(); + topk_by_bitonic_sort(result_distances_buffer, + result_indices_buffer, + itopk_size + (num_parents * graph_degree), + itopk_size); + _CLK_REC(clk_topk); + + if (iter + 1 == max_iteration) { + __syncthreads(); + break; + } + + // pick up next parents + _CLK_START(); + pickup_next_parents( + parent_indices_buffer, num_parents, result_indices_buffer, itopk_size, terminate_flag); + _CLK_REC(clk_pickup_parents); + + __syncthreads(); + if (*terminate_flag && iter >= min_iteration) { break; } + + // compute the norms between child nodes and query node + _CLK_START(); + // constexpr unsigned max_n_frags = 16; + constexpr unsigned max_n_frags = 0; + device:: + compute_distance_to_child_nodes( + result_indices_buffer + itopk_size, + result_distances_buffer + itopk_size, + query_buffer, + dataset_ptr, + dataset_dim, + dataset_ld, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + parent_indices_buffer, + num_parents); + _CLK_REC(clk_compute_distance); + __syncthreads(); + + iter++; + } + + for (uint32_t i = threadIdx.x; i < itopk_size; i += BLOCK_SIZE) { + uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + result_indices_ptr[j] = + result_indices_buffer[i] & ~index_msb_1_mask; // clear most significant bit + } + + if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { + num_executed_iterations[query_id] = iter + 1; + } + +#ifdef _CLK_BREAKDOWN + if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && (blockIdx.x == 0) && + ((query_id * 3) % gridDim.y < 3)) { + RAFT_LOG_DEBUG( + "query, %d, thread, %d" + ", init, %d" + ", 1st_distance, %lu" + ", topk, %lu" + ", pickup_parents, %lu" + ", distance, %lu" + "\n", + query_id, + threadIdx.x, + clk_init, + clk_compute_1st_distance, + clk_topk, + clk_pickup_parents, + clk_compute_distance); + } +#endif +} +} // namespace multi_cta_search +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh new file mode 100644 index 0000000000..e003907292 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "search_multi_cta_kernel-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "search_multi_cta_kernel-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh index 82dff6d78d..122daa5ab2 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh @@ -976,6 +976,278 @@ instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, uint8_t, float instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, uint8_t, float, uint32_t, uint4); instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, uint8_t, float, uint32_t, uint4); +// search_single_cta_float_uint64_dim1024_t32.cu +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 1024, float, float, uint64_t, uint4); + +// search_single_cta_float_uint64_dim128_t8.cu +instantiate_single_cta_search_kernel(8, 64, 16, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 32, 0, 128, float, float, uint64_t, uint4); + +// search_single_cta_float_uint64_dim256_t16.cu +instantiate_single_cta_search_kernel(16, 64, 16, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 32, 0, 256, float, float, uint64_t, uint4); + +// search_single_cta_float_uint64_dim512_t32.cu +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, float, float, uint64_t, uint4); + #undef instantiate_single_cta_search_kernel } // namespace single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py new file mode 100644 index 0000000000..8a4ad4bad2 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py @@ -0,0 +1,113 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \\ + BLOCK_SIZE, \\ + BLOCK_COUNT, \\ + MAX_ELEMENTS, \\ + MAX_DATASET_DIM, \\ + DATA_T, \\ + DISTANCE_T, \\ + INDEX_T, \\ + LOAD_T) \\ + template __global__ void raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel(INDEX_T* const result_indices_ptr, \\ + DISTANCE_T* const result_distances_ptr, \\ + const DATA_T* const dataset_ptr, \\ + const size_t dataset_dim, \\ + const size_t dataset_size, \\ + const size_t dataset_ld, \\ + const DATA_T* const queries_ptr, \\ + const INDEX_T* const knn_graph, \\ + const uint32_t graph_degree, \\ + const unsigned num_distilation, \\ + const uint64_t rand_xor_mask, \\ + const INDEX_T* seed_ptr, \\ + const uint32_t num_seeds, \\ + INDEX_T* const visited_hashmap_ptr, \\ + const uint32_t hash_bitlen, \\ + const uint32_t itopk_size, \\ + const uint32_t num_parents, \\ + const uint32_t min_iteration, \\ + const uint32_t max_iteration, \\ + uint32_t* const num_executed_iterations); + +""" + +trailer = """ +#undef instantiate_multi_cta_search_kernel + +""" + +mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] +block = [(64, 16), (128, 8), (256, 4), (512, 2), (1024, 1)] +mxelem = [64, 128, 256] +load_types = ["uint4"] +search_types = dict( + float_uint32=("float", "uint32_t", "float"), # data_t, idx_t, distance_t + int8_uint32=("int8_t", "uint32_t", "float"), + uint8_uint32=("uint8_t", "uint32_t", "float"), + float_uint64=("float", "uint64_t", "float"), +) + +# knn +for type_path, (data_t, idx_t, distance_t) in search_types.items(): + for (mxdim, team) in mxdim_team: + path = f"search_multi_cta_{type_path}_dim{mxdim}_t{team}.cu" + with open(path, "w") as f: + f.write(header) + for load_t in load_types: + for b in block: + for elem in mxelem: + f.write( + f"instantiate_multi_cta_search_kernel({team}, {b[0]}, {b[1]}, {elem}, {mxdim},{data_t}, {distance_t}, {idx_t}, {load_t});\n" + ) + f.write(trailer) + # For pasting into CMakeLists.txt + print(f"src/neighbors/detail/cagra/{path}") diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu new file mode 100644 index 0000000000..7c3ef61886 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, float, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu new file mode 100644 index 0000000000..294f56f7b1 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(8, 64, 16, 64, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 128, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 256, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 64, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 128, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 256, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 64, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 128, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 256, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 64, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 128, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 256, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 64, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 128, 128, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 256, 128, float, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu new file mode 100644 index 0000000000..8c38a55b25 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(16, 64, 16, 64, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 128, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 256, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 64, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 128, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 256, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 64, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 128, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 256, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 64, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 128, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 256, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 64, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 128, 256, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 256, 256, float, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu new file mode 100644 index 0000000000..9a77350417 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, float, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, float, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu new file mode 100644 index 0000000000..bb1a67a735 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, float, float, uint64_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu new file mode 100644 index 0000000000..9fe6c24061 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(8, 64, 16, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 256, 128, float, float, uint64_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu new file mode 100644 index 0000000000..9001701b7c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(16, 64, 16, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 256, 256, float, float, uint64_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu new file mode 100644 index 0000000000..fcaa85d06c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, float, float, uint64_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu new file mode 100644 index 0000000000..f8a1ca47b3 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, int8_t, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu new file mode 100644 index 0000000000..3ce811f956 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(8, 64, 16, 64, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 128, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 256, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 64, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 128, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 256, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 64, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 128, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 256, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 64, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 128, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 256, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 64, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 128, 128, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 256, 128, int8_t, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu new file mode 100644 index 0000000000..c8082f77ea --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(16, 64, 16, 64, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 128, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 256, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 64, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 128, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 256, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 64, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 128, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 256, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 64, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 128, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 256, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 64, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 128, 256, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 256, 256, int8_t, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu new file mode 100644 index 0000000000..db66992318 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, int8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, int8_t, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu new file mode 100644 index 0000000000..3c805800c4 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, uint8_t, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu new file mode 100644 index 0000000000..0be714140d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(8, 64, 16, 64, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 128, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 256, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 64, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 128, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 256, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 64, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 128, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 256, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 64, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 128, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 256, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 64, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 128, 128, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 256, 128, uint8_t, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu new file mode 100644 index 0000000000..ef6fe045da --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(16, 64, 16, 64, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 128, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 256, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 64, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 128, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 256, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 64, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 128, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 256, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 64, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 128, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 256, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 64, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 128, 256, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 256, 256, uint8_t, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu new file mode 100644 index 0000000000..6d8203d910 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu @@ -0,0 +1,85 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python search_multi_cta_00_generate.py + * + */ + +#include + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + template __global__ void \ + raft::neighbors::experimental::cagra::detail::multi_cta_search::search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, uint8_t, float, uint32_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, uint8_t, float, uint32_t, uint4); + +#undef instantiate_multi_cta_search_kernel diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 04104c09db..77f571f705 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -319,10 +319,14 @@ if(BUILD_TESTS) test/neighbors/ann_cagra/test_int8_t_uint32_t.cu test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu test/neighbors/ann_cagra/test_float_int64_t.cu + src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu + src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu + src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu + src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu - # src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu - # src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu - # src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu + src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu + src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu + src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu diff --git a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu index d7405c166c..a0b3dd3f07 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu @@ -16,9 +16,6 @@ #include -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY -#undef RAFT_COMPILED - #include "../ann_cagra.cuh" namespace raft::neighbors::experimental::cagra { From 0a5da3cf3616387cbaa2fb59a604002f1564a900 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 4 Jul 2023 23:17:13 +0200 Subject: [PATCH 6/6] Move uint64 extern template declarations to separate header --- .../cagra/search_multi_cta_kernel-ext.cuh | 68 --- .../cagra/search_single_cta_kernel-ext.cuh | 272 ----------- .../ann_cagra/search_kernel_uint64_t.cuh | 454 ++++++++++++++++++ .../neighbors/ann_cagra/test_float_int64_t.cu | 1 + 4 files changed, 455 insertions(+), 340 deletions(-) create mode 100644 cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh index c6c9185b6d..d7ad4402c8 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh @@ -298,74 +298,6 @@ instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, uint8_t, float, uint32 instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, uint8_t, float, uint32_t, uint4); instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, uint8_t, float, uint32_t, uint4); -// search_multi_cta_float_uint64_dim1024_t32.cu -instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, float, float, uint64_t, uint4); - -// search_multi_cta_float_uint64_dim128_t8.cu -instantiate_multi_cta_search_kernel(8, 64, 16, 64, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 64, 16, 128, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 64, 16, 256, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 128, 8, 64, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 128, 8, 128, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 128, 8, 256, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 256, 4, 64, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 256, 4, 128, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 256, 4, 256, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 512, 2, 64, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 512, 2, 128, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 512, 2, 256, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 1024, 1, 64, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 1024, 1, 128, 128, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(8, 1024, 1, 256, 128, float, float, uint64_t, uint4); - -// search_multi_cta_float_uint64_dim256_t16.cu -instantiate_multi_cta_search_kernel(16, 64, 16, 64, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 64, 16, 128, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 64, 16, 256, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 128, 8, 64, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 128, 8, 128, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 128, 8, 256, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 256, 4, 64, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 256, 4, 128, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 256, 4, 256, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 512, 2, 64, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 512, 2, 128, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 512, 2, 256, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 1024, 1, 64, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 1024, 1, 128, 256, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(16, 1024, 1, 256, 256, float, float, uint64_t, uint4); - -// search_multi_cta_float_uint64_dim512_t32.cu -instantiate_multi_cta_search_kernel(32, 64, 16, 64, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 64, 16, 128, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 64, 16, 256, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 128, 8, 64, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 128, 8, 128, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 128, 8, 256, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 256, 4, 64, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 256, 4, 128, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 256, 4, 256, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 512, 2, 64, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 512, 2, 128, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 512, 2, 256, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, float, float, uint64_t, uint4); -instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, float, float, uint64_t, uint4); - #undef instantiate_multi_cta_search_kernel } // namespace multi_cta_search } // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh index 122daa5ab2..82dff6d78d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh @@ -976,278 +976,6 @@ instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, uint8_t, float instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, uint8_t, float, uint32_t, uint4); instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, uint8_t, float, uint32_t, uint4); -// search_single_cta_float_uint64_dim1024_t32.cu -instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 1024, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 1024, float, float, uint64_t, uint4); - -// search_single_cta_float_uint64_dim128_t8.cu -instantiate_single_cta_search_kernel(8, 64, 16, 64, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 128, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 256, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 512, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 64, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 128, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 256, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 512, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 64, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 128, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 256, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 64, 16, 512, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 64, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 128, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 256, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 512, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 64, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 128, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 256, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 512, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 64, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 128, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 256, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 128, 8, 512, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 64, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 128, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 256, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 512, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 64, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 128, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 256, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 512, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 64, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 128, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 256, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 512, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 64, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 128, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 256, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 512, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 64, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 128, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 256, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 512, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 64, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 128, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 256, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 512, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 64, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 128, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 256, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 512, 64, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 64, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 128, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 256, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 512, 128, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 64, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 128, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 256, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 512, 256, 1, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 256, 32, 0, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 256, 4, 512, 32, 0, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 256, 32, 0, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 512, 2, 512, 32, 0, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 256, 32, 0, 128, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(8, 1024, 1, 512, 32, 0, 128, float, float, uint64_t, uint4); - -// search_single_cta_float_uint64_dim256_t16.cu -instantiate_single_cta_search_kernel(16, 64, 16, 64, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 128, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 256, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 512, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 64, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 128, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 256, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 512, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 64, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 128, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 256, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 64, 16, 512, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 64, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 128, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 256, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 512, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 64, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 128, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 256, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 512, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 64, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 128, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 256, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 128, 8, 512, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 64, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 128, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 256, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 512, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 64, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 128, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 256, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 512, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 64, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 128, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 256, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 512, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 64, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 128, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 256, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 512, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 64, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 128, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 256, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 512, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 64, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 128, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 256, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 512, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 64, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 128, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 256, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 512, 64, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 64, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 128, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 256, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 512, 128, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 64, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 128, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 256, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 512, 256, 1, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 256, 32, 0, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 256, 4, 512, 32, 0, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 256, 32, 0, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 512, 2, 512, 32, 0, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 256, 32, 0, 256, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(16, 1024, 1, 512, 32, 0, 256, float, float, uint64_t, uint4); - -// search_single_cta_float_uint64_dim512_t32.cu -instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, float, float, uint64_t, uint4); -instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, float, float, uint64_t, uint4); - #undef instantiate_single_cta_search_kernel } // namespace single_cta_search diff --git a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh new file mode 100644 index 0000000000..f8bf6f2312 --- /dev/null +++ b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh @@ -0,0 +1,454 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // RAFT_EXPLICIT + +namespace raft::neighbors::experimental::cagra::detail { +namespace multi_cta_search { + +#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ELEMENTS, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + extern template __global__ void search_kernel(INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const DATA_T* const dataset_ptr, \ + const size_t dataset_dim, \ + const size_t dataset_size, \ + const size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const uint32_t hash_bitlen, \ + const uint32_t itopk_size, \ + const uint32_t num_parents, \ + const uint32_t min_iteration, \ + const uint32_t max_iteration, \ + uint32_t* const num_executed_iterations); + +// search_multi_cta_float_uint64_dim1024_t32.cu +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, float, float, uint64_t, uint4); + +// search_multi_cta_float_uint64_dim128_t8.cu +instantiate_multi_cta_search_kernel(8, 64, 16, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 64, 16, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 128, 8, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 256, 4, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 512, 2, 256, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 64, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 128, 128, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(8, 1024, 1, 256, 128, float, float, uint64_t, uint4); + +// search_multi_cta_float_uint64_dim256_t16.cu +instantiate_multi_cta_search_kernel(16, 64, 16, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 64, 16, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 128, 8, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 256, 4, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 512, 2, 256, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 64, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 128, 256, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(16, 1024, 1, 256, 256, float, float, uint64_t, uint4); + +// search_multi_cta_float_uint64_dim512_t32.cu +instantiate_multi_cta_search_kernel(32, 64, 16, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 64, 16, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 128, 8, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 256, 4, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 512, 2, 256, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 512, float, float, uint64_t, uint4); +instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 512, float, float, uint64_t, uint4); + +#undef instantiate_multi_cta_search_kernel +} // namespace multi_cta_search + +namespace single_cta_search { + +#define instantiate_single_cta_search_kernel(TEAM_SIZE, \ + BLOCK_SIZE, \ + BLOCK_COUNT, \ + MAX_ITOPK, \ + MAX_CANDIDATES, \ + TOPK_BY_BITONIC_SORT, \ + MAX_DATASET_DIM, \ + DATA_T, \ + DISTANCE_T, \ + INDEX_T, \ + LOAD_T) \ + extern template __global__ void search_kernel( \ + INDEX_T* const result_indices_ptr, \ + DISTANCE_T* const result_distances_ptr, \ + const std::uint32_t top_k, \ + const DATA_T* const dataset_ptr, \ + const std::size_t dataset_dim, \ + const std::size_t dataset_size, \ + const std::size_t dataset_ld, \ + const DATA_T* const queries_ptr, \ + const INDEX_T* const knn_graph, \ + const std::uint32_t graph_degree, \ + const unsigned num_distilation, \ + const uint64_t rand_xor_mask, \ + const INDEX_T* seed_ptr, \ + const uint32_t num_seeds, \ + INDEX_T* const visited_hashmap_ptr, \ + const std::uint32_t internal_topk, \ + const std::uint32_t num_parents, \ + const std::uint32_t min_iteration, \ + const std::uint32_t max_iteration, \ + std::uint32_t* const num_executed_iterations, \ + const std::uint32_t hash_bitlen, \ + const std::uint32_t small_hash_bitlen, \ + const std::uint32_t small_hash_reset_interval); + +// search_single_cta_float_uint64_dim1024_t32.cu +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 1024, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 1024, float, float, uint64_t, uint4); + +// search_single_cta_float_uint64_dim128_t8.cu +instantiate_single_cta_search_kernel(8, 64, 16, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 64, 16, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 128, 8, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 64, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 128, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 64, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 128, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 256, 1, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 256, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 256, 4, 512, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 256, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 512, 2, 512, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 256, 32, 0, 128, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(8, 1024, 1, 512, 32, 0, 128, float, float, uint64_t, uint4); + +// search_single_cta_float_uint64_dim256_t16.cu +instantiate_single_cta_search_kernel(16, 64, 16, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 64, 16, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 128, 8, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 64, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 128, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 64, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 128, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 256, 1, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 256, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 256, 4, 512, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 256, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 512, 2, 512, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 256, 32, 0, 256, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(16, 1024, 1, 512, 32, 0, 256, float, float, uint64_t, uint4); + +// search_single_cta_float_uint64_dim512_t32.cu +instantiate_single_cta_search_kernel(32, 64, 16, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 64, 16, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 128, 8, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 64, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 128, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 64, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 128, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 256, 1, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 256, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 256, 4, 512, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 256, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 512, 2, 512, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 256, 32, 0, 512, float, float, uint64_t, uint4); +instantiate_single_cta_search_kernel(32, 1024, 1, 512, 32, 0, 512, float, float, uint64_t, uint4); +#undef instantiate_single_cta_search_kernel + +} // namespace single_cta_search +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu index a0b3dd3f07..fa3d76d066 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu @@ -17,6 +17,7 @@ #include #include "../ann_cagra.cuh" +#include "search_kernel_uint64_t.cuh" namespace raft::neighbors::experimental::cagra {