diff --git a/cpp/include/raft/handle.hpp b/cpp/include/raft/handle.hpp index 1b1923abb7..611868f55f 100644 --- a/cpp/include/raft/handle.hpp +++ b/cpp/include/raft/handle.hpp @@ -140,14 +140,10 @@ class handle_t { return cusparse_handle_; } - thrust_exec_policy_t get_thrust_policy() const { + rmm::exec_policy get_thrust_policy() const { std::lock_guard _(mutex_); if (!thrust_policy_initialized_) { - if (!thrust_policy_) { - thrust_policy_ = - (thrust_exec_policy_t*)malloc(sizeof(thrust_exec_policy_t)); - } - *thrust_policy_ = rmm::exec_policy(this->get_stream()); + thrust_policy_ = new rmm::exec_policy(get_stream()); thrust_policy_initialized_ = true; } return *thrust_policy_; @@ -240,7 +236,7 @@ class handle_t { mutable bool cusolver_sp_initialized_{false}; mutable cusparseHandle_t cusparse_handle_; mutable bool cusparse_initialized_{false}; - mutable thrust_exec_policy_t* thrust_policy_{nullptr}; + mutable rmm::exec_policy* thrust_policy_{nullptr}; mutable bool thrust_policy_initialized_{false}; cudaStream_t user_stream_{nullptr}; cudaEvent_t event_;