diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 19f65baf1a..9905f2abae 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -81,7 +81,17 @@ void build_knn_graph(raft::resources const& res, std::optional build_params = std::nullopt, std::optional search_params = std::nullopt) { - detail::build_knn_graph(res, dataset, knn_graph, refine_rate, build_params, search_params); + using internal_IdxT = typename std::make_unsigned::type; + + auto knn_graph_internal = make_host_matrix_view( + reinterpret_cast(knn_graph.data_handle()), + knn_graph.extent(0), + knn_graph.extent(1)); + auto dataset_internal = mdspan, row_major, accessor>( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + + detail::build_knn_graph( + res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); } /** @@ -124,7 +134,20 @@ void sort_knn_graph(raft::resources const& res, mdspan, row_major, d_accessor> dataset, mdspan, row_major, g_accessor> knn_graph) { - detail::graph::sort_knn_graph(res, dataset, knn_graph); + using internal_IdxT = typename std::make_unsigned::type; + + using g_accessor_internal = + host_device_accessor, g_accessor::mem_type>; + auto knn_graph_internal = + mdspan, row_major, g_accessor_internal>( + reinterpret_cast(knn_graph.data_handle()), + knn_graph.extent(0), + knn_graph.extent(1)); + + auto dataset_internal = mdspan, row_major, d_accessor>( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + + detail::graph::sort_knn_graph(res, dataset_internal, knn_graph_internal); } /** @@ -148,7 +171,22 @@ void prune(raft::resources const& res, mdspan, row_major, g_accessor> knn_graph, raft::host_matrix_view new_graph) { - detail::graph::prune(res, knn_graph, new_graph); + using internal_IdxT = typename std::make_unsigned::type; + + auto new_graph_internal = raft::make_host_matrix_view( + reinterpret_cast(new_graph.data_handle()), + new_graph.extent(0), + new_graph.extent(1)); + + using g_accessor_internal = + host_device_accessor, memory_type::host>; + auto knn_graph_internal = + mdspan, row_major, g_accessor_internal>( + reinterpret_cast(knn_graph.data_handle()), + knn_graph.extent(0), + knn_graph.extent(1)); + + detail::graph::prune(res, knn_graph_internal, new_graph_internal); } /** @@ -200,7 +238,7 @@ index build(raft::resources const& res, mdspan, row_major, Accessor> dataset) { size_t degree = params.intermediate_graph_degree; - if (degree >= dataset.extent(0)) { + if (degree >= static_cast(dataset.extent(0))) { RAFT_LOG_WARN( "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", dataset.extent(0)); @@ -256,7 +294,18 @@ void search(raft::resources const& res, RAFT_EXPECTS(queries.extent(1) == idx.dim(), "Number of query dimensions should equal number of dimensions in the index."); - detail::search_main(res, params, idx, queries, neighbors, distances); + using internal_IdxT = typename std::make_unsigned::type; + auto queries_internal = raft::make_device_matrix_view( + queries.data_handle(), queries.extent(0), queries.extent(1)); + auto neighbors_internal = raft::make_device_matrix_view( + reinterpret_cast(neighbors.data_handle()), + neighbors.extent(0), + neighbors.extent(1)); + auto distances_internal = raft::make_device_matrix_view( + distances.data_handle(), distances.extent(0), distances.extent(1)); + + detail::search_main( + res, params, idx, queries_internal, neighbors_internal, distances_internal); } /** @} */ // end group cagra diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index f0eeb2b36c..d88aaa245a 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -38,8 +38,6 @@ namespace raft::neighbors::experimental::cagra::detail { -using INDEX_T = std::uint32_t; - template void build_knn_graph(raft::resources const& res, mdspan, row_major, accessor> dataset, @@ -96,14 +94,14 @@ void build_knn_graph(raft::resources const& res, // search top (k + 1) neighbors // if (!search_params) { - search_params = ivf_pq::search_params{}; - search_params->n_probes = std::min(dataset.extent(1) * 2, build_params->n_lists); - search_params->lut_dtype = CUDA_R_8U; + search_params = ivf_pq::search_params{}; + search_params->n_probes = std::min(dataset.extent(1) * 2, build_params->n_lists); + search_params->lut_dtype = CUDA_R_8U; search_params->internal_distance_dtype = CUDA_R_32F; } const auto top_k = node_degree + 1; uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f); - gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); + gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); const auto num_queries = dataset.extent(0); const auto max_batch_size = 1024; RAFT_LOG_DEBUG( diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 0073f66d0b..d3b24dc861 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -52,13 +52,13 @@ namespace raft::neighbors::experimental::cagra::detail { * k] */ -template +template void search_main(raft::resources const& res, search_params params, const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) { RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", static_cast(index.dataset().extent(0)), @@ -69,8 +69,9 @@ void search_main(raft::resources const& res, RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match"); uint32_t topk = neighbors.extent(1); - std::unique_ptr> plan = - factory::create(res, params, index.dim(), index.graph_degree(), topk); + std::unique_ptr> plan = + factory::create( + res, params, index.dim(), index.graph_degree(), topk); plan->check(neighbors.extent(1)); @@ -79,18 +80,29 @@ void search_main(raft::resources const& res, uint32_t query_dim = queries.extent(1); for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) { - const uint32_t n_queries = std::min(max_queries, queries.extent(0) - qid); - IdxT* _topk_indices_ptr = neighbors.data_handle() + (topk * qid); + const uint32_t n_queries = std::min(max_queries, queries.extent(0) - qid); + internal_IdxT* _topk_indices_ptr = + reinterpret_cast(neighbors.data_handle()) + (topk * qid); DistanceT* _topk_distances_ptr = distances.data_handle() + (topk * qid); // todo(tfeher): one could keep distances optional and pass nullptr const T* _query_ptr = queries.data_handle() + (query_dim * qid); - const IdxT* _seed_ptr = - plan->num_seeds > 0 ? plan->dev_seed.data() + (plan->num_seeds * qid) : nullptr; + const internal_IdxT* _seed_ptr = + plan->num_seeds > 0 + ? reinterpret_cast(plan->dev_seed.data()) + (plan->num_seeds * qid) + : nullptr; uint32_t* _num_executed_iterations = nullptr; + auto dataset_internal = raft::make_device_matrix_view( + index.dataset().data_handle(), index.dataset().extent(0), index.dataset().extent(1)); + auto graph_internal = + raft::make_device_matrix_view( + reinterpret_cast(index.graph().data_handle()), + index.graph().extent(0), + index.graph().extent(1)); + (*plan)(res, - index.dataset(), - index.graph(), + dataset_internal, + graph_internal, _topk_indices_ptr, _topk_distances_ptr, _query_ptr, diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index 52e5c62169..fd66735cf6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -59,9 +59,9 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( const std::size_t num_pickup, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_seeds] + const INDEX_T* const seed_ptr, // [num_seeds] const uint32_t num_seeds, - uint32_t* const visited_hash_ptr, + INDEX_T* const visited_hash_ptr, const uint32_t hash_bitlen, const uint32_t block_id = 0, const uint32_t num_blocks = 1) @@ -79,7 +79,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( DISTANCE_T best_norm2_team_local = utils::get_max_value(); for (uint32_t j = 0; j < num_distilation; j++) { // Select a node randomly and compute the distance to it - uint32_t seed_index; + INDEX_T seed_index; DISTANCE_T norm2 = 0.0; if (valid_i) { // uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id))); @@ -150,7 +150,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in const INDEX_T* const knn_graph, const std::uint32_t knn_k, // hashmap - std::uint32_t* const visited_hashmap_ptr, + INDEX_T* const visited_hashmap_ptr, const std::uint32_t hash_bitlen, const INDEX_T* const parent_indices, const std::uint32_t num_parents) diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index aa3f7dd29f..feb9b76b2d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -33,6 +33,8 @@ #include +#include "utils.hpp" + namespace raft::neighbors::experimental::cagra::detail { namespace graph { @@ -115,7 +117,7 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, my_vals[i] = smem_vals[k]; } else { my_keys[i] = FLT_MAX; - my_vals[i] = ~static_cast(0); + my_vals[i] = utils::get_max_value(); } } __syncthreads(); @@ -607,7 +609,7 @@ void prune(raft::resources const& res, memcpy(output_graph_ptr, pruned_graph.data_handle(), - sizeof(uint32_t) * graph_size * output_graph_degree); + sizeof(IdxT) * graph_size * output_graph_degree); constexpr int _omp_chunk = 1024; #pragma omp parallel for schedule(dynamic, _omp_chunk) @@ -616,15 +618,15 @@ void prune(raft::resources const& res, uint64_t k = rev_graph_count.data_handle()[j] - 1 - _k; uint64_t i = rev_graph.data_handle()[k + (output_graph_degree * j)]; - uint64_t pos = pos_in_array( - i, output_graph_ptr + (output_graph_degree * j), output_graph_degree); + uint64_t pos = + pos_in_array(i, output_graph_ptr + (output_graph_degree * j), output_graph_degree); if (pos < num_protected_edges) { continue; } uint64_t num_shift = pos - num_protected_edges; if (pos == output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } - shift_array(output_graph_ptr + num_protected_edges + (output_graph_degree * j), - num_shift); + shift_array(output_graph_ptr + num_protected_edges + (output_graph_degree * j), + num_shift); output_graph_ptr[num_protected_edges + (output_graph_degree * j)] = i; } if ((omp_get_thread_num() == 0) && ((j % _omp_chunk) == 0)) { @@ -641,9 +643,9 @@ void prune(raft::resources const& res, #pragma omp parallel for reduction(+ : num_replaced_edges) for (uint64_t i = 0; i < graph_size; i++) { for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = pruned_graph.data_handle()[k + (output_graph_degree * i)]; - const uint64_t pos = pos_in_array( - j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); + const uint64_t j = pruned_graph.data_handle()[k + (output_graph_degree * i)]; + const uint64_t pos = + pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); if (pos == output_graph_degree) { num_replaced_edges += 1; } } } diff --git a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp index 18f4006367..cd2c8ec491 100644 --- a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp @@ -27,32 +27,33 @@ namespace hashmap { _RAFT_HOST_DEVICE inline uint32_t get_size(const uint32_t bitlen) { return 1U << bitlen; } -template -_RAFT_DEVICE inline void init(uint32_t* table, const uint32_t bitlen) +template +_RAFT_DEVICE inline void init(IdxT* const table, const unsigned bitlen) { if (threadIdx.x < FIRST_TID) return; for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += blockDim.x - FIRST_TID) { - table[i] = utils::get_max_value(); + table[i] = utils::get_max_value(); } } -template -_RAFT_DEVICE inline void init(uint32_t* table, const uint32_t bitlen) +template +_RAFT_DEVICE inline void init(IdxT* const table, const uint32_t bitlen) { if ((FIRST_TID > 0 && threadIdx.x < FIRST_TID) || threadIdx.x >= LAST_TID) return; for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += LAST_TID - FIRST_TID) { - table[i] = utils::get_max_value(); + table[i] = utils::get_max_value(); } } -_RAFT_DEVICE inline uint32_t insert(uint32_t* table, const uint32_t bitlen, const uint32_t key) +template +_RAFT_DEVICE inline uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) { // Open addressing is used for collision resolution const uint32_t size = get_size(bitlen); const uint32_t bit_mask = size - 1; #if 1 // Linear probing - uint32_t index = (key ^ (key >> bitlen)) & bit_mask; + IdxT index = (key ^ (key >> bitlen)) & bit_mask; constexpr uint32_t stride = 1; #else // Double hashing @@ -60,8 +61,8 @@ _RAFT_DEVICE inline uint32_t insert(uint32_t* table, const uint32_t bitlen, cons const uint32_t stride = (key >> bitlen) * 2 + 1; #endif for (unsigned i = 0; i < size; i++) { - const uint32_t old = atomicCAS(&table[index], ~0u, key); - if (old == ~0u) { + const IdxT old = atomicCAS(&table[index], ~static_cast(0), key); + if (old == ~static_cast(0)) { return 1; } else if (old == key) { return 0; @@ -71,10 +72,10 @@ _RAFT_DEVICE inline uint32_t insert(uint32_t* table, const uint32_t bitlen, cons return 0; } -template -_RAFT_DEVICE inline uint32_t insert(uint32_t* table, const uint32_t bitlen, const uint32_t key) +template +_RAFT_DEVICE inline uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) { - uint32_t ret = 0; + IdxT ret = 0; if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); } for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) { ret |= __shfl_xor_sync(0xffffffff, ret, offset); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 4cccc36a23..f9a0fef2fe 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -52,7 +52,8 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num const size_t num_itopk, uint32_t* const terminate_flag) { - const unsigned warp_id = threadIdx.x / 32; + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const unsigned warp_id = threadIdx.x / 32; if (warp_id > 0) { return; } const unsigned lane_id = threadIdx.x % 32; for (uint32_t i = lane_id; i < num_parents; i += 32) { @@ -66,7 +67,7 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num int new_parent = 0; if (j < num_itopk) { index = itopk_indices[j]; - if ((index & 0x80000000) == 0) { // check if most significant bit is set + if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set new_parent = 1; } } @@ -75,7 +76,7 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; if (i < num_parents) { next_parent_indices[i] = index; - itopk_indices[j] |= 0x80000000; // set most significant bit as used node + itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node } } num_new_parents += __popc(ballot_mask); @@ -84,9 +85,9 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } } -template +template __device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements] - uint32_t* indices, // [num_elements] + INDEX_T* indices, // [num_elements] const uint32_t num_elements, const uint32_t num_itopk // num_itopk <= num_elements ) @@ -96,7 +97,7 @@ __device__ inline void topk_by_bitonic_sort(float* distances, // [num_el const unsigned lane_id = threadIdx.x % 32; constexpr unsigned N = (MAX_ELEMENTS + 31) / 32; float key[N]; - uint32_t val[N]; + INDEX_T val[N]; for (unsigned i = 0; i < N; i++) { unsigned j = lane_id + (32 * i); if (j < num_elements) { @@ -104,11 +105,11 @@ __device__ inline void topk_by_bitonic_sort(float* distances, // [num_el val[i] = indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Warp Sort */ - bitonic::warp_sort(key, val); + bitonic::warp_sort(key, val); /* Store itopk sorted results */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; @@ -142,9 +143,9 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( const uint32_t graph_degree, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, - uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const uint32_t hash_bitlen, const uint32_t itopk_size, const uint32_t num_parents, @@ -194,7 +195,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( auto result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); auto parent_indices_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); + reinterpret_cast(result_distances_buffer + result_buffer_size_32); auto terminate_flag = reinterpret_cast(parent_indices_buffer + num_parents); #if 0 @@ -215,7 +216,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( } } if (threadIdx.x == 0) { terminate_flag[0] = 0; } - uint32_t* local_visited_hashmap_ptr = + INDEX_T* const local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); __syncthreads(); _CLK_REC(clk_init); @@ -246,10 +247,10 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( while (1) { // topk with bitonic sort _CLK_START(); - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (num_parents * graph_degree), - itopk_size); + topk_by_bitonic_sort(result_distances_buffer, + result_indices_buffer, + itopk_size + (num_parents * graph_degree), + itopk_size); _CLK_REC(clk_topk); if (iter + 1 == max_iteration) { @@ -292,7 +293,11 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( for (uint32_t i = threadIdx.x; i < itopk_size; i += BLOCK_SIZE) { uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } - result_indices_ptr[j] = result_indices_buffer[i] & ~0x80000000; // clear most significant bit + + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + result_indices_ptr[j] = + result_indices_buffer[i] & ~index_msb_1_mask; // clear most significant bit } if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { @@ -368,7 +373,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const uint32_t hash_bitlen, \ const uint32_t itopk_size, \ const uint32_t num_parents, \ @@ -456,7 +461,7 @@ struct search : public search_plan_impl { using search_plan_impl::num_seeds; uint32_t num_cta_per_query; - rmm::device_uvector intermediate_indices; + rmm::device_uvector intermediate_indices; rmm::device_uvector intermediate_distances; size_t topk_workspace_size; rmm::device_uvector topk_workspace; @@ -583,7 +588,7 @@ struct search : public search_plan_impl { // Initialize hash table const uint32_t hash_size = hashmap::get_size(hash_bitlen); set_value_batch( - hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); + hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); dim3 block_dims(block_size, 1, 1); dim3 grid_dims(num_cta_per_query, num_queries, 1); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index 439ebd563b..8fbd5d8f03 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -97,12 +97,12 @@ __global__ void random_pickup_kernel( const std::size_t num_pickup, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, - INDEX_T* const result_indices_ptr, // [num_queries, ldr] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::uint32_t ldr, // (*) ldr >= num_pickup - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] + INDEX_T* const result_indices_ptr, // [num_queries, ldr] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] + const std::uint32_t ldr, // (*) ldr >= num_pickup + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] const std::uint32_t hash_bitlen) { const auto ldb = hashmap::get_size(hash_bitlen); @@ -168,12 +168,12 @@ void random_pickup(const DATA_T* const dataset_ptr, // [dataset_size, dataset_d const std::size_t num_pickup, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, - INDEX_T* const result_indices_ptr, // [num_queries, ldr] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::size_t ldr, // (*) ldr >= num_pickup - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] + INDEX_T* const result_indices_ptr, // [num_queries, ldr] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] + const std::size_t ldr, // (*) ldr >= num_pickup + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] const std::uint32_t hash_bitlen, cudaStream_t const cuda_stream = 0) { @@ -204,7 +204,7 @@ __global__ void pickup_next_parents_kernel( INDEX_T* const parent_candidates_ptr, // [num_queries, lds] const std::size_t lds, // (*) lds >= parent_candidates_size const std::uint32_t parent_candidates_size, // - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const std::size_t hash_bitlen, const std::uint32_t small_hash_bitlen, INDEX_T* const parent_list_ptr, // [num_queries, ldd] @@ -212,6 +212,8 @@ __global__ void pickup_next_parents_kernel( const std::size_t parent_list_size, // std::uint32_t* const terminate_flag) { + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const std::size_t ldb = hashmap::get_size(hash_bitlen); const uint32_t query_id = blockIdx.x; if (threadIdx.x < 32) { @@ -229,7 +231,7 @@ __global__ void pickup_next_parents_kernel( int new_parent = 0; if (j < parent_candidates_size) { index = parent_candidates_ptr[j + (lds * query_id)]; - if ((index & 0x80000000) == 0) { // check most significant bit + if ((index & index_msb_1_mask) == 0) { // check most significant bit new_parent = 1; } } @@ -239,7 +241,7 @@ __global__ void pickup_next_parents_kernel( if (i < parent_list_size) { parent_list_ptr[i + (ldd * query_id)] = index; parent_candidates_ptr[j + (lds * query_id)] |= - 0x80000000; // set most significant bit as used node + index_msb_1_mask; // set most significant bit as used node } } num_new_parents += __popc(ballot_mask); @@ -255,27 +257,26 @@ __global__ void pickup_next_parents_kernel( __syncthreads(); // insert internal-topk indices into small-hash for (unsigned i = threadIdx.x; i < parent_candidates_size; i += blockDim.x) { - auto key = - parent_candidates_ptr[i + (lds * query_id)] & ~0x80000000; // clear most significant bit + auto key = parent_candidates_ptr[i + (lds * query_id)] & + ~index_msb_1_mask; // clear most significant bit hashmap::insert(visited_hashmap_ptr + (ldb * query_id), hash_bitlen, key); } } } template -void pickup_next_parents( - INDEX_T* const parent_candidates_ptr, // [num_queries, lds] - const std::size_t lds, // (*) lds >= parent_candidates_size - const std::size_t parent_candidates_size, // - const std::size_t num_queries, - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::size_t hash_bitlen, - const std::size_t small_hash_bitlen, - INDEX_T* const parent_list_ptr, // [num_queries, ldd] - const std::size_t ldd, // (*) ldd >= parent_list_size - const std::size_t parent_list_size, // - std::uint32_t* const terminate_flag, - cudaStream_t cuda_stream = 0) +void pickup_next_parents(INDEX_T* const parent_candidates_ptr, // [num_queries, lds] + const std::size_t lds, // (*) lds >= parent_candidates_size + const std::size_t parent_candidates_size, // + const std::size_t num_queries, + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::size_t hash_bitlen, + const std::size_t small_hash_bitlen, + INDEX_T* const parent_list_ptr, // [num_queries, ldd] + const std::size_t ldd, // (*) ldd >= parent_list_size + const std::size_t parent_list_size, // + std::uint32_t* const terminate_flag, + cudaStream_t cuda_stream = 0) { std::uint32_t block_size = 32; if (small_hash_bitlen) { @@ -309,14 +310,14 @@ __global__ void compute_distance_to_child_nodes_kernel( const DATA_T* const dataset_ptr, // [dataset_size, data_dim] const std::uint32_t data_dim, const std::uint32_t dataset_size, - const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] + const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] const std::uint32_t graph_degree, - const DATA_T* query_ptr, // [num_queries, data_dim] - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const DATA_T* query_ptr, // [num_queries, data_dim] + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const std::uint32_t hash_bitlen, - INDEX_T* const result_indices_ptr, // [num_queries, ldd] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd // (*) ldd >= num_parents * graph_degree + INDEX_T* const result_indices_ptr, // [num_queries, ldd] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] + const std::uint32_t ldd // (*) ldd >= num_parents * graph_degree ) { const uint32_t ldb = hashmap::get_size(hash_bitlen); @@ -334,7 +335,8 @@ __global__ void compute_distance_to_child_nodes_kernel( const std::size_t child_id = neighbor_list_head_ptr[global_team_id % graph_degree]; - if (hashmap::insert(visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id)) { + if (hashmap::insert( + visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id)) { device::fragment frag_target; device::load_vector_sync(frag_target, dataset_ptr + (data_dim * child_id), data_dim); @@ -368,15 +370,15 @@ void compute_distance_to_child_nodes( const DATA_T* const dataset_ptr, // [dataset_size, data_dim] const std::uint32_t data_dim, const std::uint32_t dataset_size, - const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] + const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] const std::uint32_t graph_degree, - const DATA_T* query_ptr, // [num_queries, data_dim] + const DATA_T* query_ptr, // [num_queries, data_dim] const std::uint32_t num_queries, - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const std::uint32_t hash_bitlen, - INDEX_T* const result_indices_ptr, // [num_queries, ldd] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd, // (*) ldd >= num_parents * graph_degree + INDEX_T* const result_indices_ptr, // [num_queries, ldd] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] + const std::uint32_t ldd, // (*) ldd >= num_parents * graph_degree cudaStream_t cuda_stream = 0) { const auto block_size = 128; @@ -405,11 +407,13 @@ __global__ void remove_parent_bit_kernel(const std::uint32_t num_queries, INDEX_T* const topk_indices_ptr, // [ld, num_queries] const std::uint32_t ld) { + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + uint32_t i_query = blockIdx.x; if (i_query >= num_queries) return; for (unsigned i = threadIdx.x; i < num_topk; i += blockDim.x) { - topk_indices_ptr[i + (ld * i_query)] &= ~0x80000000; // clear most significant bit + topk_indices_ptr[i + (ld * i_query)] &= ~index_msb_1_mask; // clear most significant bit } } @@ -537,9 +541,9 @@ struct search : search_plan_impl { using search_plan_impl::num_seeds; size_t result_buffer_allocation_size; - rmm::device_uvector result_indices; // results_indices_buffer - rmm::device_uvector result_distances; // result_distances_buffer - rmm::device_uvector parent_node_list; + rmm::device_uvector result_indices; // results_indices_buffer + rmm::device_uvector result_distances; // result_distances_buffer + rmm::device_uvector parent_node_list; rmm::device_uvector topk_hint; rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; rmm::device_uvector topk_workspace; @@ -600,7 +604,7 @@ struct search : search_plan_impl { cudaStream_t stream = resource::get_cuda_stream(res); const uint32_t hash_size = hashmap::get_size(hash_bitlen); set_value_batch( - hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); + hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); // Init topk_hint if (topk_hint.size() > 0) { set_value(topk_hint.data(), 0xffffffffu, num_queries, stream); } diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index b573d7d7ca..3bed100a70 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -81,9 +81,9 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t topk; uint32_t num_seeds; - rmm::device_uvector hashmap; + rmm::device_uvector hashmap; rmm::device_uvector num_executed_iterations; // device or managed? - rmm::device_uvector dev_seed; // IdxT + rmm::device_uvector dev_seed; search_plan_impl(raft::resources const& res, search_params params, @@ -243,7 +243,7 @@ struct search_plan_impl : public search_plan_impl_base { if (small_hash_bitlen > 0) { RAFT_LOG_DEBUG("# small_hash_reset_interval = %lu", small_hash_reset_interval); } - hashmap_size = sizeof(std::uint32_t) * max_queries * hashmap::get_size(hash_bitlen); + hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen); RAFT_LOG_DEBUG("# hashmap size: %lu", hashmap_size); if (hashmap_size >= 1024 * 1024 * 1024) { RAFT_LOG_DEBUG(" (%.2f GiB)", (double)hashmap_size / (1024 * 1024 * 1024)); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index d64afb0d11..9400a16c36 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -53,6 +53,7 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag, const std::size_t dataset_size, const std::uint32_t num_parents) { + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; // if (threadIdx.x >= 32) return; for (std::uint32_t i = threadIdx.x; i < num_parents; i += 32) { @@ -68,7 +69,7 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag, int new_parent = 0; if (j < internal_topk_size) { index = internal_topk_indices[jj]; - if ((index & 0x80000000) == 0) { // check if most significant bit is set + if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set new_parent = 1; } } @@ -78,7 +79,7 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag, if (i < num_parents) { next_parent_indices[i] = index; // set most significant bit as used node - internal_topk_indices[jj] |= 0x80000000; + internal_topk_indices[jj] |= index_msb_1_mask; } } num_new_parents += __popc(ballot_mask); @@ -93,49 +94,52 @@ struct topk_by_radix_sort_base { static constexpr std::uint32_t state_bit_lenght = 0; static constexpr std::uint32_t vecLen = 2; // TODO }; -template +template struct topk_by_radix_sort : topk_by_radix_sort_base {}; -template +template struct topk_by_radix_sort> : topk_by_radix_sort_base { __device__ void operator()(uint32_t topk, uint32_t batch_size, uint32_t len_x, const uint32_t* _x, - const uint32_t* _in_vals, + const IdxT* _in_vals, uint32_t* _y, - uint32_t* _out_vals, + IdxT* _out_vals, uint32_t* work, uint32_t* _hints, bool sort, uint32_t* _smem) { - std::uint8_t* state = (std::uint8_t*)work; + std::uint8_t* const state = reinterpret_cast(work); topk_cta_11_core::state_bit_lenght, topk_by_radix_sort_base::vecLen, 64, - 32>(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); + 32, + IdxT>(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); } }; #define TOP_FUNC_PARTIAL_SPECIALIZATION(V) \ - template \ + template \ struct topk_by_radix_sort< \ MAX_INTERNAL_TOPK, \ BLOCK_SIZE, \ + IdxT, \ std::enable_if_t<((MAX_INTERNAL_TOPK <= V) && (2 * MAX_INTERNAL_TOPK > V))>> \ : topk_by_radix_sort_base { \ __device__ void operator()(uint32_t topk, \ uint32_t batch_size, \ uint32_t len_x, \ const uint32_t* _x, \ - const uint32_t* _in_vals, \ + const IdxT* _in_vals, \ uint32_t* _y, \ - uint32_t* _out_vals, \ + IdxT* _out_vals, \ uint32_t* work, \ uint32_t* _hints, \ bool sort, \ @@ -147,7 +151,8 @@ struct topk_by_radix_sort::state_bit_lenght, \ topk_by_radix_sort_base::vecLen, \ V, \ - V / 4>( \ + V / 4, \ + IdxT>( \ topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); \ } \ }; @@ -156,12 +161,11 @@ TOP_FUNC_PARTIAL_SPECIALIZATION(256); TOP_FUNC_PARTIAL_SPECIALIZATION(512); TOP_FUNC_PARTIAL_SPECIALIZATION(1024); -template -__device__ inline void topk_by_bitonic_sort_1st( - float* candidate_distances, // [num_candidates] - std::uint32_t* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - const std::uint32_t num_itopk) +template +__device__ inline void topk_by_bitonic_sort_1st(float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) { const unsigned lane_id = threadIdx.x % 32; const unsigned warp_id = threadIdx.x / 32; @@ -169,7 +173,7 @@ __device__ inline void topk_by_bitonic_sort_1st( if (warp_id > 0) { return; } constexpr unsigned N = (MAX_CANDIDATES + 31) / 32; float key[N]; - std::uint32_t val[N]; + IdxT val[N]; /* Candidates -> Reg */ for (unsigned i = 0; i < N; i++) { unsigned j = lane_id + (32 * i); @@ -178,11 +182,11 @@ __device__ inline void topk_by_bitonic_sort_1st( val[i] = candidate_indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Sort */ - bitonic::warp_sort(key, val); + bitonic::warp_sort(key, val); /* Reg -> Temp_itopk */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; @@ -196,7 +200,7 @@ __device__ inline void topk_by_bitonic_sort_1st( constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; constexpr unsigned N = (max_candidates_per_warp + 31) / 32; float key[N]; - std::uint32_t val[N]; + IdxT val[N]; if (warp_id < 2) { /* Candidates -> Reg */ for (unsigned i = 0; i < N; i++) { @@ -207,11 +211,11 @@ __device__ inline void topk_by_bitonic_sort_1st( val[i] = candidate_indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Sort */ - bitonic::warp_sort(key, val); + bitonic::warp_sort(key, val); /* Reg -> Temp_candidates */ for (unsigned i = 0; i < N; i++) { unsigned jl = (N * lane_id) + i; @@ -244,7 +248,7 @@ __device__ inline void topk_by_bitonic_sort_1st( if (num_warps_used > 1) { __syncthreads(); } if (warp_id < num_warps_used) { /* Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, 32); /* Reg -> Temp_itopk */ for (unsigned i = 0; i < N; i++) { unsigned jl = (N * lane_id) + i; @@ -259,16 +263,15 @@ __device__ inline void topk_by_bitonic_sort_1st( } } -template -__device__ inline void topk_by_bitonic_sort_2nd( - float* itopk_distances, // [num_itopk] - std::uint32_t* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - std::uint32_t* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first) +template +__device__ inline void topk_by_bitonic_sort_2nd(float* itopk_distances, // [num_itopk] + IdxT* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) { const unsigned lane_id = threadIdx.x % 32; const unsigned warp_id = threadIdx.x / 32; @@ -276,7 +279,7 @@ __device__ inline void topk_by_bitonic_sort_2nd( if (warp_id > 0) { return; } constexpr unsigned N = (MAX_ITOPK + 31) / 32; float key[N]; - std::uint32_t val[N]; + IdxT val[N]; if (first) { /* Load itopk results */ for (unsigned i = 0; i < N; i++) { @@ -286,11 +289,11 @@ __device__ inline void topk_by_bitonic_sort_2nd( val[i] = itopk_indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Warp Sort */ - bitonic::warp_sort(key, val); + bitonic::warp_sort(key, val); } else { /* Load itopk results */ for (unsigned i = 0; i < N; i++) { @@ -300,7 +303,7 @@ __device__ inline void topk_by_bitonic_sort_2nd( val[i] = itopk_indices[device::swizzling(j)]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } } @@ -316,7 +319,7 @@ __device__ inline void topk_by_bitonic_sort_2nd( } } /* Warp Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, 32); /* Store new itopk results */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; @@ -330,7 +333,7 @@ __device__ inline void topk_by_bitonic_sort_2nd( constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; constexpr unsigned N = (max_itopk_per_warp + 31) / 32; float key[N]; - std::uint32_t val[N]; + IdxT val[N]; if (first) { /* Load itop results (not sorted) */ if (warp_id < 2) { @@ -341,11 +344,11 @@ __device__ inline void topk_by_bitonic_sort_2nd( val[i] = itopk_indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Warp Sort */ - bitonic::warp_sort(key, val); + bitonic::warp_sort(key, val); /* Store intermedidate results */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * threadIdx.x) + i; @@ -369,7 +372,7 @@ __device__ inline void topk_by_bitonic_sort_2nd( } } /* Warp Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, 32); } __syncthreads(); /* Store itopk results (sorted) */ @@ -414,8 +417,8 @@ __device__ inline void topk_by_bitonic_sort_2nd( if (key_0 > key_1) { itopk_distances[device::swizzling(j)] = key_1; itopk_distances[device::swizzling(k)] = key_0; - std::uint32_t val_0 = itopk_indices[device::swizzling(j)]; - std::uint32_t val_1 = itopk_indices[device::swizzling(k)]; + IdxT val_0 = itopk_indices[device::swizzling(j)]; + IdxT val_1 = itopk_indices[device::swizzling(k)]; itopk_indices[device::swizzling(j)] = val_1; itopk_indices[device::swizzling(k)] = val_0; atomicMin(work_buf + 0, j); @@ -447,11 +450,11 @@ __device__ inline void topk_by_bitonic_sort_2nd( val[i] = itopk_indices[device::swizzling(k)]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Warp Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, 32); /* Store new itopk results */ for (unsigned i = 0; i < N; i++) { const unsigned j = (N * lane_id) + i; @@ -468,41 +471,44 @@ __device__ inline void topk_by_bitonic_sort_2nd( template -__device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] - std::uint32_t* itopk_indices, // [num_itopk] + unsigned MULTI_WARPS_2, + class IdxT> +__device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] + IdxT* itopk_indices, // [num_itopk] const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - std::uint32_t* candidate_indices, // [num_candidates] + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] const std::uint32_t num_candidates, std::uint32_t* work_buf, const bool first) { // The results in candidate_distances/indices are sorted by bitonic sort. - topk_by_bitonic_sort_1st( + topk_by_bitonic_sort_1st( candidate_distances, candidate_indices, num_candidates, num_itopk); // The results sorted above are merged with the internal intermediate top-k // results so far using bitonic merge. - topk_by_bitonic_sort_2nd(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); + topk_by_bitonic_sort_2nd(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); } template -__device__ inline void hashmap_restore(uint32_t* hashmap_ptr, +__device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr, const size_t hashmap_bitlen, const INDEX_T* itopk_indices, uint32_t itopk_size) { + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + if (threadIdx.x < FIRST_TID || threadIdx.x >= LAST_TID) return; for (unsigned i = threadIdx.x - FIRST_TID; i < itopk_size; i += LAST_TID - FIRST_TID) { - auto key = itopk_indices[i] & ~0x80000000; // clear most significant bit + auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit hashmap::insert(hashmap_ptr, hashmap_bitlen, key); } } @@ -539,9 +545,9 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ const std::uint32_t graph_degree, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const std::uint32_t internal_topk, const std::uint32_t num_parents, const std::uint32_t min_iteration, @@ -587,8 +593,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ auto result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); auto visited_hash_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); auto topk_ws = reinterpret_cast(parent_list_buffer + num_parents); auto terminate_flag = reinterpret_cast(topk_ws + 3); auto smem_working_ptr = reinterpret_cast(terminate_flag + 1); @@ -608,7 +614,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ } // Init hashmap - uint32_t* local_visited_hashmap_ptr; + INDEX_T* local_visited_hashmap_ptr; if (small_hash_bitlen) { local_visited_hashmap_ptr = visited_hash_buffer; } else { @@ -693,7 +699,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ } else { _CLK_START(); // topk with radix block sort - topk_by_radix_sort{}( + topk_by_radix_sort{}( internal_topk, gridDim.x, result_buffer_size, @@ -768,7 +774,11 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ unsigned ii = i; if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } - result_indices_ptr[j] = result_indices_buffer[ii] & ~0x80000000; // clear most significant bit + + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + result_indices_ptr[j] = + result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit } if (threadIdx.x == 0 && num_executed_iterations != nullptr) { num_executed_iterations[query_id] = iter + 1; @@ -868,7 +878,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - std::uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const std::uint32_t itopk_size, \ const std::uint32_t num_parents, \ const std::uint32_t min_iteration, \ @@ -999,17 +1009,18 @@ struct search : search_plan_impl { const std::uint32_t topk_ws_size = 3; const std::uint32_t base_smem_size = sizeof(float) * max_dim + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + - sizeof(std::uint32_t) * hashmap::get_size(small_hash_bitlen) + - sizeof(std::uint32_t) * num_parents + sizeof(std::uint32_t) * topk_ws_size + - sizeof(std::uint32_t); + sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) + sizeof(INDEX_T) * num_parents + + sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t); smem_size = base_smem_size; if (num_itopk_candidates > 256) { // Tentatively calculate the required share memory size when radix // sort based topk is used, assuming the block size is the maximum. if (itopk_size <= 256) { - smem_size += topk_by_radix_sort<256, max_block_size>::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort<256, max_block_size, INDEX_T>::smem_size * sizeof(std::uint32_t); } else { - smem_size += topk_by_radix_sort<512, max_block_size>::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort<512, max_block_size, INDEX_T>::smem_size * sizeof(std::uint32_t); } } @@ -1080,32 +1091,38 @@ struct search : search_plan_impl { constexpr unsigned MAX_ITOPK = 256; if (block_size == 256) { constexpr unsigned BLOCK_SIZE = 256; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } else if (block_size == 512) { constexpr unsigned BLOCK_SIZE = 512; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } else { constexpr unsigned BLOCK_SIZE = 1024; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } } else { constexpr unsigned MAX_ITOPK = 512; if (block_size == 256) { constexpr unsigned BLOCK_SIZE = 256; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } else if (block_size == 512) { constexpr unsigned BLOCK_SIZE = 512; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } else { constexpr unsigned BLOCK_SIZE = 1024; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } } } RAFT_LOG_DEBUG("# smem_size: %u", smem_size); hashmap_size = 0; if (small_hash_bitlen == 0) { - hashmap_size = sizeof(uint32_t) * max_queries * hashmap::get_size(hash_bitlen); + hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen); hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); } RAFT_LOG_DEBUG("# hashmap_size: %lu", hashmap_size); diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h index ccb65fd0ea..2896dba1f3 100644 --- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h +++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h @@ -27,17 +27,18 @@ size_t _cuann_find_topk_bufferSize(uint32_t topK, cudaDataType_t sampleDtype = CUDA_R_32F); // +template void _cuann_find_topk(uint32_t topK, uint32_t sizeBatch, uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const uint32_t* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - uint32_t* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK + const float* inputKeys, // [sizeBatch, ldIK,] + uint32_t ldIK, // (*) ldIK >= numElements + const ValT* inputVals, // [sizeBatch, ldIV,] + uint32_t ldIV, // (*) ldIV >= numElements + float* outputKeys, // [sizeBatch, ldOK,] + uint32_t ldOK, // (*) ldOK >= topK + ValT* outputVals, // [sizeBatch, ldOV,] + uint32_t ldOV, // (*) ldOV >= topK void* workspace, bool sort = false, uint32_t* hint = NULL, @@ -54,4 +55,4 @@ CUDA_DEVICE_HOST_FUNC inline size_t _cuann_aligned(size_t size, size_t unit = 12 if (size % unit) { size += unit - (size % unit); } return size; } -} // namespace raft::neighbors::experimental::cagra::detail \ No newline at end of file +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh index 072593550e..5bc4b70791 100644 --- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh @@ -237,7 +237,7 @@ __device__ inline void update_histogram(int itr, } #pragma unroll for (int v = 0; v < max(vecLen, stateBitLen); v += vecLen) { - int iv = i + (num_threads * v); + const int iv = i + (num_threads * v); if (iv >= nx) break; struct u32_vector x_u32_vec; @@ -249,7 +249,7 @@ __device__ inline void update_histogram(int itr, } #pragma unroll for (int u = 0; u < vecLen; u++) { - int ivu = iv + u; + const int ivu = iv + u; if (ivu >= nx) break; uint8_t mask = (uint8_t)0x1 << (v + u); @@ -270,7 +270,7 @@ __device__ inline void update_histogram(int itr, iState |= mask; } } else { - uint32_t k = (xi - threshold) >> shift; // 0 <= k + const uint32_t k = (xi - threshold) >> shift; // 0 <= k if (k >= num_bins) { if (stateBitLen == 8) { iState |= mask; } } else if (k + 1 < num_bins) { @@ -287,15 +287,16 @@ __device__ inline void update_histogram(int itr, // template -__device__ inline void select_best_index_for_next_threshold(uint32_t topk, - uint32_t threshold, - uint32_t max_threshold, - uint32_t nx_below_threshold, - uint32_t num_bins, - uint32_t shift, - const uint32_t* hist, // [num_bins] - uint32_t* best_index, - uint32_t* best_csum) +__device__ inline void select_best_index_for_next_threshold( + const uint32_t topk, + const uint32_t threshold, + const uint32_t max_threshold, + const uint32_t nx_below_threshold, + const uint32_t num_bins, + const uint32_t shift, + const uint32_t* const hist, // [num_bins] + uint32_t* const best_index, + uint32_t* const best_csum) { // Scan the histogram ('hist') and compute csum. Then, find the largest // index under the condition that the sum of the number of elements found @@ -311,7 +312,7 @@ __device__ inline void select_best_index_for_next_threshold(uint32_t topk, if (threadIdx.x < num_bins) { csum = hist[threadIdx.x]; } BlockScanT(temp_storage).InclusiveSum(csum, csum); if (threadIdx.x < num_bins) { - uint32_t index = threadIdx.x; + const uint32_t index = threadIdx.x; if ((nx_below_threshold + csum <= topk) && (threshold + (index << shift) <= max_threshold)) { my_index = index; my_csum = csum; @@ -327,7 +328,7 @@ __device__ inline void select_best_index_for_next_threshold(uint32_t topk, BlockScanT(temp_storage).InclusiveSum(csum, csum); for (int i = n_data - 1; i >= 0; i--) { if (nx_below_threshold + csum[i] > topk) continue; - uint32_t index = i + (n_data * threadIdx.x); + const uint32_t index = i + (n_data * threadIdx.x); if (threshold + (index << shift) > max_threshold) continue; my_index = index; my_csum = csum[i]; @@ -342,7 +343,7 @@ __device__ inline void select_best_index_for_next_threshold(uint32_t topk, BlockScanT(temp_storage).InclusiveSum(csum, csum); for (int i = n_data - 1; i >= 0; i--) { if (nx_below_threshold + csum[i] > topk) continue; - uint32_t index = i + (n_data * threadIdx.x); + const uint32_t index = i + (n_data * threadIdx.x); if (threshold + (index << shift) > max_threshold) continue; my_index = index; my_csum = csum[i]; @@ -351,9 +352,9 @@ __device__ inline void select_best_index_for_next_threshold(uint32_t topk, } } if (threadIdx.x < num_bins) { - int laneid = 31 - __clz(__ballot_sync(0xffffffff, (my_index != 0xffffffff))); + const int laneid = 31 - __clz(__ballot_sync(0xffffffff, (my_index != 0xffffffff))); if ((threadIdx.x & 0x1f) == laneid) { - uint32_t old_index = atomicMax(best_index, my_index); + const uint32_t old_index = atomicMax(best_index, my_index); if (old_index < my_index) { atomicMax(best_csum, my_csum); } } } @@ -362,17 +363,17 @@ __device__ inline void select_best_index_for_next_threshold(uint32_t topk, // template -__device__ inline void output_index_below_threshold(uint32_t topk, - uint32_t thread_id, - uint32_t num_threads, - uint32_t threshold, - uint32_t nx_below_threshold, - const T* x, // [nx,] - uint32_t nx, +__device__ inline void output_index_below_threshold(const uint32_t topk, + const uint32_t thread_id, + const uint32_t num_threads, + const uint32_t threshold, + const uint32_t nx_below_threshold, + const T* const x, // [nx,] + const uint32_t nx, const uint8_t* state, - uint32_t* output, // [topk] - uint32_t* output_count, - uint32_t* output_count_eq) + uint32_t* const output, // [topk] + uint32_t* const output_count, + uint32_t* const output_count_eq) { int ii = 0; for (int i = thread_id * vecLen; i < nx; i += num_threads * max(vecLen, stateBitLen), ii++) { @@ -383,7 +384,7 @@ __device__ inline void output_index_below_threshold(uint32_t topk, } #pragma unroll for (int v = 0; v < max(vecLen, stateBitLen); v += vecLen) { - int iv = i + (num_threads * v); + const int iv = i + (num_threads * v); if (iv >= nx) break; struct u32_vector u32_vec; @@ -395,10 +396,10 @@ __device__ inline void output_index_below_threshold(uint32_t topk, } #pragma unroll for (int u = 0; u < vecLen; u++) { - int ivu = iv + u; + const int ivu = iv + u; if (ivu >= nx) break; - uint8_t mask = (uint8_t)0x1 << (v + u); + const uint8_t mask = (uint8_t)0x1 << (v + u); if ((stateBitLen == 8) && (iState & mask)) continue; uint32_t xi; @@ -425,9 +426,9 @@ __device__ inline void output_index_below_threshold(uint32_t topk, template __device__ inline void swap(T& val1, T& val2) { - T val0 = val1; - val1 = val2; - val2 = val0; + const T val0 = val1; + val1 = val2; + val2 = val0; } // @@ -493,44 +494,44 @@ __device__ __host__ inline uint32_t get_state_size(uint32_t len_x) } // -template +template __device__ inline void topk_cta_11_core(uint32_t topk, uint32_t len_x, - const uint32_t* _x, // [size_batch, ld_x,] - const uint32_t* _in_vals, // [size_batch, ld_iv,] - uint32_t* _y, // [size_batch, ld_y,] - uint32_t* _out_vals, // [size_batch, ld_ov,] - uint8_t* _state, // [size_batch, ...,] + const uint32_t* _x, // [size_batch, ld_x,] + const ValT* _in_vals, // [size_batch, ld_iv,] + uint32_t* _y, // [size_batch, ld_y,] + ValT* _out_vals, // [size_batch, ld_ov,] + uint8_t* _state, // [size_batch, ...,] uint32_t* _hint, bool sort, uint32_t* _smem) { - uint32_t* smem_out_vals = _smem; - uint32_t* hist = &(_smem[2 * maxTopk]); - uint32_t* best_index = &(_smem[2 * maxTopk + 2048]); - uint32_t* best_csum = &(_smem[2 * maxTopk + 2048 + 3]); + uint32_t* const smem_out_vals = _smem; + uint32_t* const hist = &(_smem[2 * maxTopk]); + uint32_t* const best_index = &(_smem[2 * maxTopk + 2048]); + uint32_t* const best_csum = &(_smem[2 * maxTopk + 2048 + 3]); const uint32_t num_threads = blockDim_x; const uint32_t thread_id = threadIdx.x; uint32_t nx = len_x; - const uint32_t* x = _x; - const uint32_t* in_vals = NULL; + const uint32_t* const x = _x; + const ValT* in_vals = NULL; if (_in_vals) { in_vals = _in_vals; } uint32_t* y = NULL; if (_y) { y = _y; } - uint32_t* out_vals = NULL; + ValT* out_vals = NULL; if (_out_vals) { out_vals = _out_vals; } - uint8_t* state = _state; - uint32_t hint = (_hint == NULL ? ~0u : *_hint); + uint8_t* state = _state; + const uint32_t hint = (_hint == NULL ? ~0u : *_hint); // Initialize shared memory for (int i = 2 * maxTopk + thread_id; i < 2 * maxTopk + 2048 + 8; i += num_threads) { _smem[i] = 0; } - uint32_t* output_count = &(_smem[2 * maxTopk + 2048 + 6]); - uint32_t* output_count_eq = &(_smem[2 * maxTopk + 2048 + 7]); - uint32_t threshold = 0; - uint32_t nx_below_threshold = 0; + uint32_t* const output_count = &(_smem[2 * maxTopk + 2048 + 6]); + uint32_t* const output_count_eq = &(_smem[2 * maxTopk + 2048 + 7]); + uint32_t threshold = 0; + uint32_t nx_below_threshold = 0; __syncthreads(); // @@ -601,7 +602,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, if (!sort) { for (int k = thread_id; k < topk; k += blockDim_x) { - uint32_t i = smem_out_vals[k]; + const uint32_t i = smem_out_vals[k]; if (y) { y[k] = x[i]; } if (out_vals) { if (in_vals) { @@ -616,15 +617,15 @@ __device__ inline void topk_cta_11_core(uint32_t topk, constexpr int numTopkPerThread = maxTopk / numSortThreads; float my_keys[numTopkPerThread]; - uint32_t my_vals[numTopkPerThread]; + ValT my_vals[numTopkPerThread]; // Read keys and values to registers if (thread_id < numSortThreads) { for (int i = 0; i < numTopkPerThread; i++) { - int k = thread_id + (numSortThreads * i); + const int k = thread_id + (numSortThreads * i); if (k < topk) { - int j = smem_out_vals[k]; - my_keys[i] = ((float*)x)[j]; + const int j = smem_out_vals[k]; + my_keys[i] = ((float*)x)[j]; if (in_vals) { my_vals[i] = in_vals[j]; } else { @@ -632,7 +633,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, } } else { my_keys[i] = FLT_MAX; - my_vals[i] = 0xffffffffU; + my_vals[i] = ~static_cast(0); } } } @@ -641,21 +642,21 @@ __device__ inline void topk_cta_11_core(uint32_t topk, // Sorting by thread if (thread_id < numSortThreads) { - bool ascending = ((thread_id & mask) == 0); + const bool ascending = ((thread_id & mask) == 0); if (numTopkPerThread == 3) { - swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); - swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); - swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); + swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); + swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); + swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); } else { for (int j = 0; j < numTopkPerThread / 2; j += 1) { #pragma unroll for (int i = 0; i < numTopkPerThread; i += 2) { - swap_if_needed( + swap_if_needed( my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); } #pragma unroll for (int i = 1; i < numTopkPerThread - 1; i += 2) { - swap_if_needed( + swap_if_needed( my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); } } @@ -667,11 +668,12 @@ __device__ inline void topk_cta_11_core(uint32_t topk, uint32_t next_mask = mask << 1; for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) { - bool ascending = ((thread_id & curr_mask) == 0) == ((thread_id & next_mask) == 0); + const bool ascending = ((thread_id & curr_mask) == 0) == ((thread_id & next_mask) == 0); if (curr_mask >= 32) { // inter warp - uint32_t* smem_vals = _smem; // [numTopkPerThread, numSortThreads] - float* smem_keys = (float*)(_smem + numTopkPerThread * numSortThreads); + ValT* const smem_vals = reinterpret_cast(_smem); // [maxTopk] + float* const smem_keys = + reinterpret_cast(smem_vals + maxTopk); // [numTopkPerThread, numSortThreads] __syncthreads(); if (thread_id < numSortThreads) { #pragma unroll @@ -684,9 +686,9 @@ __device__ inline void topk_cta_11_core(uint32_t topk, if (thread_id < numSortThreads) { #pragma unroll for (int i = 0; i < numTopkPerThread; i++) { - float opp_key = smem_keys[(thread_id ^ curr_mask) + (numSortThreads * i)]; - uint32_t opp_val = smem_vals[(thread_id ^ curr_mask) + (numSortThreads * i)]; - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + float opp_key = smem_keys[(thread_id ^ curr_mask) + (numSortThreads * i)]; + ValT opp_val = smem_vals[(thread_id ^ curr_mask) + (numSortThreads * i)]; + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } } else { @@ -694,29 +696,28 @@ __device__ inline void topk_cta_11_core(uint32_t topk, if (thread_id < numSortThreads) { #pragma unroll for (int i = 0; i < numTopkPerThread; i++) { - float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); - uint32_t opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); + ValT opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } } } if (thread_id < numSortThreads) { - bool ascending = ((thread_id & next_mask) == 0); + const bool ascending = ((thread_id & next_mask) == 0); if (numTopkPerThread == 3) { - swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); - swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); - swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); + swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); + swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); + swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); } else { #pragma unroll for (uint32_t curr_mask = numTopkPerThread / 2; curr_mask > 0; curr_mask >>= 1) { #pragma unroll for (int i = 0; i < numTopkPerThread; i++) { - int j = i ^ curr_mask; + const int j = i ^ curr_mask; if (i > j) continue; - swap_if_needed( - my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); + swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); } } } @@ -727,9 +728,9 @@ __device__ inline void topk_cta_11_core(uint32_t topk, // Write sorted keys and values if (thread_id < numSortThreads) { for (int i = 0; i < numTopkPerThread; i++) { - int k = i + (numTopkPerThread * thread_id); + const int k = i + (numTopkPerThread * thread_id); if (k < topk) { - if (y) { y[k] = ((uint32_t*)my_keys)[i]; } + if (y) { y[k] = reinterpret_cast(my_keys)[i]; } if (out_vals) { out_vals[k] = my_vals[i]; } } } @@ -755,28 +756,32 @@ int _get_vecLen(uint32_t maxSamples, int maxVecLen = MAX_VEC_LENGTH) } } // unnamed namespace -template +template __launch_bounds__(1024, 1) __global__ void kern_topk_cta_11(uint32_t topk, uint32_t size_batch, uint32_t len_x, - const uint32_t* _x, // [size_batch, ld_x,] + const uint32_t* _x, // [size_batch, ld_x,] uint32_t ld_x, - const uint32_t* _in_vals, // [size_batch, ld_iv,] + const ValT* _in_vals, // [size_batch, ld_iv,] uint32_t ld_iv, - uint32_t* _y, // [size_batch, ld_y,] + uint32_t* _y, // [size_batch, ld_y,] uint32_t ld_y, - uint32_t* _out_vals, // [size_batch, ld_ov,] + ValT* _out_vals, // [size_batch, ld_ov,] uint32_t ld_ov, - uint8_t* _state, // [size_batch, ...,] - uint32_t* _hints, // [size_batch,] + uint8_t* _state, // [size_batch, ...,] + uint32_t* _hints, // [size_batch,] bool sort) { - uint32_t i_batch = blockIdx.x; + const uint32_t i_batch = blockIdx.x; if (i_batch >= size_batch) return; - __shared__ uint32_t _smem[2 * maxTopk + 2048 + 8]; - topk_cta_11_core( + constexpr uint32_t smem_len = 2 * maxTopk + 2048 + 8; + static_assert(maxTopk * (1 + utils::size_of() / utils::size_of()) <= smem_len, + "maxTopk * sizeof(ValT) must be smaller or equal to 8192 byte"); + __shared__ uint32_t _smem[smem_len]; + + topk_cta_11_core( topk, len_x, (_x == NULL ? NULL : _x + i_batch * ld_x), @@ -809,17 +814,18 @@ size_t inline _cuann_find_topk_bufferSize(uint32_t topK, return workspaceSize; } +template inline void _cuann_find_topk(uint32_t topK, uint32_t sizeBatch, uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const uint32_t* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - uint32_t* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK + const float* inputKeys, // [sizeBatch, ldIK,] + uint32_t ldIK, // (*) ldIK >= numElements + const ValT* inputVals, // [sizeBatch, ldIV,] + uint32_t ldIV, // (*) ldIV >= numElements + float* outputKeys, // [sizeBatch, ldOK,] + uint32_t ldOK, // (*) ldOK >= topK + ValT* outputVals, // [sizeBatch, ldOV,] + uint32_t ldOV, // (*) ldOV >= topK void* workspace, bool sort, uint32_t* hints, @@ -845,48 +851,48 @@ inline void _cuann_find_topk(uint32_t topK, uint32_t, const uint32_t*, uint32_t, - const uint32_t*, + const ValT*, uint32_t, uint32_t*, uint32_t, - uint32_t*, + ValT*, uint32_t, uint8_t*, uint32_t*, bool) = nullptr; // V:vecLen, K:maxTopk, T:numSortThreads -#define SET_KERNEL_VKT(V, K, T) \ - do { \ - assert(numThreads >= T); \ - assert((K % T) == 0); \ - assert((K / T) <= 4); \ - cta_kernel = kern_topk_cta_11; \ +#define SET_KERNEL_VKT(V, K, T, ValT) \ + do { \ + assert(numThreads >= T); \ + assert((K % T) == 0); \ + assert((K / T) <= 4); \ + cta_kernel = kern_topk_cta_11; \ } while (0) // V: vecLen -#define SET_KERNEL_V(V) \ +#define SET_KERNEL_V(V, ValT) \ do { \ if (topK <= 32) { \ - SET_KERNEL_VKT(V, 32, 32); \ + SET_KERNEL_VKT(V, 32, 32, ValT); \ } else if (topK <= 64) { \ - SET_KERNEL_VKT(V, 64, 32); \ + SET_KERNEL_VKT(V, 64, 32, ValT); \ } else if (topK <= 96) { \ - SET_KERNEL_VKT(V, 96, 32); \ + SET_KERNEL_VKT(V, 96, 32, ValT); \ } else if (topK <= 128) { \ - SET_KERNEL_VKT(V, 128, 32); \ + SET_KERNEL_VKT(V, 128, 32, ValT); \ } else if (topK <= 192) { \ - SET_KERNEL_VKT(V, 192, 64); \ + SET_KERNEL_VKT(V, 192, 64, ValT); \ } else if (topK <= 256) { \ - SET_KERNEL_VKT(V, 256, 64); \ + SET_KERNEL_VKT(V, 256, 64, ValT); \ } else if (topK <= 384) { \ - SET_KERNEL_VKT(V, 384, 128); \ + SET_KERNEL_VKT(V, 384, 128, ValT); \ } else if (topK <= 512) { \ - SET_KERNEL_VKT(V, 512, 128); \ + SET_KERNEL_VKT(V, 512, 128, ValT); \ } else if (topK <= 768) { \ - SET_KERNEL_VKT(V, 768, 256); \ + SET_KERNEL_VKT(V, 768, 256, ValT); \ } else if (topK <= 1024) { \ - SET_KERNEL_VKT(V, 1024, 256); \ + SET_KERNEL_VKT(V, 1024, 256, ValT); \ } \ /* else if (topK <= 1536) { SET_KERNEL_VKT(V, 1536, 512); } */ \ /* else if (topK <= 2048) { SET_KERNEL_VKT(V, 2048, 512); } */ \ @@ -901,9 +907,9 @@ inline void _cuann_find_topk(uint32_t topK, int _vecLen = _get_vecLen(ldIK, 2); if (_vecLen == 2) { - SET_KERNEL_V(2); + SET_KERNEL_V(2, ValT); } else if (_vecLen == 1) { - SET_KERNEL_V(1); + SET_KERNEL_V(1, ValT); } cta_kernel<<>>(topK, @@ -923,4 +929,4 @@ inline void _cuann_find_topk(uint32_t topK, return; } -} // namespace raft::neighbors::experimental::cagra::detail \ No newline at end of file +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp index 3e329c9239..934e84d4d5 100644 --- a/cpp/include/raft/neighbors/detail/cagra/utils.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/utils.hpp @@ -128,6 +128,11 @@ _RAFT_HOST_DEVICE inline std::uint32_t get_max_value() { return 0xffffffffu; }; +template <> +_RAFT_HOST_DEVICE inline std::uint64_t get_max_value() +{ + return 0xfffffffffffffffflu; +}; template struct constexpr_max { @@ -138,6 +143,11 @@ template struct constexpr_max A), bool>> { static const int value = B; }; + +template +struct gen_index_msb_1_mask { + static constexpr IdxT value = static_cast(1) << (utils::size_of() * 8 - 1); +}; } // namespace utils } // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 98ce8ac5bd..1b4d269d1b 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -316,6 +316,7 @@ if(BUILD_TESTS) test/neighbors/ann_cagra/test_float_uint32_t.cu test/neighbors/ann_cagra/test_int8_t_uint32_t.cu test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu + test/neighbors/ann_cagra/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu diff --git a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu new file mode 100644 index 0000000000..e473a72b2b --- /dev/null +++ b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "../ann_cagra.cuh" + +namespace raft::neighbors::experimental::cagra { + +typedef AnnCagraTest AnnCagraTestF_I64; +TEST_P(AnnCagraTestF_I64, AnnCagra) { this->testCagra(); } + +INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_I64, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index adb44a9264..dbaf4dedd9 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -20,13 +20,13 @@ namespace raft::neighbors::experimental::cagra { -typedef AnnCagraTest AnnCagraTestF; -TEST_P(AnnCagraTestF, AnnCagra) { this->testCagra(); } +typedef AnnCagraTest AnnCagraTestF_U32; +TEST_P(AnnCagraTestF_U32, AnnCagra) { this->testCagra(); } -typedef AnnCagraSortTest AnnCagraSortTestF; -TEST_P(AnnCagraSortTestF, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraSortTest AnnCagraSortTestF_U32; +TEST_P(AnnCagraSortTestF_U32, AnnCagraSort) { this->testCagraSort(); } -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF_U32, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors::experimental::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu index 11c986c189..ba60131677 100644 --- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu @@ -20,12 +20,12 @@ namespace raft::neighbors::experimental::cagra { -typedef AnnCagraTest AnnCagraTestI8; -TEST_P(AnnCagraTestI8, AnnCagra) { this->testCagra(); } -typedef AnnCagraSortTest AnnCagraSortTestI8; -TEST_P(AnnCagraSortTestI8, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraTest AnnCagraTestI8_U32; +TEST_P(AnnCagraTestI8_U32, AnnCagra) { this->testCagra(); } +typedef AnnCagraSortTest AnnCagraSortTestI8_U32; +TEST_P(AnnCagraSortTestI8_U32, AnnCagraSort) { this->testCagraSort(); } -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8_U32, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors::experimental::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu index 51d4feeed2..cc172e4833 100644 --- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu @@ -20,13 +20,13 @@ namespace raft::neighbors::experimental::cagra { -typedef AnnCagraTest AnnCagraTestU8; -TEST_P(AnnCagraTestU8, AnnCagra) { this->testCagra(); } +typedef AnnCagraTest AnnCagraTestU8_U32; +TEST_P(AnnCagraTestU8_U32, AnnCagra) { this->testCagra(); } -typedef AnnCagraSortTest AnnCagraSortTestU8; -TEST_P(AnnCagraSortTestU8, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraSortTest AnnCagraSortTestU8_U32; +TEST_P(AnnCagraSortTestU8_U32, AnnCagraSort) { this->testCagraSort(); } -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestU8, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestU8_U32, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors::experimental::cagra