Merge pull request #3 from RayWang96/port-nn-descent
Fix bugs in NN-Descent
divyegala authored Aug 22, 2023
2 parents 558f849 + 40e1cf0 commit 1f1f32d
Showing 1 changed file with 50 additions and 39 deletions.
cpp/include/raft/neighbors/detail/nn_descent.cuh (89 changes: 50 additions & 39 deletions)
@@ -24,6 +24,7 @@
 #include <limits>
 #include <thrust/execution_policy.h>
 #include <thrust/fill.h>
+#include <queue>
 
 #include "../nn_descent_types.hpp"

@@ -372,7 +373,7 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec
                                          const int lane_id) {
   if constexpr (std::is_same_v<Data_t, float> or std::is_same_v<Data_t, uint8_t> or std::is_same_v<Data_t, int8_t>) {
     constexpr int num_load_elems_per_warp = WARP_SIZE;
-    for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) {
+    for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) {
       int idx = step * num_load_elems_per_warp + lane_id;
       if (idx < load_dims) {
         vec_buffer[idx] = d_vec[idx];
@@ -381,12 +382,12 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec
       }
     }
   }
-  if constexpr (std::is_same<Data_t, __half>::value) {
-    if ((size_t)vec_buffer % sizeof(float2) == 0 && load_dims % 4 == 0 &&
-        padding_dims % 4 == 0) {
+  if constexpr (std::is_same_v<Data_t, __half>) {
+    if ((size_t)d_vec % sizeof(float2) == 0 && (size_t)vec_buffer % sizeof(float2) == 0 &&
+        load_dims % 4 == 0 && padding_dims % 4 == 0) {
       constexpr int num_load_elems_per_warp = WARP_SIZE * 4;
 #pragma unroll
-      for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) {
+      for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) {
         int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4;
         if (idx_in_vec + 4 <= load_dims) {
           *(float2 *)(vec_buffer + idx_in_vec) = *(float2 *)(d_vec + idx_in_vec);
@@ -396,7 +397,7 @@
       }
     } else {
       constexpr int num_load_elems_per_warp = WARP_SIZE;
-      for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) {
+      for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) {
        int idx = step * num_load_elems_per_warp + lane_id;
        if (idx < load_dims) {
          vec_buffer[idx] = d_vec[idx];
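
All three load paths in load_vec now iterate to padding_dims rather than load_dims, and the __half path additionally checks that d_vec itself is float2-aligned before issuing 8-byte vectorized loads. The shared-memory buffer is padded to a multiple of the warp (or vectorized) load width, so the loop must visit that padded tail for it to be zero-filled rather than left uninitialized. A minimal standalone sketch of the pattern; the zero-fill branch is collapsed out of the diff above, so its exact form here is an assumption:

```cuda
#include <cuda_runtime.h>

constexpr int WARP_SIZE = 32;

__host__ __device__ constexpr int div_up(int a, int b) { return (a + b - 1) / b; }

// Warp-cooperative copy with zero padding: the loop bound must be
// padding_dims, otherwise the tail [load_dims, padding_dims) of the buffer
// is never written and later distance computations read garbage.
template <typename T>
__device__ void warp_load_padded(T* vec_buffer, const T* d_vec, int load_dims,
                                 int padding_dims, int lane_id) {
  for (int step = 0; step < div_up(padding_dims, WARP_SIZE); step++) {
    int idx = step * WARP_SIZE + lane_id;
    if (idx < load_dims) {
      vec_buffer[idx] = d_vec[idx];  // real data
    } else if (idx < padding_dims) {
      vec_buffer[idx] = T(0);        // zero the padded tail
    }
  }
}
```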
@@ -408,15 +409,15 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec
   }
 }
 
-template <typename Data_t, typename Index_t>
-__global__ void preprocess_data_kernel(const Data_t *input_data, __half *output_data, Index_t nrow,
-                                       int dim, DistData_t *l2_norms) {
+template <typename Data_t>
+__global__ void preprocess_data_kernel(const Data_t* input_data, __half* output_data, int dim,
+                                       DistData_t* l2_norms, size_t list_offset = 0) {
   extern __shared__ char buffer[];
   __shared__ float l2_norm;
   Data_t *s_vec = (Data_t *)buffer;
-  size_t list_id = blockIdx.x;
+  size_t list_id = list_offset + blockIdx.x;
 
-  load_vec(s_vec, input_data + list_id * dim, dim, dim, threadIdx.x % WARP_SIZE);
+  load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % WARP_SIZE);
   if (threadIdx.x == 0) {
     l2_norm = 0;
   }
@@ -443,9 +444,10 @@ __global__ void preprocess_data_kernel(const Data_t *input_data, __half *output_
     int idx = step * WARP_SIZE + threadIdx.x;
     if (idx < dim) {
       if (l2_norms == nullptr) {
-        output_data[list_id * dim + idx] = (float)input_data[list_id * dim + idx] / sqrt(l2_norm);
+        output_data[list_id * dim + idx] =
+          (float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm);
       } else {
-        output_data[list_id * dim + idx] = input_data[list_id * dim + idx];
+        output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx];
         if (idx == 0) {
           l2_norms[list_id] = l2_norm;
         }
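
preprocess_data_kernel is now batch-aware: input_data points at a staging buffer holding only the current batch of rows, so reads are indexed by blockIdx.x (batch-local) while writes go to list_id = list_offset + blockIdx.x (dataset-global). The (size_t) cast on blockIdx.x also keeps the element offset from overflowing 32-bit arithmetic on large batches. A compact sketch of the mapping (simplified; the real kernel stages the row through shared memory and normalizes it):

```cuda
#include <cuda_fp16.h>

// Reads are batch-local, writes are dataset-global. `input_data` is assumed
// to hold rows [list_offset, list_offset + gridDim.x) of the full dataset.
__global__ void preprocess_batch_sketch(const float* input_data, __half* output_data,
                                        int dim, size_t list_offset) {
  size_t list_id = list_offset + blockIdx.x;  // global row this block produces
  for (int idx = threadIdx.x; idx < dim; idx += blockDim.x) {
    output_data[list_id * dim + idx] =
      __float2half(input_data[(size_t)blockIdx.x * dim + idx]);
  }
}
```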
@@ -475,8 +477,7 @@ __global__ void add_rev_edges_kernel(const Index_t *graph, Index_t *rev_graph, i
 
 template <typename Index_t, typename ID_t = InternalID_t<Index_t>>
 __device__ void insert_to_global_graph(ResultItem<Index_t> elem, size_t list_id, ID_t *graph,
-                                       DistData_t *dists, int node_degree, int *locks,
-                                       bool new_new = true) {
+                                       DistData_t *dists, int node_degree, int *locks) {
   int tx = threadIdx.x;
   int lane_id = tx % WARP_SIZE;
   size_t global_idx_base = list_id * node_degree;
@@ -760,8 +761,7 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 4)
     if (idx_in_list >= list_new_size) continue;
     auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances);
     if (min_elem.id() < gridDim.x) {
-      insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks,
-                             true);
+      insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks);
     }
   }
 
@@ -851,8 +851,7 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 4)
     }
 
     if (min_elem.id() < gridDim.x) {
-      insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks,
-                             false);
+      insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks);
     }
   }
 }
@@ -945,18 +944,13 @@ void GnndGraph<Index_t>::init_random_graph() {
 
 #pragma omp parallel for
   for (size_t i = 0; i < nrow; i++) {
-    for (size_t j = 0; j < NUM_SAMPLES; j++) {
-      size_t idx = i * NUM_SAMPLES + j;
+    for (size_t j = 0; j < node_degree; j++) {
+      size_t idx = i * node_degree + j;
       Index_t id = rand_seq[idx % nrow];
       if ((size_t)id == i) {
-        id = rand_seq[(idx + NUM_SAMPLES) % nrow];
+        id = rand_seq[(idx + node_degree) % nrow];
       }
       h_graph[i * node_degree + j].id_with_flag() = id;
     }
-    for (size_t j = NUM_SAMPLES; j < node_degree; j++) {
-      h_graph[i * node_degree + j].id_with_flag() = std::numeric_limits<Index_t>::max();
-    }
     for (size_t j = 0; j < node_degree; j++) {
       h_dists[i * node_degree + j] = std::numeric_limits<DistData_t>::max();
     }
   }
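
Before this change only the first NUM_SAMPLES slots of each row received random neighbor ids, with the remaining node_degree - NUM_SAMPLES slots set to the sentinel std::numeric_limits<Index_t>::max(); now every slot starts with a random (non-self) id, while every distance still starts at the float sentinel so genuine neighbors found during descent displace the random ones. A standalone sketch of the fixed initialization (the surrounding class members are stand-ins):

```cuda
#include <algorithm>
#include <limits>
#include <numeric>
#include <random>
#include <vector>

// Every one of the node_degree slots gets a random neighbor id; all
// distances start at "infinity" so real neighbors can replace them.
void init_random_graph_sketch(std::vector<int>& ids, std::vector<float>& dists,
                              size_t nrow, size_t node_degree) {
  std::vector<int> rand_seq(nrow);
  std::iota(rand_seq.begin(), rand_seq.end(), 0);
  std::shuffle(rand_seq.begin(), rand_seq.end(), std::mt19937{1234});
  for (size_t i = 0; i < nrow; i++) {
    for (size_t j = 0; j < node_degree; j++) {
      size_t idx = i * node_degree + j;
      int id = rand_seq[idx % nrow];
      if ((size_t)id == i) { id = rand_seq[(idx + node_degree) % nrow]; }  // avoid self-edge
      ids[idx]   = id;
      dists[idx] = std::numeric_limits<float>::max();
    }
  }
}
```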
Expand Down Expand Up @@ -1113,7 +1107,7 @@ void GNND<Data_t, Index_t>::add_reverse_edges(Index_t* graph_ptr, Index_t* h_rev
list_sizes);
RAFT_CUDA_TRY(cudaMemcpyAsync(h_rev_graph_ptr, d_rev_graph_ptr,
sizeof(*h_rev_graph_ptr) * nrow_ * NUM_SAMPLES,
cudaMemcpyDeviceToHost, stream));
cudaMemcpyDefault, stream));
}

template <typename Data_t, typename Index_t>
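
This commit replaces every explicit copy direction (cudaMemcpyDeviceToHost, cudaMemcpyHostToDevice, cudaMemcpyHostToHost) with cudaMemcpyDefault, which lets the runtime infer the direction from the pointer values under unified virtual addressing (standard on 64-bit platforms). The same call is then correct whether a buffer such as h_rev_graph_ptr ends up in pageable host, pinned host, or device memory. A minimal illustration of the idea:

```cuda
#include <cuda_runtime.h>

// With cudaMemcpyDefault the runtime resolves the direction per call, so one
// helper covers H2D, D2H, D2D, and H2H without the caller hard-coding it.
template <typename T>
cudaError_t copy_any_direction(T* dst, const T* src, size_t n, cudaStream_t stream) {
  return cudaMemcpyAsync(dst, src, sizeof(T) * n, cudaMemcpyDefault, stream);
}
```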
@@ -1136,14 +1130,31 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
   cudaPointerAttributes data_ptr_attr;
   RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data));
   if (data_ptr_attr.type == cudaMemoryTypeUnregistered) {
-    RAFT_CUDA_TRY(cudaHostRegister(const_cast<std::remove_const_t<Data_t>*>(data),
-                                   sizeof(Data_t) * nrow * build_config_.dataset_dim,
-                                   cudaHostRegisterDefault));
+    typename std::remove_const<Data_t>::type* input_data;
+    size_t batch_size = 100000;
+    RAFT_CUDA_TRY(cudaMallocAsync(&input_data,
+                                  sizeof(Data_t) * batch_size * build_config_.dataset_dim, stream));
+    for (size_t step = 0; step < div_up(nrow_, batch_size); step++) {
+      size_t list_offset = step * batch_size;
+      size_t num_lists =
+        step != div_up(nrow_, batch_size) - 1 ? batch_size : nrow_ - list_offset;
+      RAFT_CUDA_TRY(cudaMemcpyAsync(
+        input_data, data + list_offset * build_config_.dataset_dim,
+        sizeof(Data_t) * num_lists * build_config_.dataset_dim, cudaMemcpyDefault, stream));
+      preprocess_data_kernel<<<num_lists, WARP_SIZE,
+                               sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) *
+                                 WARP_SIZE,
+                               stream>>>(input_data, d_data_, build_config_.dataset_dim,
+                                         l2_norms_, list_offset);
+    }
+    RAFT_CUDA_TRY(cudaFreeAsync(input_data, stream));
+  } else {
+    preprocess_data_kernel<<<
+      nrow_, WARP_SIZE,
+      sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) * WARP_SIZE, stream>>>(
+      data, d_data_, build_config_.dataset_dim, l2_norms_);
   }
-  preprocess_data_kernel<<<
-    nrow_, WARP_SIZE, sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) * WARP_SIZE,
-    stream>>>(data, d_data_, nrow_, build_config_.dataset_dim, l2_norms_);
 
   thrust::fill(thrust::device.on(stream), (Index_t*)graph_buffer_,
                (Index_t*)graph_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE,
                std::numeric_limits<Index_t>::max());
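
Rather than cudaHostRegister-ing the caller's entire pageable dataset (which pins a potentially huge host allocation and can fail or stall), the build path now streams the data through a fixed-size device staging buffer of batch_size rows, launching the preprocess kernel once per batch with the matching list_offset. Extra device memory is bounded by batch_size * dataset_dim regardless of nrow. The generic shape of the pattern, as a sketch (process_batch stands in for the kernel launch):

```cuda
#include <algorithm>
#include <cuda_runtime.h>

// Stream `nrow` rows of host data through a bounded device staging buffer.
// Each iteration copies one batch and hands (buffer, global offset, rows) to
// the caller's callback, e.g. a kernel launch on the same stream. Requires
// CUDA 11.2+ for the stream-ordered cudaMallocAsync/cudaFreeAsync allocator.
template <typename T, typename F>
void for_each_batch(const T* host_data, size_t nrow, size_t dim, size_t batch_size,
                    cudaStream_t stream, F&& process_batch) {
  T* staging = nullptr;
  cudaMallocAsync(&staging, sizeof(T) * batch_size * dim, stream);
  size_t num_batches = (nrow + batch_size - 1) / batch_size;
  for (size_t step = 0; step < num_batches; step++) {
    size_t offset = step * batch_size;
    size_t rows   = std::min(batch_size, nrow - offset);
    cudaMemcpyAsync(staging, host_data + offset * dim, sizeof(T) * rows * dim,
                    cudaMemcpyDefault, stream);
    process_batch(staging, offset, rows);  // enqueued on the same stream
  }
  cudaFreeAsync(staging, stream);  // freed once prior work on `stream` completes
}
```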
@@ -1168,13 +1179,13 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
   for (size_t it = 0; it < build_config_.max_iterations; it++) {
     RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_new_, graph_.h_list_sizes_new,
                                   sizeof(*d_list_sizes_new_) * nrow_,
-                                  cudaMemcpyHostToDevice, stream));
+                                  cudaMemcpyDefault, stream));
     RAFT_CUDA_TRY(cudaMemcpyAsync(h_graph_old_, graph_.h_graph_old,
                                   sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES,
-                                  cudaMemcpyHostToHost, stream));
+                                  cudaMemcpyDefault, stream));
     RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_old_, graph_.h_list_sizes_old,
                                   sizeof(*d_list_sizes_old_) * nrow_,
-                                  cudaMemcpyHostToDevice, stream));
+                                  cudaMemcpyDefault, stream));
     RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
     std::thread update_and_sample_thread(update_and_sample, it);
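
The cudaStreamSynchronize just before the host thread matters: the async copies above read graph_.h_list_sizes_* and graph_.h_graph_old from host memory, so the CPU-side update_and_sample must not start rewriting those buffers until the copies have landed. A self-contained sketch of the idiom (names are illustrative):

```cuda
#include <cuda_runtime.h>
#include <thread>
#include <vector>

// Enqueue a copy that READS a host buffer, wait for it to land, then let a
// CPU thread regenerate the buffer while later GPU work consumes the device copy.
void overlap_iteration(std::vector<int>& h_buf, int* d_buf, cudaStream_t stream) {
  cudaMemcpyAsync(d_buf, h_buf.data(), sizeof(int) * h_buf.size(),
                  cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);  // h_buf is now safe to mutate on the CPU
  std::thread cpu_worker([&] {
    for (auto& v : h_buf) v += 1;  // stand-in for graph update/sampling
  });
  // ... launch kernels on `stream` that read d_buf, overlapping cpu_worker ...
  cpu_worker.join();
}
```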
@@ -1201,11 +1212,11 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
     RAFT_CUDA_TRY(cudaMemcpyAsync(graph_host_buffer_, graph_buffer_,
                                   sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE,
-                                  cudaMemcpyDeviceToHost, stream));
+                                  cudaMemcpyDefault, stream));
     RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
     RAFT_CUDA_TRY(cudaMemcpyAsync(dists_host_buffer_, dists_buffer_,
                                   sizeof(*dists_buffer_) * nrow_ * DEGREE_ON_DEVICE,
-                                  cudaMemcpyDeviceToHost, stream));
+                                  cudaMemcpyDefault, stream));
     graph_.sample_graph_new(graph_host_buffer_, DEGREE_ON_DEVICE);
   }
