diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index 496c65d91f..856ecc96d7 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -121,7 +121,7 @@ class device_resources : public resources { cusparseHandle_t get_cusparse_handle() const { return resource::get_cusparse_handle(*this); } - rmm::exec_policy& get_thrust_policy() const { return resource::get_thrust_policy(*this); } + rmm::exec_policy_nosync& get_thrust_policy() const { return resource::get_thrust_policy(*this); } /** * @brief synchronize a stream on the current container diff --git a/cpp/include/raft/core/resource/thrust_policy.hpp b/cpp/include/raft/core/resource/thrust_policy.hpp index f81898be8a..c728f0a00e 100644 --- a/cpp/include/raft/core/resource/thrust_policy.hpp +++ b/cpp/include/raft/core/resource/thrust_policy.hpp @@ -24,7 +24,7 @@ namespace raft::resource { class thrust_policy_resource : public resource { public: thrust_policy_resource(rmm::cuda_stream_view stream_view) - : thrust_policy_(std::make_unique(stream_view)) + : thrust_policy_(std::make_unique(stream_view)) { } void* get_resource() override { return thrust_policy_.get(); } @@ -32,7 +32,7 @@ class thrust_policy_resource : public resource { ~thrust_policy_resource() override {} private: - std::unique_ptr thrust_policy_; + std::unique_ptr thrust_policy_; }; /** @@ -60,13 +60,13 @@ class thrust_policy_resource_factory : public resource_factory { * @param res raft res object for managing resources * @return thrust execution policy */ -inline rmm::exec_policy& get_thrust_policy(resources const& res) +inline rmm::exec_policy_nosync& get_thrust_policy(resources const& res) { if (!res.has_resource_factory(resource_type::THRUST_POLICY)) { rmm::cuda_stream_view stream = get_cuda_stream(res); res.add_resource_factory(std::make_shared(stream)); } - return *res.get_resource(resource_type::THRUST_POLICY); + return *res.get_resource(resource_type::THRUST_POLICY); }; /** diff --git a/cpp/include/raft/spectral/detail/matrix_wrappers.hpp b/cpp/include/raft/spectral/detail/matrix_wrappers.hpp index 30dd6e5e69..1fe078bd32 100644 --- a/cpp/include/raft/spectral/detail/matrix_wrappers.hpp +++ b/cpp/include/raft/spectral/detail/matrix_wrappers.hpp @@ -129,7 +129,7 @@ class vector_t { private: using thrust_exec_policy_t = thrust::detail::execute_with_allocator, - thrust::cuda_cub::execute_on_stream_base>; + thrust::cuda_cub::execute_on_stream_nosync_base>; rmm::device_uvector buffer_; const thrust_exec_policy_t thrust_policy; };