Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename the CAGRA prune function to optimize #1588

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ namespace raft::neighbors::experimental::cagra {
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
* auto pruned_gaph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
* cagra::prune(res, dataset, knn_graph.view(), pruned_graph.view());
* // Construct an index from dataset and pruned knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset, pruned_graph.view());
* auto optimized_gaph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset,
* optimized_graph.view());
* @endcode
*
* @tparam T data element type
Expand Down Expand Up @@ -96,10 +97,10 @@ void build_knn_graph(raft::resources const& res,

/**
* @brief Sort a KNN graph index.
* Preprocessing step for `cagra::prune`: If a KNN graph is not built using
* Preprocessing step for `cagra::optimize`: If a KNN graph is not built using
* `cagra::build_knn_graph`, then it is necessary to call this function before calling
* `cagra::prune`. If the graph is built by `cagra::build_knn_graph`, it is already sorted and you
* do not need to call this function.
* `cagra::optimize`. If the graph is built by `cagra::build_knn_graph`, it is already sorted and
* you do not need to call this function.
*
* Usage example:
* @code{.cpp}
Expand All @@ -110,10 +111,11 @@ void build_knn_graph(raft::resources const& res,
* // build(knn_graph, dataset, ...);
* // sort graph index
* sort_knn_graph(res, dataset.view(), knn_graph.view());
* // prune graph
* cagra::prune(res, dataset, knn_graph.view(), pruned_graph.view());
* // Construct an index from dataset and pruned knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset, pruned_graph.view());
* // optimize graph
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset,
* optimized_graph.view());
* @endcode
*
* @tparam DataT type of the data in the source dataset
Expand Down Expand Up @@ -162,14 +164,14 @@ void sort_knn_graph(raft::resources const& res,
* @param[in] res raft resources
* @param[in] knn_graph a matrix view (host or device) of the input knn graph [n_rows,
* knn_graph_degree]
* @param[out] new_graph a host matrix view of the pruned knn graph [n_rows, graph_degree]
* @param[out] new_graph a host matrix view of the optimized knn graph [n_rows, graph_degree]
*/
template <typename IdxT = uint32_t,
typename g_accessor =
host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
void prune(raft::resources const& res,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
void optimize(raft::resources const& res,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
{
using internal_IdxT = typename std::make_unsigned<IdxT>::type;

Expand All @@ -186,21 +188,21 @@ void prune(raft::resources const& res,
knn_graph.extent(0),
knn_graph.extent(1));

detail::graph::prune(res, knn_graph_internal, new_graph_internal);
detail::graph::optimize(res, knn_graph_internal, new_graph_internal);
}

/**
* @brief Build the index from the dataset for efficient search.
*
* The build consist of two steps: build an intermediate knn-graph, and prune it to
* The build consist of two steps: build an intermediate knn-graph, and optimize it to
* create the final graph. The index_params struct controls the node degree of these
* graphs.
*
* It is required that dataset and the pruned graph fit the GPU memory.
* It is required that dataset and the optimized graph fit the GPU memory.
*
* To customize the parameters for knn-graph building and pruning, and to reuse the
* intermediate results, you could build the index in two steps using
* [cagra::build_knn_graph](#cagra::build_knn_graph) and [cagra::prune](#cagra::prune).
* [cagra::build_knn_graph](#cagra::build_knn_graph) and [cagra::optimize](#cagra::optimize).
*
* The following distance metrics are supported:
* - L2
Expand Down Expand Up @@ -260,9 +262,9 @@ index<T, IdxT> build(raft::resources const& res,

auto cagra_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), graph_degree);

prune<IdxT>(res, knn_graph.view(), cagra_graph.view());
optimize<IdxT>(res, knn_graph.view(), cagra_graph.view());

// Construct an index from dataset and pruned knn graph.
// Construct an index from dataset and optimized knn graph.
return index<T, IdxT>(res, params.metric, dataset, cagra_graph.view());
}

Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,9 @@ void sort_knn_graph(raft::resources const& res,
template <typename IdxT = uint32_t,
typename g_accessor =
host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
void prune(raft::resources const& res,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
void optimize(raft::resources const& res,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
{
RAFT_LOG_DEBUG(
"# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1));
Expand Down