Skip to content

Commit

Permalink
Move single and multi cta kernels into separate files
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed May 14, 2023
1 parent b32cd60 commit 95428b3
Show file tree
Hide file tree
Showing 5 changed files with 1,148 additions and 1,033 deletions.
279 changes: 1 addition & 278 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "compute_distance.hpp"
#include "device_common.hpp"
#include "hashmap.hpp"
#include "search_multi_cta_kernel-inl.cuh"
#include "search_plan.cuh"
#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible
#include "utils.hpp"
Expand All @@ -41,284 +42,6 @@
namespace raft::neighbors::experimental::cagra::detail {
namespace multi_cta_search {

// #define _CLK_BREAKDOWN

template <class INDEX_T>
__device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num_parents]
const uint32_t num_parents,
INDEX_T* const itopk_indices, // [num_itopk]
const size_t num_itopk,
uint32_t* const terminate_flag)
{
const unsigned warp_id = threadIdx.x / 32;
if (warp_id > 0) { return; }
const unsigned lane_id = threadIdx.x % 32;
for (uint32_t i = lane_id; i < num_parents; i += 32) {
next_parent_indices[i] = utils::get_max_value<INDEX_T>();
}
uint32_t max_itopk = num_itopk;
if (max_itopk % 32) { max_itopk += 32 - (max_itopk % 32); }
uint32_t num_new_parents = 0;
for (uint32_t j = lane_id; j < max_itopk; j += 32) {
INDEX_T index;
int new_parent = 0;
if (j < num_itopk) {
index = itopk_indices[j];
if ((index & 0x80000000) == 0) { // check if most significant bit is set
new_parent = 1;
}
}
const uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent);
if (new_parent) {
const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents;
if (i < num_parents) {
next_parent_indices[i] = index;
itopk_indices[j] |= 0x80000000; // set most significant bit as used node
}
}
num_new_parents += __popc(ballot_mask);
if (num_new_parents >= num_parents) { break; }
}
if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; }
}

template <unsigned MAX_ELEMENTS>
__device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements]
uint32_t* indices, // [num_elements]
const uint32_t num_elements,
const uint32_t num_itopk // num_itopk <= num_elements
)
{
const unsigned warp_id = threadIdx.x / 32;
if (warp_id > 0) { return; }
const unsigned lane_id = threadIdx.x % 32;
constexpr unsigned N = (MAX_ELEMENTS + 31) / 32;
float key[N];
uint32_t val[N];
for (unsigned i = 0; i < N; i++) {
unsigned j = lane_id + (32 * i);
if (j < num_elements) {
key[i] = distances[j];
val[i] = indices[j];
} else {
key[i] = utils::get_max_value<float>();
val[i] = utils::get_max_value<uint32_t>();
}
}
/* Warp Sort */
bitonic::warp_sort<float, uint32_t, N>(key, val);
/* Store itopk sorted results */
for (unsigned i = 0; i < N; i++) {
unsigned j = (N * lane_id) + i;
if (j < num_itopk) {
distances[j] = key[i];
indices[j] = val[i];
}
}
}

//
// multiple CTAs per single query
//
template <unsigned TEAM_SIZE,
unsigned BLOCK_SIZE,
unsigned BLOCK_COUNT,
unsigned MAX_ELEMENTS,
unsigned MAX_DATASET_DIM,
class DATA_T,
class DISTANCE_T,
class INDEX_T,
class LOAD_T>
__launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size]
DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size]
const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim]
const size_t dataset_dim,
const size_t dataset_size,
const DATA_T* const queries_ptr, // [num_queries, dataset_dim]
const INDEX_T* const knn_graph, // [dataset_size, graph_degree]
const uint32_t graph_degree,
const unsigned num_distilation,
const uint64_t rand_xor_mask,
const INDEX_T* seed_ptr, // [num_queries, num_seeds]
const uint32_t num_seeds,
uint32_t* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen]
const uint32_t hash_bitlen,
const uint32_t itopk_size,
const uint32_t num_parents,
const uint32_t min_iteration,
const uint32_t max_iteration,
uint32_t* const num_executed_iterations /* stats */
)
{
assert(blockDim.x == BLOCK_SIZE);
assert(dataset_dim <= MAX_DATASET_DIM);

// const auto num_queries = gridDim.y;
const auto query_id = blockIdx.y;
const auto num_cta_per_query = gridDim.x;
const auto cta_id = blockIdx.x; // local CTA ID

#ifdef _CLK_BREAKDOWN
uint64_t clk_init = 0;
uint64_t clk_compute_1st_distance = 0;
uint64_t clk_topk = 0;
uint64_t clk_pickup_parents = 0;
uint64_t clk_compute_distance = 0;
uint64_t clk_start;
#define _CLK_START() clk_start = clock64()
#define _CLK_REC(V) V += clock64() - clk_start;
#else
#define _CLK_START()
#define _CLK_REC(V)
#endif
_CLK_START();

extern __shared__ uint32_t smem[];

// Layout of result_buffer
// +----------------+------------------------------+---------+
// | internal_top_k | neighbors of parent nodes | padding |
// | <itopk_size> | <num_parents * graph_degree> | upto 32 |
// +----------------+------------------------------+---------+
// |<--- result_buffer_size --->|
uint32_t result_buffer_size = itopk_size + (num_parents * graph_degree);
uint32_t result_buffer_size_32 = result_buffer_size;
if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); }
assert(result_buffer_size_32 <= MAX_ELEMENTS);

auto query_buffer = reinterpret_cast<float*>(smem);
auto result_indices_buffer = reinterpret_cast<INDEX_T*>(query_buffer + MAX_DATASET_DIM);
auto result_distances_buffer =
reinterpret_cast<DISTANCE_T*>(result_indices_buffer + result_buffer_size_32);
auto parent_indices_buffer =
reinterpret_cast<uint32_t*>(result_distances_buffer + result_buffer_size_32);
auto terminate_flag = reinterpret_cast<uint32_t*>(parent_indices_buffer + num_parents);

#if 0
/* debug */
for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += BLOCK_SIZE) {
result_indices_buffer[i] = utils::get_max_value<INDEX_T>();
result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
}
#endif

const DATA_T* const query_ptr = queries_ptr + (dataset_dim * query_id);
for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) {
unsigned j = device::swizzling(i);
if (i < dataset_dim) {
query_buffer[j] = spatial::knn::detail::utils::mapping<float>{}(query_ptr[i]);
} else {
query_buffer[j] = 0.0;
}
}
if (threadIdx.x == 0) { terminate_flag[0] = 0; }
uint32_t* local_visited_hashmap_ptr =
visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id);
__syncthreads();
_CLK_REC(clk_init);

// compute distance to randomly selecting nodes
_CLK_START();
const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr;
device::compute_distance_to_random_nodes<TEAM_SIZE, MAX_DATASET_DIM, LOAD_T>(
result_indices_buffer,
result_distances_buffer,
query_buffer,
dataset_ptr,
dataset_dim,
dataset_size,
result_buffer_size,
num_distilation,
rand_xor_mask,
local_seed_ptr,
num_seeds,
local_visited_hashmap_ptr,
hash_bitlen,
cta_id,
num_cta_per_query);
__syncthreads();
_CLK_REC(clk_compute_1st_distance);

uint32_t iter = 0;
while (1) {
// topk with bitonic sort
_CLK_START();
topk_by_bitonic_sort<MAX_ELEMENTS>(result_distances_buffer,
result_indices_buffer,
itopk_size + (num_parents * graph_degree),
itopk_size);
_CLK_REC(clk_topk);

if (iter + 1 == max_iteration) {
__syncthreads();
break;
}

// pick up next parents
_CLK_START();
pickup_next_parents<INDEX_T>(
parent_indices_buffer, num_parents, result_indices_buffer, itopk_size, terminate_flag);
_CLK_REC(clk_pickup_parents);

__syncthreads();
if (*terminate_flag && iter >= min_iteration) { break; }

// compute the norms between child nodes and query node
_CLK_START();
// constexpr unsigned max_n_frags = 16;
constexpr unsigned max_n_frags = 0;
device::
compute_distance_to_child_nodes<TEAM_SIZE, BLOCK_SIZE, MAX_DATASET_DIM, max_n_frags, LOAD_T>(
result_indices_buffer + itopk_size,
result_distances_buffer + itopk_size,
query_buffer,
dataset_ptr,
dataset_dim,
knn_graph,
graph_degree,
local_visited_hashmap_ptr,
hash_bitlen,
parent_indices_buffer,
num_parents);
_CLK_REC(clk_compute_distance);
__syncthreads();

iter++;
}

for (uint32_t i = threadIdx.x; i < itopk_size; i += BLOCK_SIZE) {
uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id)));
if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; }
result_indices_ptr[j] = result_indices_buffer[i] & ~0x80000000; // clear most significant bit
}

if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) {
num_executed_iterations[query_id] = iter + 1;
}

#ifdef _CLK_BREAKDOWN
if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && (blockIdx.x == 0) &&
((query_id * 3) % gridDim.y < 3)) {
RAFT_LOG_DEBUG(
"query, %d, thread, %d"
", init, %d"
", 1st_distance, %lu"
", topk, %lu"
", pickup_parents, %lu"
", distance, %lu"
"\n",
query_id,
threadIdx.x,
clk_init,
clk_compute_1st_distance,
clk_topk,
clk_pickup_parents,
clk_compute_distance);
}
#endif
}

#define SET_MC_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS, LOAD_T) \
kernel = search_kernel<TEAM_SIZE, \
BLOCK_SIZE, \
Expand Down
Loading

0 comments on commit 95428b3

Please sign in to comment.