Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a CAGRA graph opt bug #192

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions cpp/src/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -435,24 +435,39 @@ void optimize(
const auto num_full = host_stats.data_handle()[1];

// Create pruned kNN graph
uint32_t max_detour = 0;
#pragma omp parallel for reduction(max : max_detour)
#pragma omp parallel for
for (uint64_t i = 0; i < graph_size; i++) {
uint64_t pk = 0;
for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) {
if (max_detour < num_detour) { max_detour = num_detour; /* stats */ }
// Find the `output_graph_degree` smallest detourable count nodes by checking the detourable
// count of the neighbors while increasing the target detourable count from zero.
uint64_t pk = 0;
uint32_t num_detour = 0;
while (pk < output_graph_degree) {
uint32_t next_num_detour = std::numeric_limits<uint32_t>::max();
for (uint64_t k = 0; k < input_graph_degree; k++) {
if (detour_count.data_handle()[k + (input_graph_degree * i)] != num_detour) { continue; }
const auto num_detour_k = detour_count.data_handle()[k + (input_graph_degree * i)];
// Find the detourable count to check in the next iteration
if (num_detour_k > num_detour) {
next_num_detour = std::min(static_cast<uint32_t>(num_detour_k), next_num_detour);
}

// Store the neighbor index if its detourable count is equal to `num_detour`.
if (num_detour_k != num_detour) { continue; }
output_graph_ptr[pk + (output_graph_degree * i)] =
input_graph_ptr[k + (input_graph_degree * i)];
pk += 1;
if (pk >= output_graph_degree) break;
}
if (pk >= output_graph_degree) break;

assert(next_num_detour != std::numeric_limits<uint32_t>::max());
num_detour = next_num_detour;
}
assert(pk == output_graph_degree);
RAFT_EXPECTS(pk == output_graph_degree,
"Couldn't find the output_graph_degree (%u) smallest detourable count nodes for "
"node %lu in the rank-based node reranking process",
output_graph_degree,
static_cast<uint64_t>(i));
}
// RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour);

const double time_prune_end = cur_time();
RAFT_LOG_DEBUG(
Expand Down
Loading