From 6672be47ccd05579e6db7b862e686e9f163fe3af Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Mon, 17 Jun 2024 19:27:48 +0900 Subject: [PATCH] Fix a CAGRA graph opt bug --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 2e90eed64..acc0043d8 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -430,24 +430,39 @@ void optimize( const auto num_full = host_stats.data_handle()[1]; // Create pruned kNN graph - uint32_t max_detour = 0; -#pragma omp parallel for reduction(max : max_detour) +#pragma omp parallel for for (uint64_t i = 0; i < graph_size; i++) { - uint64_t pk = 0; - for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) { - if (max_detour < num_detour) { max_detour = num_detour; /* stats */ } + // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable + // count of the neighbors while increasing the target detourable count from zero. + uint64_t pk = 0; + uint32_t num_detour = 0; + while (pk < output_graph_degree) { + uint32_t next_num_detour = std::numeric_limits<uint32_t>::max(); for (uint64_t k = 0; k < input_graph_degree; k++) { - if (detour_count.data_handle()[k + (input_graph_degree * i)] != num_detour) { continue; } + const auto num_detour_k = detour_count.data_handle()[k + (input_graph_degree * i)]; + // Find the detourable count to check in the next iteration + if (num_detour_k > num_detour) { + next_num_detour = std::min(static_cast<uint32_t>(num_detour_k), next_num_detour); + } + + // Store the neighbor index if its detourable count is equal to `num_detour`.
+ if (num_detour_k != num_detour) { continue; } output_graph_ptr[pk + (output_graph_degree * i)] = input_graph_ptr[k + (input_graph_degree * i)]; pk += 1; if (pk >= output_graph_degree) break; } if (pk >= output_graph_degree) break; + + assert(next_num_detour != std::numeric_limits<uint32_t>::max()); + num_detour = next_num_detour; } - assert(pk == output_graph_degree); + RAFT_EXPECTS(pk == output_graph_degree, + "Couldn't find the output_graph_degree (%u) smallest detourable count nodes for " + "node %lu in the rank-based node reranking process", + output_graph_degree, + static_cast<uint64_t>(i)); } - // RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour); const double time_prune_end = cur_time(); RAFT_LOG_DEBUG(