Skip to content

Commit

Permalink
Fix a CAGRA graph opt bug (#192)
Browse files Browse the repository at this point in the history
There is a bug in the current CAGRA graph rank-based neighbor reordering process. A low recall or illegal memory access can occur if there are many detourable nodes from a node to its neighbors, e.g. there is a small subgraph in the initial kNN graph. This PR fixes this.

Authors:
  - tsuki (https://github.com/enp1s0)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #192
  • Loading branch information
enp1s0 authored Jun 18, 2024
1 parent 9dc3a4d commit 581c9cc
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions cpp/src/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -435,24 +435,39 @@ void optimize(
const auto num_full = host_stats.data_handle()[1];

// Create pruned kNN graph
uint32_t max_detour = 0;
#pragma omp parallel for reduction(max : max_detour)
#pragma omp parallel for
for (uint64_t i = 0; i < graph_size; i++) {
uint64_t pk = 0;
for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) {
if (max_detour < num_detour) { max_detour = num_detour; /* stats */ }
// Find the `output_graph_degree` smallest detourable count nodes by checking the detourable
// count of the neighbors while increasing the target detourable count from zero.
uint64_t pk = 0;
uint32_t num_detour = 0;
while (pk < output_graph_degree) {
uint32_t next_num_detour = std::numeric_limits<uint32_t>::max();
for (uint64_t k = 0; k < input_graph_degree; k++) {
if (detour_count.data_handle()[k + (input_graph_degree * i)] != num_detour) { continue; }
const auto num_detour_k = detour_count.data_handle()[k + (input_graph_degree * i)];
// Find the detourable count to check in the next iteration
if (num_detour_k > num_detour) {
next_num_detour = std::min(static_cast<uint32_t>(num_detour_k), next_num_detour);
}

// Store the neighbor index if its detourable count is equal to `num_detour`.
if (num_detour_k != num_detour) { continue; }
output_graph_ptr[pk + (output_graph_degree * i)] =
input_graph_ptr[k + (input_graph_degree * i)];
pk += 1;
if (pk >= output_graph_degree) break;
}
if (pk >= output_graph_degree) break;

assert(next_num_detour != std::numeric_limits<uint32_t>::max());
num_detour = next_num_detour;
}
assert(pk == output_graph_degree);
RAFT_EXPECTS(pk == output_graph_degree,
"Couldn't find the output_graph_degree (%u) smallest detourable count nodes for "
"node %lu in the rank-based node reranking process",
output_graph_degree,
static_cast<uint64_t>(i));
}
// RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour);

const double time_prune_end = cur_time();
RAFT_LOG_DEBUG(
Expand Down

0 comments on commit 581c9cc

Please sign in to comment.