Skip to content

Commit

Permalink
Update CAGRA knn_graph_sort to use Raft::bitonic_sort (#1550)
Browse files Browse the repository at this point in the history
This PR changes CAGRA `knn_graph_sort` function to use `raft::util::bitonic_sort` instead of a custom sorting function.

Rel: #1503 (comment)

Authors:
  - tsuki (https://github.com/enp1s0)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1550
  • Loading branch information
enp1s0 authored Jun 2, 2023
1 parent e804156 commit cfe27ec
Showing 1 changed file with 46 additions and 111 deletions.
157 changes: 46 additions & 111 deletions cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <random>
#include <sys/time.h>

#include <raft/util/bitonic_sort.cuh>
#include <raft/util/cuda_rt_essentials.hpp>

#include "utils.hpp"
Expand Down Expand Up @@ -67,125 +68,56 @@ __device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool a
return false;
}

template <class DATA_T, class IdxT, int blockDim_x, int numElementsPerThread>
template <class DATA_T, class IdxT, int numElementsPerThread>
__global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, dataset_dim]
const IdxT dataset_size,
const uint32_t dataset_dim,
IdxT* const knn_graph, // [graph_chunk_size, graph_degree]
const uint32_t graph_size,
const uint32_t graph_degree)
{
__shared__ float smem_keys[blockDim_x * numElementsPerThread];
__shared__ IdxT smem_vals[blockDim_x * numElementsPerThread];

const IdxT srcNode = blockIdx.x;
const IdxT srcNode = (blockDim.x * blockIdx.x + threadIdx.x) / raft::WarpSize;
if (srcNode >= graph_size) { return; }

const uint32_t num_warps = blockDim_x / 32;
const uint32_t warp_id = threadIdx.x / 32;
const uint32_t lane_id = threadIdx.x % 32;
const uint32_t lane_id = threadIdx.x % raft::WarpSize;

float my_keys[numElementsPerThread];
IdxT my_vals[numElementsPerThread];

// Compute distance from a src node to its neighbors
for (int k = warp_id; k < graph_degree; k += num_warps) {
const IdxT dstNode = knn_graph[k + ((uint64_t)graph_degree * srcNode)];
for (int k = 0; k < graph_degree; k++) {
const IdxT dstNode = knn_graph[k + static_cast<uint64_t>(graph_degree) * srcNode];
float dist = 0.0;
for (int d = lane_id; d < dataset_dim; d += 32) {
for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
float diff = spatial::knn::detail::utils::mapping<float>{}(
dataset[d + ((uint64_t)dataset_dim * srcNode)]) -
dataset[d + static_cast<uint64_t>(dataset_dim) * srcNode]) -
spatial::knn::detail::utils::mapping<float>{}(
dataset[d + ((uint64_t)dataset_dim * dstNode)]);
dataset[d + static_cast<uint64_t>(dataset_dim) * dstNode]);
dist += diff * diff;
}
dist += __shfl_xor_sync(0xffffffff, dist, 1);
dist += __shfl_xor_sync(0xffffffff, dist, 2);
dist += __shfl_xor_sync(0xffffffff, dist, 4);
dist += __shfl_xor_sync(0xffffffff, dist, 8);
dist += __shfl_xor_sync(0xffffffff, dist, 16);
if (lane_id == 0) {
smem_keys[k] = dist;
smem_vals[k] = dstNode;
}
}
__syncthreads();

float my_keys[numElementsPerThread];
IdxT my_vals[numElementsPerThread];
for (int i = 0; i < numElementsPerThread; i++) {
const int k = i + (numElementsPerThread * threadIdx.x);
if (k < graph_degree) {
my_keys[i] = smem_keys[k];
my_vals[i] = smem_vals[k];
} else {
my_keys[i] = FLT_MAX;
my_vals[i] = utils::get_max_value<IdxT>();
if (lane_id == (k % raft::WarpSize)) {
my_keys[k / raft::WarpSize] = dist;
my_vals[k / raft::WarpSize] = dstNode;
}
}
__syncthreads();

// Sorting by thread
uint32_t mask = 1;
const bool ascending = ((threadIdx.x & mask) == 0);
for (int j = 0; j < numElementsPerThread; j += 2) {
#pragma unroll
for (int i = 0; i < numElementsPerThread; i += 2) {
swap_if_needed<float, IdxT>(
my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending);
}
#pragma unroll
for (int i = 1; i < numElementsPerThread - 1; i += 2) {
swap_if_needed<float, IdxT>(
my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending);
for (int k = graph_degree; k < raft::WarpSize * numElementsPerThread; k++) {
if (lane_id == k % raft::WarpSize) {
my_keys[k / raft::WarpSize] = utils::get_max_value<float>();
my_vals[k / raft::WarpSize] = utils::get_max_value<IdxT>();
}
}

// Bitonic Sorting
while (mask < blockDim_x) {
const uint32_t next_mask = mask << 1;

for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) {
const bool ascending = ((threadIdx.x & curr_mask) == 0) == ((threadIdx.x & next_mask) == 0);
if (mask >= 32) {
// inter warp
__syncthreads();
#pragma unroll
for (int i = 0; i < numElementsPerThread; i++) {
smem_keys[threadIdx.x + (blockDim_x * i)] = my_keys[i];
smem_vals[threadIdx.x + (blockDim_x * i)] = my_vals[i];
}
__syncthreads();
#pragma unroll
for (int i = 0; i < numElementsPerThread; i++) {
float opp_key = smem_keys[(threadIdx.x ^ curr_mask) + (blockDim_x * i)];
IdxT opp_val = smem_vals[(threadIdx.x ^ curr_mask) + (blockDim_x * i)];
swap_if_needed<float, IdxT>(my_keys[i], opp_key, my_vals[i], opp_val, ascending);
}
} else {
// intra warp
#pragma unroll
for (int i = 0; i < numElementsPerThread; i++) {
float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask);
IdxT opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask);
swap_if_needed<float, IdxT>(my_keys[i], opp_key, my_vals[i], opp_val, ascending);
}
}
}

const bool ascending = ((threadIdx.x & next_mask) == 0);
#pragma unroll
for (uint32_t curr_mask = numElementsPerThread / 2; curr_mask > 0; curr_mask >>= 1) {
#pragma unroll
for (int i = 0; i < numElementsPerThread; i++) {
int j = i ^ curr_mask;
if (i > j) continue;
swap_if_needed<float, IdxT>(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending);
}
}
mask = next_mask;
}
// Sort by RAFT bitonic sort
raft::util::bitonic<numElementsPerThread>(true).sort(my_keys, my_vals);

// Update knn_graph
for (int i = 0; i < numElementsPerThread; i++) {
const int k = i + (numElementsPerThread * threadIdx.x);
const int k = i * raft::WarpSize + lane_id;
if (k < graph_degree) {
knn_graph[k + (static_cast<uint64_t>(graph_degree) * srcNode)] = my_vals[i];
}
Expand Down Expand Up @@ -333,35 +265,38 @@ void sort_knn_graph(raft::resources const& res,

void (*kernel_sort)(
const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t);
constexpr int numElementsPerThread = 4;
dim3 threads_sort(1, 1, 1);
if (input_graph_degree <= numElementsPerThread * 32) {
constexpr int blockDim_x = 32;
kernel_sort = kern_sort<DataT, IdxT, blockDim_x, numElementsPerThread>;
threads_sort.x = blockDim_x;
} else if (input_graph_degree <= numElementsPerThread * 64) {
constexpr int blockDim_x = 64;
kernel_sort = kern_sort<DataT, IdxT, blockDim_x, numElementsPerThread>;
threads_sort.x = blockDim_x;
} else if (input_graph_degree <= numElementsPerThread * 128) {
constexpr int blockDim_x = 128;
kernel_sort = kern_sort<DataT, IdxT, blockDim_x, numElementsPerThread>;
threads_sort.x = blockDim_x;
} else if (input_graph_degree <= numElementsPerThread * 256) {
constexpr int blockDim_x = 256;
kernel_sort = kern_sort<DataT, IdxT, blockDim_x, numElementsPerThread>;
threads_sort.x = blockDim_x;
if (input_graph_degree <= 32) {
constexpr int numElementsPerThread = 1;
kernel_sort = kern_sort<DataT, IdxT, numElementsPerThread>;
} else if (input_graph_degree <= 64) {
constexpr int numElementsPerThread = 2;
kernel_sort = kern_sort<DataT, IdxT, numElementsPerThread>;
} else if (input_graph_degree <= 128) {
constexpr int numElementsPerThread = 4;
kernel_sort = kern_sort<DataT, IdxT, numElementsPerThread>;
} else if (input_graph_degree <= 256) {
constexpr int numElementsPerThread = 8;
kernel_sort = kern_sort<DataT, IdxT, numElementsPerThread>;
} else if (input_graph_degree <= 512) {
constexpr int numElementsPerThread = 16;
kernel_sort = kern_sort<DataT, IdxT, numElementsPerThread>;
} else if (input_graph_degree <= 1024) {
constexpr int numElementsPerThread = 32;
kernel_sort = kern_sort<DataT, IdxT, numElementsPerThread>;
} else {
RAFT_LOG_ERROR(
"[ERROR] The degree of input knn graph is too large (%u). "
"It must be equal to or small than %d.\n",
input_graph_degree,
numElementsPerThread * 256);
1024);
exit(-1);
}
dim3 blocks_sort(graph_size, 1, 1);
const auto block_size = 256;
const auto num_warps_per_block = block_size / raft::WarpSize;
const auto grid_size = (graph_size + num_warps_per_block - 1) / num_warps_per_block;

RAFT_LOG_DEBUG(".");
kernel_sort<<<blocks_sort, threads_sort, 0, resource::get_cuda_stream(res)>>>(
kernel_sort<<<grid_size, block_size, 0, resource::get_cuda_stream(res)>>>(
d_dataset.data_handle(),
dataset_size,
dataset_dim,
Expand Down

0 comments on commit cfe27ec

Please sign in to comment.