From 0327ce591a246499676427d03abbc9b42528b5a1 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Wed, 5 Jun 2024 23:44:35 +0000 Subject: [PATCH] change template param for index --- .../neighbors/detail/cagra/cagra_build.cuh | 2 +- .../raft/neighbors/detail/nn_descent.cuh | 6 +++--- cpp/include/raft/neighbors/nn_descent.cuh | 8 ++++---- .../raft/neighbors/nn_descent_types.hpp | 18 +++++++++--------- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 40dcf68e68..c558f90c84 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -240,7 +240,7 @@ void build_knn_graph(raft::resources const& res, raft::host_matrix_view knn_graph, experimental::nn_descent::index_params build_params) { - auto nn_descent_idx = experimental::nn_descent::index(res, knn_graph); + auto nn_descent_idx = experimental::nn_descent::index(res, knn_graph); experimental::nn_descent::build(res, build_params, dataset, nn_descent_idx); using internal_IdxT = typename std::make_unsigned::type; diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index 333c12d303..a7aa956c51 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -1391,7 +1391,7 @@ template , row_major, Accessor> dataset, - index& idx) + index& idx) { RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits::max() - 1, "The dataset size for GNND should be less than %d", @@ -1453,7 +1453,7 @@ template , memory_type::host>> -index build(raft::resources const& res, +index build(raft::resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset) { @@ -1469,7 +1469,7 @@ index build(raft::resources const& res, graph_degree = intermediate_degree; } - index idx{ + index idx{ res, dataset.extent(0), static_cast(graph_degree), params.return_distances}; build(res, params, dataset, idx); diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh index ceb5ae5643..62c03e2b19 100644 --- a/cpp/include/raft/neighbors/nn_descent.cuh +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -56,7 +56,7 @@ namespace raft::neighbors::experimental::nn_descent { * @return index index containing all-neighbors knn graph in host memory */ template -index build(raft::resources const& res, +index build(raft::resources const& res, index_params const& params, raft::device_matrix_view dataset) { @@ -97,7 +97,7 @@ template void build(raft::resources const& res, index_params const& params, raft::device_matrix_view dataset, - index& idx) + index& idx) { detail::build(res, params, dataset, idx); } @@ -130,7 +130,7 @@ void build(raft::resources const& res, * @return index index containing all-neighbors knn graph in host memory */ template -index build(raft::resources const& res, +index build(raft::resources const& res, index_params const& params, raft::host_matrix_view dataset) { @@ -171,7 +171,7 @@ template void build(raft::resources const& res, index_params const& params, raft::host_matrix_view dataset, - index& idx) + index& idx) { detail::build(res, params, dataset, idx); } diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index efc773e275..b353d7fd02 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -30,7 +30,7 @@ #include namespace raft::neighbors::experimental::nn_descent { -using DistData_t = float; +// using DistData_t = float; /** * @ingroup nn-descent * @{ @@ -72,7 +72,7 @@ struct index_params : ann::index_params { * * @tparam IdxT dtype to be used for constructing knn-graph */ -template +template struct index : ann::index { public: /** @@ -95,7 +95,7 @@ struct index : ann::index { return_distances_(return_distances) { if (return_distances) { - distances_ = raft::make_device_matrix(res_, n_rows, n_cols); + distances_ = raft::make_device_matrix(res_, n_rows, n_cols); distances_view_ = distances_.value().view(); } } @@ -112,14 +112,14 @@ struct index : ann::index { */ index(raft::resources const& res, raft::host_matrix_view graph_view, - std::optional> distances_view = + std::optional> distances_view = std::nullopt, bool return_distances = false) : ann::index(), res_{res}, metric_{raft::distance::DistanceType::L2Expanded}, graph_{raft::make_host_matrix(0, 0)}, - distances_{raft::make_device_matrix(res_, 0, 0)}, + distances_{raft::make_device_matrix(res_, 0, 0)}, graph_view_{graph_view}, distances_view_(distances_view), return_distances_(return_distances) @@ -153,12 +153,12 @@ struct index : ann::index { return graph_view_; } - [[nodiscard]] inline auto distances() noexcept -> device_matrix_view + [[nodiscard]] inline auto distances() noexcept -> device_matrix_view { if (distances_view_.has_value()) { return distances_view_.value(); } else { - return raft::make_device_matrix(res_, 0, 0).view(); + return raft::make_device_matrix(res_, 0, 0).view(); } } @@ -173,10 +173,10 @@ struct index : ann::index { raft::resources const& res_; raft::distance::DistanceType metric_; raft::host_matrix graph_; // graph to return for non-int IdxT - std::optional> distances_; + std::optional> distances_; raft::host_matrix_view graph_view_; // view of graph for user provided matrix - std::optional> distances_view_; + std::optional> distances_view_; bool return_distances_; };