diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index feb9b76b2d..e56198ebb7 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -31,6 +31,7 @@ #include #include +#include #include #include "utils.hpp" @@ -67,7 +68,7 @@ __device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool a return false; } -template +template __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, dataset_dim] const IdxT dataset_size, const uint32_t dataset_dim, @@ -75,25 +76,23 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, const uint32_t graph_size, const uint32_t graph_degree) { - __shared__ float smem_keys[blockDim_x * numElementsPerThread]; - __shared__ IdxT smem_vals[blockDim_x * numElementsPerThread]; - - const IdxT srcNode = blockIdx.x; + const IdxT srcNode = (blockDim.x * blockIdx.x + threadIdx.x) / raft::WarpSize; if (srcNode >= graph_size) { return; } - const uint32_t num_warps = blockDim_x / 32; - const uint32_t warp_id = threadIdx.x / 32; - const uint32_t lane_id = threadIdx.x % 32; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; + + float my_keys[numElementsPerThread]; + IdxT my_vals[numElementsPerThread]; // Compute distance from a src node to its neighbors - for (int k = warp_id; k < graph_degree; k += num_warps) { - const IdxT dstNode = knn_graph[k + ((uint64_t)graph_degree * srcNode)]; + for (int k = 0; k < graph_degree; k++) { + const IdxT dstNode = knn_graph[k + static_cast(graph_degree) * srcNode]; float dist = 0.0; - for (int d = lane_id; d < dataset_dim; d += 32) { + for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) { float diff = spatial::knn::detail::utils::mapping{}( - dataset[d + ((uint64_t)dataset_dim * srcNode)]) - + dataset[d + static_cast(dataset_dim) * srcNode]) - spatial::knn::detail::utils::mapping{}( - dataset[d + ((uint64_t)dataset_dim * dstNode)]); + dataset[d + static_cast(dataset_dim) * dstNode]); dist += diff * diff; } dist += __shfl_xor_sync(0xffffffff, dist, 1); @@ -101,91 +100,24 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, dist += __shfl_xor_sync(0xffffffff, dist, 4); dist += __shfl_xor_sync(0xffffffff, dist, 8); dist += __shfl_xor_sync(0xffffffff, dist, 16); - if (lane_id == 0) { - smem_keys[k] = dist; - smem_vals[k] = dstNode; - } - } - __syncthreads(); - - float my_keys[numElementsPerThread]; - IdxT my_vals[numElementsPerThread]; - for (int i = 0; i < numElementsPerThread; i++) { - const int k = i + (numElementsPerThread * threadIdx.x); - if (k < graph_degree) { - my_keys[i] = smem_keys[k]; - my_vals[i] = smem_vals[k]; - } else { - my_keys[i] = FLT_MAX; - my_vals[i] = utils::get_max_value(); + if (lane_id == (k % raft::WarpSize)) { + my_keys[k / raft::WarpSize] = dist; + my_vals[k / raft::WarpSize] = dstNode; } } - __syncthreads(); - - // Sorting by thread - uint32_t mask = 1; - const bool ascending = ((threadIdx.x & mask) == 0); - for (int j = 0; j < numElementsPerThread; j += 2) { -#pragma unroll - for (int i = 0; i < numElementsPerThread; i += 2) { - swap_if_needed( - my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); - } -#pragma unroll - for (int i = 1; i < numElementsPerThread - 1; i += 2) { - swap_if_needed( - my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); + for (int k = graph_degree; k < raft::WarpSize * numElementsPerThread; k++) { + if (lane_id == k % raft::WarpSize) { + my_keys[k / raft::WarpSize] = utils::get_max_value(); + my_vals[k / raft::WarpSize] = utils::get_max_value(); } } - // Bitonic Sorting - while (mask < blockDim_x) { - const uint32_t next_mask = mask << 1; - - for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) { - const bool ascending = ((threadIdx.x & curr_mask) == 0) == ((threadIdx.x & next_mask) == 0); - if (mask >= 32) { - // inter warp - __syncthreads(); -#pragma unroll - for (int i = 0; i < numElementsPerThread; i++) { - smem_keys[threadIdx.x + (blockDim_x * i)] = my_keys[i]; - smem_vals[threadIdx.x + (blockDim_x * i)] = my_vals[i]; - } - __syncthreads(); -#pragma unroll - for (int i = 0; i < numElementsPerThread; i++) { - float opp_key = smem_keys[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; - IdxT opp_val = smem_vals[(threadIdx.x ^ curr_mask) + (blockDim_x * i)]; - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); - } - } else { -// intra warp -#pragma unroll - for (int i = 0; i < numElementsPerThread; i++) { - float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); - IdxT opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); - } - } - } - - const bool ascending = ((threadIdx.x & next_mask) == 0); -#pragma unroll - for (uint32_t curr_mask = numElementsPerThread / 2; curr_mask > 0; curr_mask >>= 1) { -#pragma unroll - for (int i = 0; i < numElementsPerThread; i++) { - int j = i ^ curr_mask; - if (i > j) continue; - swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); - } - } - mask = next_mask; - } + // Sort by RAFT bitonic sort + raft::util::bitonic(true).sort(my_keys, my_vals); // Update knn_graph for (int i = 0; i < numElementsPerThread; i++) { - const int k = i + (numElementsPerThread * threadIdx.x); + const int k = i * raft::WarpSize + lane_id; if (k < graph_degree) { knn_graph[k + (static_cast(graph_degree) * srcNode)] = my_vals[i]; } @@ -333,35 +265,38 @@ void sort_knn_graph(raft::resources const& res, void (*kernel_sort)( const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t); - constexpr int numElementsPerThread = 4; - dim3 threads_sort(1, 1, 1); - if (input_graph_degree <= numElementsPerThread * 32) { - constexpr int blockDim_x = 32; - kernel_sort = kern_sort; - threads_sort.x = blockDim_x; - } else if (input_graph_degree <= numElementsPerThread * 64) { - constexpr int blockDim_x = 64; - kernel_sort = kern_sort; - threads_sort.x = blockDim_x; - } else if (input_graph_degree <= numElementsPerThread * 128) { - constexpr int blockDim_x = 128; - kernel_sort = kern_sort; - threads_sort.x = blockDim_x; - } else if (input_graph_degree <= numElementsPerThread * 256) { - constexpr int blockDim_x = 256; - kernel_sort = kern_sort; - threads_sort.x = blockDim_x; + if (input_graph_degree <= 32) { + constexpr int numElementsPerThread = 1; + kernel_sort = kern_sort; + } else if (input_graph_degree <= 64) { + constexpr int numElementsPerThread = 2; + kernel_sort = kern_sort; + } else if (input_graph_degree <= 128) { + constexpr int numElementsPerThread = 4; + kernel_sort = kern_sort; + } else if (input_graph_degree <= 256) { + constexpr int numElementsPerThread = 8; + kernel_sort = kern_sort; + } else if (input_graph_degree <= 512) { + constexpr int numElementsPerThread = 16; + kernel_sort = kern_sort; + } else if (input_graph_degree <= 1024) { + constexpr int numElementsPerThread = 32; + kernel_sort = kern_sort; } else { RAFT_LOG_ERROR( "[ERROR] The degree of input knn graph is too large (%u). " "It must be equal to or small than %d.\n", input_graph_degree, - numElementsPerThread * 256); + 1024); exit(-1); } - dim3 blocks_sort(graph_size, 1, 1); + const auto block_size = 256; + const auto num_warps_per_block = block_size / raft::WarpSize; + const auto grid_size = (graph_size + num_warps_per_block - 1) / num_warps_per_block; + RAFT_LOG_DEBUG("."); - kernel_sort<<>>( + kernel_sort<<>>( d_dataset.data_handle(), dataset_size, dataset_dim,