From 3f49752e2937a0d7a4c98ddc7a9b95c6417245b1 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Thu, 6 Jun 2024 01:55:22 +0000 Subject: [PATCH] update test --- cpp/test/neighbors/ann_nn_descent.cuh | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh index edf0a890ad..dc8060c4d4 100644 --- a/cpp/test/neighbors/ann_nn_descent.cuh +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -17,6 +17,7 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" +#include "raft/util/cudart_utils.hpp" #include #include @@ -94,6 +95,7 @@ class AnnNNDescentTest : public ::testing::TestWithParam { index_params.graph_degree = ps.graph_degree; index_params.intermediate_graph_degree = 2 * ps.graph_degree; index_params.max_iterations = 100; + index_params.return_distances = true; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); @@ -105,16 +107,12 @@ class AnnNNDescentTest : public ::testing::TestWithParam { auto database_host_view = raft::make_host_matrix_view( (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); auto index = nn_descent::build(handle_, index_params, database_host_view); - update_host( - indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); - update_host( - distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_); + raft::copy(indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + raft::copy(distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_); } else { auto index = nn_descent::build(handle_, index_params, database_view); - update_host( - indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); - update_host( - distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_); + raft::copy(indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + raft::copy(distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_); }; } resource::sync_stream(handle_);