diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index cd3c6f3947..cdfb9d9931 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -45,6 +45,7 @@ #include #include +#include #include #include @@ -217,6 +218,7 @@ struct BuildConfig { // If internal_node_degree == 0, the value of node_degree will be assigned to it size_t max_iterations{50}; float termination_threshold{0.0001}; + size_t output_graph_degree{32}; }; template @@ -345,7 +347,11 @@ class GNND { GNND(const GNND&) = delete; GNND& operator=(const GNND&) = delete; - void build(Data_t* data, const Index_t nrow, Index_t* output_graph); + void build(Data_t* data, + const Index_t nrow, + Index_t* output_graph, + bool return_distances = false, + DistData_t* output_distances = nullptr); ~GNND() = default; using ID_t = InternalID_t; @@ -1212,7 +1218,11 @@ void GNND::local_join(cudaStream_t stream) } template -void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph) +void GNND::build(Data_t* data, + const Index_t nrow, + Index_t* output_graph, + bool return_distances, + DistData_t* output_distances) { using input_t = typename std::remove_const::type; @@ -1338,6 +1348,20 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out // Reuse graph_.h_dists as the buffer for shrink the lists in graph static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t)); + + if (return_distances) { + RAFT_EXPECTS(output_distances != nullptr, + "output_distances must be allocated when return_distances is true"); + for (size_t i = 0; i < (size_t)nrow_; i++) { + raft::copy(output_distances + i * build_config_.output_graph_degree, + graph_.h_dists.data_handle() + i * build_config_.node_degree, + build_config_.output_graph_degree, + raft::resource::get_cuda_stream(res)); + } + // graph_.h_dists is reused below as graph_shrink_buffer; make sure the copies have finished. + raft::resource::sync_stream(res); + } + Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle(); #pragma omp parallel for @@ -1410,10 +1434,21 @@ void build(raft::resources const& res, .node_degree = extended_graph_degree, 
.internal_node_degree = extended_intermediate_degree, .max_iterations = params.max_iterations, - .termination_threshold = params.termination_threshold}; + .termination_threshold = params.termination_threshold, + .output_graph_degree = params.graph_degree}; GNND nnd(res, build_config); - nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle()); + + // Fail fast when distances are requested but the index has no view to receive them. + RAFT_EXPECTS(idx.distances().has_value() || !params.return_distances, + "Distance view not allocated. Using return_distances set to true requires " + "distance view to be allocated."); + auto distances_view = idx.distances(); + nnd.build(dataset.data_handle(), + dataset.extent(0), + int_graph.data_handle(), + params.return_distances, + distances_view.has_value() ? distances_view->data_handle() : nullptr); #pragma omp parallel for for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { @@ -1444,7 +1479,8 @@ index build(raft::resources const& res, graph_degree = intermediate_degree; } - index idx{res, dataset.extent(0), static_cast(graph_degree)}; + index idx{ + res, dataset.extent(0), static_cast(graph_degree), params.return_distances}; build(res, params, dataset, idx); diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index e1fc96878a..5d23ff2c2e 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -18,6 +18,8 @@ #include "ann_types.hpp" +#include +#include #include #include #include @@ -25,6 +27,8 @@ #include #include +#include + namespace raft::neighbors::experimental::nn_descent { /** * @ingroup nn-descent @@ -51,6 +55,7 @@ struct index_params : ann::index_params { size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. size_t max_iterations = 20; // Number of nn-descent iterations. float termination_threshold = 0.0001; // Termination threshold of nn-descent. 
+ bool return_distances = false; // return distances if true }; /** @@ -79,14 +84,20 @@ struct index : ann::index { * @param res raft::resources is an object mangaging resources * @param n_rows number of rows in knn-graph * @param n_cols number of cols in knn-graph + * @param return_distances if true, allocate a device matrix [n_rows, n_cols] for the distances */ - index(raft::resources const& res, int64_t n_rows, int64_t n_cols) + index(raft::resources const& res, int64_t n_rows, int64_t n_cols, bool return_distances = false) : ann::index(), res_{res}, metric_{raft::distance::DistanceType::L2Expanded}, graph_{raft::make_host_matrix(n_rows, n_cols)}, - graph_view_{graph_.view()} + graph_view_{graph_.view()}, + return_distances_(return_distances) { + if (return_distances) { + distances_ = raft::make_device_matrix(res_, n_rows, n_cols); + distances_view_ = distances_.value().view(); + } } /** @@ -98,14 +109,23 @@ struct index : ann::index { * * @param res raft::resources is an object mangaging resources * @param graph_view raft::host_matrix_view for storing knn-graph + * @param distances_view optional user-provided device matrix view for + * storing knn-graph distances + * @param return_distances whether distances are requested; storage must come from distances_view */ index(raft::resources const& res, - raft::host_matrix_view graph_view) + raft::host_matrix_view graph_view, + std::optional> distances_view = + std::nullopt, + bool return_distances = false) : ann::index(), res_{res}, metric_{raft::distance::DistanceType::L2Expanded}, graph_{raft::make_host_matrix(0, 0)}, - graph_view_{graph_view} + distances_{std::nullopt}, // no owned storage: distances live in the user-provided view + graph_view_{graph_view}, + distances_view_(distances_view), + return_distances_(return_distances) { } @@ -133,6 +153,13 @@ struct index : ann::index { return graph_view_; } + /** neighborhood graph distances [size, graph-degree] */ + [[nodiscard]] inline auto distances() noexcept + -> std::optional> + { + return distances_view_; + } + // Don't allow copying the index for performance 
reasons (try avoiding copying data) index(const index&) = delete; index(index&&) = default; @@ -144,8 +171,11 @@ struct index : ann::index { raft::resources const& res_; raft::distance::DistanceType metric_; raft::host_matrix graph_; // graph to return for non-int IdxT + std::optional> distances_; /* distance storage owned by the index (see constructors) */ raft::host_matrix_view graph_view_; // view of graph for user provided matrix + std::optional> distances_view_; /* value returned by distances(); may be std::nullopt */ + bool return_distances_; /* flag passed at construction */ }; /** @} */ diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh index 495af081f1..f74cadb415 100644 --- a/cpp/test/neighbors/ann_nn_descent.cuh +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -15,11 +15,11 @@ */ #pragma once -#include "../test_utils.cuh" #include "ann_utils.cuh" #include #include +#include #include #include @@ -65,7 +65,9 @@ class AnnNNDescentTest : public ::testing::TestWithParam { { size_t queries_size = ps.n_rows * ps.graph_degree; std::vector indices_NNDescent(queries_size); + std::vector distances_NNDescent(queries_size); std::vector indices_naive(queries_size); + std::vector distances_naive(queries_size); { rmm::device_uvector distances_naive_dev(queries_size, stream_); @@ -81,6 +83,7 @@ class AnnNNDescentTest : public ::testing::TestWithParam { ps.graph_degree, ps.metric); update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); resource::sync_stream(handle_); } @@ -91,6 +94,7 @@ class AnnNNDescentTest : public ::testing::TestWithParam { index_params.graph_degree = ps.graph_degree; index_params.intermediate_graph_degree = 2 * ps.graph_degree; index_params.max_iterations = 100; + index_params.return_distances = true; /* also request neighbor distances from build() */ auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); @@ -102,20 +106,39 @@ class AnnNNDescentTest : public ::testing::TestWithParam { auto database_host_view = raft::make_host_matrix_view( (const 
DataT*)database_host.data_handle(), ps.n_rows, ps.dim); auto index = nn_descent::build(handle_, index_params, database_host_view); - update_host( + raft::copy( indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { /* distances() is optional; copy to the host vector only when present */ + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); + } + } else { auto index = nn_descent::build(handle_, index_params, database_view); - update_host( + raft::copy( indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { /* same optional-distances handling as the host-dataset branch */ + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); + } }; } resource::sync_stream(handle_); } double min_recall = ps.min_recall; - EXPECT_TRUE(eval_recall( - indices_naive, indices_NNDescent, ps.n_rows, ps.graph_degree, 0.001, min_recall)); + EXPECT_TRUE(eval_neighbours(indices_naive, /* ids and distances of both the naive and NN-descent results are passed */ + indices_NNDescent, + distances_naive, + distances_NNDescent, + ps.n_rows, + ps.graph_degree, + 0.001, + min_recall)); } }