Skip to content

Commit

Permalink
RAFT handle update
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Jul 1, 2021
1 parent 0a5cbc5 commit 565cf68
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 26 deletions.
52 changes: 28 additions & 24 deletions cpp/include/raft/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@
#include <raft/linalg/cusolver_wrappers.h>
#include <raft/sparse/cusparse_wrappers.h>
#include <raft/comms/comms.hpp>
#include <raft/mr/device/allocator.hpp>
#include <raft/mr/host/allocator.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/exec_policy.hpp>
#include "cudart_utils.h"

namespace raft {
Expand All @@ -48,6 +47,9 @@ namespace raft {
* necessary cuda kernels and/or libraries
*/
class handle_t {
using thrust_exec_policy_t = thrust::detail::execute_with_allocator<
rmm::mr::thrust_allocator<char>, thrust::cuda_cub::execute_on_stream_base>;

private:
static constexpr int kNumDefaultWorkerStreams = 0;

Expand All @@ -63,9 +65,7 @@ class handle_t {
CUDA_CHECK(cudaGetDevice(&cur_dev));
return cur_dev;
}()),
streams_(n_streams),
device_allocator_(std::make_shared<mr::device::default_allocator>()),
host_allocator_(std::make_shared<mr::host::default_allocator>()) {
streams_(n_streams) {
create_resources();
}

Expand All @@ -86,8 +86,6 @@ class handle_t {
"ERROR: the main handle must have at least one worker stream\n");
prop_ = other.get_device_properties();
device_prop_initialized_ = true;
device_allocator_ = other.get_device_allocator();
host_allocator_ = other.get_host_allocator();
create_resources();
set_stream(other.get_internal_stream(stream_id));
}
Expand All @@ -97,26 +95,15 @@ class handle_t {

int get_device() const { return dev_id_; }

void set_stream(cudaStream_t stream) { user_stream_ = stream; }
void set_stream(cudaStream_t stream) {
thrust_policy_initialized_ = false;
user_stream_ = stream;
}
cudaStream_t get_stream() const { return user_stream_; }
rmm::cuda_stream_view get_stream_view() const {
return rmm::cuda_stream_view(user_stream_);
}

void set_device_allocator(std::shared_ptr<mr::device::allocator> allocator) {
device_allocator_ = allocator;
}
std::shared_ptr<mr::device::allocator> get_device_allocator() const {
return device_allocator_;
}

void set_host_allocator(std::shared_ptr<mr::host::allocator> allocator) {
host_allocator_ = allocator;
}
std::shared_ptr<mr::host::allocator> get_host_allocator() const {
return host_allocator_;
}

cublasHandle_t get_cublas_handle() const {
std::lock_guard<std::mutex> _(mutex_);
if (!cublas_initialized_) {
Expand Down Expand Up @@ -153,6 +140,23 @@ class handle_t {
return cusparse_handle_;
}

thrust_exec_policy_t get_thrust_policy() const {
std::lock_guard<std::mutex> _(mutex_);
if (!thrust_policy_initialized_) {
if (!thrust_policy_) {
thrust_policy_ =
(thrust_exec_policy_t*)malloc(sizeof(thrust_exec_policy_t));
}
*thrust_policy_ = rmm::exec_policy(this->get_stream());
thrust_policy_initialized_ = true;
}
return *thrust_policy_;
}

thrust_exec_policy_t get_thrust_policy(cudaStream_t stream) const {
return rmm::exec_policy(stream);
}

// legacy compatibility for cuML
cudaStream_t get_internal_stream(int sid) const {
return streams_.get_stream(sid).value();
Expand Down Expand Up @@ -236,8 +240,8 @@ class handle_t {
mutable bool cusolver_sp_initialized_{false};
mutable cusparseHandle_t cusparse_handle_;
mutable bool cusparse_initialized_{false};
std::shared_ptr<mr::device::allocator> device_allocator_;
std::shared_ptr<mr::host::allocator> host_allocator_;
mutable thrust_exec_policy_t* thrust_policy_{nullptr};
mutable bool thrust_policy_initialized_{false};
cudaStream_t user_stream_{nullptr};
cudaEvent_t event_;
mutable cudaDeviceProp prop_;
Expand Down
2 changes: 0 additions & 2 deletions python/raft/common/handle.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,5 @@ cdef extern from "raft/handle.hpp" namespace "raft" nogil:
handle_t() except +
handle_t(int ns) except +
void set_stream(_Stream s) except +
void set_device_allocator(shared_ptr[allocator] a) except +
shared_ptr[allocator] get_device_allocator() except +
_Stream get_stream() except +
int get_num_internal_streams() except +

0 comments on commit 565cf68

Please sign in to comment.