Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

One cudaStream_t instance per raft::handle_t #291

Merged
merged 58 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
dc8ce65
checking in with handle changes
divyegala Jul 12, 2021
e9c88df
working handle cpp tests
divyegala Jul 14, 2021
ec5ec5d
working python handle
divyegala Jul 14, 2021
347b702
Merge branch 'branch-21.08' of https://github.com/rapidsai/raft into …
divyegala Jul 14, 2021
8b27ba2
styling changes
divyegala Jul 14, 2021
102ad4e
removing unnecessary TearDown from matrix gtest
divyegala Jul 14, 2021
6ddb1ff
renaming wrong variable name
divyegala Jul 14, 2021
52e775e
better doc for handle constructor according to review
divyegala Jul 15, 2021
54c67d4
review feedback
divyegala Jul 16, 2021
aedfa52
adjusting default handle stream to per thread
divyegala Jul 19, 2021
a502087
adjusting doc
divyegala Jul 19, 2021
8045f16
handle on knn detail API
divyegala Jul 19, 2021
8b2ab71
convenience function on handle to get stream from pool
divyegala Jul 19, 2021
fa320dd
correcting build
divyegala Jul 20, 2021
c25ab19
stream from pool at index
divyegala Jul 20, 2021
1ccc5cc
removing getting stream from pool functionality on handle
divyegala Jul 20, 2021
240fcf6
passing cpp tests
divyegala Sep 23, 2021
522c571
per-thread stream tests passing
divyegala Sep 23, 2021
89a23f6
solving pos argument
divyegala Oct 4, 2021
c24ecc8
merge upstream
divyegala Oct 4, 2021
e8a7856
passing tests
divyegala Oct 4, 2021
0c9871a
fix for failures in CI
divyegala Oct 4, 2021
5ab4f7c
Merge branch 'branch-21.12' of https://github.com/rapidsai/raft into …
divyegala Oct 13, 2021
830db09
review comments
divyegala Oct 14, 2021
2cf1e51
merging upstream
divyegala Oct 18, 2021
9a20bbf
resolving bad merge
divyegala Oct 18, 2021
7288978
changing sync method from cdef to def
divyegala Oct 22, 2021
ed6e4d8
removing cdef sync from handle pxd
divyegala Oct 28, 2021
9e83e9a
Merge branch 'branch-21.12' of https://github.com/rapidsai/raft into …
divyegala Oct 28, 2021
865fa7a
trying legacy stream
divyegala Nov 9, 2021
2044fb2
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 9, 2021
8bdbf81
back to default stream per thread
divyegala Nov 16, 2021
fe05b09
merging branch-22.02
divyegala Nov 16, 2021
d243eca
fixing bad merge
divyegala Nov 16, 2021
553453f
merge branch-21.12
divyegala Nov 16, 2021
5287e6e
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 17, 2021
2e60f56
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 17, 2021
480ba37
correcting legacy to per-thread
divyegala Nov 17, 2021
1877061
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 22, 2021
1be9fc7
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 22, 2021
0efbd91
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 22, 2021
ceac531
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 22, 2021
239a887
merging upstream
divyegala Dec 7, 2021
41d0694
merging upstream
divyegala Dec 7, 2021
a89ab29
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Dec 9, 2021
d106d6e
fixing compiler error
divyegala Dec 9, 2021
daadd95
merging upstream
divyegala Dec 9, 2021
7051e39
Reverting fused l2 changes. cuml CI still seems to be broken
cjnolet Dec 10, 2021
6bb7eeb
Fixing style
cjnolet Dec 10, 2021
3322ebe
merging corey's fused l2 knn bug revert
divyegala Dec 10, 2021
cbb0540
fixing macro name
divyegala Dec 10, 2021
ea97177
fixing typo with curly brace
divyegala Dec 10, 2021
8338dcc
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Dec 10, 2021
9659249
Adding no throw macro variants
cjnolet Dec 10, 2021
d12db1c
Fixing typo
cjnolet Dec 10, 2021
6186ead
pulling corey's macro updates
divyegala Dec 10, 2021
5ed4289
merging upstream
divyegala Dec 10, 2021
e97a938
merging upstream
divyegala Dec 13, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cpp/include/raft/comms/test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <raft/comms/comms.hpp>
#include <raft/handle.hpp>
#include <raft/mr/device/buffer.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

Expand Down Expand Up @@ -513,7 +515,8 @@ bool test_commsplit(const handle_t &h, int n_colors) {
int color = rank % n_colors;
int key = rank / n_colors;

handle_t new_handle(1);
rmm::cuda_stream_pool stream_pool(1);
handle_t new_handle(rmm::cuda_stream_default, stream_pool);
auto shared_comm =
std::make_shared<comms_t>(communicator.comm_split(color, key));
new_handle.set_comms(shared_comm);
Expand Down
120 changes: 61 additions & 59 deletions cpp/include/raft/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <raft/mr/device/allocator.hpp>
#include <raft/mr/host/allocator.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/cuda_stream_view.hpp>
#include "cudart_utils.h"

namespace raft {
Expand All @@ -52,56 +53,47 @@ class handle_t {
static constexpr int kNumDefaultWorkerStreams = 0;
divyegala marked this conversation as resolved.
Show resolved Hide resolved

public:
// delete copy/move constructors and assignment operators as
// copying and moving underlying resources is unsafe
handle_t(const handle_t&) = delete;
handle_t& operator=(const handle_t&) = delete;
handle_t(handle_t&&) = delete;
handle_t& operator=(handle_t&&) = delete;

/**
* @brief Construct a handle with the specified number of worker streams
* @brief Construct a handle with a stream view and stream pool
*
* @param[in] n_streams number worker streams to be created
* @param[in] stream the default stream (which has the default value of nullptr if unspecified)
divyegala marked this conversation as resolved.
Show resolved Hide resolved
* @param[in] stream_pool the stream pool used (which has default pool of size 0 if unspecified)
*/
explicit handle_t(int n_streams = kNumDefaultWorkerStreams)
handle_t(rmm::cuda_stream_view stream = {},
const rmm::cuda_stream_pool& stream_pool = rmm::cuda_stream_pool{0})
: dev_id_([]() -> int {
int cur_dev = -1;
CUDA_CHECK(cudaGetDevice(&cur_dev));
return cur_dev;
}()),
streams_(n_streams),
device_allocator_(std::make_shared<mr::device::default_allocator>()),
host_allocator_(std::make_shared<mr::host::default_allocator>()) {
create_resources();
}

/**
* @brief Construct a light handle copy from another
* user stream, cuda handles, comms and worker pool are not copied
* The user_stream of the returned handle is set to the specified stream
* of the other handle worker pool
* @param[in] stream_id stream id in `other` worker streams
* to be set as user stream in the constructed handle
* @param[in] n_streams number worker streams to be created
*/
handle_t(const handle_t& other, int stream_id,
int n_streams = kNumDefaultWorkerStreams)
: dev_id_(other.get_device()), streams_(n_streams) {
RAFT_EXPECTS(
other.get_num_internal_streams() > 0,
"ERROR: the main handle must have at least one worker stream\n");
prop_ = other.get_device_properties();
device_prop_initialized_ = true;
device_allocator_ = other.get_device_allocator();
host_allocator_ = other.get_host_allocator();
host_allocator_(std::make_shared<mr::host::default_allocator>()),
stream_view_(stream),
stream_pool_(stream_pool) {
create_resources();
set_stream(other.get_internal_stream(stream_id));
}

/** Destroys all held-up resources */
virtual ~handle_t() { destroy_resources(); }

int get_device() const { return dev_id_; }

void set_stream(cudaStream_t stream) { user_stream_ = stream; }
cudaStream_t get_stream() const { return user_stream_; }
rmm::cuda_stream_view get_stream_view() const {
return rmm::cuda_stream_view(user_stream_);
}
/**
* @brief returns main stream on the handle
*/
const rmm::cuda_stream_view& get_stream() const { return stream_view_; }
divyegala marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief returns stream pool on the handle, could be 0 sized
*/
const rmm::cuda_stream_pool& get_stream_pool() const { return stream_pool_; }

void set_device_allocator(std::shared_ptr<mr::device::allocator> allocator) {
device_allocator_ = allocator;
Expand All @@ -121,6 +113,7 @@ class handle_t {
std::lock_guard<std::mutex> _(mutex_);
if (!cublas_initialized_) {
CUBLAS_CHECK(cublasCreate(&cublas_handle_));
CUBLAS_CHECK(cublasSetStream(cublas_handle_, stream_view_));
cublas_initialized_ = true;
}
return cublas_handle_;
Expand All @@ -130,6 +123,7 @@ class handle_t {
std::lock_guard<std::mutex> _(mutex_);
if (!cusolver_dn_initialized_) {
CUSOLVER_CHECK(cusolverDnCreate(&cusolver_dn_handle_));
CUSOLVER_CHECK(cusolverDnSetStream(cusolver_dn_handle_, stream_view_));
cusolver_dn_initialized_ = true;
}
return cusolver_dn_handle_;
Expand All @@ -139,6 +133,7 @@ class handle_t {
std::lock_guard<std::mutex> _(mutex_);
if (!cusolver_sp_initialized_) {
CUSOLVER_CHECK(cusolverSpCreate(&cusolver_sp_handle_));
CUSOLVER_CHECK(cusolverSpSetStream(cusolver_sp_handle_, stream_view_));
cusolver_sp_initialized_ = true;
}
return cusolver_sp_handle_;
Expand All @@ -148,40 +143,44 @@ class handle_t {
std::lock_guard<std::mutex> _(mutex_);
if (!cusparse_initialized_) {
CUSPARSE_CHECK(cusparseCreate(&cusparse_handle_));
CUSPARSE_CHECK(cusparseSetStream(cusparse_handle_, stream_view_));
cusparse_initialized_ = true;
}
return cusparse_handle_;
}

// legacy compatibility for cuML
cudaStream_t get_internal_stream(int sid) const {
return streams_.get_stream(sid).value();
}
// new accessor return rmm::cuda_stream_view
rmm::cuda_stream_view get_internal_stream_view(int sid) const {
return streams_.get_stream(sid);
}
/**
* @brief synchronize main stream on the handle
*/
void sync_stream() const { stream_view_.synchronize(); }

int get_num_internal_streams() const { return streams_.get_pool_size(); }
std::vector<cudaStream_t> get_internal_streams() const {
std::vector<cudaStream_t> int_streams_vec;
for (int i = 0; i < get_num_internal_streams(); i++) {
int_streams_vec.push_back(get_internal_stream(i));
/**
* @brief synchronize the stream pool on the handle
*/
void sync_stream_pool() const {
divyegala marked this conversation as resolved.
Show resolved Hide resolved
for (std::size_t i = 0; i < stream_pool_.get_pool_size(); i++) {
stream_pool_.get_stream(i).synchronize();
}
return int_streams_vec;
}

void wait_on_user_stream() const {
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
CUDA_CHECK(cudaEventRecord(event_, user_stream_));
for (int i = 0; i < get_num_internal_streams(); i++) {
CUDA_CHECK(cudaStreamWaitEvent(get_internal_stream(i), event_, 0));
/**
* @brief synchronize subset of stream pool
*
* @param[in] stream_indices the indices of the streams in the stream pool to synchronize
*/
void sync_stream_pool(const std::vector<std::size_t> stream_indices) const {
for (const auto& stream_index : stream_indices) {
stream_pool_.get_stream(stream_index).synchronize();
}
}

void wait_on_internal_streams() const {
for (int i = 0; i < get_num_internal_streams(); i++) {
CUDA_CHECK(cudaEventRecord(event_, get_internal_stream(i)));
CUDA_CHECK(cudaStreamWaitEvent(user_stream_, event_, 0));
/**
* @brief ask stream pool to wait on last event in main stream
*/
void wait_stream_pool_on_stream() const {
CUDA_CHECK(cudaEventRecord(event_, stream_view_));
for (std::size_t i = 0; i < stream_pool_.get_pool_size(); i++) {
CUDA_CHECK(cudaStreamWaitEvent(stream_pool_.get_stream(i), event_, 0));
}
}

Expand Down Expand Up @@ -227,7 +226,6 @@ class handle_t {
std::unordered_map<std::string, std::shared_ptr<comms::comms_t>> subcomms_;

const int dev_id_;
rmm::cuda_stream_pool streams_{0};
mutable cublasHandle_t cublas_handle_;
mutable bool cublas_initialized_{false};
mutable cusolverDnHandle_t cusolver_dn_handle_;
Expand All @@ -238,7 +236,8 @@ class handle_t {
mutable bool cusparse_initialized_{false};
std::shared_ptr<mr::device::allocator> device_allocator_;
std::shared_ptr<mr::host::allocator> host_allocator_;
cudaStream_t user_stream_{nullptr};
rmm::cuda_stream_view stream_view_;
const rmm::cuda_stream_pool& stream_pool_;
cudaEvent_t event_;
mutable cudaDeviceProp prop_;
mutable bool device_prop_initialized_{false};
Expand Down Expand Up @@ -277,9 +276,12 @@ class handle_t {
class stream_syncer {
public:
explicit stream_syncer(const handle_t& handle) : handle_(handle) {
handle_.wait_on_user_stream();
handle_.sync_stream();
}
~stream_syncer() {
handle_.wait_stream_pool_on_stream();
handle_.sync_stream_pool();
}
~stream_syncer() { handle_.wait_on_internal_streams(); }

stream_syncer(const stream_syncer& other) = delete;
stream_syncer& operator=(const stream_syncer& other) = delete;
Expand Down
18 changes: 10 additions & 8 deletions cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/cudart_utils.h>
#include <raft/cuda_utils.cuh>
#include <rmm/cuda_stream_pool.hpp>

#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/GpuResources.h>
Expand Down Expand Up @@ -200,9 +201,8 @@ void brute_force_knn_impl(std::vector<float *> &input, std::vector<int> &sizes,
int64_t *res_I, float *res_D, IntType k,
std::shared_ptr<deviceAllocator> allocator,
cudaStream_t userStream,
cudaStream_t *internalStreams = nullptr,
int n_int_streams = 0, bool rowMajorIndex = true,
bool rowMajorQuery = true,
const rmm::cuda_stream_pool &internalStreams,
bool rowMajorIndex = true, bool rowMajorQuery = true,
std::vector<int64_t> *translations = nullptr,
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Expanded,
Expand Down Expand Up @@ -263,14 +263,16 @@ void brute_force_knn_impl(std::vector<float *> &input, std::vector<int> &sizes,
}

// Sync user stream only if using other streams to parallelize query
if (n_int_streams > 0) CUDA_CHECK(cudaStreamSynchronize(userStream));
auto n_internal_streams = internalStreams.get_pool_size();
if (n_internal_streams > 0) CUDA_CHECK(cudaStreamSynchronize(userStream));

for (size_t i = 0; i < input.size(); i++) {
float *out_d_ptr = out_D + (i * k * n);
int64_t *out_i_ptr = out_I + (i * k * n);

cudaStream_t stream =
raft::select_stream(userStream, internalStreams, n_int_streams, i);
cudaStream_t stream = n_internal_streams > 0
? internalStreams.get_stream().value()
: userStream;

switch (metric) {
case raft::distance::DistanceType::Haversine:
Expand Down Expand Up @@ -318,8 +320,8 @@ void brute_force_knn_impl(std::vector<float *> &input, std::vector<int> &sizes,
// Sync internal streams if used. We don't need to
// sync the user stream because we'll already have
// fully serial execution.
for (int i = 0; i < n_int_streams; i++) {
CUDA_CHECK(cudaStreamSynchronize(internalStreams[i]));
for (std::size_t i = 0; i < internalStreams.get_pool_size(); i++) {
CUDA_CHECK(cudaStreamSynchronize(internalStreams.get_stream(i)));
}

if (input.size() > 1 || translations != nullptr) {
Expand Down
8 changes: 3 additions & 5 deletions cpp/include/raft/spatial/knn/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,11 @@ inline void brute_force_knn(
ASSERT(input.size() == sizes.size(),
"input and sizes vectors must be the same size");

std::vector<cudaStream_t> int_streams = handle.get_internal_streams();

detail::brute_force_knn_impl(input, sizes, D, search_items, n, res_I, res_D,
k, handle.get_device_allocator(),
handle.get_stream(), int_streams.data(),
handle.get_num_internal_streams(), rowMajorIndex,
rowMajorQuery, translations, metric, metric_arg);
handle.get_stream(), handle.get_stream_pool(),
rowMajorIndex, rowMajorQuery, translations,
metric, metric_arg);
}

} // namespace knn
Expand Down
1 change: 0 additions & 1 deletion cpp/test/cluster_solvers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ TEST(Raft, ModularitySolvers) {
using value_type = double;

handle_t h;
ASSERT_EQ(0, h.get_num_internal_streams());
ASSERT_EQ(0, h.get_device());

index_type neigvs{10};
Expand Down
2 changes: 0 additions & 2 deletions cpp/test/eigen_solvers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ TEST(Raft, EigenSolvers) {
using value_type = double;

handle_t h;
ASSERT_EQ(0, h.get_num_internal_streams());
ASSERT_EQ(0, h.get_device());

index_type* ro{nullptr};
Expand Down Expand Up @@ -73,7 +72,6 @@ TEST(Raft, SpectralSolvers) {
using value_type = double;

handle_t h;
ASSERT_EQ(0, h.get_num_internal_streams());
ASSERT_EQ(0, h.get_device());

index_type neigvs{10};
Expand Down
Loading