From a4ca4a54ca3706594378751249031ec0896eae7e Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Fri, 12 May 2023 13:51:11 +0900 Subject: [PATCH 01/18] Support 64-bit index data type in CAGRA search --- .../neighbors/detail/cagra/cagra_build.cuh | 10 +- .../neighbors/detail/cagra/graph_core.cuh | 841 +++++++----------- .../detail/cagra/search_multi_cta.cuh | 22 +- .../detail/cagra/search_multi_kernel.cuh | 6 +- .../neighbors/detail/cagra/search_plan.cuh | 2 +- .../detail/cagra/search_single_cta.cuh | 156 ++-- .../detail/cagra/topk_for_cagra/topk.h | 19 +- .../detail/cagra/topk_for_cagra/topk_core.cuh | 145 +-- .../raft/neighbors/detail/cagra/utils.hpp | 5 + 9 files changed, 510 insertions(+), 696 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 54c806ba13..1c5851f53c 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -37,8 +37,6 @@ namespace raft::neighbors::experimental::cagra::detail { -using INDEX_T = std::uint32_t; - template void build_knn_graph(raft::device_resources const& res, mdspan, row_major, accessor> dataset, @@ -95,14 +93,14 @@ void build_knn_graph(raft::device_resources const& res, // search top (k + 1) neighbors // if (!search_params) { - search_params = ivf_pq::search_params{}; - search_params->n_probes = std::min(dataset.extent(1) * 2, build_params->n_lists); - search_params->lut_dtype = CUDA_R_8U; + search_params = ivf_pq::search_params{}; + search_params->n_probes = std::min(dataset.extent(1) * 2, build_params->n_lists); + search_params->lut_dtype = CUDA_R_8U; search_params->internal_distance_dtype = CUDA_R_32F; } const auto top_k = node_degree + 1; uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f); - gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); + gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); const auto num_queries = dataset.extent(0); const auto max_batch_size = 1024; RAFT_LOG_DEBUG( diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index a08c83677b..dfe02d4579 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -35,20 +35,8 @@ namespace raft::neighbors::experimental::cagra::detail { namespace graph { -template -__host__ __device__ float compute_norm2(const T* a, - const T* b, - const std::size_t dim, - const float scale) -{ - float sum = 0.f; - for (std::size_t j = 0; j < dim; j++) { - const auto diff = a[j] * scale - b[j] * scale; - sum += diff * diff; - } - return sum; -} - +// unnamed namespace to avoid multiple definition error +namespace { inline double cur_time(void) { struct timeval tv; @@ -76,25 +64,18 @@ __device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool a return false; } -template -__global__ void kern_sort( - DATA_T** dataset, // [num_gpus][dataset_chunk_size, dataset_dim] - uint32_t dataset_size, - uint32_t dataset_chunk_size, // (*) num_gpus * dataset_chunk_size >= dataset_size - uint32_t dataset_dim, - float scale, - uint32_t** knn_graph, // [num_gpus][graph_chunk_size, graph_degree] - uint32_t graph_size, - uint32_t graph_chunk_size, // (*) num_gpus * graph_chunk_size >= graph_size - uint32_t graph_degree, - int dev_id) +template +__global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, dataset_dim] + const IdxT 
dataset_size, + const uint32_t dataset_dim, + IdxT* const knn_graph, // [graph_chunk_size, graph_degree] + const uint32_t graph_size, + const uint32_t graph_degree) { __shared__ float smem_keys[blockDim_x * numElementsPerThread]; - __shared__ uint32_t smem_vals[blockDim_x * numElementsPerThread]; + __shared__ IdxT smem_vals[blockDim_x * numElementsPerThread]; - uint64_t srcNode = blockIdx.x + ((uint64_t)graph_chunk_size * dev_id); - uint64_t srcNode_dev = srcNode / graph_chunk_size; - uint64_t srcNode_loc = srcNode % graph_chunk_size; + const IdxT srcNode = blockIdx.x; if (srcNode >= graph_size) { return; } const uint32_t num_warps = blockDim_x / 32; @@ -103,14 +84,13 @@ __global__ void kern_sort( // Compute distance from a src node to its neighbors for (int k = warp_id; k < graph_degree; k += num_warps) { - uint64_t dstNode = knn_graph[srcNode_dev][k + ((uint64_t)graph_degree * srcNode_loc)]; - uint64_t dstNode_dev = dstNode / graph_chunk_size; - uint64_t dstNode_loc = dstNode % graph_chunk_size; - float dist = 0.0; + const IdxT dstNode = knn_graph[k + ((uint64_t)graph_degree * srcNode)]; + float dist = 0.0; for (int d = lane_id; d < dataset_dim; d += 32) { - float diff = - (float)(dataset[srcNode_dev][d + ((uint64_t)dataset_dim * srcNode_loc)]) * scale - - (float)(dataset[dstNode_dev][d + ((uint64_t)dataset_dim * dstNode_loc)]) * scale; + float diff = spatial::knn::detail::utils::mapping{}( + dataset[d + ((uint64_t)dataset_dim * srcNode)]) - + spatial::knn::detail::utils::mapping{}( + dataset[d + ((uint64_t)dataset_dim * dstNode)]); dist += diff * diff; } dist += __shfl_xor_sync(0xffffffff, dist, 1); @@ -126,41 +106,41 @@ __global__ void kern_sort( __syncthreads(); float my_keys[numElementsPerThread]; - uint32_t my_vals[numElementsPerThread]; + IdxT my_vals[numElementsPerThread]; for (int i = 0; i < numElementsPerThread; i++) { - int k = i + (numElementsPerThread * threadIdx.x); + const int k = i + (numElementsPerThread * threadIdx.x); if (k < graph_degree) { my_keys[i] = smem_keys[k]; my_vals[i] = smem_vals[k]; } else { my_keys[i] = FLT_MAX; - my_vals[i] = 0xffffffffU; + my_vals[i] = ~static_cast(0); } } __syncthreads(); // Sorting by thread - uint32_t mask = 1; - bool ascending = ((threadIdx.x & mask) == 0); + uint32_t mask = 1; + const bool ascending = ((threadIdx.x & mask) == 0); for (int j = 0; j < numElementsPerThread; j += 2) { #pragma unroll for (int i = 0; i < numElementsPerThread; i += 2) { - swap_if_needed( + swap_if_needed( my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); } #pragma unroll for (int i = 1; i < numElementsPerThread - 1; i += 2) { - swap_if_needed( + swap_if_needed( my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); } } // Bitonic Sorting while (mask < blockDim_x) { - uint32_t next_mask = mask << 1; + const uint32_t next_mask = mask << 1; for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) { - bool ascending = ((threadIdx.x & curr_mask) == 0) == ((threadIdx.x & next_mask) == 0); + const bool ascending = ((threadIdx.x & curr_mask) == 0) == ((threadIdx.x & next_mask) == 0); if (mask >= 32) { // inter warp __syncthreads(); @@ -172,29 +152,29 @@ __global__ void kern_sort( __syncthreads(); #pragma unroll for (int i = 0; i < numElementsPerThread; i++) { - float opp_key = smem_keys[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; - uint32_t opp_val = smem_vals[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + float opp_key = smem_keys[(threadIdx.x ^ 
curr_mask) + (blockDim_x * i)]; + IdxT opp_val = smem_vals[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } else { // intra warp #pragma unroll for (int i = 0; i < numElementsPerThread; i++) { - float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); - uint32_t opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); + IdxT opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } } - bool ascending = ((threadIdx.x & next_mask) == 0); + const bool ascending = ((threadIdx.x & next_mask) == 0); #pragma unroll for (uint32_t curr_mask = numElementsPerThread / 2; curr_mask > 0; curr_mask >>= 1) { #pragma unroll for (int i = 0; i < numElementsPerThread; i++) { int j = i ^ curr_mask; if (i > j) continue; - swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); + swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); } } mask = next_mask; @@ -202,54 +182,47 @@ __global__ void kern_sort( // Update knn_graph for (int i = 0; i < numElementsPerThread; i++) { - int k = i + (numElementsPerThread * threadIdx.x); + const int k = i + (numElementsPerThread * threadIdx.x); if (k < graph_degree) { - knn_graph[srcNode_dev][k + ((uint64_t)graph_degree * srcNode_loc)] = my_vals[i]; + knn_graph[k + (static_cast(graph_degree) * srcNode)] = my_vals[i]; } } } -template -__global__ void kern_prune( - uint32_t** knn_graph, // [num_gpus][graph_chunk_size, graph_degree] - uint32_t graph_size, - uint32_t graph_chunk_size, // (*) num_gpus * graph_chunk_size >= graph_size - uint32_t graph_degree, - uint32_t degree, - int dev_id, - uint32_t batch_size, - uint32_t batch_id, - uint8_t** detour_count, // [num_gpus][graph_chunk_size, graph_degree] - uint32_t** num_no_detour_edges, // [num_gpus][graph_size] - uint64_t* stats) +template +__global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] + const uint32_t graph_size, + const uint32_t graph_degree, + const uint32_t degree, + const uint32_t batch_size, + const uint32_t batch_id, + uint8_t* const detour_count, // [graph_chunk_size, graph_degree] + uint32_t* const num_no_detour_edges, // [graph_size] + uint64_t* const stats) { __shared__ uint32_t smem_num_detour[MAX_DEGREE]; - uint64_t* num_retain = stats; - uint64_t* num_full = stats + 1; + uint64_t* const num_retain = stats; + uint64_t* const num_full = stats + 1; - uint64_t nid = blockIdx.x + (batch_size * batch_id); - if (nid >= graph_chunk_size) { return; } + const uint64_t nid = blockIdx.x + (batch_size * batch_id); + if (nid >= graph_size) { return; } for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { smem_num_detour[k] = 0; } __syncthreads(); - uint64_t iA = nid + ((uint64_t)graph_chunk_size * dev_id); - uint64_t iA_dev = iA / graph_chunk_size; - uint64_t iA_loc = iA % graph_chunk_size; + const uint64_t iA = nid; if (iA >= graph_size) { return; } // count number of detours (A->D->B) for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { - uint64_t iD = knn_graph[iA_dev][kAD + (graph_degree * iA_loc)]; - uint64_t iD_dev = iD / graph_chunk_size; - uint64_t iD_loc = iD % graph_chunk_size; + const uint64_t iD = knn_graph[kAD + (graph_degree * iA)]; for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { - 
uint64_t iB_candidate = knn_graph[iD_dev][kDB + ((uint64_t)graph_degree * iD_loc)]; + const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { // if ( kDB < kAB ) { - uint64_t iB = knn_graph[iA_dev][kAB + (graph_degree * iA_loc)]; + const uint64_t iB = knn_graph[kAB + (graph_degree * iA)]; if (iB == iB_candidate) { atomicAdd(smem_num_detour + kAB, 1); break; @@ -262,7 +235,7 @@ __global__ void kern_prune( uint32_t num_edges_no_detour = 0; for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - detour_count[iA_dev][k + (graph_degree * iA_loc)] = min(smem_num_detour[k], (uint32_t)255); + detour_count[k + (graph_degree * iA)] = min(smem_num_detour[k], (uint32_t)255); if (smem_num_detour[k] == 0) { num_edges_no_detour++; } } num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); @@ -273,119 +246,29 @@ __global__ void kern_prune( num_edges_no_detour = min(num_edges_no_detour, degree); if (threadIdx.x == 0) { - num_no_detour_edges[iA_dev][iA_loc] = num_edges_no_detour; + num_no_detour_edges[iA] = num_edges_no_detour; atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } } } -// unnamed namespace to avoid multiple definition error -namespace { -__global__ void kern_make_rev_graph(const uint32_t i_gpu, - const uint32_t* dest_nodes, // [global_graph_size] - const uint32_t global_graph_size, - uint32_t* rev_graph, // [graph_size, degree] - uint32_t* rev_graph_count, // [graph_size] +template +__global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] + IdxT* const rev_graph, // [size, degree] + uint32_t* const rev_graph_count, // [graph_size] const uint32_t graph_size, const uint32_t degree) { const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); const uint32_t tnum = blockDim.x * gridDim.x; - for (uint32_t gl_src_id = tid; gl_src_id < global_graph_size; gl_src_id += tnum) { - uint32_t gl_dest_id = dest_nodes[gl_src_id]; - if (gl_dest_id < graph_size * i_gpu) continue; - if (gl_dest_id >= graph_size * (i_gpu + 1)) continue; - if (gl_dest_id >= global_graph_size) continue; - - uint32_t dest_id = gl_dest_id - (graph_size * i_gpu); - uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = gl_src_id; } - } -} -} // namespace -template -T*** mgpu_alloc(int n_gpus, uint32_t chunk, uint32_t nelems) -{ - T** arrays; // [n_gpus][chunk, nelems] - arrays = (T**)malloc(sizeof(T*) * n_gpus); /* h1 */ - size_t bsize = sizeof(T) * chunk * nelems; - // RAFT_LOG_DEBUG("[%s, %s, %d] n_gpus: %d, chunk: %u, nelems: %u, bsize: %lu (%lu MiB)\n", - // __FILE__, __func__, __LINE__, n_gpus, chunk, nelems, bsize, bsize / 1024 / 1024); - for (int i_gpu = 0; i_gpu < n_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaMalloc(&(arrays[i_gpu]), bsize)); /* d1 */ - } - T*** d_arrays; // [n_gpus+1][n_gpus][chunk, nelems] - d_arrays = (T***)malloc(sizeof(T**) * (n_gpus + 1)); /* h2 */ - bsize = sizeof(T*) * n_gpus; - for (int i_gpu = 0; i_gpu < n_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaMalloc(&(d_arrays[i_gpu]), bsize)); /* d2 */ - RAFT_CUDA_TRY(cudaMemcpy(d_arrays[i_gpu], arrays, bsize, cudaMemcpyDefault)); - } - RAFT_CUDA_TRY(cudaSetDevice(0)); - d_arrays[n_gpus] = arrays; - return d_arrays; -} - -template -void mgpu_free(T*** 
d_arrays, int n_gpus) -{ - for (int i_gpu = 0; i_gpu < n_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaFree(d_arrays[n_gpus][i_gpu])); /* d1 */ - RAFT_CUDA_TRY(cudaFree(d_arrays[i_gpu])); /* d2 */ - } - RAFT_CUDA_TRY(cudaSetDevice(0)); - free(d_arrays[n_gpus]); /* h1 */ - free(d_arrays); /* h2 */ -} - -template -void mgpu_H2D(T*** d_arrays, // [n_gpus+1][n_gpus][chunk, nelems] - const T* h_array, // [size, nelems] - int n_gpus, - uint32_t size, - uint32_t chunk, // (*) n_gpus * chunk >= size - uint32_t nelems) -{ -#pragma omp parallel num_threads(n_gpus) - { - int i_gpu = omp_get_thread_num(); - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - uint32_t _chunk = std::min(size - (chunk * i_gpu), chunk); - size_t bsize = sizeof(T) * _chunk * nelems; - RAFT_CUDA_TRY(cudaMemcpy(d_arrays[n_gpus][i_gpu], - h_array + ((uint64_t)chunk * nelems * i_gpu), - bsize, - cudaMemcpyDefault)); - } - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - RAFT_CUDA_TRY(cudaSetDevice(0)); -} + for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { + const IdxT dest_id = dest_nodes[src_id]; + if (dest_id >= graph_size) continue; -template -void mgpu_D2H(T*** d_arrays, // [n_gpus+1][n_gpus][chunk, nelems] - T* h_array, // [size, nelems] - int n_gpus, - uint32_t size, - uint32_t chunk, // (*) n_gpus * chunk >= size - uint32_t nelems) -{ -#pragma omp parallel num_threads(n_gpus) - { - int i_gpu = omp_get_thread_num(); - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - uint32_t _chunk = std::min(size - (chunk * i_gpu), chunk); - size_t bsize = sizeof(T) * _chunk * nelems; - RAFT_CUDA_TRY(cudaMemcpy(h_array + ((uint64_t)chunk * nelems * i_gpu), - d_arrays[n_gpus][i_gpu], - bsize, - cudaMemcpyDefault)); + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } } - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - RAFT_CUDA_TRY(cudaSetDevice(0)); } template @@ -404,6 +287,7 @@ void shift_array(T* array, uint64_t num) array[i] = array[i - 1]; } } +} // namespace template (num_gpus, graph_chunk_size, input_graph_degree); + auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); - DataT*** d_dataset_ptr = NULL; // [num_gpus+1][...][...] 
- const uint32_t dataset_chunk_size = (dataset_size + num_gpus - 1) / num_gpus; - assert(dataset_chunk_size == graph_chunk_size); - d_dataset_ptr = mgpu_alloc(num_gpus, dataset_chunk_size, dataset_dim); + // + // Sorting kNN graph + // + const double time_sort_start = cur_time(); + RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs "); - const float scale = 1.0f / raft::spatial::knn::detail::utils::config::kDivisor; + auto d_dataset = raft::make_device_matrix(res, dataset_size, dataset_dim); + raft::copy(d_dataset.data_handle(), dataset_ptr, dataset_size * dataset_dim, res.get_stream()); - mgpu_H2D( - d_dataset_ptr, dataset_ptr, num_gpus, dataset_size, dataset_chunk_size, dataset_dim); + raft::copy(d_input_graph.data_handle(), + input_graph_ptr, + graph_size * input_graph_degree, + res.get_stream()); - double time_sort_start = cur_time(); - RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs "); - mgpu_H2D(d_input_graph_ptr, - input_graph_ptr, - num_gpus, - dataset_size, - graph_chunk_size, - input_graph_degree); void (*kernel_sort)( - DataT**, uint32_t, uint32_t, uint32_t, float, uint32_t**, uint32_t, uint32_t, uint32_t, int); + const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t); constexpr int numElementsPerThread = 4; dim3 threads_sort(1, 1, 1); if (input_graph_degree <= numElementsPerThread * 32) { constexpr int blockDim_x = 32; - kernel_sort = kern_sort; + kernel_sort = kern_sort; threads_sort.x = blockDim_x; } else if (input_graph_degree <= numElementsPerThread * 64) { constexpr int blockDim_x = 64; - kernel_sort = kern_sort; + kernel_sort = kern_sort; threads_sort.x = blockDim_x; } else if (input_graph_degree <= numElementsPerThread * 128) { constexpr int blockDim_x = 128; - kernel_sort = kern_sort; + kernel_sort = kern_sort; threads_sort.x = blockDim_x; } else if (input_graph_degree <= numElementsPerThread * 256) { constexpr int blockDim_x = 256; - kernel_sort = kern_sort; + kernel_sort = kern_sort; threads_sort.x = blockDim_x; } else { - fprintf(stderr, - "[ERROR] The degree of input knn graph is too large (%u). " - "It must be equal to or small than %d.\n", - input_graph_degree, - numElementsPerThread * 256); + RAFT_LOG_ERROR( + "[ERROR] The degree of input knn graph is too large (%u). 
" + "It must be equal to or small than %d.\n", + input_graph_degree, + numElementsPerThread * 256); exit(-1); } - dim3 blocks_sort(graph_chunk_size, 1, 1); - for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { - RAFT_LOG_DEBUG("."); - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - kernel_sort<<>>(d_dataset_ptr[i_gpu], - dataset_size, - dataset_chunk_size, - dataset_dim, - scale, - d_input_graph_ptr[i_gpu], - dataset_size, - graph_chunk_size, - input_graph_degree, - i_gpu); - } - RAFT_CUDA_TRY(cudaSetDevice(0)); - RAFT_CUDA_TRY(cudaDeviceSynchronize()); + dim3 blocks_sort(graph_size, 1, 1); + RAFT_LOG_DEBUG("."); + kernel_sort<<>>(d_dataset.data_handle(), + dataset_size, + dataset_dim, + d_input_graph.data_handle(), + graph_size, + input_graph_degree); + res.sync_stream(); RAFT_LOG_DEBUG("."); - mgpu_D2H(d_input_graph_ptr, - input_graph_ptr, - num_gpus, - dataset_size, - graph_chunk_size, - input_graph_degree); + raft::copy(input_graph_ptr, + d_input_graph.data_handle(), + graph_size * input_graph_degree, + res.get_stream()); RAFT_LOG_DEBUG("\n"); - double time_sort_end = cur_time(); - RAFT_LOG_DEBUG("# Sorting kNN graph time: %.1lf sec\n", time_sort_end - time_sort_start); - mgpu_free(d_dataset_ptr, num_gpus); + const double time_sort_end = cur_time(); + RAFT_LOG_DEBUG("# Sorting kNN graph time: %.1lf sec\n", time_sort_end - time_sort_start); } -/** Input arrays can be both host and device*/ template , memory_type::host>> @@ -538,308 +389,252 @@ void prune(raft::device_resources const& res, "output graph cannot have more columns than input graph"); const uint32_t input_graph_degree = knn_graph.extent(1); const uint32_t output_graph_degree = new_graph.extent(1); - uint32_t* input_graph_ptr = (uint32_t*)knn_graph.data_handle(); - uint32_t* output_graph_ptr = new_graph.data_handle(); - const std::size_t graph_size = new_graph.extent(0); - size_t array_size; - - // Setup GPUs - int num_gpus = 0; - - // Setup GPUs - RAFT_CUDA_TRY(cudaGetDeviceCount(&num_gpus)); - RAFT_LOG_DEBUG("# num_gpus: %d\n", num_gpus); - for (int self = 0; self < num_gpus; self++) { - RAFT_CUDA_TRY(cudaSetDevice(self)); - for (int peer = 0; peer < num_gpus; peer++) { - if (self == peer) { continue; } - RAFT_CUDA_TRY(cudaDeviceEnablePeerAccess(peer, 0)); - } - } - RAFT_CUDA_TRY(cudaSetDevice(0)); - - uint32_t graph_chunk_size = graph_size; - uint32_t*** d_input_graph_ptr = NULL; // [...][num_gpus][graph_chunk_size, input_graph_degree] - graph_chunk_size = (graph_size + num_gpus - 1) / num_gpus; - d_input_graph_ptr = mgpu_alloc(num_gpus, graph_chunk_size, input_graph_degree); - - // - uint8_t* detour_count; // [graph_size, input_graph_degree] - array_size = sizeof(uint8_t) * graph_size * input_graph_degree; - detour_count = (uint8_t*)malloc(array_size); - memset(detour_count, 0xff, array_size); - - uint8_t*** d_detour_count = NULL; // [...][num_gpus][graph_chunk_size, input_graph_degree] - d_detour_count = mgpu_alloc(num_gpus, graph_chunk_size, input_graph_degree); - mgpu_H2D( - d_detour_count, detour_count, num_gpus, graph_size, graph_chunk_size, input_graph_degree); - - // - uint32_t* num_no_detour_edges; // [graph_size] - array_size = sizeof(uint32_t) * graph_size; - num_no_detour_edges = (uint32_t*)malloc(array_size); - memset(num_no_detour_edges, 0, array_size); + auto input_graph_ptr = knn_graph.data_handle(); + auto output_graph_ptr = new_graph.data_handle(); + const IdxT graph_size = new_graph.extent(0); - uint32_t*** d_num_no_detour_edges = NULL; // [...][num_gpus][graph_chunk_size] - d_num_no_detour_edges = 
   mgpu_alloc(num_gpus, graph_chunk_size, 1);
- mgpu_H2D(
-   d_num_no_detour_edges, num_no_detour_edges, num_gpus, graph_size, graph_chunk_size, 1);
+ auto pruned_graph = raft::make_host_matrix(graph_size, output_graph_degree);

- //
- uint64_t** dev_stats  = NULL;  // [num_gpus][2]
- uint64_t** host_stats = NULL;  // [num_gpus][2]
- dev_stats  = (uint64_t**)malloc(sizeof(uint64_t*) * num_gpus);
- host_stats = (uint64_t**)malloc(sizeof(uint64_t*) * num_gpus);
- array_size = sizeof(uint64_t) * 2;
- for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) {
-   RAFT_CUDA_TRY(cudaSetDevice(i_gpu));
-   RAFT_CUDA_TRY(cudaMalloc(&(dev_stats[i_gpu]), array_size));
-   host_stats[i_gpu] = (uint64_t*)malloc(array_size);
- }
- RAFT_CUDA_TRY(cudaSetDevice(0));
+ {
+   //
+   // Prune kNN graph
+   //
+   auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree);
+
+   auto detour_count = raft::make_host_matrix(graph_size, input_graph_degree);
+   auto d_detour_count =
+     raft::make_device_matrix(res, graph_size, input_graph_degree);
+   RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(),
+                                 0xff,
+                                 graph_size * input_graph_degree * sizeof(uint8_t),
+                                 res.get_stream()));
+
+   auto d_num_no_detour_edges = raft::make_device_vector(res, graph_size);
+   RAFT_CUDA_TRY(cudaMemsetAsync(
+     d_num_no_detour_edges.data_handle(), 0x00, graph_size * sizeof(uint32_t), res.get_stream()));
+
+   auto dev_stats  = raft::make_device_vector(res, 2);
+   auto host_stats = raft::make_host_vector(2);
+
+   //
+   // Prune unimportant edges.
+   //
+   // The edge to be retained is determined without explicitly considering
+   // distance or angle. Suppose the edge is the k-th edge of some node-A to
+   // node-B (A->B). Among the edges originating at node-A, there are k-1 edges
+   // shorter than the edge A->B. Each of these k-1 edges is connected to one of
+   // k-1 different nodes. Among these k-1 nodes, count the number of nodes with
+   // edges to node-B, which is the number of 2-hop detours for the edge A->B.
+   // For example, if the shortest edge of node-A leads to a node whose neighbor
+   // list also contains node-B, the edge A->B has at least one 2-hop detour.
+   // Once the number of 2-hop detours has been counted for all edges, the
+   // specified number of edges are picked up for each node, starting with the
+   // edge with the lowest number of 2-hop detours.
+   //
+   const double time_prune_start = cur_time();
+   RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r");
+
+   raft::copy(d_input_graph.data_handle(),
+              input_graph_ptr,
+              graph_size * input_graph_degree,
+              res.get_stream());
+   void (*kernel_prune)(const IdxT* const,
+                        const uint32_t,
+                        const uint32_t,
+                        const uint32_t,
+                        const uint32_t,
+                        const uint32_t,
+                        uint8_t* const,
+                        uint32_t* const,
+                        uint64_t* const);

- //
- double time_prune_start = cur_time();
- uint64_t num_keep       = 0;
- uint64_t num_full       = 0;
- RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r");
- mgpu_H2D(
-   d_input_graph_ptr, input_graph_ptr, num_gpus, graph_size, graph_chunk_size, input_graph_degree);
- void (*kernel_prune)(uint32_t**,
-                      uint32_t,
-                      uint32_t,
-                      uint32_t,
-                      uint32_t,
-                      int,
-                      uint32_t,
-                      uint32_t,
-                      uint8_t**,
-                      uint32_t**,
-                      uint64_t*);
- if (input_graph_degree <= 1024) {
    constexpr int MAX_DEGREE = 1024;
-   kernel_prune = kern_prune;
- } else {
-   fprintf(stderr,
-           "[ERROR] The degree of input knn graph is too large (%u). "
-           "It must be equal to or small than %d.\n",
-           input_graph_degree,
-           1024);
-   exit(-1);
- }
- uint32_t batch_size = std::min(graph_chunk_size, (uint32_t)256 * 1024);
- uint32_t num_batch  = (graph_chunk_size + batch_size - 1) / batch_size;
- dim3 threads_prune(32, 1, 1);
- dim3 blocks_prune(batch_size, 1, 1);
- for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) {
-   RAFT_CUDA_TRY(cudaSetDevice(i_gpu));
-   RAFT_CUDA_TRY(cudaMemset(dev_stats[i_gpu], 0, sizeof(uint64_t) * 2));
- }
- for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) {
-   for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) {
-     RAFT_CUDA_TRY(cudaSetDevice(i_gpu));
-     kernel_prune<<>>(d_input_graph_ptr[i_gpu],
-                      graph_size,
-                      graph_chunk_size,
-                      input_graph_degree,
-                      output_graph_degree,
-                      i_gpu,
-                      batch_size,
-                      i_batch,
-                      d_detour_count[i_gpu],
-                      d_num_no_detour_edges[i_gpu],
-                      dev_stats[i_gpu]);
+   if (input_graph_degree <= MAX_DEGREE) {
+     kernel_prune = kern_prune;
+   } else {
+     RAFT_LOG_ERROR(
+       "[ERROR] The degree of input knn graph is too large (%u). "
+       "It must be equal to or smaller than %d.\n",
+       input_graph_degree,
+       1024);
+     exit(-1);
    }
-   RAFT_CUDA_TRY(cudaDeviceSynchronize());
-   fprintf(
-     stderr,
-     "# Pruning kNN Graph on GPUs (%.1lf %%)\r",
-     (double)std::min((i_batch + 1) * batch_size, graph_chunk_size) / graph_chunk_size * 100);
- }
- for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) {
-   RAFT_CUDA_TRY(cudaSetDevice(i_gpu));
+   const uint32_t batch_size =
+     std::min(static_cast(graph_size), static_cast(256 * 1024));
+   const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size;
+   const dim3 threads_prune(32, 1, 1);
+   const dim3 blocks_prune(batch_size, 1, 1);
+
    RAFT_CUDA_TRY(
-     cudaMemcpy(host_stats[i_gpu], dev_stats[i_gpu], sizeof(uint64_t) * 2, cudaMemcpyDefault));
-   num_keep += host_stats[i_gpu][0];
-   num_full += host_stats[i_gpu][1];
- }
- RAFT_CUDA_TRY(cudaDeviceSynchronize());
- RAFT_CUDA_TRY(cudaSetDevice(0));
- RAFT_LOG_DEBUG("\n");
+     cudaMemsetAsync(dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, res.get_stream()));
+
+   for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) {
+     kernel_prune<<>>(
+       d_input_graph.data_handle(),
+       graph_size,
+       input_graph_degree,
+       output_graph_degree,
+       batch_size,
+       i_batch,
+       d_detour_count.data_handle(),
+       d_num_no_detour_edges.data_handle(),
+       dev_stats.data_handle());
+     res.sync_stream();
+     RAFT_LOG_DEBUG(
+       "# Pruning kNN Graph on GPUs (%.1lf %%)\r",
+       (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100);
+   }
+   res.sync_stream();
+   RAFT_LOG_DEBUG("\n");

- mgpu_D2H(
-   d_detour_count, detour_count, num_gpus, graph_size, graph_chunk_size, input_graph_degree);
- mgpu_D2H(
-   d_num_no_detour_edges, num_no_detour_edges, num_gpus, graph_size, graph_chunk_size, 1);
+   raft::copy(detour_count.data_handle(),
+              d_detour_count.data_handle(),
+              graph_size * input_graph_degree,
+              res.get_stream());

- mgpu_free(d_input_graph_ptr, num_gpus);
- mgpu_free(d_detour_count,
num_gpus); - mgpu_free(d_num_no_detour_edges, num_gpus); + raft::copy(host_stats.data_handle(), dev_stats.data_handle(), 2, res.get_stream()); + const auto num_keep = host_stats.data_handle()[0]; + const auto num_full = host_stats.data_handle()[1]; - // Create pruned kNN graph - array_size = sizeof(uint32_t) * graph_size * output_graph_degree; - uint32_t* pruned_graph_ptr = (uint32_t*)malloc(array_size); - uint32_t max_detour = 0; + // Create pruned kNN graph + uint32_t max_detour = 0; #pragma omp parallel for reduction(max : max_detour) - for (uint64_t i = 0; i < graph_size; i++) { - uint64_t pk = 0; - for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) { - if (max_detour < num_detour) { max_detour = num_detour; /* stats */ } - for (uint64_t k = 0; k < input_graph_degree; k++) { - if (detour_count[k + (input_graph_degree * i)] != num_detour) { continue; } - pruned_graph_ptr[pk + (output_graph_degree * i)] = - input_graph_ptr[k + (input_graph_degree * i)]; - pk += 1; + for (uint64_t i = 0; i < graph_size; i++) { + uint64_t pk = 0; + for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) { + if (max_detour < num_detour) { max_detour = num_detour; /* stats */ } + for (uint64_t k = 0; k < input_graph_degree; k++) { + if (detour_count.data_handle()[k + (input_graph_degree * i)] != num_detour) { continue; } + pruned_graph.data_handle()[pk + (output_graph_degree * i)] = + input_graph_ptr[k + (input_graph_degree * i)]; + pk += 1; + if (pk >= output_graph_degree) break; + } if (pk >= output_graph_degree) break; } - if (pk >= output_graph_degree) break; + assert(pk == output_graph_degree); } - assert(pk == output_graph_degree); - } - // RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour); - - double time_prune_end = cur_time(); - fprintf(stderr, - "# Pruning time: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%\n", - time_prune_end - time_prune_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); + // RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour); - // - // Make reverse graph - // - double time_make_start = cur_time(); - - array_size = sizeof(uint32_t) * graph_size * output_graph_degree; - uint32_t* rev_graph_ptr = (uint32_t*)malloc(array_size); - memset(rev_graph_ptr, 0xff, array_size); - - uint32_t*** d_rev_graph_ptr; // [...][num_gpus][graph_chunk_size, output_graph_degree] - d_rev_graph_ptr = mgpu_alloc(num_gpus, graph_chunk_size, output_graph_degree); - mgpu_H2D( - d_rev_graph_ptr, rev_graph_ptr, num_gpus, graph_size, graph_chunk_size, output_graph_degree); - - array_size = sizeof(uint32_t) * graph_size; - uint32_t* rev_graph_count = (uint32_t*)malloc(array_size); - memset(rev_graph_count, 0, array_size); - - uint32_t*** d_rev_graph_count; // [...][num_gpus][graph_chunk_size, 1] - d_rev_graph_count = mgpu_alloc(num_gpus, graph_chunk_size, 1); - mgpu_H2D(d_rev_graph_count, rev_graph_count, num_gpus, graph_size, graph_chunk_size, 1); - - uint32_t* dest_nodes; // [graph_size] - dest_nodes = (uint32_t*)malloc(sizeof(uint32_t) * graph_size); - uint32_t** d_dest_nodes; // [num_gpus][graph_size] - d_dest_nodes = (uint32_t**)malloc(sizeof(uint32_t*) * num_gpus); - for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaMalloc(&(d_dest_nodes[i_gpu]), sizeof(uint32_t) * graph_size)); + const double time_prune_end = cur_time(); + RAFT_LOG_DEBUG( + "# Pruning time: %.1lf sec, " + 
"avg_no_detour_edges_per_node: %.2lf/%u, " + "nodes_with_no_detour_at_all_edges: %.1lf%%\n", + time_prune_end - time_prune_start, + (double)num_keep / graph_size, + output_graph_degree, + (double)num_full / graph_size * 100); } - for (uint64_t k = 0; k < output_graph_degree; k++) { + auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); + auto rev_graph_count = raft::make_host_vector(graph_size); + + { + // + // Make reverse graph + // + const double time_make_start = cur_time(); + + auto d_rev_graph = raft::make_device_matrix(res, graph_size, output_graph_degree); + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), + 0xff, + graph_size * output_graph_degree * sizeof(IdxT), + res.get_stream())); + + auto d_rev_graph_count = raft::make_device_vector(res, graph_size); + RAFT_CUDA_TRY(cudaMemsetAsync( + d_rev_graph_count.data_handle(), 0x00, graph_size * sizeof(uint32_t), res.get_stream())); + + auto dest_nodes = raft::make_host_vector(graph_size); + auto d_dest_nodes = raft::make_device_vector(res, graph_size); + + for (uint64_t k = 0; k < output_graph_degree; k++) { #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - dest_nodes[i] = pruned_graph_ptr[k + (output_graph_degree * i)]; - } - RAFT_CUDA_TRY(cudaDeviceSynchronize()); -#pragma omp parallel num_threads(num_gpus) - { - int i_gpu = omp_get_thread_num(); - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaMemcpy( - d_dest_nodes[i_gpu], dest_nodes, sizeof(uint32_t) * graph_size, cudaMemcpyHostToDevice)); + for (uint64_t i = 0; i < graph_size; i++) { + dest_nodes.data_handle()[i] = pruned_graph.data_handle()[k + (output_graph_degree * i)]; + } + res.sync_stream(); + + raft::copy( + d_dest_nodes.data_handle(), dest_nodes.data_handle(), graph_size, res.get_stream()); + dim3 threads(256, 1, 1); dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>(i_gpu, - d_dest_nodes[i_gpu], - graph_size, - d_rev_graph_ptr[num_gpus][i_gpu], - d_rev_graph_count[num_gpus][i_gpu], - graph_chunk_size, - output_graph_degree); + kern_make_rev_graph<<>>(d_dest_nodes.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); } - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); - } - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - RAFT_CUDA_TRY(cudaSetDevice(0)); - RAFT_LOG_DEBUG("\n"); - mgpu_D2H( - d_rev_graph_ptr, rev_graph_ptr, num_gpus, graph_size, graph_chunk_size, output_graph_degree); - mgpu_D2H(d_rev_graph_count, rev_graph_count, num_gpus, graph_size, graph_chunk_size, 1); - mgpu_free(d_rev_graph_ptr, num_gpus); - mgpu_free(d_rev_graph_count, num_gpus); + res.sync_stream(); + RAFT_LOG_DEBUG("\n"); - double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf sec", time_make_end - time_make_start); + raft::copy(rev_graph.data_handle(), + d_rev_graph.data_handle(), + graph_size * output_graph_degree, + res.get_stream()); + raft::copy( + rev_graph_count.data_handle(), d_rev_graph_count.data_handle(), graph_size, res.get_stream()); - // - // Replace some edges with reverse edges - // - double time_replace_start = cur_time(); + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf sec", time_make_end - time_make_start); + } + + { + // + // Replace some edges with reverse edges + // + const double time_replace_start = cur_time(); - uint64_t num_protected_edges = 
output_graph_degree / 2;
- RAFT_LOG_DEBUG("# num_protected_edges: %lu", num_protected_edges);
+   const uint64_t num_protected_edges = output_graph_degree / 2;
+   RAFT_LOG_DEBUG("# num_protected_edges: %lu", num_protected_edges);

- array_size = sizeof(uint32_t) * graph_size * output_graph_degree;
- memcpy(output_graph_ptr, pruned_graph_ptr, array_size);
+   memcpy(output_graph_ptr,
+          pruned_graph.data_handle(),
+          sizeof(IdxT) * graph_size * output_graph_degree);

- constexpr int _omp_chunk = 1024;
+   constexpr int _omp_chunk = 1024;
#pragma omp parallel for schedule(dynamic, _omp_chunk)
- for (uint64_t j = 0; j < graph_size; j++) {
-   for (uint64_t _k = 0; _k < rev_graph_count[j]; _k++) {
-     uint64_t k = rev_graph_count[j] - 1 - _k;
-     uint64_t i = rev_graph_ptr[k + (output_graph_degree * j)];
-
-     uint64_t pos = pos_in_array(
-       i, output_graph_ptr + (output_graph_degree * j), output_graph_degree);
-     if (pos < num_protected_edges) { continue; }
-     uint64_t num_shift = pos - num_protected_edges;
-     if (pos == output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; }
-     shift_array(output_graph_ptr + num_protected_edges + (output_graph_degree * j),
-                 num_shift);
-     output_graph_ptr[num_protected_edges + (output_graph_degree * j)] = i;
-   }
-   if ((omp_get_thread_num() == 0) && ((j % _omp_chunk) == 0)) {
-     RAFT_LOG_DEBUG("# Replacing reverse edges: %lu / %lu    ", j, graph_size);
+   for (uint64_t j = 0; j < graph_size; j++) {
+     for (uint64_t _k = 0; _k < rev_graph_count.data_handle()[j]; _k++) {
+       uint64_t k = rev_graph_count.data_handle()[j] - 1 - _k;
+       uint64_t i = rev_graph.data_handle()[k + (output_graph_degree * j)];
+
+       uint64_t pos =
+         pos_in_array(i, output_graph_ptr + (output_graph_degree * j), output_graph_degree);
+       if (pos < num_protected_edges) { continue; }
+       uint64_t num_shift = pos - num_protected_edges;
+       if (pos == output_graph_degree) {
+         num_shift = output_graph_degree - num_protected_edges - 1;
+       }
+       shift_array(output_graph_ptr + num_protected_edges + (output_graph_degree * j),
+                   num_shift);
+       output_graph_ptr[num_protected_edges + (output_graph_degree * j)] = i;
+     }
+     if ((omp_get_thread_num() == 0) && ((j % _omp_chunk) == 0)) {
+       RAFT_LOG_DEBUG("# Replacing reverse edges: %lu / %lu    ", j, graph_size);
+     }
    }
- }
- RAFT_LOG_DEBUG("\n");
- free(rev_graph_ptr);
- free(rev_graph_count);
+   RAFT_LOG_DEBUG("\n");

- double time_replace_end = cur_time();
- RAFT_LOG_DEBUG("# Replacing edges time: %.1lf sec", time_replace_end - time_replace_start);
+   const double time_replace_end = cur_time();
+   RAFT_LOG_DEBUG("# Replacing edges time: %.1lf sec", time_replace_end - time_replace_start);

- /* stats */
- uint64_t num_replaced_edges = 0;
+   /* stats */
+   uint64_t num_replaced_edges = 0;
#pragma omp parallel for reduction(+ : num_replaced_edges)
- for (uint64_t i = 0; i < graph_size; i++) {
-   for (uint64_t k = 0; k < output_graph_degree; k++) {
-     uint64_t j   = pruned_graph_ptr[k + (output_graph_degree * i)];
-     uint64_t pos = pos_in_array(
-       j, output_graph_ptr + (output_graph_degree * i), output_graph_degree);
-     if (pos == output_graph_degree) { num_replaced_edges += 1; }
+   for (uint64_t i = 0; i < graph_size; i++) {
+     for (uint64_t k = 0; k < output_graph_degree; k++) {
+       const uint64_t j = pruned_graph.data_handle()[k + (output_graph_degree * i)];
+       const uint64_t pos =
+         pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree);
+       if (pos == output_graph_degree) { num_replaced_edges += 1; }
+     }
    }
+   RAFT_LOG_DEBUG("# Average number of replaced edges per
node: %.2f", + (double)num_replaced_edges / graph_size); } - fprintf(stderr, - "# Average number of replaced edges per node: %.2f", - (double)num_replaced_edges / graph_size); } } // namespace graph diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 99553632ac..dc3607016b 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -82,9 +82,9 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } } -template +template __device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements] - uint32_t* indices, // [num_elements] + INDEX_T* indices, // [num_elements] const uint32_t num_elements, const uint32_t num_itopk // num_itopk <= num_elements ) @@ -94,7 +94,7 @@ __device__ inline void topk_by_bitonic_sort(float* distances, // [num_el const unsigned lane_id = threadIdx.x % 32; constexpr unsigned N = (MAX_ELEMENTS + 31) / 32; float key[N]; - uint32_t val[N]; + INDEX_T val[N]; for (unsigned i = 0; i < N; i++) { unsigned j = lane_id + (32 * i); if (j < num_elements) { @@ -102,11 +102,11 @@ __device__ inline void topk_by_bitonic_sort(float* distances, // [num_el val[i] = indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Warp Sort */ - bitonic::warp_sort(key, val); + bitonic::warp_sort(key, val); /* Store itopk sorted results */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; @@ -192,7 +192,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( auto result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); auto parent_indices_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); + reinterpret_cast(result_distances_buffer + result_buffer_size_32); auto terminate_flag = reinterpret_cast(parent_indices_buffer + num_parents); #if 0 @@ -244,10 +244,10 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( while (1) { // topk with bitonic sort _CLK_START(); - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (num_parents * graph_degree), - itopk_size); + topk_by_bitonic_sort(result_distances_buffer, + result_indices_buffer, + itopk_size + (num_parents * graph_degree), + itopk_size); _CLK_REC(clk_topk); if (iter + 1 == max_iteration) { @@ -454,7 +454,7 @@ struct search : public search_plan_impl { using search_plan_impl::num_seeds; uint32_t num_cta_per_query; - rmm::device_uvector intermediate_indices; + rmm::device_uvector intermediate_indices; rmm::device_uvector intermediate_distances; size_t topk_workspace_size; rmm::device_uvector topk_workspace; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index e3e9c8a655..5a16163d97 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -536,9 +536,9 @@ struct search : search_plan_impl { using search_plan_impl::num_seeds; size_t result_buffer_allocation_size; - rmm::device_uvector result_indices; // results_indices_buffer - rmm::device_uvector result_distances; // result_distances_buffer - rmm::device_uvector parent_node_list; + rmm::device_uvector 
result_indices; // results_indices_buffer + rmm::device_uvector result_distances; // result_distances_buffer + rmm::device_uvector parent_node_list; rmm::device_uvector topk_hint; rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; rmm::device_uvector topk_workspace; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 09d5e71254..1ef35ae97c 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -82,7 +82,7 @@ struct search_plan_impl : public search_plan_impl_base { rmm::device_uvector hashmap; rmm::device_uvector num_executed_iterations; // device or managed? - rmm::device_uvector dev_seed; // IdxT + rmm::device_uvector dev_seed; // IdxT search_plan_impl(raft::device_resources const& res, search_params params, diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index 531b30ba85..0e4ec10abd 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -91,49 +91,52 @@ struct topk_by_radix_sort_base { static constexpr std::uint32_t state_bit_lenght = 0; static constexpr std::uint32_t vecLen = 2; // TODO }; -template +template struct topk_by_radix_sort : topk_by_radix_sort_base {}; -template +template struct topk_by_radix_sort> : topk_by_radix_sort_base { __device__ void operator()(uint32_t topk, uint32_t batch_size, uint32_t len_x, const uint32_t* _x, - const uint32_t* _in_vals, + const IdxT* _in_vals, uint32_t* _y, - uint32_t* _out_vals, + IdxT* _out_vals, uint32_t* work, uint32_t* _hints, bool sort, uint32_t* _smem) { - std::uint8_t* state = (std::uint8_t*)work; + std::uint8_t* const state = reinterpret_cast(work); topk_cta_11_core::state_bit_lenght, topk_by_radix_sort_base::vecLen, 64, - 32>(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); + 32, + IdxT>(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); } }; #define TOP_FUNC_PARTIAL_SPECIALIZATION(V) \ - template \ + template \ struct topk_by_radix_sort< \ MAX_INTERNAL_TOPK, \ BLOCK_SIZE, \ + IdxT, \ std::enable_if_t<((MAX_INTERNAL_TOPK <= V) && (2 * MAX_INTERNAL_TOPK > V))>> \ : topk_by_radix_sort_base { \ __device__ void operator()(uint32_t topk, \ uint32_t batch_size, \ uint32_t len_x, \ const uint32_t* _x, \ - const uint32_t* _in_vals, \ + const IdxT* _in_vals, \ uint32_t* _y, \ - uint32_t* _out_vals, \ + IdxT* _out_vals, \ uint32_t* work, \ uint32_t* _hints, \ bool sort, \ @@ -145,7 +148,8 @@ struct topk_by_radix_sort::state_bit_lenght, \ topk_by_radix_sort_base::vecLen, \ V, \ - V / 4>( \ + V / 4, \ + IdxT>( \ topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); \ } \ }; @@ -154,12 +158,11 @@ TOP_FUNC_PARTIAL_SPECIALIZATION(256); TOP_FUNC_PARTIAL_SPECIALIZATION(512); TOP_FUNC_PARTIAL_SPECIALIZATION(1024); -template -__device__ inline void topk_by_bitonic_sort_1st( - float* candidate_distances, // [num_candidates] - std::uint32_t* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - const std::uint32_t num_itopk) +template +__device__ inline void topk_by_bitonic_sort_1st(float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) { const unsigned lane_id = threadIdx.x % 32; const unsigned 
warp_id = threadIdx.x / 32; @@ -167,7 +170,7 @@ __device__ inline void topk_by_bitonic_sort_1st( if (warp_id > 0) { return; } constexpr unsigned N = (MAX_CANDIDATES + 31) / 32; float key[N]; - std::uint32_t val[N]; + IdxT val[N]; /* Candidates -> Reg */ for (unsigned i = 0; i < N; i++) { unsigned j = lane_id + (32 * i); @@ -176,11 +179,11 @@ __device__ inline void topk_by_bitonic_sort_1st( val[i] = candidate_indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Sort */ - bitonic::warp_sort(key, val); + bitonic::warp_sort(key, val); /* Reg -> Temp_itopk */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; @@ -194,7 +197,7 @@ __device__ inline void topk_by_bitonic_sort_1st( constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; constexpr unsigned N = (max_candidates_per_warp + 31) / 32; float key[N]; - std::uint32_t val[N]; + IdxT val[N]; if (warp_id < 2) { /* Candidates -> Reg */ for (unsigned i = 0; i < N; i++) { @@ -205,11 +208,11 @@ __device__ inline void topk_by_bitonic_sort_1st( val[i] = candidate_indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Sort */ - bitonic::warp_sort(key, val); + bitonic::warp_sort(key, val); /* Reg -> Temp_candidates */ for (unsigned i = 0; i < N; i++) { unsigned jl = (N * lane_id) + i; @@ -242,7 +245,7 @@ __device__ inline void topk_by_bitonic_sort_1st( if (num_warps_used > 1) { __syncthreads(); } if (warp_id < num_warps_used) { /* Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, 32); /* Reg -> Temp_itopk */ for (unsigned i = 0; i < N; i++) { unsigned jl = (N * lane_id) + i; @@ -257,16 +260,15 @@ __device__ inline void topk_by_bitonic_sort_1st( } } -template -__device__ inline void topk_by_bitonic_sort_2nd( - float* itopk_distances, // [num_itopk] - std::uint32_t* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - std::uint32_t* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first) +template +__device__ inline void topk_by_bitonic_sort_2nd(float* itopk_distances, // [num_itopk] + IdxT* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) { const unsigned lane_id = threadIdx.x % 32; const unsigned warp_id = threadIdx.x / 32; @@ -274,7 +276,7 @@ __device__ inline void topk_by_bitonic_sort_2nd( if (warp_id > 0) { return; } constexpr unsigned N = (MAX_ITOPK + 31) / 32; float key[N]; - std::uint32_t val[N]; + IdxT val[N]; if (first) { /* Load itopk results */ for (unsigned i = 0; i < N; i++) { @@ -284,11 +286,11 @@ __device__ inline void topk_by_bitonic_sort_2nd( val[i] = itopk_indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Warp Sort */ - bitonic::warp_sort(key, val); + bitonic::warp_sort(key, val); } else { /* Load itopk results */ for (unsigned i = 0; i < N; i++) { @@ -298,7 +300,7 @@ __device__ inline void topk_by_bitonic_sort_2nd( val[i] = itopk_indices[device::swizzling(j)]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } } @@ -314,7 +316,7 @@ __device__ inline void 
topk_by_bitonic_sort_2nd( } } /* Warp Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, 32); /* Store new itopk results */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; @@ -328,7 +330,7 @@ __device__ inline void topk_by_bitonic_sort_2nd( constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; constexpr unsigned N = (max_itopk_per_warp + 31) / 32; float key[N]; - std::uint32_t val[N]; + IdxT val[N]; if (first) { /* Load itop results (not sorted) */ if (warp_id < 2) { @@ -339,11 +341,11 @@ __device__ inline void topk_by_bitonic_sort_2nd( val[i] = itopk_indices[j]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Warp Sort */ - bitonic::warp_sort(key, val); + bitonic::warp_sort(key, val); /* Store intermedidate results */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * threadIdx.x) + i; @@ -367,7 +369,7 @@ __device__ inline void topk_by_bitonic_sort_2nd( } } /* Warp Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, 32); } __syncthreads(); /* Store itopk results (sorted) */ @@ -412,8 +414,8 @@ __device__ inline void topk_by_bitonic_sort_2nd( if (key_0 > key_1) { itopk_distances[device::swizzling(j)] = key_1; itopk_distances[device::swizzling(k)] = key_0; - std::uint32_t val_0 = itopk_indices[device::swizzling(j)]; - std::uint32_t val_1 = itopk_indices[device::swizzling(k)]; + IdxT val_0 = itopk_indices[device::swizzling(j)]; + IdxT val_1 = itopk_indices[device::swizzling(k)]; itopk_indices[device::swizzling(j)] = val_1; itopk_indices[device::swizzling(k)] = val_0; atomicMin(work_buf + 0, j); @@ -445,11 +447,11 @@ __device__ inline void topk_by_bitonic_sort_2nd( val[i] = itopk_indices[device::swizzling(k)]; } else { key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); } } /* Warp Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, 32); /* Store new itopk results */ for (unsigned i = 0; i < N; i++) { const unsigned j = (N * lane_id) + i; @@ -466,30 +468,31 @@ __device__ inline void topk_by_bitonic_sort_2nd( template -__device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] - std::uint32_t* itopk_indices, // [num_itopk] + unsigned MULTI_WARPS_2, + class IdxT> +__device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] + IdxT* itopk_indices, // [num_itopk] const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - std::uint32_t* candidate_indices, // [num_candidates] + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] const std::uint32_t num_candidates, std::uint32_t* work_buf, const bool first) { // The results in candidate_distances/indices are sorted by bitonic sort. - topk_by_bitonic_sort_1st( + topk_by_bitonic_sort_1st( candidate_distances, candidate_indices, num_candidates, num_itopk); // The results sorted above are merged with the internal intermediate top-k // results so far using bitonic merge. 
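  // Because the itopk buffer is kept sorted across iterations (it is only
  // warp-sorted in full on the first call), the second stage can use a bitonic
  // merge of two sorted sequences instead of re-sorting the combined
  // itopk + candidate range, which keeps the per-iteration cost low.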
- topk_by_bitonic_sort_2nd(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); + topk_by_bitonic_sort_2nd(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); } template @@ -586,7 +589,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ reinterpret_cast(result_indices_buffer + result_buffer_size_32); auto visited_hash_buffer = reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); + auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); auto topk_ws = reinterpret_cast(parent_list_buffer + num_parents); auto terminate_flag = reinterpret_cast(topk_ws + 3); auto smem_working_ptr = reinterpret_cast(terminate_flag + 1); @@ -691,7 +694,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ } else { _CLK_START(); // topk with radix block sort - topk_by_radix_sort{}( + topk_by_radix_sort{}( internal_topk, gridDim.x, result_buffer_size, @@ -997,17 +1000,18 @@ struct search : search_plan_impl { const std::uint32_t topk_ws_size = 3; const std::uint32_t base_smem_size = sizeof(float) * max_dim + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + - sizeof(std::uint32_t) * hashmap::get_size(small_hash_bitlen) + - sizeof(std::uint32_t) * num_parents + sizeof(std::uint32_t) * topk_ws_size + - sizeof(std::uint32_t); + sizeof(std::uint32_t) * hashmap::get_size(small_hash_bitlen) + sizeof(INDEX_T) * num_parents + + sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t); smem_size = base_smem_size; if (num_itopk_candidates > 256) { // Tentatively calculate the required share memory size when radix // sort based topk is used, assuming the block size is the maximum. 
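      // (This is an upper-bound estimate only; the shared memory size is
      // computed again once the actual block size has been determined, in the
      // block-size dispatch further below.)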
if (itopk_size <= 256) { - smem_size += topk_by_radix_sort<256, max_block_size>::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort<256, max_block_size, INDEX_T>::smem_size * sizeof(std::uint32_t); } else { - smem_size += topk_by_radix_sort<512, max_block_size>::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort<512, max_block_size, INDEX_T>::smem_size * sizeof(std::uint32_t); } } @@ -1078,25 +1082,31 @@ struct search : search_plan_impl { constexpr unsigned MAX_ITOPK = 256; if (block_size == 256) { constexpr unsigned BLOCK_SIZE = 256; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } else if (block_size == 512) { constexpr unsigned BLOCK_SIZE = 512; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } else { constexpr unsigned BLOCK_SIZE = 1024; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } } else { constexpr unsigned MAX_ITOPK = 512; if (block_size == 256) { constexpr unsigned BLOCK_SIZE = 256; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } else if (block_size == 512) { constexpr unsigned BLOCK_SIZE = 512; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } else { constexpr unsigned BLOCK_SIZE = 1024; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); + smem_size += + topk_by_radix_sort::smem_size * sizeof(std::uint32_t); } } } diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h index ccb65fd0ea..2896dba1f3 100644 --- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h +++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h @@ -27,17 +27,18 @@ size_t _cuann_find_topk_bufferSize(uint32_t topK, cudaDataType_t sampleDtype = CUDA_R_32F); // +template void _cuann_find_topk(uint32_t topK, uint32_t sizeBatch, uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const uint32_t* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - uint32_t* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK + const float* inputKeys, // [sizeBatch, ldIK,] + uint32_t ldIK, // (*) ldIK >= numElements + const ValT* inputVals, // [sizeBatch, ldIV,] + uint32_t ldIV, // (*) ldIV >= numElements + float* outputKeys, // [sizeBatch, ldOK,] + uint32_t ldOK, // (*) ldOK >= topK + ValT* outputVals, // [sizeBatch, ldOV,] + uint32_t ldOV, // (*) ldOV >= topK void* workspace, bool sort = false, uint32_t* hint = NULL, @@ -54,4 +55,4 @@ CUDA_DEVICE_HOST_FUNC inline size_t _cuann_aligned(size_t size, size_t unit = 12 if (size % unit) { size += unit - (size % unit); } return size; } -} // namespace raft::neighbors::experimental::cagra::detail \ No newline at end of file +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh index 072593550e..eddaa9cea8 100644 --- 
diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh
index 072593550e..eddaa9cea8 100644
--- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh
@@ -493,14 +493,14 @@ __device__ __host__ inline uint32_t get_state_size(uint32_t len_x)
 }
 
 //
-template <int blockDim_x, int stateBitLen, int vecLen, int maxTopk, int numSortThreads>
+template <int blockDim_x, int stateBitLen, int vecLen, int maxTopk, int numSortThreads, class ValT>
 __device__ inline void topk_cta_11_core(uint32_t topk,
                                         uint32_t len_x,
-                                        const uint32_t* _x,        // [size_batch, ld_x,]
-                                        const uint32_t* _in_vals,  // [size_batch, ld_iv,]
-                                        uint32_t* _y,              // [size_batch, ld_y,]
-                                        uint32_t* _out_vals,       // [size_batch, ld_ov,]
-                                        uint8_t* _state,           // [size_batch, ...,]
+                                        const uint32_t* _x,    // [size_batch, ld_x,]
+                                        const ValT* _in_vals,  // [size_batch, ld_iv,]
+                                        uint32_t* _y,          // [size_batch, ld_y,]
+                                        ValT* _out_vals,       // [size_batch, ld_ov,]
+                                        uint8_t* _state,       // [size_batch, ...,]
                                         uint32_t* _hint,
                                         bool sort,
                                         uint32_t* _smem)
@@ -514,11 +514,11 @@ __device__ inline void topk_cta_11_core(uint32_t topk,
   const uint32_t thread_id = threadIdx.x;
   uint32_t nx              = len_x;
   const uint32_t* x        = _x;
-  const uint32_t* in_vals = NULL;
+  const ValT* in_vals = NULL;
   if (_in_vals) { in_vals = _in_vals; }
   uint32_t* y = NULL;
   if (_y) { y = _y; }
-  uint32_t* out_vals = NULL;
+  ValT* out_vals = NULL;
   if (_out_vals) { out_vals = _out_vals; }
   uint8_t* state = _state;
   uint32_t hint  = (_hint == NULL ? ~0u : *_hint);
@@ -616,7 +616,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk,
 
     constexpr int numTopkPerThread = maxTopk / numSortThreads;
     float my_keys[numTopkPerThread];
-    uint32_t my_vals[numTopkPerThread];
+    ValT my_vals[numTopkPerThread];
 
     // Read keys and values to registers
     if (thread_id < numSortThreads) {
@@ -632,7 +632,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk,
           }
         } else {
           my_keys[i] = FLT_MAX;
-          my_vals[i] = 0xffffffffU;
+          my_vals[i] = ~static_cast<ValT>(0);
         }
       }
     }
@@ -641,21 +641,21 @@ __device__ inline void topk_cta_11_core(uint32_t topk,
 
     // Sorting by thread
     if (thread_id < numSortThreads) {
-      bool ascending = ((thread_id & mask) == 0);
+      const bool ascending = ((thread_id & mask) == 0);
       if (numTopkPerThread == 3) {
-        swap_if_needed<float, uint32_t>(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending);
-        swap_if_needed<float, uint32_t>(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending);
-        swap_if_needed<float, uint32_t>(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending);
+        swap_if_needed<float, ValT>(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending);
+        swap_if_needed<float, ValT>(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending);
+        swap_if_needed<float, ValT>(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending);
       } else {
         for (int j = 0; j < numTopkPerThread / 2; j += 1) {
 #pragma unroll
           for (int i = 0; i < numTopkPerThread; i += 2) {
-            swap_if_needed<float, uint32_t>(
+            swap_if_needed<float, ValT>(
               my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending);
           }
 #pragma unroll
           for (int i = 1; i < numTopkPerThread - 1; i += 2) {
-            swap_if_needed<float, uint32_t>(
+            swap_if_needed<float, ValT>(
              my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending);
          }
        }
@@ -667,11 +667,12 @@ __device__ inline void topk_cta_11_core(uint32_t topk,
 
       uint32_t next_mask = mask << 1;
       for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) {
-        bool ascending = ((thread_id & curr_mask) == 0) == ((thread_id & next_mask) == 0);
+        const bool ascending = ((thread_id & curr_mask) == 0) == ((thread_id & next_mask) == 0);
         if (curr_mask >= 32) {
           // inter warp
-          uint32_t* smem_vals = _smem;  // [numTopkPerThread, numSortThreads]
-          float* smem_keys    = (float*)(_smem + numTopkPerThread * numSortThreads);
+          ValT* smem_vals = reinterpret_cast<ValT*>(_smem);  // [maxTopk]
+          float* smem_keys =
+            reinterpret_cast<float*>(smem_vals + maxTopk);  // 
[numTopkPerThread, numSortThreads] __syncthreads(); if (thread_id < numSortThreads) { #pragma unroll @@ -684,9 +685,9 @@ __device__ inline void topk_cta_11_core(uint32_t topk, if (thread_id < numSortThreads) { #pragma unroll for (int i = 0; i < numTopkPerThread; i++) { - float opp_key = smem_keys[(thread_id ^ curr_mask) + (numSortThreads * i)]; - uint32_t opp_val = smem_vals[(thread_id ^ curr_mask) + (numSortThreads * i)]; - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + float opp_key = smem_keys[(thread_id ^ curr_mask) + (numSortThreads * i)]; + ValT opp_val = smem_vals[(thread_id ^ curr_mask) + (numSortThreads * i)]; + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } } else { @@ -694,9 +695,9 @@ __device__ inline void topk_cta_11_core(uint32_t topk, if (thread_id < numSortThreads) { #pragma unroll for (int i = 0; i < numTopkPerThread; i++) { - float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); - uint32_t opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); + ValT opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } } @@ -705,9 +706,9 @@ __device__ inline void topk_cta_11_core(uint32_t topk, if (thread_id < numSortThreads) { bool ascending = ((thread_id & next_mask) == 0); if (numTopkPerThread == 3) { - swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); - swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); - swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); + swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); + swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); + swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); } else { #pragma unroll for (uint32_t curr_mask = numTopkPerThread / 2; curr_mask > 0; curr_mask >>= 1) { @@ -715,8 +716,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, for (int i = 0; i < numTopkPerThread; i++) { int j = i ^ curr_mask; if (i > j) continue; - swap_if_needed( - my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); + swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); } } } @@ -729,7 +729,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, for (int i = 0; i < numTopkPerThread; i++) { int k = i + (numTopkPerThread * thread_id); if (k < topk) { - if (y) { y[k] = ((uint32_t*)my_keys)[i]; } + if (y) { y[k] = reinterpret_cast(my_keys)[i]; } if (out_vals) { out_vals[k] = my_vals[i]; } } } @@ -755,28 +755,32 @@ int _get_vecLen(uint32_t maxSamples, int maxVecLen = MAX_VEC_LENGTH) } } // unnamed namespace -template +template __launch_bounds__(1024, 1) __global__ void kern_topk_cta_11(uint32_t topk, uint32_t size_batch, uint32_t len_x, - const uint32_t* _x, // [size_batch, ld_x,] + const uint32_t* _x, // [size_batch, ld_x,] uint32_t ld_x, - const uint32_t* _in_vals, // [size_batch, ld_iv,] + const ValT* _in_vals, // [size_batch, ld_iv,] uint32_t ld_iv, - uint32_t* _y, // [size_batch, ld_y,] + uint32_t* _y, // [size_batch, ld_y,] uint32_t ld_y, - uint32_t* _out_vals, // [size_batch, ld_ov,] + ValT* _out_vals, // [size_batch, ld_ov,] uint32_t ld_ov, - uint8_t* _state, // [size_batch, ...,] - uint32_t* _hints, // [size_batch,] + uint8_t* _state, // [size_batch, ...,] + uint32_t* _hints, // 
[size_batch,] bool sort) { uint32_t i_batch = blockIdx.x; if (i_batch >= size_batch) return; - __shared__ uint32_t _smem[2 * maxTopk + 2048 + 8]; - topk_cta_11_core( + constexpr uint32_t smem_len = 2 * maxTopk + 2048 + 8; + static_assert(maxTopk * (1 + utils::size_of() / utils::size_of()) <= smem_len, + "maxTopk * sizeof(ValT) must be smaller or equal to 8192 byte"); + __shared__ uint32_t _smem[smem_len]; + + topk_cta_11_core( topk, len_x, (_x == NULL ? NULL : _x + i_batch * ld_x), @@ -809,17 +813,18 @@ size_t inline _cuann_find_topk_bufferSize(uint32_t topK, return workspaceSize; } +template inline void _cuann_find_topk(uint32_t topK, uint32_t sizeBatch, uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const uint32_t* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - uint32_t* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK + const float* inputKeys, // [sizeBatch, ldIK,] + uint32_t ldIK, // (*) ldIK >= numElements + const ValT* inputVals, // [sizeBatch, ldIV,] + uint32_t ldIV, // (*) ldIV >= numElements + float* outputKeys, // [sizeBatch, ldOK,] + uint32_t ldOK, // (*) ldOK >= topK + ValT* outputVals, // [sizeBatch, ldOV,] + uint32_t ldOV, // (*) ldOV >= topK void* workspace, bool sort, uint32_t* hints, @@ -845,48 +850,48 @@ inline void _cuann_find_topk(uint32_t topK, uint32_t, const uint32_t*, uint32_t, - const uint32_t*, + const ValT*, uint32_t, uint32_t*, uint32_t, - uint32_t*, + ValT*, uint32_t, uint8_t*, uint32_t*, bool) = nullptr; // V:vecLen, K:maxTopk, T:numSortThreads -#define SET_KERNEL_VKT(V, K, T) \ - do { \ - assert(numThreads >= T); \ - assert((K % T) == 0); \ - assert((K / T) <= 4); \ - cta_kernel = kern_topk_cta_11; \ +#define SET_KERNEL_VKT(V, K, T, ValT) \ + do { \ + assert(numThreads >= T); \ + assert((K % T) == 0); \ + assert((K / T) <= 4); \ + cta_kernel = kern_topk_cta_11; \ } while (0) // V: vecLen -#define SET_KERNEL_V(V) \ +#define SET_KERNEL_V(V, ValT) \ do { \ if (topK <= 32) { \ - SET_KERNEL_VKT(V, 32, 32); \ + SET_KERNEL_VKT(V, 32, 32, ValT); \ } else if (topK <= 64) { \ - SET_KERNEL_VKT(V, 64, 32); \ + SET_KERNEL_VKT(V, 64, 32, ValT); \ } else if (topK <= 96) { \ - SET_KERNEL_VKT(V, 96, 32); \ + SET_KERNEL_VKT(V, 96, 32, ValT); \ } else if (topK <= 128) { \ - SET_KERNEL_VKT(V, 128, 32); \ + SET_KERNEL_VKT(V, 128, 32, ValT); \ } else if (topK <= 192) { \ - SET_KERNEL_VKT(V, 192, 64); \ + SET_KERNEL_VKT(V, 192, 64, ValT); \ } else if (topK <= 256) { \ - SET_KERNEL_VKT(V, 256, 64); \ + SET_KERNEL_VKT(V, 256, 64, ValT); \ } else if (topK <= 384) { \ - SET_KERNEL_VKT(V, 384, 128); \ + SET_KERNEL_VKT(V, 384, 128, ValT); \ } else if (topK <= 512) { \ - SET_KERNEL_VKT(V, 512, 128); \ + SET_KERNEL_VKT(V, 512, 128, ValT); \ } else if (topK <= 768) { \ - SET_KERNEL_VKT(V, 768, 256); \ + SET_KERNEL_VKT(V, 768, 256, ValT); \ } else if (topK <= 1024) { \ - SET_KERNEL_VKT(V, 1024, 256); \ + SET_KERNEL_VKT(V, 1024, 256, ValT); \ } \ /* else if (topK <= 1536) { SET_KERNEL_VKT(V, 1536, 512); } */ \ /* else if (topK <= 2048) { SET_KERNEL_VKT(V, 2048, 512); } */ \ @@ -901,9 +906,9 @@ inline void _cuann_find_topk(uint32_t topK, int _vecLen = _get_vecLen(ldIK, 2); if (_vecLen == 2) { - SET_KERNEL_V(2); + SET_KERNEL_V(2, ValT); } else if (_vecLen == 1) { - SET_KERNEL_V(1); + SET_KERNEL_V(1, ValT); } cta_kernel<<>>(topK, @@ -923,4 +928,4 @@ inline void _cuann_find_topk(uint32_t topK, 
   return;
 }
 
-}  // namespace raft::neighbors::experimental::cagra::detail
\ No newline at end of file
+}  // namespace raft::neighbors::experimental::cagra::detail
diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp
index 3e329c9239..14187c6d31 100644
--- a/cpp/include/raft/neighbors/detail/cagra/utils.hpp
+++ b/cpp/include/raft/neighbors/detail/cagra/utils.hpp
@@ -128,6 +128,11 @@ _RAFT_HOST_DEVICE inline std::uint32_t get_max_value<std::uint32_t>()
 {
   return 0xffffffffu;
 };
+template <>
+_RAFT_HOST_DEVICE inline std::uint64_t get_max_value<std::uint64_t>()
+{
+  return 0xfffffffffffffffflu;
+};
 
 template
 struct constexpr_max {

From fd0291ffcf530ea201ae3f8b09ab1cf5bf3b27c7 Mon Sep 17 00:00:00 2001
From: Hiroyuki Ootomo
Date: Fri, 12 May 2023 13:51:49 +0900
Subject: [PATCH 02/18] Add CAGRA test for 64-bit index data type

---
 cpp/test/CMakeLists.txt                            |  1 +
 .../ann_cagra/test_float_uint64_t.cu               | 29 +++++++++++++++++++
 2 files changed, 30 insertions(+)
 create mode 100644 cpp/test/neighbors/ann_cagra/test_float_uint64_t.cu

diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index 88ad7772c2..aad32ddc51 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -316,6 +316,7 @@ if(BUILD_TESTS)
     test/neighbors/ann_cagra/test_float_uint32_t.cu
     test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
     test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
+    test/neighbors/ann_cagra/test_float_uint64_t.cu
     test/neighbors/ann_ivf_flat/test_float_int64_t.cu
     test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu
     test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu
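Note (editorial sketch, not part of the patch): the `get_max_value<std::uint64_t>()` specialization added above mirrors the existing 32-bit one. Throughout this series the all-ones bit pattern doubles as the "invalid index" sentinel, for example as the empty-slot marker in the visited-set hashmap updated in a later patch. A minimal sketch of that convention; the helper name below is hypothetical:

    #include <cstdint>

    // All-ones sentinel: no valid index can ever compare equal to it,
    // which is what makes it safe as an "empty slot" marker.
    template <class IdxT>
    constexpr IdxT invalid_index = ~static_cast<IdxT>(0);  // hypothetical helper
    static_assert(invalid_index<std::uint64_t> == 0xfffffffffffffffflu,
                  "sentinel is the all-ones pattern");
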
diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint64_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint64_t.cu
new file mode 100644
index 0000000000..3fceb58918
--- /dev/null
+++ b/cpp/test/neighbors/ann_cagra/test_float_uint64_t.cu
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
+#include "../ann_cagra.cuh"
+
+namespace raft::neighbors::experimental::cagra {
+
+typedef AnnCagraTest<float, float, std::uint64_t> AnnCagraTestF_I64;
+TEST_P(AnnCagraTestF_I64, AnnCagra) { this->testCagra(); }
+
+INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_I64, ::testing::ValuesIn(inputs));
+
+}  // namespace raft::neighbors::experimental::cagra

From 4dec1438d532a4f26b2de6d50f958c46b43a18b6 Mon Sep 17 00:00:00 2001
From: Hiroyuki Ootomo
Date: Fri, 12 May 2023 16:08:59 +0900
Subject: [PATCH 03/18] Fix a bug in CAGRA topk_cta_11_core

---
 .../raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh
index eddaa9cea8..d5015c7ec7 100644
--- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh
@@ -729,7 +729,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk,
       for (int i = 0; i < numTopkPerThread; i++) {
         int k = i + (numTopkPerThread * thread_id);
         if (k < topk) {
-          if (y) { y[k] = reinterpret_cast<ValT*>(my_keys)[i]; }
+          if (y) { y[k] = reinterpret_cast<std::uint32_t*>(my_keys)[i]; }
           if (out_vals) { out_vals[k] = my_vals[i]; }
         }
       }

From 566f96d3540002aa9ac5cd9caa4aff69d541af18 Mon Sep 17 00:00:00 2001
From: Hiroyuki Ootomo
Date: Fri, 12 May 2023 18:00:21 +0900
Subject: [PATCH 04/18] Add const

---
 .../detail/cagra/topk_for_cagra/topk_core.cuh | 96 +++++++++----------
 1 file changed, 48 insertions(+), 48 deletions(-)

diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh
index d5015c7ec7..d6aca93b57 100644
--- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh
@@ -237,7 +237,7 @@ __device__ inline void update_histogram(int itr,
     }
 #pragma unroll
     for (int v = 0; v < max(vecLen, stateBitLen); v += vecLen) {
-      int iv = i + (num_threads * v);
+      const int iv = i + (num_threads * v);
       if (iv >= nx) break;
 
       struct u32_vector x_u32_vec;
@@ -249,7 +249,7 @@ __device__ inline void update_histogram(int itr,
       }
 #pragma unroll
       for (int u = 0; u < vecLen; u++) {
-        int ivu = iv + u;
+        const int ivu = iv + u;
         if (ivu >= nx) break;
 
         uint8_t mask = (uint8_t)0x1 << (v + u);
@@ -270,7 +270,7 @@ __device__ inline void update_histogram(int itr,
             iState |= mask;
           }
         } else {
-          uint32_t k = (xi - threshold) >> shift;  // 0 <= k
+          const uint32_t k = (xi - threshold) >> shift;  // 0 <= k
           if (k >= num_bins) {
             if (stateBitLen == 8) { iState |= mask; }
           } else if (k + 1 < num_bins) {
@@ -287,15 +287,15 @@ __device__ inline void update_histogram(int itr,
 
 //
 template
-__device__ inline void select_best_index_for_next_threshold(uint32_t topk,
-                                                            uint32_t threshold,
-                                                            uint32_t max_threshold,
-                                                            uint32_t nx_below_threshold,
-                                                            uint32_t num_bins,
-                                                            uint32_t shift,
-                                                            const uint32_t* hist,  // [num_bins]
-                                                            uint32_t* best_index,
-                                                            uint32_t* best_csum)
+__device__ inline void select_best_index_for_next_threshold(const uint32_t topk,
+                                                            const uint32_t threshold,
+                                                            const uint32_t max_threshold,
+                                                            const uint32_t nx_below_threshold,
+                                                            const uint32_t num_bins,
+                                                            const uint32_t shift,
+                                                            const uint32_t* const hist,  // [num_bins]
+                                                            uint32_t* const best_index,
+                                                            uint32_t* const best_csum)
 {
   // Scan the histogram ('hist') and compute csum. 
Then, find the largest // index under the condition that the sum of the number of elements found @@ -311,7 +311,7 @@ __device__ inline void select_best_index_for_next_threshold(uint32_t topk, if (threadIdx.x < num_bins) { csum = hist[threadIdx.x]; } BlockScanT(temp_storage).InclusiveSum(csum, csum); if (threadIdx.x < num_bins) { - uint32_t index = threadIdx.x; + const uint32_t index = threadIdx.x; if ((nx_below_threshold + csum <= topk) && (threshold + (index << shift) <= max_threshold)) { my_index = index; my_csum = csum; @@ -327,7 +327,7 @@ __device__ inline void select_best_index_for_next_threshold(uint32_t topk, BlockScanT(temp_storage).InclusiveSum(csum, csum); for (int i = n_data - 1; i >= 0; i--) { if (nx_below_threshold + csum[i] > topk) continue; - uint32_t index = i + (n_data * threadIdx.x); + const uint32_t index = i + (n_data * threadIdx.x); if (threshold + (index << shift) > max_threshold) continue; my_index = index; my_csum = csum[i]; @@ -342,7 +342,7 @@ __device__ inline void select_best_index_for_next_threshold(uint32_t topk, BlockScanT(temp_storage).InclusiveSum(csum, csum); for (int i = n_data - 1; i >= 0; i--) { if (nx_below_threshold + csum[i] > topk) continue; - uint32_t index = i + (n_data * threadIdx.x); + const uint32_t index = i + (n_data * threadIdx.x); if (threshold + (index << shift) > max_threshold) continue; my_index = index; my_csum = csum[i]; @@ -351,9 +351,9 @@ __device__ inline void select_best_index_for_next_threshold(uint32_t topk, } } if (threadIdx.x < num_bins) { - int laneid = 31 - __clz(__ballot_sync(0xffffffff, (my_index != 0xffffffff))); + const int laneid = 31 - __clz(__ballot_sync(0xffffffff, (my_index != 0xffffffff))); if ((threadIdx.x & 0x1f) == laneid) { - uint32_t old_index = atomicMax(best_index, my_index); + const uint32_t old_index = atomicMax(best_index, my_index); if (old_index < my_index) { atomicMax(best_csum, my_csum); } } } @@ -362,17 +362,17 @@ __device__ inline void select_best_index_for_next_threshold(uint32_t topk, // template -__device__ inline void output_index_below_threshold(uint32_t topk, - uint32_t thread_id, - uint32_t num_threads, - uint32_t threshold, - uint32_t nx_below_threshold, - const T* x, // [nx,] - uint32_t nx, +__device__ inline void output_index_below_threshold(const uint32_t topk, + const uint32_t thread_id, + const uint32_t num_threads, + const uint32_t threshold, + const uint32_t nx_below_threshold, + const T* const x, // [nx,] + const uint32_t nx, const uint8_t* state, - uint32_t* output, // [topk] - uint32_t* output_count, - uint32_t* output_count_eq) + uint32_t* const output, // [topk] + uint32_t* const output_count, + uint32_t* const output_count_eq) { int ii = 0; for (int i = thread_id * vecLen; i < nx; i += num_threads * max(vecLen, stateBitLen), ii++) { @@ -383,7 +383,7 @@ __device__ inline void output_index_below_threshold(uint32_t topk, } #pragma unroll for (int v = 0; v < max(vecLen, stateBitLen); v += vecLen) { - int iv = i + (num_threads * v); + const int iv = i + (num_threads * v); if (iv >= nx) break; struct u32_vector u32_vec; @@ -395,10 +395,10 @@ __device__ inline void output_index_below_threshold(uint32_t topk, } #pragma unroll for (int u = 0; u < vecLen; u++) { - int ivu = iv + u; + const int ivu = iv + u; if (ivu >= nx) break; - uint8_t mask = (uint8_t)0x1 << (v + u); + const uint8_t mask = (uint8_t)0x1 << (v + u); if ((stateBitLen == 8) && (iState & mask)) continue; uint32_t xi; @@ -425,7 +425,7 @@ __device__ inline void output_index_below_threshold(uint32_t topk, template __device__ 
inline void swap(T& val1, T& val2) { - T val0 = val1; + const T val0 = val1; val1 = val2; val2 = val0; } @@ -505,15 +505,15 @@ __device__ inline void topk_cta_11_core(uint32_t topk, bool sort, uint32_t* _smem) { - uint32_t* smem_out_vals = _smem; - uint32_t* hist = &(_smem[2 * maxTopk]); - uint32_t* best_index = &(_smem[2 * maxTopk + 2048]); - uint32_t* best_csum = &(_smem[2 * maxTopk + 2048 + 3]); + uint32_t* const smem_out_vals = _smem; + uint32_t* const hist = &(_smem[2 * maxTopk]); + uint32_t* const best_index = &(_smem[2 * maxTopk + 2048]); + uint32_t* const best_csum = &(_smem[2 * maxTopk + 2048 + 3]); const uint32_t num_threads = blockDim_x; const uint32_t thread_id = threadIdx.x; uint32_t nx = len_x; - const uint32_t* x = _x; + const uint32_t* const x = _x; const ValT* in_vals = NULL; if (_in_vals) { in_vals = _in_vals; } uint32_t* y = NULL; @@ -521,14 +521,14 @@ __device__ inline void topk_cta_11_core(uint32_t topk, ValT* out_vals = NULL; if (_out_vals) { out_vals = _out_vals; } uint8_t* state = _state; - uint32_t hint = (_hint == NULL ? ~0u : *_hint); + const uint32_t hint = (_hint == NULL ? ~0u : *_hint); // Initialize shared memory for (int i = 2 * maxTopk + thread_id; i < 2 * maxTopk + 2048 + 8; i += num_threads) { _smem[i] = 0; } - uint32_t* output_count = &(_smem[2 * maxTopk + 2048 + 6]); - uint32_t* output_count_eq = &(_smem[2 * maxTopk + 2048 + 7]); + uint32_t* const output_count = &(_smem[2 * maxTopk + 2048 + 6]); + uint32_t* const output_count_eq = &(_smem[2 * maxTopk + 2048 + 7]); uint32_t threshold = 0; uint32_t nx_below_threshold = 0; __syncthreads(); @@ -601,7 +601,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, if (!sort) { for (int k = thread_id; k < topk; k += blockDim_x) { - uint32_t i = smem_out_vals[k]; + const uint32_t i = smem_out_vals[k]; if (y) { y[k] = x[i]; } if (out_vals) { if (in_vals) { @@ -621,9 +621,9 @@ __device__ inline void topk_cta_11_core(uint32_t topk, // Read keys and values to registers if (thread_id < numSortThreads) { for (int i = 0; i < numTopkPerThread; i++) { - int k = thread_id + (numSortThreads * i); + const int k = thread_id + (numSortThreads * i); if (k < topk) { - int j = smem_out_vals[k]; + const int j = smem_out_vals[k]; my_keys[i] = ((float*)x)[j]; if (in_vals) { my_vals[i] = in_vals[j]; @@ -670,8 +670,8 @@ __device__ inline void topk_cta_11_core(uint32_t topk, const bool ascending = ((thread_id & curr_mask) == 0) == ((thread_id & next_mask) == 0); if (curr_mask >= 32) { // inter warp - ValT* smem_vals = reinterpret_cast(_smem); // [maxTopk] - float* smem_keys = + ValT* const smem_vals = reinterpret_cast(_smem); // [maxTopk] + float* const smem_keys = reinterpret_cast(smem_vals + maxTopk); // [numTopkPerThread, numSortThreads] __syncthreads(); if (thread_id < numSortThreads) { @@ -704,7 +704,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, } if (thread_id < numSortThreads) { - bool ascending = ((thread_id & next_mask) == 0); + const bool ascending = ((thread_id & next_mask) == 0); if (numTopkPerThread == 3) { swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); @@ -714,7 +714,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, for (uint32_t curr_mask = numTopkPerThread / 2; curr_mask > 0; curr_mask >>= 1) { #pragma unroll for (int i = 0; i < numTopkPerThread; i++) { - int j = i ^ curr_mask; + const int j = i ^ curr_mask; if (i > j) continue; swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], 
ascending); } @@ -727,7 +727,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, // Write sorted keys and values if (thread_id < numSortThreads) { for (int i = 0; i < numTopkPerThread; i++) { - int k = i + (numTopkPerThread * thread_id); + const int k = i + (numTopkPerThread * thread_id); if (k < topk) { if (y) { y[k] = reinterpret_cast(my_keys)[i]; } if (out_vals) { out_vals[k] = my_vals[i]; } @@ -772,7 +772,7 @@ __launch_bounds__(1024, 1) __global__ uint32_t* _hints, // [size_batch,] bool sort) { - uint32_t i_batch = blockIdx.x; + const uint32_t i_batch = blockIdx.x; if (i_batch >= size_batch) return; constexpr uint32_t smem_len = 2 * maxTopk + 2048 + 8; From 76194ae7ff93e101e87279be915cebf9aa80575e Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Mon, 15 May 2023 14:21:05 +0900 Subject: [PATCH 05/18] Fix the data type of seed index in CAGRA::random_pickup --- cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index 52e5c62169..a738de05f0 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -79,7 +79,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( DISTANCE_T best_norm2_team_local = utils::get_max_value(); for (uint32_t j = 0; j < num_distilation; j++) { // Select a node randomly and compute the distance to it - uint32_t seed_index; + INDEX_T seed_index; DISTANCE_T norm2 = 0.0; if (valid_i) { // uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id))); From 3cbd6242f7e7d0dedeeb15a40e20a4d301524196 Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Mon, 15 May 2023 14:42:58 +0900 Subject: [PATCH 06/18] Fix indent --- .../detail/cagra/topk_for_cagra/topk_core.cuh | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh index d6aca93b57..5bc4b70791 100644 --- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh @@ -287,15 +287,16 @@ __device__ inline void update_histogram(int itr, // template -__device__ inline void select_best_index_for_next_threshold(const uint32_t topk, - const uint32_t threshold, - const uint32_t max_threshold, - const uint32_t nx_below_threshold, - const uint32_t num_bins, - const uint32_t shift, - const uint32_t* const hist, // [num_bins] - uint32_t* const best_index, - uint32_t* const best_csum) +__device__ inline void select_best_index_for_next_threshold( + const uint32_t topk, + const uint32_t threshold, + const uint32_t max_threshold, + const uint32_t nx_below_threshold, + const uint32_t num_bins, + const uint32_t shift, + const uint32_t* const hist, // [num_bins] + uint32_t* const best_index, + uint32_t* const best_csum) { // Scan the histogram ('hist') and compute csum. 
Then, find the largest // index under the condition that the sum of the number of elements found @@ -426,8 +427,8 @@ template __device__ inline void swap(T& val1, T& val2) { const T val0 = val1; - val1 = val2; - val2 = val0; + val1 = val2; + val2 = val0; } // @@ -520,17 +521,17 @@ __device__ inline void topk_cta_11_core(uint32_t topk, if (_y) { y = _y; } ValT* out_vals = NULL; if (_out_vals) { out_vals = _out_vals; } - uint8_t* state = _state; - const uint32_t hint = (_hint == NULL ? ~0u : *_hint); + uint8_t* state = _state; + const uint32_t hint = (_hint == NULL ? ~0u : *_hint); // Initialize shared memory for (int i = 2 * maxTopk + thread_id; i < 2 * maxTopk + 2048 + 8; i += num_threads) { _smem[i] = 0; } - uint32_t* const output_count = &(_smem[2 * maxTopk + 2048 + 6]); - uint32_t* const output_count_eq = &(_smem[2 * maxTopk + 2048 + 7]); - uint32_t threshold = 0; - uint32_t nx_below_threshold = 0; + uint32_t* const output_count = &(_smem[2 * maxTopk + 2048 + 6]); + uint32_t* const output_count_eq = &(_smem[2 * maxTopk + 2048 + 7]); + uint32_t threshold = 0; + uint32_t nx_below_threshold = 0; __syncthreads(); // @@ -623,8 +624,8 @@ __device__ inline void topk_cta_11_core(uint32_t topk, for (int i = 0; i < numTopkPerThread; i++) { const int k = thread_id + (numSortThreads * i); if (k < topk) { - const int j = smem_out_vals[k]; - my_keys[i] = ((float*)x)[j]; + const int j = smem_out_vals[k]; + my_keys[i] = ((float*)x)[j]; if (in_vals) { my_vals[i] = in_vals[j]; } else { @@ -672,7 +673,7 @@ __device__ inline void topk_cta_11_core(uint32_t topk, // inter warp ValT* const smem_vals = reinterpret_cast(_smem); // [maxTopk] float* const smem_keys = - reinterpret_cast(smem_vals + maxTopk); // [numTopkPerThread, numSortThreads] + reinterpret_cast(smem_vals + maxTopk); // [numTopkPerThread, numSortThreads] __syncthreads(); if (thread_id < numSortThreads) { #pragma unroll From 8a0cbfe78435ba0c5b3ac85aa2d3f21d51a8cdaa Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Mon, 15 May 2023 18:21:53 +0900 Subject: [PATCH 07/18] Update hashmap to support uint64 --- .../detail/cagra/compute_distance.hpp | 6 +- .../raft/neighbors/detail/cagra/hashmap.hpp | 27 +++---- .../detail/cagra/search_multi_cta.cuh | 10 +-- .../detail/cagra/search_multi_kernel.cuh | 76 +++++++++---------- .../neighbors/detail/cagra/search_plan.cuh | 4 +- .../detail/cagra/search_single_cta.cuh | 16 ++-- 6 files changed, 70 insertions(+), 69 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index a738de05f0..fd66735cf6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -59,9 +59,9 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( const std::size_t num_pickup, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_seeds] + const INDEX_T* const seed_ptr, // [num_seeds] const uint32_t num_seeds, - uint32_t* const visited_hash_ptr, + INDEX_T* const visited_hash_ptr, const uint32_t hash_bitlen, const uint32_t block_id = 0, const uint32_t num_blocks = 1) @@ -150,7 +150,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in const INDEX_T* const knn_graph, const std::uint32_t knn_k, // hashmap - std::uint32_t* const visited_hashmap_ptr, + INDEX_T* const visited_hashmap_ptr, const std::uint32_t hash_bitlen, const INDEX_T* const parent_indices, const std::uint32_t num_parents) 
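Note (editorial sketch, not part of the patch): the hashmap.hpp diff below generalizes the visited-node filter from `uint32_t` keys to any index type `IdxT`. The structure is a power-of-two, open-addressed table probed linearly: `init()` fills every slot with the all-ones sentinel, and `insert()` returns nonzero only on a first-time insertion, which is what lets the search skip already-visited graph nodes. For orientation, a host-side, single-threaded analogue of the same probing scheme (illustration only; the device version uses atomicCAS so that concurrently probing threads race safely):

    #include <cstdint>
    #include <vector>

    // Minimal sketch of the device-side insert(): open addressing with
    // stride-1 linear probing over a power-of-two table whose empty slots
    // hold the all-ones sentinel. Returns true only for a first-time visit.
    template <class IdxT>
    bool visited_insert(std::vector<IdxT>& table, std::uint32_t bitlen, IdxT key)
    {
      const IdxT empty             = ~static_cast<IdxT>(0);
      const std::uint32_t size     = 1u << bitlen;
      const std::uint32_t bit_mask = size - 1;
      // Same hash as the device code: fold the high bits into the low bits.
      std::uint32_t index = static_cast<std::uint32_t>(key ^ (key >> bitlen)) & bit_mask;
      for (std::uint32_t i = 0; i < size; i++) {
        if (table[index] == empty) {  // first visit: claim the slot
          table[index] = key;
          return true;
        }
        if (table[index] == key) { return false; }  // already visited
        index = (index + 1) & bit_mask;  // linear probe, stride 1
      }
      return false;  // table full; the real code sizes the table to avoid this
    }
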
diff --git a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp index 18f4006367..cd2c8ec491 100644 --- a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp @@ -27,32 +27,33 @@ namespace hashmap { _RAFT_HOST_DEVICE inline uint32_t get_size(const uint32_t bitlen) { return 1U << bitlen; } -template -_RAFT_DEVICE inline void init(uint32_t* table, const uint32_t bitlen) +template +_RAFT_DEVICE inline void init(IdxT* const table, const unsigned bitlen) { if (threadIdx.x < FIRST_TID) return; for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += blockDim.x - FIRST_TID) { - table[i] = utils::get_max_value(); + table[i] = utils::get_max_value(); } } -template -_RAFT_DEVICE inline void init(uint32_t* table, const uint32_t bitlen) +template +_RAFT_DEVICE inline void init(IdxT* const table, const uint32_t bitlen) { if ((FIRST_TID > 0 && threadIdx.x < FIRST_TID) || threadIdx.x >= LAST_TID) return; for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += LAST_TID - FIRST_TID) { - table[i] = utils::get_max_value(); + table[i] = utils::get_max_value(); } } -_RAFT_DEVICE inline uint32_t insert(uint32_t* table, const uint32_t bitlen, const uint32_t key) +template +_RAFT_DEVICE inline uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) { // Open addressing is used for collision resolution const uint32_t size = get_size(bitlen); const uint32_t bit_mask = size - 1; #if 1 // Linear probing - uint32_t index = (key ^ (key >> bitlen)) & bit_mask; + IdxT index = (key ^ (key >> bitlen)) & bit_mask; constexpr uint32_t stride = 1; #else // Double hashing @@ -60,8 +61,8 @@ _RAFT_DEVICE inline uint32_t insert(uint32_t* table, const uint32_t bitlen, cons const uint32_t stride = (key >> bitlen) * 2 + 1; #endif for (unsigned i = 0; i < size; i++) { - const uint32_t old = atomicCAS(&table[index], ~0u, key); - if (old == ~0u) { + const IdxT old = atomicCAS(&table[index], ~static_cast(0), key); + if (old == ~static_cast(0)) { return 1; } else if (old == key) { return 0; @@ -71,10 +72,10 @@ _RAFT_DEVICE inline uint32_t insert(uint32_t* table, const uint32_t bitlen, cons return 0; } -template -_RAFT_DEVICE inline uint32_t insert(uint32_t* table, const uint32_t bitlen, const uint32_t key) +template +_RAFT_DEVICE inline uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) { - uint32_t ret = 0; + IdxT ret = 0; if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); } for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) { ret |= __shfl_xor_sync(0xffffffff, ret, offset); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index dc3607016b..4a6571df1d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -140,9 +140,9 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( const uint32_t graph_degree, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, - uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const uint32_t hash_bitlen, const uint32_t itopk_size, const uint32_t num_parents, @@ -213,7 +213,7 
@@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( } } if (threadIdx.x == 0) { terminate_flag[0] = 0; } - uint32_t* local_visited_hashmap_ptr = + INDEX_T* const local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); __syncthreads(); _CLK_REC(clk_init); @@ -366,7 +366,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const uint32_t hash_bitlen, \ const uint32_t itopk_size, \ const uint32_t num_parents, \ @@ -581,7 +581,7 @@ struct search : public search_plan_impl { // Initialize hash table const uint32_t hash_size = hashmap::get_size(hash_bitlen); set_value_batch( - hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); + hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); dim3 block_dims(block_size, 1, 1); dim3 grid_dims(num_cta_per_query, num_queries, 1); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index 5a16163d97..8ac27fbf98 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -96,12 +96,12 @@ __global__ void random_pickup_kernel( const std::size_t num_pickup, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, - INDEX_T* const result_indices_ptr, // [num_queries, ldr] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::uint32_t ldr, // (*) ldr >= num_pickup - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] + INDEX_T* const result_indices_ptr, // [num_queries, ldr] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] + const std::uint32_t ldr, // (*) ldr >= num_pickup + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] const std::uint32_t hash_bitlen) { const auto ldb = hashmap::get_size(hash_bitlen); @@ -167,12 +167,12 @@ void random_pickup(const DATA_T* const dataset_ptr, // [dataset_size, dataset_d const std::size_t num_pickup, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, - INDEX_T* const result_indices_ptr, // [num_queries, ldr] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::size_t ldr, // (*) ldr >= num_pickup - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] + INDEX_T* const result_indices_ptr, // [num_queries, ldr] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] + const std::size_t ldr, // (*) ldr >= num_pickup + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] const std::uint32_t hash_bitlen, cudaStream_t const cuda_stream = 0) { @@ -203,7 +203,7 @@ __global__ void pickup_next_parents_kernel( INDEX_T* const parent_candidates_ptr, // [num_queries, lds] const std::size_t lds, // (*) lds >= parent_candidates_size const std::uint32_t parent_candidates_size, // - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const 
std::size_t hash_bitlen, const std::uint32_t small_hash_bitlen, INDEX_T* const parent_list_ptr, // [num_queries, ldd] @@ -262,19 +262,18 @@ __global__ void pickup_next_parents_kernel( } template -void pickup_next_parents( - INDEX_T* const parent_candidates_ptr, // [num_queries, lds] - const std::size_t lds, // (*) lds >= parent_candidates_size - const std::size_t parent_candidates_size, // - const std::size_t num_queries, - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::size_t hash_bitlen, - const std::size_t small_hash_bitlen, - INDEX_T* const parent_list_ptr, // [num_queries, ldd] - const std::size_t ldd, // (*) ldd >= parent_list_size - const std::size_t parent_list_size, // - std::uint32_t* const terminate_flag, - cudaStream_t cuda_stream = 0) +void pickup_next_parents(INDEX_T* const parent_candidates_ptr, // [num_queries, lds] + const std::size_t lds, // (*) lds >= parent_candidates_size + const std::size_t parent_candidates_size, // + const std::size_t num_queries, + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::size_t hash_bitlen, + const std::size_t small_hash_bitlen, + INDEX_T* const parent_list_ptr, // [num_queries, ldd] + const std::size_t ldd, // (*) ldd >= parent_list_size + const std::size_t parent_list_size, // + std::uint32_t* const terminate_flag, + cudaStream_t cuda_stream = 0) { std::uint32_t block_size = 32; if (small_hash_bitlen) { @@ -308,14 +307,14 @@ __global__ void compute_distance_to_child_nodes_kernel( const DATA_T* const dataset_ptr, // [dataset_size, data_dim] const std::uint32_t data_dim, const std::uint32_t dataset_size, - const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] + const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] const std::uint32_t graph_degree, - const DATA_T* query_ptr, // [num_queries, data_dim] - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const DATA_T* query_ptr, // [num_queries, data_dim] + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const std::uint32_t hash_bitlen, - INDEX_T* const result_indices_ptr, // [num_queries, ldd] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd // (*) ldd >= num_parents * graph_degree + INDEX_T* const result_indices_ptr, // [num_queries, ldd] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] + const std::uint32_t ldd // (*) ldd >= num_parents * graph_degree ) { const uint32_t ldb = hashmap::get_size(hash_bitlen); @@ -333,7 +332,8 @@ __global__ void compute_distance_to_child_nodes_kernel( const std::size_t child_id = neighbor_list_head_ptr[global_team_id % graph_degree]; - if (hashmap::insert(visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id)) { + if (hashmap::insert( + visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id)) { device::fragment frag_target; device::load_vector_sync(frag_target, dataset_ptr + (data_dim * child_id), data_dim); @@ -367,15 +367,15 @@ void compute_distance_to_child_nodes( const DATA_T* const dataset_ptr, // [dataset_size, data_dim] const std::uint32_t data_dim, const std::uint32_t dataset_size, - const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] + const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] const std::uint32_t graph_degree, - const DATA_T* query_ptr, // [num_queries, data_dim] + const DATA_T* query_ptr, // [num_queries, data_dim] const std::uint32_t num_queries, - 
std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const std::uint32_t hash_bitlen, - INDEX_T* const result_indices_ptr, // [num_queries, ldd] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd, // (*) ldd >= num_parents * graph_degree + INDEX_T* const result_indices_ptr, // [num_queries, ldd] + DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] + const std::uint32_t ldd, // (*) ldd >= num_parents * graph_degree cudaStream_t cuda_stream = 0) { const auto block_size = 128; @@ -597,7 +597,7 @@ struct search : search_plan_impl { cudaStream_t stream = res.get_stream(); const uint32_t hash_size = hashmap::get_size(hash_bitlen); set_value_batch( - hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); + hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); // Init topk_hint if (topk_hint.size() > 0) { set_value(topk_hint.data(), 0xffffffffu, num_queries, stream); } diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 1ef35ae97c..8a45d35b26 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -80,7 +80,7 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t topk; uint32_t num_seeds; - rmm::device_uvector hashmap; + rmm::device_uvector hashmap; rmm::device_uvector num_executed_iterations; // device or managed? rmm::device_uvector dev_seed; // IdxT @@ -242,7 +242,7 @@ struct search_plan_impl : public search_plan_impl_base { if (small_hash_bitlen > 0) { RAFT_LOG_DEBUG("# small_hash_reset_interval = %lu", small_hash_reset_interval); } - hashmap_size = sizeof(std::uint32_t) * max_queries * hashmap::get_size(hash_bitlen); + hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen); RAFT_LOG_DEBUG("# hashmap size: %lu", hashmap_size); if (hashmap_size >= 1024 * 1024 * 1024) { RAFT_LOG_DEBUG(" (%.2f GiB)", (double)hashmap_size / (1024 * 1024 * 1024)); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index 0e4ec10abd..df9cc22114 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -496,7 +496,7 @@ __device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] } template -__device__ inline void hashmap_restore(uint32_t* hashmap_ptr, +__device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr, const size_t hashmap_bitlen, const INDEX_T* itopk_indices, uint32_t itopk_size) @@ -540,9 +540,9 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ const std::uint32_t graph_degree, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, - std::uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const std::uint32_t internal_topk, const std::uint32_t num_parents, const std::uint32_t min_iteration, @@ -588,7 +588,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ auto result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); auto 
visited_hash_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); + reinterpret_cast(result_distances_buffer + result_buffer_size_32); auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); auto topk_ws = reinterpret_cast(parent_list_buffer + num_parents); auto terminate_flag = reinterpret_cast(topk_ws + 3); @@ -609,7 +609,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ } // Init hashmap - uint32_t* local_visited_hashmap_ptr; + INDEX_T* local_visited_hashmap_ptr; if (small_hash_bitlen) { local_visited_hashmap_ptr = visited_hash_buffer; } else { @@ -869,7 +869,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ const uint64_t rand_xor_mask, \ const INDEX_T* seed_ptr, \ const uint32_t num_seeds, \ - std::uint32_t* const visited_hashmap_ptr, \ + INDEX_T* const visited_hashmap_ptr, \ const std::uint32_t itopk_size, \ const std::uint32_t num_parents, \ const std::uint32_t min_iteration, \ @@ -1000,7 +1000,7 @@ struct search : search_plan_impl { const std::uint32_t topk_ws_size = 3; const std::uint32_t base_smem_size = sizeof(float) * max_dim + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + - sizeof(std::uint32_t) * hashmap::get_size(small_hash_bitlen) + sizeof(INDEX_T) * num_parents + + sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) + sizeof(INDEX_T) * num_parents + sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t); smem_size = base_smem_size; if (num_itopk_candidates > 256) { @@ -1113,7 +1113,7 @@ struct search : search_plan_impl { RAFT_LOG_DEBUG("# smem_size: %u", smem_size); hashmap_size = 0; if (small_hash_bitlen == 0) { - hashmap_size = sizeof(uint32_t) * max_queries * hashmap::get_size(hash_bitlen); + hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen); hashmap.resize(hashmap_size, res.get_stream()); } RAFT_LOG_DEBUG("# hashmap_size: %lu", hashmap_size); From fb2cd93afbb2e2efa4fe446b31942b7809f56637 Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Mon, 15 May 2023 18:30:55 +0900 Subject: [PATCH 08/18] Update "mottainai" bit to support uint64 --- .../neighbors/detail/cagra/search_multi_cta.cuh | 12 +++++++++--- .../detail/cagra/search_multi_kernel.cuh | 15 ++++++++++----- .../neighbors/detail/cagra/search_single_cta.cuh | 16 ++++++++++++---- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 4a6571df1d..14432770ad 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -50,6 +50,8 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num const size_t num_itopk, uint32_t* const terminate_flag) { + constexpr INDEX_T index_msb_1_mask = static_cast(1) + << (utils::size_of() * 8 - 1); const unsigned warp_id = threadIdx.x / 32; if (warp_id > 0) { return; } const unsigned lane_id = threadIdx.x % 32; @@ -64,7 +66,7 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num int new_parent = 0; if (j < num_itopk) { index = itopk_indices[j]; - if ((index & 0x80000000) == 0) { // check if most significant bit is set + if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set new_parent = 1; } } @@ -73,7 +75,7 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; if 
(i < num_parents) { next_parent_indices[i] = index; - itopk_indices[j] |= 0x80000000; // set most significant bit as used node + itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node } } num_new_parents += __popc(ballot_mask); @@ -290,7 +292,11 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( for (uint32_t i = threadIdx.x; i < itopk_size; i += BLOCK_SIZE) { uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } - result_indices_ptr[j] = result_indices_buffer[i] & ~0x80000000; // clear most significant bit + + constexpr INDEX_T index_msb_1_mask = static_cast(1) + << (utils::size_of() * 8 - 1); + result_indices_ptr[j] = + result_indices_buffer[i] & ~index_msb_1_mask; // clear most significant bit } if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index 8ac27fbf98..40bb8e8010 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -211,6 +211,8 @@ __global__ void pickup_next_parents_kernel( const std::size_t parent_list_size, // std::uint32_t* const terminate_flag) { + constexpr INDEX_T index_msb_1_mask = static_cast(1) + << (utils::size_of() * 8 - 1); const std::size_t ldb = hashmap::get_size(hash_bitlen); const uint32_t query_id = blockIdx.x; if (threadIdx.x < 32) { @@ -228,7 +230,7 @@ __global__ void pickup_next_parents_kernel( int new_parent = 0; if (j < parent_candidates_size) { index = parent_candidates_ptr[j + (lds * query_id)]; - if ((index & 0x80000000) == 0) { // check most significant bit + if ((index & index_msb_1_mask) == 0) { // check most significant bit new_parent = 1; } } @@ -238,7 +240,7 @@ __global__ void pickup_next_parents_kernel( if (i < parent_list_size) { parent_list_ptr[i + (ldd * query_id)] = index; parent_candidates_ptr[j + (lds * query_id)] |= - 0x80000000; // set most significant bit as used node + index_msb_1_mask; // set most significant bit as used node } } num_new_parents += __popc(ballot_mask); @@ -254,8 +256,8 @@ __global__ void pickup_next_parents_kernel( __syncthreads(); // insert internal-topk indices into small-hash for (unsigned i = threadIdx.x; i < parent_candidates_size; i += blockDim.x) { - auto key = - parent_candidates_ptr[i + (lds * query_id)] & ~0x80000000; // clear most significant bit + auto key = parent_candidates_ptr[i + (lds * query_id)] & + ~index_msb_1_mask; // clear most significant bit hashmap::insert(visited_hashmap_ptr + (ldb * query_id), hash_bitlen, key); } } @@ -404,11 +406,14 @@ __global__ void remove_parent_bit_kernel(const std::uint32_t num_queries, INDEX_T* const topk_indices_ptr, // [ld, num_queries] const std::uint32_t ld) { + constexpr INDEX_T index_msb_1_mask = static_cast(1) + << (utils::size_of() * 8 - 1); + uint32_t i_query = blockIdx.x; if (i_query >= num_queries) return; for (unsigned i = threadIdx.x; i < num_topk; i += blockDim.x) { - topk_indices_ptr[i + (ld * i_query)] &= ~0x80000000; // clear most significant bit + topk_indices_ptr[i + (ld * i_query)] &= ~index_msb_1_mask; // clear most significant bit } } diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index df9cc22114..260fb276af 100644 --- 
a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -51,6 +51,8 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag, const std::size_t dataset_size, const std::uint32_t num_parents) { + constexpr INDEX_T index_msb_1_mask = static_cast(1) + << (utils::size_of() * 8 - 1); // if (threadIdx.x >= 32) return; for (std::uint32_t i = threadIdx.x; i < num_parents; i += 32) { @@ -66,7 +68,7 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag, int new_parent = 0; if (j < internal_topk_size) { index = internal_topk_indices[jj]; - if ((index & 0x80000000) == 0) { // check if most significant bit is set + if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set new_parent = 1; } } @@ -76,7 +78,7 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag, if (i < num_parents) { next_parent_indices[i] = index; // set most significant bit as used node - internal_topk_indices[jj] |= 0x80000000; + internal_topk_indices[jj] |= index_msb_1_mask; } } num_new_parents += __popc(ballot_mask); @@ -501,9 +503,11 @@ __device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr, const INDEX_T* itopk_indices, uint32_t itopk_size) { + constexpr INDEX_T index_msb_1_mask = static_cast(1) + << (utils::size_of() * 8 - 1); if (threadIdx.x < FIRST_TID || threadIdx.x >= LAST_TID) return; for (unsigned i = threadIdx.x - FIRST_TID; i < itopk_size; i += LAST_TID - FIRST_TID) { - auto key = itopk_indices[i] & ~0x80000000; // clear most significant bit + auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit hashmap::insert(hashmap_ptr, hashmap_bitlen, key); } } @@ -769,7 +773,11 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ unsigned ii = i; if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } - result_indices_ptr[j] = result_indices_buffer[ii] & ~0x80000000; // clear most significant bit + + constexpr INDEX_T index_msb_1_mask = static_cast(1) + << (utils::size_of() * 8 - 1); + result_indices_ptr[j] = + result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit } if (threadIdx.x == 0 && num_executed_iterations != nullptr) { num_executed_iterations[query_id] = iter + 1; From 57791b74790d8db2cc50f4275040227e5572a17f Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Mon, 15 May 2023 19:37:03 +0900 Subject: [PATCH 09/18] Fix cagra::prune --- cpp/include/raft/neighbors/detail/cagra/graph_core.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index dfe02d4579..3e9d64bb78 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -592,7 +592,7 @@ void prune(raft::device_resources const& res, memcpy(output_graph_ptr, pruned_graph.data_handle(), - sizeof(uint32_t) * graph_size * output_graph_degree); + sizeof(IdxT) * graph_size * output_graph_degree); constexpr int _omp_chunk = 1024; #pragma omp parallel for schedule(dynamic, _omp_chunk) From 9448cba02f5ab413c0adc0081712fa3e3c093648 Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Mon, 15 May 2023 20:15:57 +0900 Subject: [PATCH 10/18] Update CAGRA tests for uint64 index data type --- cpp/test/CMakeLists.txt | 2 ++ .../ann_cagra/test_float_uint32_t.cu | 12 +++---- 
.../ann_cagra/test_int8_t_uint32_t.cu | 12 +++---- .../ann_cagra/test_int8_t_uint64_t.cu | 32 ++++++++++++++++++ .../ann_cagra/test_uint8_t_uint32_t.cu | 12 +++---- .../ann_cagra/test_uint8_t_uint64_t.cu | 33 +++++++++++++++++++ 6 files changed, 85 insertions(+), 18 deletions(-) create mode 100644 cpp/test/neighbors/ann_cagra/test_int8_t_uint64_t.cu create mode 100644 cpp/test/neighbors/ann_cagra/test_uint8_t_uint64_t.cu diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index aad32ddc51..f0abd30ac3 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -317,6 +317,8 @@ if(BUILD_TESTS) test/neighbors/ann_cagra/test_int8_t_uint32_t.cu test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu test/neighbors/ann_cagra/test_float_uint64_t.cu + test/neighbors/ann_cagra/test_int8_t_uint64_t.cu + test/neighbors/ann_cagra/test_uint8_t_uint64_t.cu test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index adb44a9264..dbaf4dedd9 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -20,13 +20,13 @@ namespace raft::neighbors::experimental::cagra { -typedef AnnCagraTest AnnCagraTestF; -TEST_P(AnnCagraTestF, AnnCagra) { this->testCagra(); } +typedef AnnCagraTest AnnCagraTestF_U32; +TEST_P(AnnCagraTestF_U32, AnnCagra) { this->testCagra(); } -typedef AnnCagraSortTest AnnCagraSortTestF; -TEST_P(AnnCagraSortTestF, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraSortTest AnnCagraSortTestF_U32; +TEST_P(AnnCagraSortTestF_U32, AnnCagraSort) { this->testCagraSort(); } -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF_U32, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors::experimental::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu index 11c986c189..ba60131677 100644 --- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu @@ -20,12 +20,12 @@ namespace raft::neighbors::experimental::cagra { -typedef AnnCagraTest AnnCagraTestI8; -TEST_P(AnnCagraTestI8, AnnCagra) { this->testCagra(); } -typedef AnnCagraSortTest AnnCagraSortTestI8; -TEST_P(AnnCagraSortTestI8, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraTest AnnCagraTestI8_U32; +TEST_P(AnnCagraTestI8_U32, AnnCagra) { this->testCagra(); } +typedef AnnCagraSortTest AnnCagraSortTestI8_U32; +TEST_P(AnnCagraSortTestI8_U32, AnnCagraSort) { this->testCagraSort(); } -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8_U32, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors::experimental::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint64_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint64_t.cu new file mode 100644 index 
From ef5381d6f998b416783c25b189ecfd292c0858e7 Mon Sep 17 00:00:00 2001
From: Hiroyuki Ootomo
Date: Mon, 15 May 2023 20:22:38 +0900
Subject: [PATCH 11/18] Remove a comment

---
 cpp/include/raft/neighbors/detail/cagra/search_plan.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
index 8a45d35b26..92708ecc08 100644
--- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
@@ -82,7 +82,7 @@ struct search_plan_impl : public search_plan_impl_base {
   rmm::device_uvector<INDEX_T> hashmap;
   rmm::device_uvector<uint32_t> num_executed_iterations;  // device or managed?
-  rmm::device_uvector<INDEX_T> dev_seed;  // IdxT
+  rmm::device_uvector<INDEX_T> dev_seed;

   search_plan_impl(raft::device_resources const& res,
                    search_params params,
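As the hunk above shows, search_plan_impl keeps its scratch state — the visited-node hashmap, per-query iteration counters, and the optional seed list — in rmm::device_uvector members, which hold uninitialized device memory tied to a CUDA stream and free it on destruction. For readers unfamiliar with the type, the basic usage is (a sketch, not RAFT code):

    #include <cstdint>
    #include <rmm/cuda_stream.hpp>
    #include <rmm/device_uvector.hpp>

    void example()
    {
      rmm::cuda_stream stream;
      // 1024 uninitialized elements allocated on `stream`; freed when the vector dies
      rmm::device_uvector<std::uint32_t> num_executed_iterations(1024, stream.view());
      auto* dev_ptr = num_executed_iterations.data();  // raw device pointer for kernels
      static_cast<void>(dev_ptr);
    }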
From 681c7d1c1e960c3b893cbe55b8534508c295b5df Mon Sep 17 00:00:00 2001
From: Hiroyuki Ootomo
Date: Wed, 17 May 2023 12:41:12 +0900
Subject: [PATCH 12/18] Add gen_index_msb_1_mask

---
 .../raft/neighbors/detail/cagra/graph_core.cuh        |  4 +++-
 .../raft/neighbors/detail/cagra/search_single_cta.cuh | 11 +++++------
 cpp/include/raft/neighbors/detail/cagra/utils.hpp     |  5 +++++
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
index 3e9d64bb78..339d99e8db 100644
--- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
@@ -32,6 +32,8 @@
 #include

+#include "utils.hpp"
+
 namespace raft::neighbors::experimental::cagra::detail {
 namespace graph {
@@ -114,7 +116,7 @@ __global__ void kern_sort(const DATA_T* const dataset,  // [dataset_chunk_size,
       my_vals[i] = smem_vals[k];
     } else {
       my_keys[i] = FLT_MAX;
-      my_vals[i] = ~static_cast<IdxT>(0);
+      my_vals[i] = utils::get_max_value<IdxT>();
     }
   }
   __syncthreads();

diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh
index 260fb276af..fd2e301653 100644
--- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh
@@ -51,8 +51,7 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag,
                                     const std::size_t dataset_size,
                                     const std::uint32_t num_parents)
 {
-  constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
-                                       << (utils::size_of<INDEX_T>() * 8 - 1);
+  constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;

   // if (threadIdx.x >= 32) return;
   for (std::uint32_t i = threadIdx.x; i < num_parents; i += 32) {
@@ -503,8 +502,8 @@ __device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr,
                                        const INDEX_T* itopk_indices,
                                        uint32_t itopk_size)
 {
-  constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
-                                       << (utils::size_of<INDEX_T>() * 8 - 1);
+  constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
+
   if (threadIdx.x < FIRST_TID || threadIdx.x >= LAST_TID) return;
   for (unsigned i = threadIdx.x - FIRST_TID; i < itopk_size; i += LAST_TID - FIRST_TID) {
     auto key = itopk_indices[i] & ~index_msb_1_mask;  // clear most significant bit
@@ -774,8 +773,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__
       if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); }
       if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; }

-      constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
-                                           << (utils::size_of<INDEX_T>() * 8 - 1);
+      constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
+
       result_indices_ptr[j] =
         result_indices_buffer[ii] & ~index_msb_1_mask;  // clear most significant bit
     }

diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp
index 14187c6d31..934e84d4d5 100644
--- a/cpp/include/raft/neighbors/detail/cagra/utils.hpp
+++ b/cpp/include/raft/neighbors/detail/cagra/utils.hpp
@@ -143,6 +143,11 @@ template <int A, int B>
 struct constexpr_max<A, B, std::enable_if_t<(B > A), bool>> {
   static const int value = B;
 };
+
+template <class IdxT>
+struct gen_index_msb_1_mask {
+  static constexpr IdxT value = static_cast<IdxT>(1) << (utils::size_of<IdxT>() * 8 - 1);
+};
 }  // namespace utils

 }  // namespace raft::neighbors::experimental::cagra::detail
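gen_index_msb_1_mask centralizes the mask computation that patch 08 had duplicated at three call sites. Since utils::size_of<T>() is the byte size, the shifted value is exactly the top bit for any index width, which can be pinned down with a plain-sizeof mirror of the trait (a test sketch, independent of the RAFT headers):

    #include <cstdint>

    // mirrors utils::gen_index_msb_1_mask, with sizeof standing in for utils::size_of
    template <class IdxT>
    struct gen_index_msb_1_mask {
      static constexpr IdxT value = static_cast<IdxT>(1) << (sizeof(IdxT) * 8 - 1);
    };

    static_assert(gen_index_msb_1_mask<std::uint32_t>::value == 0x80000000u, "");
    static_assert(gen_index_msb_1_mask<std::uint64_t>::value == 0x8000000000000000ull, "");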
From 0a68b43847191fb027fa010c21bd5a18c019d29e Mon Sep 17 00:00:00 2001
From: Hiroyuki Ootomo
Date: Wed, 17 May 2023 17:08:14 +0900
Subject: [PATCH 13/18] Use gen_index_msb_1_mask in multi_cta::search_kernel

---
 cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
index 1d2e60d13d..485423b8f5 100644
--- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
@@ -295,8 +295,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
     uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id)));
     if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; }

-    constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
-                                         << (utils::size_of<INDEX_T>() * 8 - 1);
+    constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
+
     result_indices_ptr[j] =
       result_indices_buffer[i] & ~index_msb_1_mask;  // clear most significant bit
   }

From b2bd012ba65649af39d3c476fb17a9576d0db331 Mon Sep 17 00:00:00 2001
From: Hiroyuki Ootomo
Date: Wed, 17 May 2023 17:13:04 +0900
Subject: [PATCH 14/18] Use gen_index_msb_1_mask in multi_cta::search_kernel

---
 .../raft/neighbors/detail/cagra/search_multi_cta.cuh    | 3 +--
 .../raft/neighbors/detail/cagra/search_multi_kernel.cuh | 7 +++----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
index 485423b8f5..8d4ec550ca 100644
--- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
@@ -52,8 +52,7 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices,  // [num
                                     const size_t num_itopk,
                                     uint32_t* const terminate_flag)
 {
-  constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
-                                       << (utils::size_of<INDEX_T>() * 8 - 1);
+  constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
   const unsigned warp_id = threadIdx.x / 32;
   if (warp_id > 0) { return; }
   const unsigned lane_id = threadIdx.x % 32;

diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
index c8dc1df354..8fbd5d8f03 100644
--- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
@@ -212,8 +212,8 @@ __global__ void pickup_next_parents_kernel(
   const std::size_t parent_list_size,  //
   std::uint32_t* const terminate_flag)
 {
-  constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
-                                       << (utils::size_of<INDEX_T>() * 8 - 1);
+  constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
+
   const std::size_t ldb   = hashmap::get_size(hash_bitlen);
   const uint32_t query_id = blockIdx.x;
   if (threadIdx.x < 32) {
@@ -407,8 +407,7 @@ __global__ void remove_parent_bit_kernel(const std::uint32_t num_queries,
                                          INDEX_T* const topk_indices_ptr,  // [ld, num_queries]
                                          const std::uint32_t ld)
 {
-  constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
-                                       << (utils::size_of<INDEX_T>() * 8 - 1);
+  constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;

   uint32_t i_query = blockIdx.x;
   if (i_query >= num_queries) return;
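With patches 13 and 14 all three variants share the same warp-level parent pickup: one warp strides over the candidate list, each lane tests whether its candidate is still unused, a ballot gathers the per-lane results, and __popc turns the ballot into the number of parents claimed in that pass. Reduced to the bare idiom (a CUDA sketch assuming a full 32-lane warp is active):

    // Count candidates whose flag word is still clear, 32 at a time.
    __device__ int count_unused(const unsigned* flags, int n)
    {
      const int lane = threadIdx.x % 32;
      int claimed    = 0;
      for (int base = 0; base < n; base += 32) {       // uniform loop: no lane exits early
        const int i           = base + lane;
        const int is_new      = (i < n && flags[i] == 0) ? 1 : 0;    // this lane's candidate
        const unsigned ballot = __ballot_sync(0xffffffffu, is_new);  // one bit per lane
        claimed += __popc(ballot);                                   // warp-wide count
      }
      return claimed;
    }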
From 1c2610989206b9e8340f5342acf2b4ddb56b4c98 Mon Sep 17 00:00:00 2001
From: Hiroyuki Ootomo
Date: Wed, 17 May 2023 17:21:10 +0900
Subject: [PATCH 15/18] Fix code format

---
 cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
index 8d4ec550ca..f9a0fef2fe 100644
--- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
@@ -53,7 +53,7 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices,  // [num
                                     uint32_t* const terminate_flag)
 {
   constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
-  const unsigned warp_id = threadIdx.x / 32;
+  const unsigned warp_id             = threadIdx.x / 32;
   if (warp_id > 0) { return; }
   const unsigned lane_id = threadIdx.x % 32;
   for (uint32_t i = lane_id; i < num_parents; i += 32) {
From 01473b6f661c683f626dd43d08ae4cf311b650d5 Mon Sep 17 00:00:00 2001
From: Hiroyuki Ootomo
Date: Wed, 17 May 2023 22:28:52 +0900
Subject: [PATCH 16/18] Remove some CAGRA IdxT=uint64 tests

---
 cpp/test/CMakeLists.txt                          |  4 +--
 ...loat_uint64_t.cu => test_float_int64_t.cu}   |  2 +-
 .../ann_cagra/test_int8_t_uint64_t.cu           | 32 ------------------
 .../ann_cagra/test_uint8_t_uint64_t.cu          | 33 -------------------
 4 files changed, 2 insertions(+), 69 deletions(-)
 rename cpp/test/neighbors/ann_cagra/{test_float_uint64_t.cu => test_float_int64_t.cu} (93%)
 delete mode 100644 cpp/test/neighbors/ann_cagra/test_int8_t_uint64_t.cu
 delete mode 100644 cpp/test/neighbors/ann_cagra/test_uint8_t_uint64_t.cu

diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index f0abd30ac3..8cdac73d15 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -316,9 +316,7 @@ if(BUILD_TESTS)
     test/neighbors/ann_cagra/test_float_uint32_t.cu
     test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
     test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
-    test/neighbors/ann_cagra/test_float_uint64_t.cu
-    test/neighbors/ann_cagra/test_int8_t_uint64_t.cu
-    test/neighbors/ann_cagra/test_uint8_t_uint64_t.cu
+    test/neighbors/ann_cagra/test_float_int64_t.cu
     test/neighbors/ann_ivf_flat/test_float_int64_t.cu
     test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu
     test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu

diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint64_t.cu b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu
similarity index 93%
rename from cpp/test/neighbors/ann_cagra/test_float_uint64_t.cu
rename to cpp/test/neighbors/ann_cagra/test_float_int64_t.cu
index 3fceb58918..e473a72b2b 100644
--- a/cpp/test/neighbors/ann_cagra/test_float_uint64_t.cu
+++ b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu
@@ -21,7 +21,7 @@

 namespace raft::neighbors::experimental::cagra {

-typedef AnnCagraTest<float, float, std::uint64_t> AnnCagraTestF_I64;
+typedef AnnCagraTest<float, float, std::int64_t> AnnCagraTestF_I64;
 TEST_P(AnnCagraTestF_I64, AnnCagra) { this->testCagra(); }

 INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_I64, ::testing::ValuesIn(inputs));

diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint64_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint64_t.cu
deleted file mode 100644
index c80db1c2a9..0000000000
--- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint64_t.cu
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Copyright (c) 2023, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <gtest/gtest.h>
-
-#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
-#include "../ann_cagra.cuh"
-
-namespace raft::neighbors::experimental::cagra {
-
-typedef AnnCagraTest<float, std::int8_t, std::uint64_t> AnnCagraTestI8_U64;
-TEST_P(AnnCagraTestI8_U64, AnnCagra) { this->testCagra(); }
-typedef AnnCagraSortTest<std::int8_t, std::uint64_t> AnnCagraSortTestI8_U64;
-TEST_P(AnnCagraSortTestI8_U64, AnnCagraSort) { this->testCagraSort(); }
-
-INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U64, ::testing::ValuesIn(inputs));
-INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8_U64, ::testing::ValuesIn(inputs));
-
-}  // namespace raft::neighbors::experimental::cagra

diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint64_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint64_t.cu
deleted file mode 100644
index 624990810f..0000000000
--- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint64_t.cu
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Copyright (c) 2023, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <gtest/gtest.h>
-
-#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
-#include "../ann_cagra.cuh"
-
-namespace raft::neighbors::experimental::cagra {
-
-typedef AnnCagraTest<float, std::uint8_t, std::uint64_t> AnnCagraTestU8_U64;
-TEST_P(AnnCagraTestU8_U64, AnnCagra) { this->testCagra(); }
-
-typedef AnnCagraSortTest<std::uint8_t, std::uint64_t> AnnCagraSortTestU8_U64;
-TEST_P(AnnCagraSortTestU8_U64, AnnCagraSort) { this->testCagraSort(); }
-
-INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8_U64, ::testing::ValuesIn(inputs));
-INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestU8_U64, ::testing::ValuesIn(inputs));
-
-}  // namespace raft::neighbors::experimental::cagra
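Trimming the suites to a single float/int64 case reflects the design that lands in the next patch ("Add support for int64 using uint64 kernels"): the public index type becomes signed, while the kernels keep running on its unsigned twin, so the unsigned-64 paths stay covered through the signed wrapper. The reinterpretation this leans on is layout-compatible and lossless for the non-negative values a neighbor list holds, which is checkable at compile time:

    #include <cstdint>
    #include <type_traits>

    using public_idx_t   = std::int64_t;
    using internal_idx_t = std::make_unsigned<public_idx_t>::type;  // std::uint64_t

    static_assert(std::is_same<internal_idx_t, std::uint64_t>::value, "");
    static_assert(sizeof(public_idx_t) == sizeof(internal_idx_t), "same object size");
    static_assert(alignof(public_idx_t) == alignof(internal_idx_t), "same alignment");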
From c91541df8babba3a522bf0f6214fb7d61ffb4a47 Mon Sep 17 00:00:00 2001
From: Hiroyuki Ootomo
Date: Thu, 18 May 2023 09:31:01 +0900
Subject: [PATCH 17/18] Add support for int64 using uint64 kernels

---
 cpp/include/raft/neighbors/cagra.cuh                | 59 +++++++++++++++++--
 .../neighbors/detail/cagra/cagra_search.cuh         | 36 +++++++----
 2 files changed, 78 insertions(+), 17 deletions(-)

diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh
index 19f65baf1a..32275dfcdb 100644
--- a/cpp/include/raft/neighbors/cagra.cuh
+++ b/cpp/include/raft/neighbors/cagra.cuh
@@ -81,7 +81,17 @@ void build_knn_graph(raft::resources const& res,
                      std::optional<ivf_pq::index_params> build_params = std::nullopt,
                      std::optional<ivf_pq::search_params> search_params = std::nullopt)
 {
-  detail::build_knn_graph(res, dataset, knn_graph, refine_rate, build_params, search_params);
+  using internal_IdxT = typename std::make_unsigned<IdxT>::type;
+
+  auto knn_graph_internal = make_host_matrix_view<internal_IdxT, int64_t>(
+    reinterpret_cast<internal_IdxT*>(knn_graph.data_handle()),
+    knn_graph.extent(0),
+    knn_graph.extent(1));
+  auto dataset_internal = mdspan<const T, matrix_extent<int64_t>, row_major, accessor>(
+    dataset.data_handle(), dataset.extent(0), dataset.extent(1));
+
+  detail::build_knn_graph(
+    res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params);
 }

/**
@@ -124,7 +134,20 @@ void sort_knn_graph(raft::resources const& res,
                     mdspan<const DataT, matrix_extent<int64_t>, row_major, d_accessor> dataset,
                     mdspan<IdxT, matrix_extent<int64_t>, row_major, g_accessor> knn_graph)
 {
-  detail::graph::sort_knn_graph(res, dataset, knn_graph);
+  using internal_IdxT = typename std::make_unsigned<IdxT>::type;
+
+  using g_accessor_internal =
+    host_device_accessor<std::experimental::default_accessor<internal_IdxT>, memory_type::host>;
+  auto knn_graph_internal =
+    mdspan<internal_IdxT, matrix_extent<int64_t>, row_major, g_accessor_internal>(
+      reinterpret_cast<internal_IdxT*>(knn_graph.data_handle()),
+      knn_graph.extent(0),
+      knn_graph.extent(1));
+
+  auto dataset_internal = mdspan<const DataT, matrix_extent<int64_t>, row_major, d_accessor>(
+    dataset.data_handle(), dataset.extent(0), dataset.extent(1));
+
+  detail::graph::sort_knn_graph(res, dataset_internal, knn_graph_internal);
 }

/**
@@ -148,7 +171,22 @@ void prune(raft::resources const& res,
           mdspan<IdxT, matrix_extent<int64_t>, row_major, g_accessor> knn_graph,
           raft::host_matrix_view<IdxT, int64_t, row_major> new_graph)
 {
-  detail::graph::prune(res, knn_graph, new_graph);
+  using internal_IdxT = typename std::make_unsigned<IdxT>::type;
+
+  auto new_graph_internal = raft::make_host_matrix_view<internal_IdxT, int64_t>(
+    reinterpret_cast<internal_IdxT*>(new_graph.data_handle()),
+    new_graph.extent(0),
+    new_graph.extent(1));
+
+  using g_accessor_internal =
+    host_device_accessor<std::experimental::default_accessor<internal_IdxT>, memory_type::host>;
+  auto knn_graph_internal =
+    mdspan<internal_IdxT, matrix_extent<int64_t>, row_major, g_accessor_internal>(
+      reinterpret_cast<internal_IdxT*>(knn_graph.data_handle()),
+      knn_graph.extent(0),
+      knn_graph.extent(1));
+
+  detail::graph::prune(res, knn_graph_internal, new_graph_internal);
 }

/**
@@ -200,7 +238,7 @@ index<T, IdxT> build(raft::resources const& res,
                     mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset)
 {
   size_t degree = params.intermediate_graph_degree;
-  if (degree >= dataset.extent(0)) {
+  if (degree >= static_cast<size_t>(dataset.extent(0))) {
     RAFT_LOG_WARN(
       "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu",
       dataset.extent(0));
@@ -256,7 +294,18 @@ void search(raft::resources const& res,
   RAFT_EXPECTS(queries.extent(1) == idx.dim(),
                "Number of query dimensions should equal number of dimensions in the index.");

-  detail::search_main(res, params, idx, queries, neighbors, distances);
+  using internal_IdxT   = typename std::make_unsigned<IdxT>::type;
+  auto queries_internal = raft::make_device_matrix_view<const T, int64_t, row_major>(
+    queries.data_handle(), queries.extent(0), queries.extent(1));
+  auto neighbors_internal = raft::make_device_matrix_view<internal_IdxT, int64_t, row_major>(
+    reinterpret_cast<internal_IdxT*>(neighbors.data_handle()),
+    neighbors.extent(0),
+    neighbors.extent(1));
+  auto distances_internal = raft::make_device_matrix_view<float, int64_t, row_major>(
+    distances.data_handle(), distances.extent(0), distances.extent(1));
+
+  detail::search_main<T, internal_IdxT, IdxT>(
+    res, params, idx, queries_internal, neighbors_internal, distances_internal);
 }

/** @} */  // end group cagra
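Every public entry point above applies the same recipe: derive internal_IdxT with std::make_unsigned, rebuild the caller's view over a reinterpret_cast of the same pointer, and forward to the detail:: implementation, so one set of unsigned kernels serves both signed and unsigned public index types. Boiled down to its essentials (a sketch with RAFT's mdspan machinery replaced by a toy view type):

    #include <cstddef>
    #include <type_traits>

    template <class T>
    struct view {                 // stand-in for raft::device_matrix_view
      T* data;
      std::size_t rows, cols;
    };

    namespace detail {
    template <class IdxT>
    void search(view<IdxT> /*neighbors*/) { /* unsigned-index kernels live here */ }
    }  // namespace detail

    // public entry point: accepts a (possibly signed) IdxT, forwards its unsigned twin
    template <class IdxT>
    void search(view<IdxT> neighbors)
    {
      using internal_IdxT = typename std::make_unsigned<IdxT>::type;
      view<internal_IdxT> internal{
        reinterpret_cast<internal_IdxT*>(neighbors.data), neighbors.rows, neighbors.cols};
      detail::search(internal);
    }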
diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
index 0073f66d0b..d3b24dc861 100644
--- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
@@ -52,13 +52,13 @@ namespace raft::neighbors::experimental::cagra::detail {
  *   k]
  */

-template <typename T, typename IdxT = uint32_t, typename DistanceT = float>
+template <typename T, typename internal_IdxT, typename IdxT = uint32_t, typename DistanceT = float>
 void search_main(raft::resources const& res,
                  search_params params,
                  const index<T, IdxT>& index,
-                 raft::device_matrix_view<const T, IdxT, row_major> queries,
-                 raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
-                 raft::device_matrix_view<DistanceT, IdxT, row_major> distances)
+                 raft::device_matrix_view<const T, int64_t, row_major> queries,
+                 raft::device_matrix_view<internal_IdxT, int64_t, row_major> neighbors,
+                 raft::device_matrix_view<DistanceT, int64_t, row_major> distances)
 {
   RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
                  static_cast<size_t>(index.dataset().extent(0)),
@@ -69,8 +69,9 @@ void search_main(raft::resources const& res,
   RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match");

   uint32_t topk = neighbors.extent(1);
-  std::unique_ptr<search_plan_impl<T, IdxT, DistanceT>> plan =
-    factory<T, IdxT, DistanceT>::create(res, params, index.dim(), index.graph_degree(), topk);
+  std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT>> plan =
+    factory<T, internal_IdxT, DistanceT>::create(
+      res, params, index.dim(), index.graph_degree(), topk);

   plan->check(neighbors.extent(1));
@@ -79,18 +80,29 @@ void search_main(raft::resources const& res,
   uint32_t query_dim = queries.extent(1);

   for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) {
-    const uint32_t n_queries = std::min(max_queries, queries.extent(0) - qid);
-    IdxT* _topk_indices_ptr  = neighbors.data_handle() + (topk * qid);
+    const uint32_t n_queries = std::min<std::size_t>(max_queries, queries.extent(0) - qid);
+    internal_IdxT* _topk_indices_ptr =
+      reinterpret_cast<internal_IdxT*>(neighbors.data_handle()) + (topk * qid);
     DistanceT* _topk_distances_ptr = distances.data_handle() + (topk * qid);
     // todo(tfeher): one could keep distances optional and pass nullptr
     const T* _query_ptr = queries.data_handle() + (query_dim * qid);
-    const IdxT* _seed_ptr =
-      plan->num_seeds > 0 ? plan->dev_seed.data() + (plan->num_seeds * qid) : nullptr;
+    const internal_IdxT* _seed_ptr =
+      plan->num_seeds > 0
+        ? reinterpret_cast<const internal_IdxT*>(plan->dev_seed.data()) + (plan->num_seeds * qid)
+        : nullptr;
     uint32_t* _num_executed_iterations = nullptr;

+    auto dataset_internal = raft::make_device_matrix_view<const T, internal_IdxT, row_major>(
+      index.dataset().data_handle(), index.dataset().extent(0), index.dataset().extent(1));
+    auto graph_internal =
+      raft::make_device_matrix_view<const internal_IdxT, internal_IdxT, row_major>(
+        reinterpret_cast<const internal_IdxT*>(index.graph().data_handle()),
+        index.graph().extent(0),
+        index.graph().extent(1));
+
     (*plan)(res,
-            index.dataset(),
-            index.graph(),
+            dataset_internal,
+            graph_internal,
             _topk_indices_ptr,
             _topk_distances_ptr,
             _query_ptr,
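search_main itself never branches on the algorithm: factory::create returns a search_plan_impl configured for the parameter set, and the plan is then invoked as a callable once per query batch — which is how the single-CTA, multi-CTA, and multi-kernel variants (the three search_*.cuh files above) stay interchangeable behind one loop. The shape of that pattern, stripped of CAGRA's argument list (an illustrative sketch, not the real class hierarchy):

    #include <memory>

    struct search_plan {                          // callable plan, cf. (*plan)(...) above
      virtual void operator()(int batch_id) = 0;  // stand-in for the real argument list
      virtual ~search_plan() = default;
    };

    struct single_cta_plan : search_plan {
      void operator()(int /*batch_id*/) override { /* launch the single-CTA kernel */ }
    };

    struct factory {
      static std::unique_ptr<search_plan> create(/* params, dim, degree, topk */)
      {
        // the real factory also selects the multi-CTA or multi-kernel variants
        return std::make_unique<single_cta_plan>();
      }
    };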
From cfbb29e3517c04252fe2234e4f2200f5e671431b Mon Sep 17 00:00:00 2001
From: tsuki <12711693+enp1s0@users.noreply.github.com>
Date: Fri, 19 May 2023 22:38:27 +0900
Subject: [PATCH 18/18] Update memory_type of knn graph in sort_knn_graph

Co-authored-by: Tamas Bela Feher
---
 cpp/include/raft/neighbors/cagra.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh
index 32275dfcdb..9905f2abae 100644
--- a/cpp/include/raft/neighbors/cagra.cuh
+++ b/cpp/include/raft/neighbors/cagra.cuh
@@ -137,7 +137,7 @@ void sort_knn_graph(raft::resources const& res,
   using internal_IdxT = typename std::make_unsigned<IdxT>::type;

   using g_accessor_internal =
-    host_device_accessor<std::experimental::default_accessor<internal_IdxT>, memory_type::host>;
+    host_device_accessor<std::experimental::default_accessor<internal_IdxT>, g_accessor::mem_type>;
   auto knn_graph_internal =
     mdspan<internal_IdxT, matrix_extent<int64_t>, row_major, g_accessor_internal>(
       reinterpret_cast<internal_IdxT*>(knn_graph.data_handle()),