Skip to content

Commit

Permalink
Rename the CAGRA prune function to optimize (#1588)
Browse files Browse the repository at this point in the history
This PR renames the `cagra::prune` function to `cagra::optimize` since it adds reverse edges other than pruning unimportant edges.

Authors:
  - tsuki (https://github.com/enp1s0)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1588
  • Loading branch information
enp1s0 authored Jun 22, 2023
1 parent 6f0abae commit 28b61c4
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
44 changes: 23 additions & 21 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ namespace raft::neighbors::experimental::cagra {
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
* auto pruned_gaph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
* cagra::prune(res, dataset, knn_graph.view(), pruned_graph.view());
* // Construct an index from dataset and pruned knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset, pruned_graph.view());
* auto optimized_gaph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset,
* optimized_graph.view());
* @endcode
*
* @tparam T data element type
Expand Down Expand Up @@ -96,10 +97,10 @@ void build_knn_graph(raft::resources const& res,

/**
* @brief Sort a KNN graph index.
* Preprocessing step for `cagra::prune`: If a KNN graph is not built using
* Preprocessing step for `cagra::optimize`: If a KNN graph is not built using
* `cagra::build_knn_graph`, then it is necessary to call this function before calling
* `cagra::prune`. If the graph is built by `cagra::build_knn_graph`, it is already sorted and you
* do not need to call this function.
* `cagra::optimize`. If the graph is built by `cagra::build_knn_graph`, it is already sorted and
* you do not need to call this function.
*
* Usage example:
* @code{.cpp}
Expand All @@ -110,10 +111,11 @@ void build_knn_graph(raft::resources const& res,
* // build(knn_graph, dataset, ...);
* // sort graph index
* sort_knn_graph(res, dataset.view(), knn_graph.view());
* // prune graph
* cagra::prune(res, dataset, knn_graph.view(), pruned_graph.view());
* // Construct an index from dataset and pruned knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset, pruned_graph.view());
* // optimize graph
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset,
* optimized_graph.view());
* @endcode
*
* @tparam DataT type of the data in the source dataset
Expand Down Expand Up @@ -162,14 +164,14 @@ void sort_knn_graph(raft::resources const& res,
* @param[in] res raft resources
* @param[in] knn_graph a matrix view (host or device) of the input knn graph [n_rows,
* knn_graph_degree]
* @param[out] new_graph a host matrix view of the pruned knn graph [n_rows, graph_degree]
* @param[out] new_graph a host matrix view of the optimized knn graph [n_rows, graph_degree]
*/
template <typename IdxT = uint32_t,
typename g_accessor =
host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
void prune(raft::resources const& res,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
void optimize(raft::resources const& res,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
{
using internal_IdxT = typename std::make_unsigned<IdxT>::type;

Expand All @@ -186,21 +188,21 @@ void prune(raft::resources const& res,
knn_graph.extent(0),
knn_graph.extent(1));

detail::graph::prune(res, knn_graph_internal, new_graph_internal);
detail::graph::optimize(res, knn_graph_internal, new_graph_internal);
}

/**
* @brief Build the index from the dataset for efficient search.
*
* The build consist of two steps: build an intermediate knn-graph, and prune it to
* The build consist of two steps: build an intermediate knn-graph, and optimize it to
* create the final graph. The index_params struct controls the node degree of these
* graphs.
*
* It is required that dataset and the pruned graph fit the GPU memory.
* It is required that dataset and the optimized graph fit the GPU memory.
*
* To customize the parameters for knn-graph building and pruning, and to reuse the
* intermediate results, you could build the index in two steps using
* [cagra::build_knn_graph](#cagra::build_knn_graph) and [cagra::prune](#cagra::prune).
* [cagra::build_knn_graph](#cagra::build_knn_graph) and [cagra::optimize](#cagra::optimize).
*
* The following distance metrics are supported:
* - L2
Expand Down Expand Up @@ -260,9 +262,9 @@ index<T, IdxT> build(raft::resources const& res,

auto cagra_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), graph_degree);

prune<IdxT>(res, knn_graph.view(), cagra_graph.view());
optimize<IdxT>(res, knn_graph.view(), cagra_graph.view());

// Construct an index from dataset and pruned knn graph.
// Construct an index from dataset and optimized knn graph.
return index<T, IdxT>(res, params.metric, dataset, cagra_graph.view());
}

Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,9 @@ void sort_knn_graph(raft::resources const& res,
template <typename IdxT = uint32_t,
typename g_accessor =
host_device_accessor<std::experimental::default_accessor<IdxT>, memory_type::host>>
void prune(raft::resources const& res,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
void optimize(raft::resources const& res,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
{
RAFT_LOG_DEBUG(
"# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1));
Expand Down

0 comments on commit 28b61c4

Please sign in to comment.