From 565cf68404670ec65a5db26237dff6d03447315e Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 1 Jul 2021 11:52:45 +0200 Subject: [PATCH] RAFT handle update --- cpp/include/raft/handle.hpp | 52 +++++++++++++++++++---------------- python/raft/common/handle.pxd | 2 -- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/cpp/include/raft/handle.hpp b/cpp/include/raft/handle.hpp index dbe7e83189..1b1923abb7 100644 --- a/cpp/include/raft/handle.hpp +++ b/cpp/include/raft/handle.hpp @@ -36,9 +36,8 @@ #include #include #include -#include -#include #include +#include #include "cudart_utils.h" namespace raft { @@ -48,6 +47,9 @@ namespace raft { * necessary cuda kernels and/or libraries */ class handle_t { + using thrust_exec_policy_t = thrust::detail::execute_with_allocator< + rmm::mr::thrust_allocator, thrust::cuda_cub::execute_on_stream_base>; + private: static constexpr int kNumDefaultWorkerStreams = 0; @@ -63,9 +65,7 @@ class handle_t { CUDA_CHECK(cudaGetDevice(&cur_dev)); return cur_dev; }()), - streams_(n_streams), - device_allocator_(std::make_shared()), - host_allocator_(std::make_shared()) { + streams_(n_streams) { create_resources(); } @@ -86,8 +86,6 @@ class handle_t { "ERROR: the main handle must have at least one worker stream\n"); prop_ = other.get_device_properties(); device_prop_initialized_ = true; - device_allocator_ = other.get_device_allocator(); - host_allocator_ = other.get_host_allocator(); create_resources(); set_stream(other.get_internal_stream(stream_id)); } @@ -97,26 +95,15 @@ class handle_t { int get_device() const { return dev_id_; } - void set_stream(cudaStream_t stream) { user_stream_ = stream; } + void set_stream(cudaStream_t stream) { + thrust_policy_initialized_ = false; + user_stream_ = stream; + } cudaStream_t get_stream() const { return user_stream_; } rmm::cuda_stream_view get_stream_view() const { return rmm::cuda_stream_view(user_stream_); } - void set_device_allocator(std::shared_ptr allocator) { - device_allocator_ = allocator; - } - std::shared_ptr get_device_allocator() const { - return device_allocator_; - } - - void set_host_allocator(std::shared_ptr allocator) { - host_allocator_ = allocator; - } - std::shared_ptr get_host_allocator() const { - return host_allocator_; - } - cublasHandle_t get_cublas_handle() const { std::lock_guard _(mutex_); if (!cublas_initialized_) { @@ -153,6 +140,23 @@ class handle_t { return cusparse_handle_; } + thrust_exec_policy_t get_thrust_policy() const { + std::lock_guard _(mutex_); + if (!thrust_policy_initialized_) { + if (!thrust_policy_) { + thrust_policy_ = + (thrust_exec_policy_t*)malloc(sizeof(thrust_exec_policy_t)); + } + *thrust_policy_ = rmm::exec_policy(this->get_stream()); + thrust_policy_initialized_ = true; + } + return *thrust_policy_; + } + + thrust_exec_policy_t get_thrust_policy(cudaStream_t stream) const { + return rmm::exec_policy(stream); + } + // legacy compatibility for cuML cudaStream_t get_internal_stream(int sid) const { return streams_.get_stream(sid).value(); @@ -236,8 +240,8 @@ class handle_t { mutable bool cusolver_sp_initialized_{false}; mutable cusparseHandle_t cusparse_handle_; mutable bool cusparse_initialized_{false}; - std::shared_ptr device_allocator_; - std::shared_ptr host_allocator_; + mutable thrust_exec_policy_t* thrust_policy_{nullptr}; + mutable bool thrust_policy_initialized_{false}; cudaStream_t user_stream_{nullptr}; cudaEvent_t event_; mutable cudaDeviceProp prop_; diff --git a/python/raft/common/handle.pxd b/python/raft/common/handle.pxd index 6076640312..884d81bed1 100644 --- a/python/raft/common/handle.pxd +++ b/python/raft/common/handle.pxd @@ -34,7 +34,5 @@ cdef extern from "raft/handle.hpp" namespace "raft" nogil: handle_t() except + handle_t(int ns) except + void set_stream(_Stream s) except + - void set_device_allocator(shared_ptr[allocator] a) except + - shared_ptr[allocator] get_device_allocator() except + _Stream get_stream() except + int get_num_internal_streams() except +