From 591bc9d624f7659e81ac026f347f73c21542c385 Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Wed, 10 May 2023 15:57:58 +0900 Subject: [PATCH 1/6] Update CAGRA prune not to use mgpu_alloc and use rmm allocator instead --- .../neighbors/detail/cagra/graph_core.cuh | 847 +++++++----------- 1 file changed, 326 insertions(+), 521 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index 02055f2a4d..c3ea5236d4 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -35,20 +35,8 @@ namespace raft::neighbors::experimental::cagra::detail { namespace graph { -template -__host__ __device__ float compute_norm2(const T* a, - const T* b, - const std::size_t dim, - const float scale) -{ - float sum = 0.f; - for (std::size_t j = 0; j < dim; j++) { - const auto diff = a[j] * scale - b[j] * scale; - sum += diff * diff; - } - return sum; -} - +// unnamed namespace to avoid multiple definition error +namespace { inline double cur_time(void) { struct timeval tv; @@ -76,25 +64,19 @@ __device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool a return false; } -template -__global__ void kern_sort( - DATA_T** dataset, // [num_gpus][dataset_chunk_size, dataset_dim] - uint32_t dataset_size, - uint32_t dataset_chunk_size, // (*) num_gpus * dataset_chunk_size >= dataset_size - uint32_t dataset_dim, - float scale, - uint32_t** knn_graph, // [num_gpus][graph_chunk_size, graph_degree] - uint32_t graph_size, - uint32_t graph_chunk_size, // (*) num_gpus * graph_chunk_size >= graph_size - uint32_t graph_degree, - int dev_id) +template +__global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] + uint32_t dataset_size, + uint32_t dataset_dim, + float scale, + INDEX_T* knn_graph, // [graph_chunk_size, graph_degree] + uint32_t graph_size, + uint32_t graph_degree) { __shared__ float smem_keys[blockDim_x * numElementsPerThread]; - __shared__ uint32_t smem_vals[blockDim_x * numElementsPerThread]; + __shared__ INDEX_T smem_vals[blockDim_x * numElementsPerThread]; - uint64_t srcNode = blockIdx.x + ((uint64_t)graph_chunk_size * dev_id); - uint64_t srcNode_dev = srcNode / graph_chunk_size; - uint64_t srcNode_loc = srcNode % graph_chunk_size; + uint64_t srcNode = blockIdx.x; if (srcNode >= graph_size) { return; } const uint32_t num_warps = blockDim_x / 32; @@ -103,14 +85,11 @@ __global__ void kern_sort( // Compute distance from a src node to its neighbors for (int k = warp_id; k < graph_degree; k += num_warps) { - uint64_t dstNode = knn_graph[srcNode_dev][k + ((uint64_t)graph_degree * srcNode_loc)]; - uint64_t dstNode_dev = dstNode / graph_chunk_size; - uint64_t dstNode_loc = dstNode % graph_chunk_size; - float dist = 0.0; + uint64_t dstNode = knn_graph[k + ((uint64_t)graph_degree * srcNode)]; + float dist = 0.0; for (int d = lane_id; d < dataset_dim; d += 32) { - float diff = - (float)(dataset[srcNode_dev][d + ((uint64_t)dataset_dim * srcNode_loc)]) * scale - - (float)(dataset[dstNode_dev][d + ((uint64_t)dataset_dim * dstNode_loc)]) * scale; + float diff = (float)(dataset[d + ((uint64_t)dataset_dim * srcNode)]) * scale - + (float)(dataset[d + ((uint64_t)dataset_dim * dstNode)]) * scale; dist += diff * diff; } dist += __shfl_xor_sync(0xffffffff, dist, 1); @@ -126,7 +105,7 @@ __global__ void kern_sort( __syncthreads(); float my_keys[numElementsPerThread]; - uint32_t my_vals[numElementsPerThread]; + INDEX_T 
my_vals[numElementsPerThread]; for (int i = 0; i < numElementsPerThread; i++) { int k = i + (numElementsPerThread * threadIdx.x); if (k < graph_degree) { @@ -134,23 +113,23 @@ __global__ void kern_sort( my_vals[i] = smem_vals[k]; } else { my_keys[i] = FLT_MAX; - my_vals[i] = 0xffffffffU; + my_vals[i] = ~static_cast(0); } } __syncthreads(); // Sorting by thread - uint32_t mask = 1; - bool ascending = ((threadIdx.x & mask) == 0); + uint32_t mask = 1; + const bool ascending = ((threadIdx.x & mask) == 0); for (int j = 0; j < numElementsPerThread; j += 2) { #pragma unroll for (int i = 0; i < numElementsPerThread; i += 2) { - swap_if_needed( + swap_if_needed( my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); } #pragma unroll for (int i = 1; i < numElementsPerThread - 1; i += 2) { - swap_if_needed( + swap_if_needed( my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); } } @@ -160,7 +139,7 @@ __global__ void kern_sort( uint32_t next_mask = mask << 1; for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) { - bool ascending = ((threadIdx.x & curr_mask) == 0) == ((threadIdx.x & next_mask) == 0); + const bool ascending = ((threadIdx.x & curr_mask) == 0) == ((threadIdx.x & next_mask) == 0); if (mask >= 32) { // inter warp __syncthreads(); @@ -174,7 +153,7 @@ __global__ void kern_sort( for (int i = 0; i < numElementsPerThread; i++) { float opp_key = smem_keys[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; uint32_t opp_val = smem_vals[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } else { // intra warp @@ -182,19 +161,19 @@ __global__ void kern_sort( for (int i = 0; i < numElementsPerThread; i++) { float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); uint32_t opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } } - bool ascending = ((threadIdx.x & next_mask) == 0); + const bool ascending = ((threadIdx.x & next_mask) == 0); #pragma unroll for (uint32_t curr_mask = numElementsPerThread / 2; curr_mask > 0; curr_mask >>= 1) { #pragma unroll for (int i = 0; i < numElementsPerThread; i++) { int j = i ^ curr_mask; if (i > j) continue; - swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); + swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); } } mask = next_mask; @@ -203,53 +182,44 @@ __global__ void kern_sort( // Update knn_graph for (int i = 0; i < numElementsPerThread; i++) { int k = i + (numElementsPerThread * threadIdx.x); - if (k < graph_degree) { - knn_graph[srcNode_dev][k + ((uint64_t)graph_degree * srcNode_loc)] = my_vals[i]; - } + if (k < graph_degree) { knn_graph[k + ((uint64_t)graph_degree * srcNode)] = my_vals[i]; } } } template -__global__ void kern_prune( - uint32_t** knn_graph, // [num_gpus][graph_chunk_size, graph_degree] - uint32_t graph_size, - uint32_t graph_chunk_size, // (*) num_gpus * graph_chunk_size >= graph_size - uint32_t graph_degree, - uint32_t degree, - int dev_id, - uint32_t batch_size, - uint32_t batch_id, - uint8_t** detour_count, // [num_gpus][graph_chunk_size, graph_degree] - uint32_t** num_no_detour_edges, // [num_gpus][graph_size] - uint64_t* stats) +__global__ void kern_prune(uint32_t* knn_graph, // [graph_chunk_size, graph_degree] + uint32_t graph_size, + uint32_t graph_degree, 
+ uint32_t degree, + uint32_t batch_size, + uint32_t batch_id, + uint8_t* detour_count, // [graph_chunk_size, graph_degree] + uint32_t* num_no_detour_edges, // [graph_size] + uint64_t* stats) { __shared__ uint32_t smem_num_detour[MAX_DEGREE]; - uint64_t* num_retain = stats; - uint64_t* num_full = stats + 1; + uint64_t* const num_retain = stats; + uint64_t* const num_full = stats + 1; - uint64_t nid = blockIdx.x + (batch_size * batch_id); - if (nid >= graph_chunk_size) { return; } + const uint64_t nid = blockIdx.x + (batch_size * batch_id); + if (nid >= graph_size) { return; } for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { smem_num_detour[k] = 0; } __syncthreads(); - uint64_t iA = nid + ((uint64_t)graph_chunk_size * dev_id); - uint64_t iA_dev = iA / graph_chunk_size; - uint64_t iA_loc = iA % graph_chunk_size; + const uint64_t iA = nid; if (iA >= graph_size) { return; } // count number of detours (A->D->B) for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { - uint64_t iD = knn_graph[iA_dev][kAD + (graph_degree * iA_loc)]; - uint64_t iD_dev = iD / graph_chunk_size; - uint64_t iD_loc = iD % graph_chunk_size; + const uint64_t iD = knn_graph[kAD + (graph_degree * iA)]; for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { - uint64_t iB_candidate = knn_graph[iD_dev][kDB + ((uint64_t)graph_degree * iD_loc)]; + const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { // if ( kDB < kAB ) { - uint64_t iB = knn_graph[iA_dev][kAB + (graph_degree * iA_loc)]; + const uint64_t iB = knn_graph[kAB + (graph_degree * iA)]; if (iB == iB_candidate) { atomicAdd(smem_num_detour + kAB, 1); break; @@ -262,7 +232,7 @@ __global__ void kern_prune( uint32_t num_edges_no_detour = 0; for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - detour_count[iA_dev][k + (graph_degree * iA_loc)] = min(smem_num_detour[k], (uint32_t)255); + detour_count[k + (graph_degree * iA)] = min(smem_num_detour[k], (uint32_t)255); if (smem_num_detour[k] == 0) { num_edges_no_detour++; } } num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); @@ -273,18 +243,14 @@ __global__ void kern_prune( num_edges_no_detour = min(num_edges_no_detour, degree); if (threadIdx.x == 0) { - num_no_detour_edges[iA_dev][iA_loc] = num_edges_no_detour; + num_no_detour_edges[iA] = num_edges_no_detour; atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } } } -// unnamed namespace to avoid multiple definition error -namespace { -__global__ void kern_make_rev_graph(const uint32_t i_gpu, - const uint32_t* dest_nodes, // [global_graph_size] - const uint32_t global_graph_size, - uint32_t* rev_graph, // [graph_size, degree] +__global__ void kern_make_rev_graph(const uint32_t* dest_nodes, // [graph_size] + uint32_t* rev_graph, // [size, degree] uint32_t* rev_graph_count, // [graph_size] const uint32_t graph_size, const uint32_t degree) @@ -292,100 +258,13 @@ __global__ void kern_make_rev_graph(const uint32_t i_gpu, const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); const uint32_t tnum = blockDim.x * gridDim.x; - for (uint32_t gl_src_id = tid; gl_src_id < global_graph_size; gl_src_id += tnum) { - uint32_t gl_dest_id = dest_nodes[gl_src_id]; - if (gl_dest_id < graph_size * i_gpu) continue; - if (gl_dest_id >= graph_size * (i_gpu + 1)) continue; - if (gl_dest_id >= global_graph_size) 
continue; - - uint32_t dest_id = gl_dest_id - (graph_size * i_gpu); - uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = gl_src_id; } - } -} -} // namespace -template -T*** mgpu_alloc(int n_gpus, uint32_t chunk, uint32_t nelems) -{ - T** arrays; // [n_gpus][chunk, nelems] - arrays = (T**)malloc(sizeof(T*) * n_gpus); /* h1 */ - size_t bsize = sizeof(T) * chunk * nelems; - // RAFT_LOG_DEBUG("[%s, %s, %d] n_gpus: %d, chunk: %u, nelems: %u, bsize: %lu (%lu MiB)\n", - // __FILE__, __func__, __LINE__, n_gpus, chunk, nelems, bsize, bsize / 1024 / 1024); - for (int i_gpu = 0; i_gpu < n_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaMalloc(&(arrays[i_gpu]), bsize)); /* d1 */ - } - T*** d_arrays; // [n_gpus+1][n_gpus][chunk, nelems] - d_arrays = (T***)malloc(sizeof(T**) * (n_gpus + 1)); /* h2 */ - bsize = sizeof(T*) * n_gpus; - for (int i_gpu = 0; i_gpu < n_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaMalloc(&(d_arrays[i_gpu]), bsize)); /* d2 */ - RAFT_CUDA_TRY(cudaMemcpy(d_arrays[i_gpu], arrays, bsize, cudaMemcpyDefault)); - } - RAFT_CUDA_TRY(cudaSetDevice(0)); - d_arrays[n_gpus] = arrays; - return d_arrays; -} - -template -void mgpu_free(T*** d_arrays, int n_gpus) -{ - for (int i_gpu = 0; i_gpu < n_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaFree(d_arrays[n_gpus][i_gpu])); /* d1 */ - RAFT_CUDA_TRY(cudaFree(d_arrays[i_gpu])); /* d2 */ - } - RAFT_CUDA_TRY(cudaSetDevice(0)); - free(d_arrays[n_gpus]); /* h1 */ - free(d_arrays); /* h2 */ -} - -template -void mgpu_H2D(T*** d_arrays, // [n_gpus+1][n_gpus][chunk, nelems] - const T* h_array, // [size, nelems] - int n_gpus, - uint32_t size, - uint32_t chunk, // (*) n_gpus * chunk >= size - uint32_t nelems) -{ -#pragma omp parallel num_threads(n_gpus) - { - int i_gpu = omp_get_thread_num(); - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - uint32_t _chunk = std::min(size - (chunk * i_gpu), chunk); - size_t bsize = sizeof(T) * _chunk * nelems; - RAFT_CUDA_TRY(cudaMemcpy(d_arrays[n_gpus][i_gpu], - h_array + ((uint64_t)chunk * nelems * i_gpu), - bsize, - cudaMemcpyDefault)); - } - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - RAFT_CUDA_TRY(cudaSetDevice(0)); -} + for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { + const uint32_t dest_id = dest_nodes[src_id]; + if (dest_id >= graph_size) continue; -template -void mgpu_D2H(T*** d_arrays, // [n_gpus+1][n_gpus][chunk, nelems] - T* h_array, // [size, nelems] - int n_gpus, - uint32_t size, - uint32_t chunk, // (*) n_gpus * chunk >= size - uint32_t nelems) -{ -#pragma omp parallel num_threads(n_gpus) - { - int i_gpu = omp_get_thread_num(); - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - uint32_t _chunk = std::min(size - (chunk * i_gpu), chunk); - size_t bsize = sizeof(T) * _chunk * nelems; - RAFT_CUDA_TRY(cudaMemcpy(h_array + ((uint64_t)chunk * nelems * i_gpu), - d_arrays[n_gpus][i_gpu], - bsize, - cudaMemcpyDefault)); + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } } - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - RAFT_CUDA_TRY(cudaSetDevice(0)); } template @@ -404,6 +283,7 @@ void shift_array(T* array, uint64_t num) array[i] = array[i - 1]; } } +} // namespace /** Input arrays can be both host and device*/ template ::kDivisor; - const std::size_t graph_size = dataset_size; - size_t array_size; - - // Setup GPUs - int num_gpus = 0; - - // Setup GPUs - 
RAFT_CUDA_TRY(cudaGetDeviceCount(&num_gpus)); - RAFT_LOG_DEBUG("# num_gpus: %d\n", num_gpus); - for (int self = 0; self < num_gpus; self++) { - RAFT_CUDA_TRY(cudaSetDevice(self)); - for (int peer = 0; peer < num_gpus; peer++) { - if (self == peer) { continue; } - RAFT_CUDA_TRY(cudaDeviceEnablePeerAccess(peer, 0)); + const DATA_T* const dataset_ptr = dataset.data_handle(); + uint32_t* const input_graph_ptr = (uint32_t*)knn_graph.data_handle(); + uint32_t* const output_graph_ptr = new_graph.data_handle(); + const float scale = 1.0f / raft::spatial::knn::detail::utils::config::kDivisor; + const IdxT graph_size = dataset_size; + + auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); + + { + // + // Sorting kNN graph + // + const double time_sort_start = cur_time(); + RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs "); + + auto d_dataset = raft::make_device_matrix(res, dataset_size, dataset_dim); + raft::copy(d_dataset.data_handle(), dataset_ptr, dataset_size * dataset_dim, res.get_stream()); + + raft::copy(d_input_graph.data_handle(), + input_graph_ptr, + graph_size * input_graph_degree, + res.get_stream()); + + void (*kernel_sort)(DATA_T*, uint32_t, uint32_t, float, IdxT*, uint32_t, uint32_t); + constexpr int numElementsPerThread = 4; + dim3 threads_sort(1, 1, 1); + if (input_graph_degree <= numElementsPerThread * 32) { + constexpr int blockDim_x = 32; + kernel_sort = kern_sort; + threads_sort.x = blockDim_x; + } else if (input_graph_degree <= numElementsPerThread * 64) { + constexpr int blockDim_x = 64; + kernel_sort = kern_sort; + threads_sort.x = blockDim_x; + } else if (input_graph_degree <= numElementsPerThread * 128) { + constexpr int blockDim_x = 128; + kernel_sort = kern_sort; + threads_sort.x = blockDim_x; + } else if (input_graph_degree <= numElementsPerThread * 256) { + constexpr int blockDim_x = 256; + kernel_sort = kern_sort; + threads_sort.x = blockDim_x; + } else { + fprintf(stderr, + "[ERROR] The degree of input knn graph is too large (%u). " + "It must be equal to or small than %d.\n", + input_graph_degree, + numElementsPerThread * 256); + exit(-1); } - } - RAFT_CUDA_TRY(cudaSetDevice(0)); - - uint32_t graph_chunk_size = graph_size; - uint32_t*** d_input_graph_ptr = NULL; // [...][num_gpus][graph_chunk_size, input_graph_degree] - graph_chunk_size = (graph_size + num_gpus - 1) / num_gpus; - d_input_graph_ptr = mgpu_alloc(num_gpus, graph_chunk_size, input_graph_degree); - - uint32_t dataset_chunk_size = dataset_size; - DATA_T*** d_dataset_ptr = NULL; // [num_gpus+1][...][...] 
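// --- Editorial aside (not part of the patch): the removed lines around here are the
// hand-rolled multi-GPU path (cudaSetDevice / cudaMalloc / mgpu_alloc / mgpu_H2D).
// Below is a minimal sketch of the single-device pattern this patch moves to, built
// only from calls that appear later in the patch (raft::make_device_matrix, raft::copy,
// res.get_stream(), res.sync_stream()); the helper name and exact template arguments
// are illustrative assumptions, and the RAFT headers already included by
// graph_core.cuh are assumed to be available.
template <class IdxT>
void with_graph_on_device(raft::device_resources const& res,
                          IdxT* host_graph,   // [graph_size, degree], row-major
                          uint32_t graph_size,
                          uint32_t degree)
{
  // One stream-ordered, RAII-owned device buffer instead of one raw cudaMalloc per GPU.
  auto d_graph = raft::make_device_matrix<IdxT, int64_t>(res, graph_size, degree);
  raft::copy(d_graph.data_handle(), host_graph,
             static_cast<size_t>(graph_size) * degree, res.get_stream());
  // ... launch kernels on res.get_stream() against d_graph.data_handle() ...
  raft::copy(host_graph, d_graph.data_handle(),
             static_cast<size_t>(graph_size) * degree, res.get_stream());
  res.sync_stream();  // the device buffer is released when d_graph goes out of scope
}
// --- end of aside ---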
- dataset_chunk_size = (dataset_size + num_gpus - 1) / num_gpus; - assert(dataset_chunk_size == graph_chunk_size); - d_dataset_ptr = mgpu_alloc(num_gpus, dataset_chunk_size, dataset_dim); - - mgpu_H2D( - d_dataset_ptr, dataset_ptr, num_gpus, dataset_size, dataset_chunk_size, dataset_dim); - - // - // Sorting kNN graph - // - double time_sort_start = cur_time(); - RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs "); - mgpu_H2D( - d_input_graph_ptr, input_graph_ptr, num_gpus, graph_size, graph_chunk_size, input_graph_degree); - void (*kernel_sort)( - DATA_T**, uint32_t, uint32_t, uint32_t, float, uint32_t**, uint32_t, uint32_t, uint32_t, int); - constexpr int numElementsPerThread = 4; - dim3 threads_sort(1, 1, 1); - if (input_graph_degree <= numElementsPerThread * 32) { - constexpr int blockDim_x = 32; - kernel_sort = kern_sort; - threads_sort.x = blockDim_x; - } else if (input_graph_degree <= numElementsPerThread * 64) { - constexpr int blockDim_x = 64; - kernel_sort = kern_sort; - threads_sort.x = blockDim_x; - } else if (input_graph_degree <= numElementsPerThread * 128) { - constexpr int blockDim_x = 128; - kernel_sort = kern_sort; - threads_sort.x = blockDim_x; - } else if (input_graph_degree <= numElementsPerThread * 256) { - constexpr int blockDim_x = 256; - kernel_sort = kern_sort; - threads_sort.x = blockDim_x; - } else { - fprintf(stderr, - "[ERROR] The degree of input knn graph is too large (%u). " - "It must be equal to or small than %d.\n", - input_graph_degree, - numElementsPerThread * 256); - exit(-1); - } - dim3 blocks_sort(graph_chunk_size, 1, 1); - for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { + dim3 blocks_sort(graph_size, 1, 1); RAFT_LOG_DEBUG("."); - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - kernel_sort<<>>(d_dataset_ptr[i_gpu], - dataset_size, - dataset_chunk_size, - dataset_dim, - scale, - d_input_graph_ptr[i_gpu], - graph_size, - graph_chunk_size, - input_graph_degree, - i_gpu); - } - RAFT_CUDA_TRY(cudaSetDevice(0)); - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - RAFT_LOG_DEBUG("."); - mgpu_D2H( - d_input_graph_ptr, input_graph_ptr, num_gpus, graph_size, graph_chunk_size, input_graph_degree); - RAFT_LOG_DEBUG("\n"); - double time_sort_end = cur_time(); - RAFT_LOG_DEBUG("# Sorting kNN graph time: %.1lf sec\n", time_sort_end - time_sort_start); - - mgpu_free(d_dataset_ptr, num_gpus); - - // - uint8_t* detour_count; // [graph_size, input_graph_degree] - array_size = sizeof(uint8_t) * graph_size * input_graph_degree; - detour_count = (uint8_t*)malloc(array_size); - memset(detour_count, 0xff, array_size); - - uint8_t*** d_detour_count = NULL; // [...][num_gpus][graph_chunk_size, input_graph_degree] - d_detour_count = mgpu_alloc(num_gpus, graph_chunk_size, input_graph_degree); - mgpu_H2D( - d_detour_count, detour_count, num_gpus, graph_size, graph_chunk_size, input_graph_degree); - - // - uint32_t* num_no_detour_edges; // [graph_size] - array_size = sizeof(uint32_t) * graph_size; - num_no_detour_edges = (uint32_t*)malloc(array_size); - memset(num_no_detour_edges, 0, array_size); - - uint32_t*** d_num_no_detour_edges = NULL; // [...][num_gpus][graph_chunk_size] - d_num_no_detour_edges = mgpu_alloc(num_gpus, graph_chunk_size, 1); - mgpu_H2D( - d_num_no_detour_edges, num_no_detour_edges, num_gpus, graph_size, graph_chunk_size, 1); - - // - uint64_t** dev_stats = NULL; // [num_gpus][2] - uint64_t** host_stats = NULL; // [num_gpus][2] - dev_stats = (uint64_t**)malloc(sizeof(uint64_t*) * num_gpus); - host_stats = (uint64_t**)malloc(sizeof(uint64_t*) * num_gpus); - array_size = 
sizeof(uint64_t) * 2; - for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaMalloc(&(dev_stats[i_gpu]), array_size)); - host_stats[i_gpu] = (uint64_t*)malloc(array_size); + kernel_sort<<>>(d_dataset.data_handle(), + dataset_size, + dataset_dim, + scale, + d_input_graph.data_handle(), + graph_size, + input_graph_degree); + res.sync_stream(); + RAFT_LOG_DEBUG("."); + raft::copy(input_graph_ptr, + d_input_graph.data_handle(), + graph_size * input_graph_degree, + res.get_stream()); + RAFT_LOG_DEBUG("\n"); + + const double time_sort_end = cur_time(); + RAFT_LOG_DEBUG("# Sorting kNN graph time: %.1lf sec\n", time_sort_end - time_sort_start); } - RAFT_CUDA_TRY(cudaSetDevice(0)); - - // - // Prune unimportant edges. - // - // The edge to be retained is determined without explicitly considering - // distance or angle. Suppose the edge is the k-th edge of some node-A to - // node-B (A->B). Among the edges originating at node-A, there are k-1 edges - // shorter than the edge A->B. Each of these k-1 edges are connected to a - // different k-1 nodes. Among these k-1 nodes, count the number of nodes with - // edges to node-B, which is the number of 2-hop detours for the edge A->B. - // Once the number of 2-hop detours has been counted for all edges, the - // specified number of edges are picked up for each node, starting with the - // edge with the lowest number of 2-hop detours. - // - double time_prune_start = cur_time(); - uint64_t num_keep = 0; - uint64_t num_full = 0; - RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - mgpu_H2D( - d_input_graph_ptr, input_graph_ptr, num_gpus, graph_size, graph_chunk_size, input_graph_degree); - void (*kernel_prune)(uint32_t**, - uint32_t, - uint32_t, - uint32_t, - uint32_t, - int, - uint32_t, - uint32_t, - uint8_t**, - uint32_t**, - uint64_t*); - if (input_graph_degree <= 1024) { + + auto pruned_graph = raft::make_host_matrix(graph_size, output_graph_degree); + + { + // + // Prune kNN graph + // + auto detour_count = raft::make_host_matrix(graph_size, input_graph_degree); + auto d_detour_count = + raft::make_device_matrix(res, graph_size, input_graph_degree); + RAFT_CUDA_TRY(cudaMemset( + d_detour_count.data_handle(), 0xff, graph_size * input_graph_degree * sizeof(uint8_t))); + + auto d_num_no_detour_edges = raft::make_device_vector(res, graph_size); + RAFT_CUDA_TRY( + cudaMemset(d_num_no_detour_edges.data_handle(), 0x00, graph_size * sizeof(uint32_t))); + + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); + + // + // Prune unimportant edges. + // + // The edge to be retained is determined without explicitly considering + // distance or angle. Suppose the edge is the k-th edge of some node-A to + // node-B (A->B). Among the edges originating at node-A, there are k-1 edges + // shorter than the edge A->B. Each of these k-1 edges are connected to a + // different k-1 nodes. Among these k-1 nodes, count the number of nodes with + // edges to node-B, which is the number of 2-hop detours for the edge A->B. + // Once the number of 2-hop detours has been counted for all edges, the + // specified number of edges are picked up for each node, starting with the + // edge with the lowest number of 2-hop detours. 
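// --- Illustrative aside (not part of the patch): the pruning rule described in the
// comment above, written as a plain host-side loop to make it concrete. `graph` is
// assumed to be the row-major [graph_size, graph_degree] neighbor list of the kNN
// graph, already sorted by distance per node; the function name and layout are
// assumptions for illustration, not code from this file.
uint32_t count_2hop_detours(const uint32_t* graph,
                            uint32_t graph_degree,
                            uint64_t node_a,
                            uint32_t rank_ab)  // edge A->B is A's rank_ab-th edge
{
  const uint32_t node_b = graph[node_a * graph_degree + rank_ab];
  uint32_t detours = 0;
  // Every neighbor D of A ranked closer than B contributes one detour A->D->B
  // if D also has an edge to B.
  for (uint32_t rank_ad = 0; rank_ad < rank_ab; ++rank_ad) {
    const uint64_t node_d = graph[node_a * graph_degree + rank_ad];
    for (uint32_t rank_db = 0; rank_db < graph_degree; ++rank_db) {
      if (graph[node_d * graph_degree + rank_db] == node_b) {
        ++detours;
        break;
      }
    }
  }
  return detours;
}
// Each node then keeps the output_graph_degree edges with the smallest detour
// counts; the kern_prune kernel in this file computes the same counts in
// parallel, one block per node, with shared-memory counters capped at 255.
// --- end of aside ---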
+ // + const double time_prune_start = cur_time(); + RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); + + raft::copy(d_input_graph.data_handle(), + input_graph_ptr, + graph_size * input_graph_degree, + res.get_stream()); + void (*kernel_prune)( + uint32_t*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint8_t*, uint32_t*, uint64_t*); + constexpr int MAX_DEGREE = 1024; - kernel_prune = kern_prune; - } else { - fprintf(stderr, - "[ERROR] The degree of input knn graph is too large (%u). " - "It must be equal to or small than %d.\n", - input_graph_degree, - 1024); - exit(-1); - } - uint32_t batch_size = std::min(graph_chunk_size, (uint32_t)256 * 1024); - uint32_t num_batch = (graph_chunk_size + batch_size - 1) / batch_size; - dim3 threads_prune(32, 1, 1); - dim3 blocks_prune(batch_size, 1, 1); - for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaMemset(dev_stats[i_gpu], 0, sizeof(uint64_t) * 2)); - } - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - kernel_prune<<>>(d_input_graph_ptr[i_gpu], - graph_size, - graph_chunk_size, - input_graph_degree, - output_graph_degree, - i_gpu, - batch_size, - i_batch, - d_detour_count[i_gpu], - d_num_no_detour_edges[i_gpu], - dev_stats[i_gpu]); + if (input_graph_degree <= MAX_DEGREE) { + kernel_prune = kern_prune; + } else { + fprintf(stderr, + "[ERROR] The degree of input knn graph is too large (%u). " + "It must be equal to or small than %d.\n", + input_graph_degree, + 1024); + exit(-1); } - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - fprintf( - stderr, - "# Pruning kNN Graph on GPUs (%.1lf %%)\r", - (double)std::min((i_batch + 1) * batch_size, graph_chunk_size) / graph_chunk_size * 100); - } - for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY( - cudaMemcpy(host_stats[i_gpu], dev_stats[i_gpu], sizeof(uint64_t) * 2, cudaMemcpyDefault)); - num_keep += host_stats[i_gpu][0]; - num_full += host_stats[i_gpu][1]; - } - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - RAFT_CUDA_TRY(cudaSetDevice(0)); - RAFT_LOG_DEBUG("\n"); - - mgpu_D2H( - d_detour_count, detour_count, num_gpus, graph_size, graph_chunk_size, input_graph_degree); - mgpu_D2H( - d_num_no_detour_edges, num_no_detour_edges, num_gpus, graph_size, graph_chunk_size, 1); - - mgpu_free(d_input_graph_ptr, num_gpus); - mgpu_free(d_detour_count, num_gpus); - mgpu_free(d_num_no_detour_edges, num_gpus); - - // Create pruned kNN graph - array_size = sizeof(uint32_t) * graph_size * output_graph_degree; - uint32_t* pruned_graph_ptr = (uint32_t*)malloc(array_size); - uint32_t max_detour = 0; + const uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + const dim3 threads_prune(32, 1, 1); + const dim3 blocks_prune(batch_size, 1, 1); + + RAFT_CUDA_TRY(cudaMemset(dev_stats.data_handle(), 0, sizeof(uint64_t) * 2)); + + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + kernel_prune<<>>( + d_input_graph.data_handle(), + graph_size, + input_graph_degree, + output_graph_degree, + batch_size, + i_batch, + d_detour_count.data_handle(), + d_num_no_detour_edges.data_handle(), + dev_stats.data_handle()); + res.sync_stream(); + fprintf(stderr, + "# Pruning kNN Graph on GPUs (%.1lf %%)\r", + (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); + } + res.sync_stream(); + 
RAFT_LOG_DEBUG("\n"); + + raft::copy(detour_count.data_handle(), + d_detour_count.data_handle(), + graph_size * input_graph_degree, + res.get_stream()); + + raft::copy(host_stats.data_handle(), dev_stats.data_handle(), 2, res.get_stream()); + const auto num_keep = host_stats.data_handle()[0]; + const auto num_full = host_stats.data_handle()[1]; + + // Create pruned kNN graph + uint32_t max_detour = 0; #pragma omp parallel for reduction(max : max_detour) - for (uint64_t i = 0; i < graph_size; i++) { - uint64_t pk = 0; - for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) { - if (max_detour < num_detour) { max_detour = num_detour; /* stats */ } - for (uint64_t k = 0; k < input_graph_degree; k++) { - if (detour_count[k + (input_graph_degree * i)] != num_detour) { continue; } - pruned_graph_ptr[pk + (output_graph_degree * i)] = - input_graph_ptr[k + (input_graph_degree * i)]; - pk += 1; + for (uint64_t i = 0; i < graph_size; i++) { + uint64_t pk = 0; + for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) { + if (max_detour < num_detour) { max_detour = num_detour; /* stats */ } + for (uint64_t k = 0; k < input_graph_degree; k++) { + if (detour_count.data_handle()[k + (input_graph_degree * i)] != num_detour) { continue; } + pruned_graph.data_handle()[pk + (output_graph_degree * i)] = + input_graph_ptr[k + (input_graph_degree * i)]; + pk += 1; + if (pk >= output_graph_degree) break; + } if (pk >= output_graph_degree) break; } - if (pk >= output_graph_degree) break; + assert(pk == output_graph_degree); } - assert(pk == output_graph_degree); - } - // RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour); - - double time_prune_end = cur_time(); - fprintf(stderr, - "# Pruning time: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%\n", - time_prune_end - time_prune_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); - - // - // Make reverse graph - // - double time_make_start = cur_time(); - - array_size = sizeof(uint32_t) * graph_size * output_graph_degree; - uint32_t* rev_graph_ptr = (uint32_t*)malloc(array_size); - memset(rev_graph_ptr, 0xff, array_size); - - uint32_t*** d_rev_graph_ptr; // [...][num_gpus][graph_chunk_size, output_graph_degree] - d_rev_graph_ptr = mgpu_alloc(num_gpus, graph_chunk_size, output_graph_degree); - mgpu_H2D( - d_rev_graph_ptr, rev_graph_ptr, num_gpus, graph_size, graph_chunk_size, output_graph_degree); - - array_size = sizeof(uint32_t) * graph_size; - uint32_t* rev_graph_count = (uint32_t*)malloc(array_size); - memset(rev_graph_count, 0, array_size); - - uint32_t*** d_rev_graph_count; // [...][num_gpus][graph_chunk_size, 1] - d_rev_graph_count = mgpu_alloc(num_gpus, graph_chunk_size, 1); - mgpu_H2D(d_rev_graph_count, rev_graph_count, num_gpus, graph_size, graph_chunk_size, 1); - - uint32_t* dest_nodes; // [graph_size] - dest_nodes = (uint32_t*)malloc(sizeof(uint32_t) * graph_size); - uint32_t** d_dest_nodes; // [num_gpus][graph_size] - d_dest_nodes = (uint32_t**)malloc(sizeof(uint32_t*) * num_gpus); - for (int i_gpu = 0; i_gpu < num_gpus; i_gpu++) { - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaMalloc(&(d_dest_nodes[i_gpu]), sizeof(uint32_t) * graph_size)); + // RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour); + + const double time_prune_end = cur_time(); + fprintf(stderr, + "# Pruning time: %.1lf sec, " + "avg_no_detour_edges_per_node: %.2lf/%u, " + "nodes_with_no_detour_at_all_edges: %.1lf%%\n", + 
time_prune_end - time_prune_start, + (double)num_keep / graph_size, + output_graph_degree, + (double)num_full / graph_size * 100); } - for (uint64_t k = 0; k < output_graph_degree; k++) { + auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); + auto rev_graph_count = raft::make_host_vector(graph_size); + + { + // + // Make reverse graph + // + const double time_make_start = cur_time(); + + auto d_rev_graph = raft::make_device_matrix(res, graph_size, output_graph_degree); + RAFT_CUDA_TRY( + cudaMemset(d_rev_graph.data_handle(), 0xff, graph_size * output_graph_degree * sizeof(IdxT))); + + auto d_rev_graph_count = raft::make_device_vector(res, graph_size); + RAFT_CUDA_TRY(cudaMemset(d_rev_graph_count.data_handle(), 0x00, graph_size * sizeof(uint32_t))); + + auto dest_nodes = raft::make_host_vector(graph_size); + auto d_dest_nodes = raft::make_device_vector(res, graph_size); + + for (uint64_t k = 0; k < output_graph_degree; k++) { #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - dest_nodes[i] = pruned_graph_ptr[k + (output_graph_degree * i)]; - } - RAFT_CUDA_TRY(cudaDeviceSynchronize()); -#pragma omp parallel num_threads(num_gpus) - { - int i_gpu = omp_get_thread_num(); - RAFT_CUDA_TRY(cudaSetDevice(i_gpu)); - RAFT_CUDA_TRY(cudaMemcpy( - d_dest_nodes[i_gpu], dest_nodes, sizeof(uint32_t) * graph_size, cudaMemcpyHostToDevice)); + for (uint64_t i = 0; i < graph_size; i++) { + dest_nodes.data_handle()[i] = pruned_graph.data_handle()[k + (output_graph_degree * i)]; + } + res.sync_stream(); + + raft::copy( + d_dest_nodes.data_handle(), dest_nodes.data_handle(), graph_size, res.get_stream()); + dim3 threads(256, 1, 1); dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>(i_gpu, - d_dest_nodes[i_gpu], - graph_size, - d_rev_graph_ptr[num_gpus][i_gpu], - d_rev_graph_count[num_gpus][i_gpu], - graph_chunk_size, - output_graph_degree); + kern_make_rev_graph<<>>(d_dest_nodes.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); } - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); - } - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - RAFT_CUDA_TRY(cudaSetDevice(0)); - RAFT_LOG_DEBUG("\n"); - mgpu_D2H( - d_rev_graph_ptr, rev_graph_ptr, num_gpus, graph_size, graph_chunk_size, output_graph_degree); - mgpu_D2H(d_rev_graph_count, rev_graph_count, num_gpus, graph_size, graph_chunk_size, 1); - mgpu_free(d_rev_graph_ptr, num_gpus); - mgpu_free(d_rev_graph_count, num_gpus); + res.sync_stream(); + RAFT_LOG_DEBUG("\n"); - double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf sec", time_make_end - time_make_start); + raft::copy(rev_graph.data_handle(), + d_rev_graph.data_handle(), + graph_size * output_graph_degree, + res.get_stream()); + raft::copy( + rev_graph_count.data_handle(), d_rev_graph_count.data_handle(), graph_size, res.get_stream()); - // - // Replace some edges with reverse edges - // - double time_replace_start = cur_time(); + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf sec", time_make_end - time_make_start); + } - uint64_t num_protected_edges = output_graph_degree / 2; - RAFT_LOG_DEBUG("# num_protected_edges: %lu", num_protected_edges); + { + // + // Replace some edges with reverse edges + // + const double time_replace_start = cur_time(); + + const uint64_t num_protected_edges = output_graph_degree 
/ 2; + RAFT_LOG_DEBUG("# num_protected_edges: %lu", num_protected_edges); - array_size = sizeof(uint32_t) * graph_size * output_graph_degree; - memcpy(output_graph_ptr, pruned_graph_ptr, array_size); + memcpy(output_graph_ptr, + pruned_graph.data_handle(), + sizeof(uint32_t) * graph_size * output_graph_degree); - constexpr int _omp_chunk = 1024; + constexpr int _omp_chunk = 1024; #pragma omp parallel for schedule(dynamic, _omp_chunk) - for (uint64_t j = 0; j < graph_size; j++) { - for (uint64_t _k = 0; _k < rev_graph_count[j]; _k++) { - uint64_t k = rev_graph_count[j] - 1 - _k; - uint64_t i = rev_graph_ptr[k + (output_graph_degree * j)]; - - uint64_t pos = pos_in_array( - i, output_graph_ptr + (output_graph_degree * j), output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos == output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } - shift_array(output_graph_ptr + num_protected_edges + (output_graph_degree * j), - num_shift); - output_graph_ptr[num_protected_edges + (output_graph_degree * j)] = i; - } - if ((omp_get_thread_num() == 0) && ((j % _omp_chunk) == 0)) { - RAFT_LOG_DEBUG("# Replacing reverse edges: %lu / %lu ", j, graph_size); + for (uint64_t j = 0; j < graph_size; j++) { + for (uint64_t _k = 0; _k < rev_graph_count.data_handle()[j]; _k++) { + uint64_t k = rev_graph_count.data_handle()[j] - 1 - _k; + uint64_t i = rev_graph.data_handle()[k + (output_graph_degree * j)]; + + uint64_t pos = pos_in_array( + i, output_graph_ptr + (output_graph_degree * j), output_graph_degree); + if (pos < num_protected_edges) { continue; } + uint64_t num_shift = pos - num_protected_edges; + if (pos == output_graph_degree) { + num_shift = output_graph_degree - num_protected_edges - 1; + } + shift_array(output_graph_ptr + num_protected_edges + (output_graph_degree * j), + num_shift); + output_graph_ptr[num_protected_edges + (output_graph_degree * j)] = i; + } + if ((omp_get_thread_num() == 0) && ((j % _omp_chunk) == 0)) { + RAFT_LOG_DEBUG("# Replacing reverse edges: %lu / %lu ", j, graph_size); + } } - } - RAFT_LOG_DEBUG("\n"); - free(rev_graph_ptr); - free(rev_graph_count); + RAFT_LOG_DEBUG("\n"); - double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf sec", time_replace_end - time_replace_start); + const double time_replace_end = cur_time(); + RAFT_LOG_DEBUG("# Replacing edges time: %.1lf sec", time_replace_end - time_replace_start); - /* stats */ - uint64_t num_replaced_edges = 0; + /* stats */ + uint64_t num_replaced_edges = 0; #pragma omp parallel for reduction(+ : num_replaced_edges) - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - uint64_t j = pruned_graph_ptr[k + (output_graph_degree * i)]; - uint64_t pos = pos_in_array( - j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); - if (pos == output_graph_degree) { num_replaced_edges += 1; } + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t k = 0; k < output_graph_degree; k++) { + const uint64_t j = pruned_graph.data_handle()[k + (output_graph_degree * i)]; + const uint64_t pos = pos_in_array( + j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); + if (pos == output_graph_degree) { num_replaced_edges += 1; } + } } + fprintf(stderr, + "# Average number of replaced edges per node: %.2f", + (double)num_replaced_edges / graph_size); } - fprintf(stderr, - "# Average number of replaced edges per node: %.2f", - 
(double)num_replaced_edges / graph_size); } } // namespace graph From 657f272261c19f3e3f09448866a9245d83a96eb4 Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Wed, 10 May 2023 17:24:07 +0900 Subject: [PATCH 2/6] Update CAGRA sorting kernel to use spatial::knn::detail::utils::mapping --- .../raft/neighbors/detail/cagra/graph_core.cuh | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index c3ea5236d4..536a71388d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -68,7 +68,6 @@ template __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] uint32_t dataset_size, uint32_t dataset_dim, - float scale, INDEX_T* knn_graph, // [graph_chunk_size, graph_degree] uint32_t graph_size, uint32_t graph_degree) @@ -88,8 +87,10 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] uint64_t dstNode = knn_graph[k + ((uint64_t)graph_degree * srcNode)]; float dist = 0.0; for (int d = lane_id; d < dataset_dim; d += 32) { - float diff = (float)(dataset[d + ((uint64_t)dataset_dim * srcNode)]) * scale - - (float)(dataset[d + ((uint64_t)dataset_dim * dstNode)]) * scale; + float diff = spatial::knn::detail::utils::mapping{}( + dataset[d + ((uint64_t)dataset_dim * srcNode)]) - + spatial::knn::detail::utils::mapping{}( + dataset[d + ((uint64_t)dataset_dim * dstNode)]); dist += diff * diff; } dist += __shfl_xor_sync(0xffffffff, dist, 1); @@ -312,8 +313,7 @@ void prune(raft::device_resources const& res, const DATA_T* const dataset_ptr = dataset.data_handle(); uint32_t* const input_graph_ptr = (uint32_t*)knn_graph.data_handle(); uint32_t* const output_graph_ptr = new_graph.data_handle(); - const float scale = 1.0f / raft::spatial::knn::detail::utils::config::kDivisor; - const IdxT graph_size = dataset_size; + const IdxT graph_size = dataset_size; auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); @@ -332,7 +332,7 @@ void prune(raft::device_resources const& res, graph_size * input_graph_degree, res.get_stream()); - void (*kernel_sort)(DATA_T*, uint32_t, uint32_t, float, IdxT*, uint32_t, uint32_t); + void (*kernel_sort)(DATA_T*, uint32_t, uint32_t, IdxT*, uint32_t, uint32_t); constexpr int numElementsPerThread = 4; dim3 threads_sort(1, 1, 1); if (input_graph_degree <= numElementsPerThread * 32) { @@ -364,7 +364,6 @@ void prune(raft::device_resources const& res, kernel_sort<<>>(d_dataset.data_handle(), dataset_size, dataset_dim, - scale, d_input_graph.data_handle(), graph_size, input_graph_degree); From 8efacd6866f86dd2256637bf2b9bc4cede1a9f76 Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Wed, 10 May 2023 19:44:16 +0900 Subject: [PATCH 3/6] Support uint64 index type in CAGRA prune --- .../neighbors/detail/cagra/graph_core.cuh | 52 +++++++++++-------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index 536a71388d..89607caaf0 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -64,16 +64,16 @@ __device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool a return false; } -template +template __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] uint32_t 
dataset_size, uint32_t dataset_dim, - INDEX_T* knn_graph, // [graph_chunk_size, graph_degree] + IdxT* knn_graph, // [graph_chunk_size, graph_degree] uint32_t graph_size, uint32_t graph_degree) { __shared__ float smem_keys[blockDim_x * numElementsPerThread]; - __shared__ INDEX_T smem_vals[blockDim_x * numElementsPerThread]; + __shared__ IdxT smem_vals[blockDim_x * numElementsPerThread]; uint64_t srcNode = blockIdx.x; if (srcNode >= graph_size) { return; } @@ -106,7 +106,7 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] __syncthreads(); float my_keys[numElementsPerThread]; - INDEX_T my_vals[numElementsPerThread]; + IdxT my_vals[numElementsPerThread]; for (int i = 0; i < numElementsPerThread; i++) { int k = i + (numElementsPerThread * threadIdx.x); if (k < graph_degree) { @@ -114,7 +114,7 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] my_vals[i] = smem_vals[k]; } else { my_keys[i] = FLT_MAX; - my_vals[i] = ~static_cast(0); + my_vals[i] = ~static_cast(0); } } __syncthreads(); @@ -125,12 +125,12 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] for (int j = 0; j < numElementsPerThread; j += 2) { #pragma unroll for (int i = 0; i < numElementsPerThread; i += 2) { - swap_if_needed( + swap_if_needed( my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); } #pragma unroll for (int i = 1; i < numElementsPerThread - 1; i += 2) { - swap_if_needed( + swap_if_needed( my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); } } @@ -154,7 +154,7 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] for (int i = 0; i < numElementsPerThread; i++) { float opp_key = smem_keys[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; uint32_t opp_val = smem_vals[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } else { // intra warp @@ -162,7 +162,7 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] for (int i = 0; i < numElementsPerThread; i++) { float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); uint32_t opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); + swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } } @@ -174,7 +174,7 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] for (int i = 0; i < numElementsPerThread; i++) { int j = i ^ curr_mask; if (i > j) continue; - swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); + swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); } } mask = next_mask; @@ -187,16 +187,16 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] } } -template -__global__ void kern_prune(uint32_t* knn_graph, // [graph_chunk_size, graph_degree] +template +__global__ void kern_prune(IdxT* const knn_graph, // [graph_chunk_size, graph_degree] uint32_t graph_size, uint32_t graph_degree, uint32_t degree, uint32_t batch_size, uint32_t batch_id, - uint8_t* detour_count, // [graph_chunk_size, graph_degree] - uint32_t* num_no_detour_edges, // [graph_size] - uint64_t* stats) + uint8_t* const detour_count, // [graph_chunk_size, graph_degree] + uint32_t* const num_no_detour_edges, // [graph_size] + uint64_t* const stats) { __shared__ uint32_t 
smem_num_detour[MAX_DEGREE]; uint64_t* const num_retain = stats; @@ -250,9 +250,10 @@ __global__ void kern_prune(uint32_t* knn_graph, // [graph_chunk_size, graph_deg } } -__global__ void kern_make_rev_graph(const uint32_t* dest_nodes, // [graph_size] - uint32_t* rev_graph, // [size, degree] - uint32_t* rev_graph_count, // [graph_size] +template +__global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] + IdxT* const rev_graph, // [size, degree] + uint32_t* const rev_graph_count, // [graph_size] const uint32_t graph_size, const uint32_t degree) { @@ -260,7 +261,7 @@ __global__ void kern_make_rev_graph(const uint32_t* dest_nodes, // [graph_size] const uint32_t tnum = blockDim.x * gridDim.x; for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { - const uint32_t dest_id = dest_nodes[src_id]; + const IdxT dest_id = dest_nodes[src_id]; if (dest_id >= graph_size) continue; const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); @@ -418,12 +419,19 @@ void prune(raft::device_resources const& res, input_graph_ptr, graph_size * input_graph_degree, res.get_stream()); - void (*kernel_prune)( - uint32_t*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint8_t*, uint32_t*, uint64_t*); + void (*kernel_prune)(IdxT* const, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint8_t* const, + uint32_t* const, + uint64_t* const); constexpr int MAX_DEGREE = 1024; if (input_graph_degree <= MAX_DEGREE) { - kernel_prune = kern_prune; + kernel_prune = kern_prune; } else { fprintf(stderr, "[ERROR] The degree of input knn graph is too large (%u). " From e9bfe51cd3ec1eab24678bac340253df9fcf6e2f Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Thu, 11 May 2023 13:04:03 +0900 Subject: [PATCH 4/6] Fix data types in CAGRA::prune and sort --- .../neighbors/detail/cagra/graph_core.cuh | 63 ++++++++++--------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index 70b430ef51..6911e98ed2 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -65,17 +65,17 @@ __device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool a } template -__global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] - IdxT dataset_size, - uint32_t dataset_dim, +__global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, dataset_dim] + const IdxT dataset_size, + const uint32_t dataset_dim, IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - uint32_t graph_size, - uint32_t graph_degree) + const uint32_t graph_size, + const uint32_t graph_degree) { __shared__ float smem_keys[blockDim_x * numElementsPerThread]; __shared__ IdxT smem_vals[blockDim_x * numElementsPerThread]; - uint64_t srcNode = blockIdx.x; + const IdxT srcNode = blockIdx.x; if (srcNode >= graph_size) { return; } const uint32_t num_warps = blockDim_x / 32; @@ -84,8 +84,8 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] // Compute distance from a src node to its neighbors for (int k = warp_id; k < graph_degree; k += num_warps) { - uint64_t dstNode = knn_graph[k + ((uint64_t)graph_degree * srcNode)]; - float dist = 0.0; + const IdxT dstNode = knn_graph[k + ((uint64_t)graph_degree * srcNode)]; + float dist = 0.0; for (int d = lane_id; d < dataset_dim; d += 32) { float diff = spatial::knn::detail::utils::mapping{}( dataset[d + 
((uint64_t)dataset_dim * srcNode)]) - @@ -108,7 +108,7 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] float my_keys[numElementsPerThread]; IdxT my_vals[numElementsPerThread]; for (int i = 0; i < numElementsPerThread; i++) { - int k = i + (numElementsPerThread * threadIdx.x); + const int k = i + (numElementsPerThread * threadIdx.x); if (k < graph_degree) { my_keys[i] = smem_keys[k]; my_vals[i] = smem_vals[k]; @@ -137,7 +137,7 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] // Bitonic Sorting while (mask < blockDim_x) { - uint32_t next_mask = mask << 1; + const uint32_t next_mask = mask << 1; for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) { const bool ascending = ((threadIdx.x & curr_mask) == 0) == ((threadIdx.x & next_mask) == 0); @@ -152,16 +152,16 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] __syncthreads(); #pragma unroll for (int i = 0; i < numElementsPerThread; i++) { - float opp_key = smem_keys[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; - uint32_t opp_val = smem_vals[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; + float opp_key = smem_keys[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; + IdxT opp_val = smem_vals[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } else { // intra warp #pragma unroll for (int i = 0; i < numElementsPerThread; i++) { - float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); - uint32_t opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); + float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); + IdxT opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); } } @@ -182,18 +182,20 @@ __global__ void kern_sort(DATA_T* dataset, // [dataset_chunk_size, dataset_dim] // Update knn_graph for (int i = 0; i < numElementsPerThread; i++) { - int k = i + (numElementsPerThread * threadIdx.x); - if (k < graph_degree) { knn_graph[k + ((uint64_t)graph_degree * srcNode)] = my_vals[i]; } + const int k = i + (numElementsPerThread * threadIdx.x); + if (k < graph_degree) { + knn_graph[k + (static_cast(graph_degree) * srcNode)] = my_vals[i]; + } } } template -__global__ void kern_prune(IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - uint32_t graph_size, - uint32_t graph_degree, - uint32_t degree, - uint32_t batch_size, - uint32_t batch_id, +__global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] + const uint32_t graph_size, + const uint32_t graph_degree, + const uint32_t degree, + const uint32_t batch_size, + const uint32_t batch_id, uint8_t* const detour_count, // [graph_chunk_size, graph_degree] uint32_t* const num_no_detour_edges, // [graph_size] uint64_t* const stats) @@ -323,7 +325,8 @@ void sort_knn_graph(raft::device_resources const& res, graph_size * input_graph_degree, res.get_stream()); - void (*kernel_sort)(DataT*, IdxT, uint32_t, IdxT* const, uint32_t, uint32_t); + void (*kernel_sort)( + const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t); constexpr int numElementsPerThread = 4; dim3 threads_sort(1, 1, 1); if (input_graph_degree <= numElementsPerThread * 32) { @@ -431,12 +434,12 @@ void prune(raft::device_resources const& res, input_graph_ptr, graph_size * input_graph_degree, res.get_stream()); - void (*kernel_prune)(IdxT* const, - uint32_t, - uint32_t, - uint32_t, - uint32_t, 
- uint32_t, + void (*kernel_prune)(const IdxT* const, + const uint32_t, + const uint32_t, + const uint32_t, + const uint32_t, + const uint32_t, uint8_t* const, uint32_t* const, uint64_t* const); @@ -630,4 +633,4 @@ void prune(raft::device_resources const& res, } } // namespace graph -} // namespace raft::neighbors::experimental::cagra::detail \ No newline at end of file +} // namespace raft::neighbors::experimental::cagra::detail From 7fd67363aaeaed19a86f7f26363aa1680c857b9b Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Thu, 11 May 2023 13:13:09 +0900 Subject: [PATCH 5/6] Use RAFT_LOG_ERROR and cudaMemsetAsync in CAGRA::prune and sort --- .../neighbors/detail/cagra/graph_core.cuh | 69 ++++++++++--------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index 6911e98ed2..44443bfa6f 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -346,11 +346,11 @@ void sort_knn_graph(raft::device_resources const& res, kernel_sort = kern_sort; threads_sort.x = blockDim_x; } else { - fprintf(stderr, - "[ERROR] The degree of input knn graph is too large (%u). " - "It must be equal to or small than %d.\n", - input_graph_degree, - numElementsPerThread * 256); + RAFT_LOG_ERROR( + "[ERROR] The degree of input knn graph is too large (%u). " + "It must be equal to or small than %d.\n", + input_graph_degree, + numElementsPerThread * 256); exit(-1); } dim3 blocks_sort(graph_size, 1, 1); @@ -404,12 +404,14 @@ void prune(raft::device_resources const& res, auto detour_count = raft::make_host_matrix(graph_size, input_graph_degree); auto d_detour_count = raft::make_device_matrix(res, graph_size, input_graph_degree); - RAFT_CUDA_TRY(cudaMemset( - d_detour_count.data_handle(), 0xff, graph_size * input_graph_degree * sizeof(uint8_t))); + RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), + 0xff, + graph_size * input_graph_degree * sizeof(uint8_t), + res.get_stream())); auto d_num_no_detour_edges = raft::make_device_vector(res, graph_size); - RAFT_CUDA_TRY( - cudaMemset(d_num_no_detour_edges.data_handle(), 0x00, graph_size * sizeof(uint32_t))); + RAFT_CUDA_TRY(cudaMemsetAsync( + d_num_no_detour_edges.data_handle(), 0x00, graph_size * sizeof(uint32_t), res.get_stream())); auto dev_stats = raft::make_device_vector(res, 2); auto host_stats = raft::make_host_vector(2); @@ -448,11 +450,11 @@ void prune(raft::device_resources const& res, if (input_graph_degree <= MAX_DEGREE) { kernel_prune = kern_prune; } else { - fprintf(stderr, - "[ERROR] The degree of input knn graph is too large (%u). " - "It must be equal to or small than %d.\n", - input_graph_degree, - 1024); + RAFT_LOG_ERROR( + "[ERROR] The degree of input knn graph is too large (%u). 
" + "It must be equal to or small than %d.\n", + input_graph_degree, + 1024); exit(-1); } const uint32_t batch_size = @@ -461,7 +463,8 @@ void prune(raft::device_resources const& res, const dim3 threads_prune(32, 1, 1); const dim3 blocks_prune(batch_size, 1, 1); - RAFT_CUDA_TRY(cudaMemset(dev_stats.data_handle(), 0, sizeof(uint64_t) * 2)); + RAFT_CUDA_TRY( + cudaMemsetAsync(dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, res.get_stream())); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { kernel_prune<<>>( @@ -475,9 +478,9 @@ void prune(raft::device_resources const& res, d_num_no_detour_edges.data_handle(), dev_stats.data_handle()); res.sync_stream(); - fprintf(stderr, - "# Pruning kNN Graph on GPUs (%.1lf %%)\r", - (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); + RAFT_LOG_ERROR( + "# Pruning kNN Graph on GPUs (%.1lf %%)\r", + (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); } res.sync_stream(); RAFT_LOG_DEBUG("\n"); @@ -512,14 +515,14 @@ void prune(raft::device_resources const& res, // RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour); const double time_prune_end = cur_time(); - fprintf(stderr, - "# Pruning time: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%\n", - time_prune_end - time_prune_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); + RAFT_LOG_ERROR( + "# Pruning time: %.1lf sec, " + "avg_no_detour_edges_per_node: %.2lf/%u, " + "nodes_with_no_detour_at_all_edges: %.1lf%%\n", + time_prune_end - time_prune_start, + (double)num_keep / graph_size, + output_graph_degree, + (double)num_full / graph_size * 100); } auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); @@ -532,11 +535,14 @@ void prune(raft::device_resources const& res, const double time_make_start = cur_time(); auto d_rev_graph = raft::make_device_matrix(res, graph_size, output_graph_degree); - RAFT_CUDA_TRY( - cudaMemset(d_rev_graph.data_handle(), 0xff, graph_size * output_graph_degree * sizeof(IdxT))); + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), + 0xff, + graph_size * output_graph_degree * sizeof(IdxT), + res.get_stream())); auto d_rev_graph_count = raft::make_device_vector(res, graph_size); - RAFT_CUDA_TRY(cudaMemset(d_rev_graph_count.data_handle(), 0x00, graph_size * sizeof(uint32_t))); + RAFT_CUDA_TRY(cudaMemsetAsync( + d_rev_graph_count.data_handle(), 0x00, graph_size * sizeof(uint32_t), res.get_stream())); auto dest_nodes = raft::make_host_vector(graph_size); auto d_dest_nodes = raft::make_device_vector(res, graph_size); @@ -626,9 +632,8 @@ void prune(raft::device_resources const& res, if (pos == output_graph_degree) { num_replaced_edges += 1; } } } - fprintf(stderr, - "# Average number of replaced edges per node: %.2f", - (double)num_replaced_edges / graph_size); + RAFT_LOG_ERROR("# Average number of replaced edges per node: %.2f", + (double)num_replaced_edges / graph_size); } } From c20b49b965605136208b79085a2b7ebc313e75ea Mon Sep 17 00:00:00 2001 From: Hiroyuki Ootomo Date: Thu, 11 May 2023 13:16:33 +0900 Subject: [PATCH 6/6] Fix log level in CAGRA::prune --- cpp/include/raft/neighbors/detail/cagra/graph_core.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index 44443bfa6f..b7fffb4eaa 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ 
b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -478,7 +478,7 @@ void prune(raft::device_resources const& res, d_num_no_detour_edges.data_handle(), dev_stats.data_handle()); res.sync_stream(); - RAFT_LOG_ERROR( + RAFT_LOG_DEBUG( "# Pruning kNN Graph on GPUs (%.1lf %%)\r", (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); } @@ -515,7 +515,7 @@ void prune(raft::device_resources const& res, // RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour); const double time_prune_end = cur_time(); - RAFT_LOG_ERROR( + RAFT_LOG_DEBUG( "# Pruning time: %.1lf sec, " "avg_no_detour_edges_per_node: %.2lf/%u, " "nodes_with_no_detour_at_all_edges: %.1lf%%\n", @@ -632,7 +632,7 @@ void prune(raft::device_resources const& res, if (pos == output_graph_degree) { num_replaced_edges += 1; } } } - RAFT_LOG_ERROR("# Average number of replaced edges per node: %.2f", + RAFT_LOG_DEBUG("# Average number of replaced edges per node: %.2f", (double)num_replaced_edges / graph_size); } }
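
Editorial note on the last step of prune(): after the detour-based selection, each node's non-protected slots (everything past output_graph_degree / 2) can be overwritten with reverse edges, so that nodes pointed to by many others also point back at them. The sketch below illustrates that insertion step only, assuming a row-major neighbor list and a hypothetical helper name; the patched code achieves the same effect with pos_in_array and shift_array while walking rev_graph_count[j] reverse edges per node in reverse order.

void insert_reverse_edge(uint32_t* list,           // node j's outgoing edges, length `degree`
                         uint32_t degree,
                         uint32_t protected_edges, // output_graph_degree / 2 in the patch
                         uint32_t rev_src)         // node i that has an edge i -> j
{
  // Position of the reverse edge in j's current list (== degree if absent).
  uint32_t pos = 0;
  while (pos < degree && list[pos] != rev_src) { ++pos; }
  if (pos < protected_edges) { return; }  // already among the protected, closest edges

  // Shift the non-protected prefix of the tail right by one slot. If rev_src was
  // absent, the last entry is dropped; otherwise rev_src is simply moved up to the
  // first non-protected slot and the entries in between slide down by one.
  const uint32_t num_shift =
    (pos == degree) ? degree - protected_edges - 1 : pos - protected_edges;
  for (uint32_t s = num_shift; s > 0; --s) {
    list[protected_edges + s] = list[protected_edges + s - 1];
  }
  list[protected_edges] = rev_src;
}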