Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

One cudaStream_t instance per raft::handle_t #291

Merged
merged 58 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
dc8ce65
checking in with handle changes
divyegala Jul 12, 2021
e9c88df
working handle cpp tests
divyegala Jul 14, 2021
ec5ec5d
working python handle
divyegala Jul 14, 2021
347b702
Merge branch 'branch-21.08' of https://github.com/rapidsai/raft into …
divyegala Jul 14, 2021
8b27ba2
styling changes
divyegala Jul 14, 2021
102ad4e
removing unnecessary TearDown from matrix gtest
divyegala Jul 14, 2021
6ddb1ff
renaming wrong variable name
divyegala Jul 14, 2021
52e775e
better doc for handle constructor according to review
divyegala Jul 15, 2021
54c67d4
review feedback
divyegala Jul 16, 2021
aedfa52
adjusting default handle stream to per thread
divyegala Jul 19, 2021
a502087
adjusting doc
divyegala Jul 19, 2021
8045f16
handle on knn detail API
divyegala Jul 19, 2021
8b2ab71
convenience function on handle to get stream from pool
divyegala Jul 19, 2021
fa320dd
correcting build
divyegala Jul 20, 2021
c25ab19
stream from pool at index
divyegala Jul 20, 2021
1ccc5cc
removing getting stream from pool functionality on handle
divyegala Jul 20, 2021
240fcf6
passing cpp tests
divyegala Sep 23, 2021
522c571
per-thread stream tests passing
divyegala Sep 23, 2021
89a23f6
solving pos argument
divyegala Oct 4, 2021
c24ecc8
merge upstream
divyegala Oct 4, 2021
e8a7856
passing tests
divyegala Oct 4, 2021
0c9871a
fix for failures in CI
divyegala Oct 4, 2021
5ab4f7c
Merge branch 'branch-21.12' of https://github.com/rapidsai/raft into …
divyegala Oct 13, 2021
830db09
review comments
divyegala Oct 14, 2021
2cf1e51
merging upstream
divyegala Oct 18, 2021
9a20bbf
resolving bad merge
divyegala Oct 18, 2021
7288978
changing sync method from cdef to def
divyegala Oct 22, 2021
ed6e4d8
removing cdef sync from handle pxd
divyegala Oct 28, 2021
9e83e9a
Merge branch 'branch-21.12' of https://github.com/rapidsai/raft into …
divyegala Oct 28, 2021
865fa7a
trying legacy stream
divyegala Nov 9, 2021
2044fb2
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 9, 2021
8bdbf81
back to default stream per thread
divyegala Nov 16, 2021
fe05b09
merging branch-22.02
divyegala Nov 16, 2021
d243eca
fixing bad merge
divyegala Nov 16, 2021
553453f
merge branch-21.12
divyegala Nov 16, 2021
5287e6e
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 17, 2021
2e60f56
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 17, 2021
480ba37
correcting legacy to per-thread
divyegala Nov 17, 2021
1877061
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 22, 2021
1be9fc7
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 22, 2021
0efbd91
Merge remote-tracking branch 'upstream/branch-21.12' into imp-21.10-h…
divyegala Nov 22, 2021
ceac531
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Nov 22, 2021
239a887
merging upstream
divyegala Dec 7, 2021
41d0694
merging upstream
divyegala Dec 7, 2021
a89ab29
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Dec 9, 2021
d106d6e
fixing compiler error
divyegala Dec 9, 2021
daadd95
merging upstream
divyegala Dec 9, 2021
7051e39
Reverting fused l2 changes. cuml CI still seems to be broken
cjnolet Dec 10, 2021
6bb7eeb
Fixing style
cjnolet Dec 10, 2021
3322ebe
merging corey's fused l2 knn bug revert
divyegala Dec 10, 2021
cbb0540
fixing macro name
divyegala Dec 10, 2021
ea97177
fixing typo with curly brace
divyegala Dec 10, 2021
8338dcc
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.10-h…
divyegala Dec 10, 2021
9659249
Adding no throw macro variants
cjnolet Dec 10, 2021
d12db1c
Fixing typo
cjnolet Dec 10, 2021
6186ead
pulling corey's macro updates
divyegala Dec 10, 2021
5ed4289
merging upstream
divyegala Dec 10, 2021
e97a938
merging upstream
divyegala Dec 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# raft 22.02.00 (Date TBD)

Please see https://github.com/rapidsai/raft/releases/tag/v22.02.00a for the latest changes to this development branch.

# raft 21.12.00 (Date TBD)

Please see https://github.com/rapidsai/raft/releases/tag/v21.12.00a for the latest changes to this development branch.
Expand Down
4 changes: 2 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#=============================================================================

cmake_minimum_required(VERSION 3.20.1 FATAL_ERROR)
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-21.12/RAPIDS.cmake
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-22.02/RAPIDS.cmake
${CMAKE_BINARY_DIR}/RAPIDS.cmake)
include(${CMAKE_BINARY_DIR}/RAPIDS.cmake)
include(rapids-cmake)
Expand All @@ -26,7 +26,7 @@ include(rapids-find)

rapids_cuda_init_architectures(RAFT)

project(RAFT VERSION 21.12.00 LANGUAGES CXX CUDA)
project(RAFT VERSION 22.02.00 LANGUAGES CXX CUDA)

##############################################################################
# - build type ---------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion cpp/cmake/modules/ConfigureCUDA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ if(CMAKE_COMPILER_IS_GNUCXX)
list(APPEND RAFT_CXX_FLAGS -Wall -Werror -Wno-unknown-pragmas -Wno-error=deprecated-declarations)
endif()

list(APPEND RAFT_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr)
list(APPEND RAFT_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr --default-stream per-thread)

# set warnings as errors
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.2.0)
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/comms/test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,8 @@ bool test_commsplit(const handle_t &h, int n_colors) {
// first we need to assign to a color, then assign the rank within the color
int color = rank % n_colors;
int key = rank / n_colors;

handle_t new_handle(1);
auto stream_pool = std::make_shared<rmm::cuda_stream_pool>(1);
handle_t new_handle(rmm::cuda_stream_default, stream_pool);
auto shared_comm =
std::make_shared<comms_t>(communicator.comm_split(color, key));
new_handle.set_comms(shared_comm);
Expand Down
189 changes: 115 additions & 74 deletions cpp/include/raft/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,69 +47,42 @@ namespace raft {
* necessary cuda kernels and/or libraries
*/
class handle_t {
private:
static constexpr int kNumDefaultWorkerStreams = 0;

public:
// delete copy/move constructors and assignment operators as
// copying and moving underlying resources is unsafe
handle_t(const handle_t&) = delete;
handle_t& operator=(const handle_t&) = delete;
handle_t(handle_t&&) = delete;
handle_t& operator=(handle_t&&) = delete;

/**
* @brief Construct a handle with the specified number of worker streams
* @brief Construct a handle with a stream view and stream pool
*
* @param[in] n_streams number worker streams to be created
* @param[in] stream the default stream (which has the default per-thread stream if unspecified)
* @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified)
*/
explicit handle_t(int n_streams = kNumDefaultWorkerStreams)
handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread,
std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr})
: dev_id_([]() -> int {
int cur_dev = -1;
CUDA_CHECK(cudaGetDevice(&cur_dev));
return cur_dev;
}()) {
if (n_streams != 0) {
streams_ = std::make_unique<rmm::cuda_stream_pool>(n_streams);
}
create_resources();
thrust_policy_ = std::make_unique<rmm::exec_policy>(user_stream_);
}

/**
* @brief Construct a light handle copy from another
* user stream, cuda handles, comms and worker pool are not copied
* The user_stream of the returned handle is set to the specified stream
* of the other handle worker pool
* @param[in] other other handle for which to use streams
* @param[in] stream_id stream id in `other` worker streams
* to be set as user stream in the constructed handle
* @param[in] n_streams number worker streams to be created
*/
handle_t(const handle_t& other, int stream_id,
int n_streams = kNumDefaultWorkerStreams)
: dev_id_(other.get_device()) {
RAFT_EXPECTS(
other.get_num_internal_streams() > 0,
"ERROR: the main handle must have at least one worker stream\n");
if (n_streams != 0) {
streams_ = std::make_unique<rmm::cuda_stream_pool>(n_streams);
}
prop_ = other.get_device_properties();
device_prop_initialized_ = true;
}()),
stream_view_{stream_view},
stream_pool_{stream_pool} {
create_resources();
set_stream(other.get_internal_stream(stream_id));
thrust_policy_ = std::make_unique<rmm::exec_policy>(user_stream_);
}

/** Destroys all held-up resources */
virtual ~handle_t() { destroy_resources(); }

int get_device() const { return dev_id_; }

void set_stream(cudaStream_t stream) { user_stream_ = stream; }
cudaStream_t get_stream() const { return user_stream_; }
rmm::cuda_stream_view get_stream_view() const {
return rmm::cuda_stream_view(user_stream_);
}

cublasHandle_t get_cublas_handle() const {
std::lock_guard<std::mutex> _(mutex_);
if (!cublas_initialized_) {
CUBLAS_CHECK(cublasCreate(&cublas_handle_));
CUBLAS_CHECK(cublasSetStream(cublas_handle_, stream_view_));
cublas_initialized_ = true;
}
return cublas_handle_;
Expand All @@ -119,6 +92,7 @@ class handle_t {
std::lock_guard<std::mutex> _(mutex_);
if (!cusolver_dn_initialized_) {
CUSOLVER_CHECK(cusolverDnCreate(&cusolver_dn_handle_));
CUSOLVER_CHECK(cusolverDnSetStream(cusolver_dn_handle_, stream_view_));
cusolver_dn_initialized_ = true;
}
return cusolver_dn_handle_;
Expand All @@ -128,6 +102,7 @@ class handle_t {
std::lock_guard<std::mutex> _(mutex_);
if (!cusolver_sp_initialized_) {
CUSOLVER_CHECK(cusolverSpCreate(&cusolver_sp_handle_));
CUSOLVER_CHECK(cusolverSpSetStream(cusolver_sp_handle_, stream_view_));
cusolver_sp_initialized_ = true;
}
return cusolver_sp_handle_;
Expand All @@ -137,51 +112,112 @@ class handle_t {
std::lock_guard<std::mutex> _(mutex_);
if (!cusparse_initialized_) {
CUSPARSE_CHECK(cusparseCreate(&cusparse_handle_));
CUSPARSE_CHECK(cusparseSetStream(cusparse_handle_, stream_view_));
cusparse_initialized_ = true;
}
return cusparse_handle_;
}

rmm::exec_policy& get_thrust_policy() const { return *thrust_policy_; }

// legacy compatibility for cuML
cudaStream_t get_internal_stream(int sid) const {
RAFT_EXPECTS(
streams_.get() != nullptr,
"ERROR: rmm::cuda_stream_pool was not initialized with a non-zero value");
return streams_->get_stream(sid).value();
/**
* @brief synchronize main stream on the handle
*/
void sync_stream() const { stream_view_.synchronize(); }

/**
* @brief returns main stream on the handle
*/
rmm::cuda_stream_view get_stream() const { return stream_view_; }

/**
* @brief returns whether stream pool was initialized on the handle
*/

bool is_stream_pool_initialized() const {
return stream_pool_.get() != nullptr;
}

/**
* @brief returns stream pool on the handle
*/
const rmm::cuda_stream_pool& get_stream_pool() const {
RAFT_EXPECTS(stream_pool_,
"ERROR: rmm::cuda_stream_pool was not initialized");
return *stream_pool_;
}

std::size_t get_stream_pool_size() const {
return is_stream_pool_initialized() ? stream_pool_->get_pool_size() : 0;
}

/**
* @brief return stream from pool
*/
rmm::cuda_stream_view get_stream_from_stream_pool() const {
RAFT_EXPECTS(stream_pool_,
"ERROR: rmm::cuda_stream_pool was not initialized");
return stream_pool_->get_stream();
}

/**
* @brief return stream from pool at index
*/
rmm::cuda_stream_view get_stream_from_stream_pool(
std::size_t stream_idx) const {
RAFT_EXPECTS(stream_pool_,
"ERROR: rmm::cuda_stream_pool was not initialized");
return stream_pool_->get_stream(stream_idx);
}
// new accessor return rmm::cuda_stream_view
rmm::cuda_stream_view get_internal_stream_view(int sid) const {
RAFT_EXPECTS(
streams_.get() != nullptr,
"ERROR: rmm::cuda_stream_pool was not initialized with a non-zero value");
return streams_->get_stream(sid);

/**
* @brief return stream from pool if size > 0, else main stream on handle
*/
rmm::cuda_stream_view get_next_usable_stream() const {
return is_stream_pool_initialized() ? get_stream_from_stream_pool()
: stream_view_;
}

int get_num_internal_streams() const {
return streams_.get() != nullptr ? streams_->get_pool_size() : 0;
/**
* @brief return stream from pool at index if size > 0, else main stream on handle
*
* @param[in] stream_index the required index of the stream in the stream pool if available
*/
rmm::cuda_stream_view get_next_usable_stream(std::size_t stream_idx) const {
return is_stream_pool_initialized()
? get_stream_from_stream_pool(stream_idx)
: stream_view_;
}

std::vector<cudaStream_t> get_internal_streams() const {
std::vector<cudaStream_t> int_streams_vec;
for (int i = 0; i < get_num_internal_streams(); i++) {
int_streams_vec.push_back(get_internal_stream(i));
/**
* @brief synchronize the stream pool on the handle
*/
void sync_stream_pool() const {
divyegala marked this conversation as resolved.
Show resolved Hide resolved
for (std::size_t i = 0; i < get_stream_pool_size(); i++) {
stream_pool_->get_stream(i).synchronize();
}
return int_streams_vec;
}

void wait_on_user_stream() const {
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
CUDA_CHECK(cudaEventRecord(event_, user_stream_));
for (int i = 0; i < get_num_internal_streams(); i++) {
CUDA_CHECK(cudaStreamWaitEvent(get_internal_stream(i), event_, 0));
/**
* @brief synchronize subset of stream pool
*
* @param[in] stream_indices the indices of the streams in the stream pool to synchronize
*/
void sync_stream_pool(const std::vector<std::size_t> stream_indices) const {
RAFT_EXPECTS(stream_pool_,
"ERROR: rmm::cuda_stream_pool was not initialized");
for (const auto& stream_index : stream_indices) {
stream_pool_->get_stream(stream_index).synchronize();
}
}

void wait_on_internal_streams() const {
for (int i = 0; i < get_num_internal_streams(); i++) {
CUDA_CHECK(cudaEventRecord(event_, get_internal_stream(i)));
CUDA_CHECK(cudaStreamWaitEvent(user_stream_, event_, 0));
/**
* @brief ask stream pool to wait on last event in main stream
*/
void wait_stream_pool_on_stream() const {
CUDA_CHECK(cudaEventRecord(event_, stream_view_));
for (std::size_t i = 0; i < get_stream_pool_size(); i++) {
CUDA_CHECK(cudaStreamWaitEvent(stream_pool_->get_stream(i), event_, 0));
}
}

Expand Down Expand Up @@ -227,7 +263,6 @@ class handle_t {
std::unordered_map<std::string, std::shared_ptr<comms::comms_t>> subcomms_;

const int dev_id_;
std::unique_ptr<rmm::cuda_stream_pool> streams_{nullptr};
mutable cublasHandle_t cublas_handle_;
mutable bool cublas_initialized_{false};
mutable cusolverDnHandle_t cusolver_dn_handle_;
Expand All @@ -237,13 +272,16 @@ class handle_t {
mutable cusparseHandle_t cusparse_handle_;
mutable bool cusparse_initialized_{false};
std::unique_ptr<rmm::exec_policy> thrust_policy_{nullptr};
cudaStream_t user_stream_{nullptr};
rmm::cuda_stream_view stream_view_{rmm::cuda_stream_per_thread};
std::shared_ptr<rmm::cuda_stream_pool> stream_pool_{nullptr};
cudaEvent_t event_;
mutable cudaDeviceProp prop_;
mutable bool device_prop_initialized_{false};
mutable std::mutex mutex_;

void create_resources() {
thrust_policy_ = std::make_unique<rmm::exec_policy>(stream_view_);

CUDA_CHECK(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}

Expand Down Expand Up @@ -276,9 +314,12 @@ class handle_t {
class stream_syncer {
public:
explicit stream_syncer(const handle_t& handle) : handle_(handle) {
handle_.wait_on_user_stream();
handle_.sync_stream();
}
~stream_syncer() {
handle_.wait_stream_pool_on_stream();
handle_.sync_stream_pool();
}
~stream_syncer() { handle_.wait_on_internal_streams(); }

stream_syncer(const stream_syncer& other) = delete;
stream_syncer& operator=(const stream_syncer& other) = delete;
Expand Down
9 changes: 5 additions & 4 deletions cpp/include/raft/label/classlabels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,18 @@ int getUniquelabels(rmm::device_uvector<value_t> &unique, value_t *y, size_t n,

// Query how much temporary storage we will need for cub operations
// and allocate it
cub::DeviceRadixSort::SortKeys(NULL, bytes, y, workspace.data(), n);
cub::DeviceRadixSort::SortKeys(NULL, bytes, y, workspace.data(), n, 0,
sizeof(value_t) * 8, stream);
cub::DeviceSelect::Unique(NULL, bytes2, workspace.data(), workspace.data(),
d_num_selected.data(), n);
d_num_selected.data(), n, stream);
bytes = max(bytes, bytes2);
rmm::device_uvector<char> cub_storage(bytes, stream);

// Select Unique classes
cub::DeviceRadixSort::SortKeys(cub_storage.data(), bytes, y, workspace.data(),
n);
n, 0, sizeof(value_t) * 8, stream);
cub::DeviceSelect::Unique(cub_storage.data(), bytes, workspace.data(),
workspace.data(), d_num_selected.data(), n);
workspace.data(), d_num_selected.data(), n, stream);

int n_unique = d_num_selected.value(stream);
// Copy unique classes to output
Expand Down
5 changes: 2 additions & 3 deletions cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,8 @@ void k_closest_landmarks(const raft::handle_t &handle,
std::vector<std::uint32_t> sizes = {index.n_landmarks};

brute_force_knn_impl<std::uint32_t, std::int64_t>(
input, sizes, index.n, const_cast<value_t *>(query_pts), n_query_pts,
R_knn_inds, R_knn_dists, k, handle.get_stream(), nullptr, 0, true, true,
nullptr, index.metric);
handle, input, sizes, index.n, const_cast<value_t *>(query_pts),
n_query_pts, R_knn_inds, R_knn_dists, k, true, true, nullptr, index.metric);
}

/**
Expand Down
Loading