Skip to content

Commit

Permalink
Merge pull request #267 from rapidsai/branch-24.08
Browse files Browse the repository at this point in the history
Forward-merge branch-24.08 into branch-24.10
  • Loading branch information
GPUtester authored Jul 31, 2024
2 parents a49e0ba + e67caa5 commit 047b262
Show file tree
Hide file tree
Showing 4 changed files with 858 additions and 31 deletions.
5 changes: 5 additions & 0 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ struct index_params : cuvs::neighbors::index_params {
graph_build_params::ivf_pq_params,
graph_build_params::nn_descent_params>
graph_build_params;

/**
* Whether to use MST optimization to guarantee graph connectivity.
*/
bool guarantee_connectivity = false;
/**
* Whether to add the dataset content to the index, i.e.:
*
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,10 @@ template <
void optimize(
raft::resources const& res,
raft::mdspan<IdxT, raft::matrix_extent<int64_t>, raft::row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, int64_t, raft::row_major> new_graph)
raft::host_matrix_view<IdxT, int64_t, raft::row_major> new_graph,
const bool guarantee_connectivity = false)
{
detail::optimize(res, knn_graph, new_graph);
detail::optimize(res, knn_graph, new_graph, guarantee_connectivity);
}

template <typename T,
Expand Down
8 changes: 5 additions & 3 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ template <
void optimize(
raft::resources const& res,
raft::mdspan<IdxT, raft::matrix_extent<int64_t>, raft::row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, int64_t, raft::row_major> new_graph)
raft::host_matrix_view<IdxT, int64_t, raft::row_major> new_graph,
const bool guarantee_connectivity = false)
{
using internal_IdxT = typename std::make_unsigned<IdxT>::type;

Expand All @@ -400,7 +401,8 @@ void optimize(
knn_graph.extent(0),
knn_graph.extent(1));

cagra::detail::graph::optimize(res, knn_graph_internal, new_graph_internal);
cagra::detail::graph::optimize(
res, knn_graph_internal, new_graph_internal, guarantee_connectivity);
}

template <typename T,
Expand Down Expand Up @@ -476,7 +478,7 @@ index<T, IdxT> build(
auto cagra_graph = raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), graph_degree);

RAFT_LOG_INFO("optimizing graph");
optimize<IdxT>(res, knn_graph->view(), cagra_graph.view());
optimize<IdxT>(res, knn_graph->view(), cagra_graph.view(), params.guarantee_connectivity);

// free intermediate graph before trying to create the index
knn_graph.reset();
Expand Down
Loading

0 comments on commit 047b262

Please sign in to comment.