Skip to content

Commit

Permalink
Fix cagra graph opt bug (#2365)
Browse files Browse the repository at this point in the history
backport of rapidsai/cuvs#192

There is a bug in the current CAGRA graph rank-based neighbor reordering process. A low recall or illegal memory access can occur if there are many detourable nodes from a node to its neighbors, e.g. there is a small subgraph in the initial kNN graph. This PR fixes this.

Authors:
  - tsuki (https://github.com/enp1s0)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #2365
  • Loading branch information
enp1s0 authored Jun 19, 2024
1 parent 8fe2983 commit b86a5f9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ If citing the k-selection routines, please consider the following bibtex:
isbn = {9798400701092},
publisher = {Association for Computing Machinery},
address = {New York, NY, USA},
location = {Denver, CO, USA}
location = {Denver, CO, USA},
series = {SC '23}
}
```
Expand Down
31 changes: 23 additions & 8 deletions cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -423,24 +423,39 @@ void optimize(raft::resources const& res,
const auto num_full = host_stats.data_handle()[1];

// Create pruned kNN graph
uint32_t max_detour = 0;
#pragma omp parallel for reduction(max : max_detour)
#pragma omp parallel for
for (uint64_t i = 0; i < graph_size; i++) {
uint64_t pk = 0;
for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) {
if (max_detour < num_detour) { max_detour = num_detour; /* stats */ }
// Find the `output_graph_degree` smallest detourable count nodes by checking the detourable
// count of the neighbors while increasing the target detourable count from zero.
uint64_t pk = 0;
uint32_t num_detour = 0;
while (pk < output_graph_degree) {
uint32_t next_num_detour = std::numeric_limits<uint32_t>::max();
for (uint64_t k = 0; k < input_graph_degree; k++) {
if (detour_count.data_handle()[k + (input_graph_degree * i)] != num_detour) { continue; }
const auto num_detour_k = detour_count.data_handle()[k + (input_graph_degree * i)];
// Find the detourable count to check in the next iteration
if (num_detour_k > num_detour) {
next_num_detour = std::min(static_cast<uint32_t>(num_detour_k), next_num_detour);
}

// Store the neighbor index if its detourable count is equal to `num_detour`.
if (num_detour_k != num_detour) { continue; }
output_graph_ptr[pk + (output_graph_degree * i)] =
input_graph_ptr[k + (input_graph_degree * i)];
pk += 1;
if (pk >= output_graph_degree) break;
}
if (pk >= output_graph_degree) break;

assert(next_num_detour != std::numeric_limits<uint32_t>::max());
num_detour = next_num_detour;
}
assert(pk == output_graph_degree);
RAFT_EXPECTS(pk == output_graph_degree,
"Couldn't find the output_graph_degree (%u) smallest detourable count nodes for "
"node %lu in the rank-based node reranking process",
output_graph_degree,
static_cast<uint64_t>(i));
}
// RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour);

const double time_prune_end = cur_time();
RAFT_LOG_DEBUG(
Expand Down

0 comments on commit b86a5f9

Please sign in to comment.