Skip to content

Commit

Permalink
change template param for index
Browse files Browse the repository at this point in the history
  • Loading branch information
jinsolp committed Jun 5, 2024
1 parent c71e97a commit 0327ce5
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ void build_knn_graph(raft::resources const& res,
raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph,
experimental::nn_descent::index_params build_params)
{
auto nn_descent_idx = experimental::nn_descent::index<IdxT>(res, knn_graph);
auto nn_descent_idx = experimental::nn_descent::index<float, IdxT>(res, knn_graph);
experimental::nn_descent::build<DataT, IdxT>(res, build_params, dataset, nn_descent_idx);

using internal_IdxT = typename std::make_unsigned<IdxT>::type;
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ template <typename T,
void build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
index<IdxT>& idx)
index<DistData_t, IdxT>& idx)
{
RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits<int>::max() - 1,
"The dataset size for GNND should be less than %d",
Expand Down Expand Up @@ -1453,7 +1453,7 @@ template <typename T,
typename IdxT = uint32_t,
typename Accessor =
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>>
index<IdxT> build(raft::resources const& res,
index<DistData_t, IdxT> build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset)
{
Expand All @@ -1469,7 +1469,7 @@ index<IdxT> build(raft::resources const& res,
graph_degree = intermediate_degree;
}
index<IdxT> idx{
index<DistData_t, IdxT> idx{
res, dataset.extent(0), static_cast<int64_t>(graph_degree), params.return_distances};
build(res, params, dataset, idx);
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/neighbors/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace raft::neighbors::experimental::nn_descent {
* @return index<IdxT> index containing all-neighbors knn graph in host memory
*/
template <typename T, typename IdxT = uint32_t>
index<IdxT> build(raft::resources const& res,
index<detail::DistData_t, IdxT> build(raft::resources const& res,
index_params const& params,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
Expand Down Expand Up @@ -97,7 +97,7 @@ template <typename T, typename IdxT = uint32_t>
void build(raft::resources const& res,
index_params const& params,
raft::device_matrix_view<const T, int64_t, row_major> dataset,
index<IdxT>& idx)
index<detail::DistData_t, IdxT>& idx)
{
detail::build<T, IdxT>(res, params, dataset, idx);
}
Expand Down Expand Up @@ -130,7 +130,7 @@ void build(raft::resources const& res,
* @return index<IdxT> index containing all-neighbors knn graph in host memory
*/
template <typename T, typename IdxT = uint32_t>
index<IdxT> build(raft::resources const& res,
index<detail::DistData_t, IdxT> build(raft::resources const& res,
index_params const& params,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
{
Expand Down Expand Up @@ -171,7 +171,7 @@ template <typename T, typename IdxT = uint32_t>
void build(raft::resources const& res,
index_params const& params,
raft::host_matrix_view<const T, int64_t, row_major> dataset,
index<IdxT>& idx)
index<detail::DistData_t, IdxT>& idx)
{
detail::build<T, IdxT>(res, params, dataset, idx);
}
Expand Down
18 changes: 9 additions & 9 deletions cpp/include/raft/neighbors/nn_descent_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include <optional>

namespace raft::neighbors::experimental::nn_descent {
using DistData_t = float;
// using DistData_t = float;
/**
* @ingroup nn-descent
* @{
Expand Down Expand Up @@ -72,7 +72,7 @@ struct index_params : ann::index_params {
*
* @tparam IdxT dtype to be used for constructing knn-graph
*/
template <typename IdxT>
template <typename T, typename IdxT>
struct index : ann::index {
public:
/**
Expand All @@ -95,7 +95,7 @@ struct index : ann::index {
return_distances_(return_distances)
{
if (return_distances) {
distances_ = raft::make_device_matrix<DistData_t, int64_t>(res_, n_rows, n_cols);
distances_ = raft::make_device_matrix<T, int64_t>(res_, n_rows, n_cols);
distances_view_ = distances_.value().view();
}
}
Expand All @@ -112,14 +112,14 @@ struct index : ann::index {
*/
index(raft::resources const& res,
raft::host_matrix_view<IdxT, int64_t, raft::row_major> graph_view,
std::optional<raft::device_matrix_view<DistData_t, int64_t, row_major>> distances_view =
std::optional<raft::device_matrix_view<T, int64_t, row_major>> distances_view =
std::nullopt,
bool return_distances = false)
: ann::index(),
res_{res},
metric_{raft::distance::DistanceType::L2Expanded},
graph_{raft::make_host_matrix<IdxT, int64_t, row_major>(0, 0)},
distances_{raft::make_device_matrix<DistData_t, int64_t>(res_, 0, 0)},
distances_{raft::make_device_matrix<T, int64_t>(res_, 0, 0)},
graph_view_{graph_view},
distances_view_(distances_view),
return_distances_(return_distances)
Expand Down Expand Up @@ -153,12 +153,12 @@ struct index : ann::index {
return graph_view_;
}

[[nodiscard]] inline auto distances() noexcept -> device_matrix_view<DistData_t, int64_t, row_major>
[[nodiscard]] inline auto distances() noexcept -> device_matrix_view<T, int64_t, row_major>
{
if (distances_view_.has_value()) {
return distances_view_.value();
} else {
return raft::make_device_matrix<DistData_t, int64_t>(res_, 0, 0).view();
return raft::make_device_matrix<T, int64_t>(res_, 0, 0).view();
}
}

Expand All @@ -173,10 +173,10 @@ struct index : ann::index {
raft::resources const& res_;
raft::distance::DistanceType metric_;
raft::host_matrix<IdxT, int64_t, row_major> graph_; // graph to return for non-int IdxT
std::optional<raft::device_matrix<DistData_t, int64_t, row_major>> distances_;
std::optional<raft::device_matrix<T, int64_t, row_major>> distances_;
raft::host_matrix_view<IdxT, int64_t, row_major>
graph_view_; // view of graph for user provided matrix
std::optional<raft::device_matrix_view<DistData_t, int64_t, row_major>> distances_view_;
std::optional<raft::device_matrix_view<T, int64_t, row_major>> distances_view_;
bool return_distances_;
};

Expand Down

0 comments on commit 0327ce5

Please sign in to comment.