Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port nn descent #4

Merged
merged 2 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ void sort_knn_graph(raft::resources const& res,
const uint32_t input_graph_degree = knn_graph.extent(1);
IdxT* const input_graph_ptr = knn_graph.data_handle();

auto d_input_graph = raft::make_device_matrix<IdxT, IdxT>(res, graph_size, input_graph_degree);
auto d_input_graph = raft::make_device_matrix<IdxT, int64_t>(res, graph_size, input_graph_degree);

//
// Sorting kNN graph
Expand Down
41 changes: 25 additions & 16 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -927,11 +927,12 @@ int insert_to_ordered_list(InternalID_t<Index_t>* list,
if (dist > dist_list[width - 1]) { return width; }

int idx_insert = width;
bool position_found = false;
for (int i = 0; i < width; i++) {
if (list[i].id() == neighb_id.id()) { return width; }
if (dist_list[i] > dist) {
if (!position_found && dist_list[i] > dist) {
idx_insert = i;
break;
position_found = true;
}
}
if (idx_insert == width) return idx_insert;
Expand Down Expand Up @@ -1000,19 +1001,27 @@ void GnndGraph<Index_t>::sample_graph_new(InternalID_t<Index_t>* new_neighbors,
template <typename Index_t>
void GnndGraph<Index_t>::init_random_graph()
{
// random sequence (range: 0~nrow)
std::vector<Index_t> rand_seq(nrow);
std::iota(rand_seq.begin(), rand_seq.end(), 0);
std::random_shuffle(rand_seq.begin(), rand_seq.end());
for (size_t seg_idx = 0; seg_idx < static_cast<size_t>(num_segments); seg_idx++) {
// random sequence (range: 0~nrow)
// segment_x stores neighbors which id % num_segments == x
std::vector<Index_t> rand_seq(nrow / num_segments);
std::iota(rand_seq.begin(), rand_seq.end(), 0);
std::random_shuffle(rand_seq.begin(), rand_seq.end());

#pragma omp parallel for
for (size_t i = 0; i < nrow; i++) {
for (size_t j = 0; j < node_degree; j++) {
size_t idx = i * node_degree + j;
Index_t id = rand_seq[idx % nrow];
if ((size_t)id == i) { id = rand_seq[(idx + node_degree) % nrow]; }
h_graph[i * node_degree + j].id_with_flag() = id;
h_dists[i * node_degree + j] = std::numeric_limits<DistData_t>::max();
for (size_t i = 0; i < nrow; i++) {
size_t base_idx = i * node_degree + seg_idx * segment_size;
auto h_neighbor_list = h_graph + base_idx;
auto h_dist_list = h_dists + base_idx;
for (size_t j = 0; j < static_cast<size_t>(segment_size); j++) {
size_t idx = base_idx + j;
Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx;
if ((size_t)id == i) {
id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx;
}
h_neighbor_list[j].id_with_flag() = id;
h_dist_list[j] = std::numeric_limits<DistData_t>::max();
}
}
}
}
Expand Down Expand Up @@ -1064,9 +1073,9 @@ void GnndGraph<Index_t>::update_graph(const InternalID_t<Index_t>* new_neighbors
auto new_dist = new_dists[i * width + j];
if (new_dist == std::numeric_limits<DistData_t>::max()) break;
if ((size_t)new_neighb_id.id() == i) continue;
int idx_seg = new_neighb_id.id() % num_segments;
auto list = h_graph + i * node_degree + idx_seg * segment_size;
auto dist_list = h_dists + i * node_degree + idx_seg * segment_size;
int seg_idx = new_neighb_id.id() % num_segments;
auto list = h_graph + i * node_degree + seg_idx * segment_size;
auto dist_list = h_dists + i * node_degree + seg_idx * segment_size;
int insert_pos =
insert_to_ordered_list(list, dist_list, segment_size, new_neighb_id, new_dist);
if (i % counter_interval == 0 && insert_pos != segment_size) { update_counter++; }
Expand Down