From 8d4d1a21eae37f8f54f5777a78845723bd5e07e4 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 2 Oct 2024 23:00:56 -0700 Subject: [PATCH] add more syncs, use thrust_policy --- cpp/src/neighbors/detail/nn_descent.cuh | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index d416bc686..6a2e5b6a2 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include #include // raft::util::arch::SM_* @@ -1162,15 +1163,19 @@ GNND::GNND(raft::resources const& res, const BuildConfig& build { static_assert(NUM_SAMPLES <= 32); - thrust::fill(thrust::device, + thrust::fill(raft::resource::get_thrust_policy(res), dists_buffer_.data_handle(), dists_buffer_.data_handle() + dists_buffer_.size(), std::numeric_limits::max()); - thrust::fill(thrust::device, + thrust::fill(raft::resource::get_thrust_policy(res), reinterpret_cast(graph_buffer_.data_handle()), reinterpret_cast(graph_buffer_.data_handle()) + graph_buffer_.size(), std::numeric_limits::max()); - thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0); + thrust::fill(raft::resource::get_thrust_policy(res), + d_locks_.data_handle(), + d_locks_.data_handle() + d_locks_.size(), + 0); + raft::resource::sync_stream(res); }; template @@ -1190,7 +1195,7 @@ void GNND::add_reverse_edges(Index_t* graph_ptr, template void GNND::local_join(cudaStream_t stream) { - thrust::fill(thrust::device.on(stream), + thrust::fill(raft::resource::get_thrust_policy(res), dists_buffer_.data_handle(), dists_buffer_.data_handle() + dists_buffer_.size(), std::numeric_limits::max()); @@ -1209,6 +1214,7 @@ void GNND::local_join(cudaStream_t stream) DEGREE_ON_DEVICE, d_locks_.data_handle(), l2_norms_.data_handle()); + raft::resource::sync_stream(res); } template @@ -1240,10 +1246,11 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out batch.offset()); } - thrust::fill(thrust::device.on(stream), + thrust::fill(raft::resource::get_thrust_policy(res), (Index_t*)graph_buffer_.data_handle(), (Index_t*)graph_buffer_.data_handle() + graph_buffer_.size(), std::numeric_limits::max()); + raft::resource::sync_stream(res); graph_.clear(); graph_.init_random_graph(); @@ -1330,6 +1337,7 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out graph_.sample_graph_new(thrust::raw_pointer_cast(graph_host_buffer_.data()), DEGREE_ON_DEVICE); } + raft::resource::sync_stream(res); graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), thrust::raw_pointer_cast(dists_host_buffer_.data()), DEGREE_ON_DEVICE, @@ -1415,6 +1423,7 @@ void build(raft::resources const& res, GNND nnd(res, build_config); nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle()); + raft::resource::sync_stream(res); #pragma omp parallel for for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) {