diff --git a/cpp/cmake/modules/ConfigureCUDA.cmake b/cpp/cmake/modules/ConfigureCUDA.cmake index 3786910be0..02bd15c407 100644 --- a/cpp/cmake/modules/ConfigureCUDA.cmake +++ b/cpp/cmake/modules/ConfigureCUDA.cmake @@ -23,7 +23,7 @@ if(CMAKE_COMPILER_IS_GNUCXX) list(APPEND RAFT_CXX_FLAGS -Wall -Werror -Wno-unknown-pragmas -Wno-error=deprecated-declarations) endif() -list(APPEND RAFT_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr) +list(APPEND RAFT_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr --default-stream per-thread) # set warnings as errors if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.2.0) diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/test.hpp index 93b57b13a0..01ad6369f8 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/test.hpp @@ -528,10 +528,10 @@ bool test_commsplit(const handle_t& h, int n_colors) if (n_colors > size) n_colors = size; // first we need to assign to a color, then assign the rank within the color - int color = rank % n_colors; - int key = rank / n_colors; - - handle_t new_handle(1); + int color = rank % n_colors; + int key = rank / n_colors; + auto stream_pool = std::make_shared(1); + handle_t new_handle(rmm::cuda_stream_default, stream_pool); auto shared_comm = std::make_shared(communicator.comm_split(color, key)); new_handle.set_comms(shared_comm); diff --git a/cpp/include/raft/handle.hpp b/cpp/include/raft/handle.hpp index 97b442afe3..bba7fabc54 100644 --- a/cpp/include/raft/handle.hpp +++ b/cpp/include/raft/handle.hpp @@ -47,48 +47,31 @@ namespace raft { * necessary cuda kernels and/or libraries */ class handle_t { - private: - static constexpr int kNumDefaultWorkerStreams = 0; - public: + // delete copy/move constructors and assignment operators as + // copying and moving underlying resources is unsafe + handle_t(const handle_t&) = delete; + handle_t& operator=(const handle_t&) = delete; + handle_t(handle_t&&) = delete; + handle_t& operator=(handle_t&&) = delete; + /** - * @brief Construct a handle with the specified number of worker streams + * @brief Construct a handle with a stream view and stream pool * - * @param[in] n_streams number worker streams to be created + * @param[in] stream the default stream (which has the default per-thread stream if unspecified) + * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) */ - explicit handle_t(int n_streams = kNumDefaultWorkerStreams) + handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, + std::shared_ptr stream_pool = {nullptr}) : dev_id_([]() -> int { int cur_dev = -1; RAFT_CUDA_TRY(cudaGetDevice(&cur_dev)); return cur_dev; - }()) + }()), + stream_view_{stream_view}, + stream_pool_{stream_pool} { - if (n_streams != 0) { streams_ = std::make_unique(n_streams); } create_resources(); - thrust_policy_ = std::make_unique(user_stream_); - } - - /** - * @brief Construct a light handle copy from another - * user stream, cuda handles, comms and worker pool are not copied - * The user_stream of the returned handle is set to the specified stream - * of the other handle worker pool - * @param[in] other other handle for which to use streams - * @param[in] stream_id stream id in `other` worker streams - * to be set as user stream in the constructed handle - * @param[in] n_streams number worker streams to be created - */ - handle_t(const handle_t& other, int stream_id, int n_streams = kNumDefaultWorkerStreams) - : dev_id_(other.get_device()) - { - RAFT_EXPECTS(other.get_num_internal_streams() > 0, - "ERROR: the main handle must have at least one worker stream\n"); - if (n_streams != 0) { streams_ = std::make_unique(n_streams); } - prop_ = other.get_device_properties(); - device_prop_initialized_ = true; - create_resources(); - set_stream(other.get_internal_stream(stream_id)); - thrust_policy_ = std::make_unique(user_stream_); } /** Destroys all held-up resources */ @@ -96,15 +79,12 @@ class handle_t { int get_device() const { return dev_id_; } - void set_stream(cudaStream_t stream) { user_stream_ = stream; } - cudaStream_t get_stream() const { return user_stream_; } - rmm::cuda_stream_view get_stream_view() const { return rmm::cuda_stream_view(user_stream_); } - cublasHandle_t get_cublas_handle() const { std::lock_guard _(mutex_); if (!cublas_initialized_) { RAFT_CUBLAS_TRY_NO_THROW(cublasCreate(&cublas_handle_)); + RAFT_CUBLAS_TRY_NO_THROW(cublasSetStream(cublas_handle_, stream_view_)); cublas_initialized_ = true; } return cublas_handle_; @@ -115,6 +95,7 @@ class handle_t { std::lock_guard _(mutex_); if (!cusolver_dn_initialized_) { RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnCreate(&cusolver_dn_handle_)); + RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnSetStream(cusolver_dn_handle_, stream_view_)); cusolver_dn_initialized_ = true; } return cusolver_dn_handle_; @@ -125,6 +106,7 @@ class handle_t { std::lock_guard _(mutex_); if (!cusolver_sp_initialized_) { RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpCreate(&cusolver_sp_handle_)); + RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpSetStream(cusolver_sp_handle_, stream_view_)); cusolver_sp_initialized_ = true; } return cusolver_sp_handle_; @@ -135,6 +117,7 @@ class handle_t { std::lock_guard _(mutex_); if (!cusparse_initialized_) { RAFT_CUSPARSE_TRY_NO_THROW(cusparseCreate(&cusparse_handle_)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseSetStream(cusparse_handle_, stream_view_)); cusparse_initialized_ = true; } return cusparse_handle_; @@ -142,48 +125,103 @@ class handle_t { rmm::exec_policy& get_thrust_policy() const { return *thrust_policy_; } - // legacy compatibility for cuML - cudaStream_t get_internal_stream(int sid) const + /** + * @brief synchronize main stream on the handle + */ + void sync_stream() const { stream_view_.synchronize(); } + + /** + * @brief returns main stream on the handle + */ + rmm::cuda_stream_view get_stream() const { return stream_view_; } + + /** + * @brief returns whether stream pool was initialized on the handle + */ + + bool is_stream_pool_initialized() const { return stream_pool_.get() != nullptr; } + + /** + * @brief returns stream pool on the handle + */ + const rmm::cuda_stream_pool& get_stream_pool() const { - RAFT_EXPECTS(streams_.get() != nullptr, - "ERROR: rmm::cuda_stream_pool was not initialized with a non-zero value"); - return streams_->get_stream(sid).value(); + RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); + return *stream_pool_; } - // new accessor return rmm::cuda_stream_view - rmm::cuda_stream_view get_internal_stream_view(int sid) const + + std::size_t get_stream_pool_size() const { - RAFT_EXPECTS(streams_.get() != nullptr, - "ERROR: rmm::cuda_stream_pool was not initialized with a non-zero value"); - return streams_->get_stream(sid); + return is_stream_pool_initialized() ? stream_pool_->get_pool_size() : 0; } - int get_num_internal_streams() const + /** + * @brief return stream from pool + */ + rmm::cuda_stream_view get_stream_from_stream_pool() const { - return streams_.get() != nullptr ? streams_->get_pool_size() : 0; + RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); + return stream_pool_->get_stream(); } - std::vector get_internal_streams() const + /** + * @brief return stream from pool at index + */ + rmm::cuda_stream_view get_stream_from_stream_pool(std::size_t stream_idx) const + { + RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); + return stream_pool_->get_stream(stream_idx); + } + + /** + * @brief return stream from pool if size > 0, else main stream on handle + */ + rmm::cuda_stream_view get_next_usable_stream() const + { + return is_stream_pool_initialized() ? get_stream_from_stream_pool() : stream_view_; + } + + /** + * @brief return stream from pool at index if size > 0, else main stream on handle + * + * @param[in] stream_index the required index of the stream in the stream pool if available + */ + rmm::cuda_stream_view get_next_usable_stream(std::size_t stream_idx) const + { + return is_stream_pool_initialized() ? get_stream_from_stream_pool(stream_idx) : stream_view_; + } + + /** + * @brief synchronize the stream pool on the handle + */ + void sync_stream_pool() const { - std::vector int_streams_vec; - for (int i = 0; i < get_num_internal_streams(); i++) { - int_streams_vec.push_back(get_internal_stream(i)); + for (std::size_t i = 0; i < get_stream_pool_size(); i++) { + stream_pool_->get_stream(i).synchronize(); } - return int_streams_vec; } - void wait_on_user_stream() const + /** + * @brief synchronize subset of stream pool + * + * @param[in] stream_indices the indices of the streams in the stream pool to synchronize + */ + void sync_stream_pool(const std::vector stream_indices) const { - RAFT_CUDA_TRY(cudaEventRecord(event_, user_stream_)); - for (int i = 0; i < get_num_internal_streams(); i++) { - RAFT_CUDA_TRY(cudaStreamWaitEvent(get_internal_stream(i), event_, 0)); + RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); + for (const auto& stream_index : stream_indices) { + stream_pool_->get_stream(stream_index).synchronize(); } } - void wait_on_internal_streams() const + /** + * @brief ask stream pool to wait on last event in main stream + */ + void wait_stream_pool_on_stream() const { - for (int i = 0; i < get_num_internal_streams(); i++) { - RAFT_CUDA_TRY(cudaEventRecord(event_, get_internal_stream(i))); - RAFT_CUDA_TRY(cudaStreamWaitEvent(user_stream_, event_, 0)); + RAFT_CUDA_TRY(cudaEventRecord(event_, stream_view_)); + for (std::size_t i = 0; i < get_stream_pool_size(); i++) { + RAFT_CUDA_TRY(cudaStreamWaitEvent(stream_pool_->get_stream(i), event_, 0)); } } @@ -229,7 +267,6 @@ class handle_t { std::unordered_map> subcomms_; const int dev_id_; - std::unique_ptr streams_{nullptr}; mutable cublasHandle_t cublas_handle_; mutable bool cublas_initialized_{false}; mutable cusolverDnHandle_t cusolver_dn_handle_; @@ -239,7 +276,8 @@ class handle_t { mutable cusparseHandle_t cusparse_handle_; mutable bool cusparse_initialized_{false}; std::unique_ptr thrust_policy_{nullptr}; - cudaStream_t user_stream_{nullptr}; + rmm::cuda_stream_view stream_view_{rmm::cuda_stream_per_thread}; + std::shared_ptr stream_pool_{nullptr}; cudaEvent_t event_; mutable cudaDeviceProp prop_; mutable bool device_prop_initialized_{false}; @@ -247,12 +285,13 @@ class handle_t { void create_resources() { + thrust_policy_ = std::make_unique(stream_view_); + RAFT_CUDA_TRY(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); } void destroy_resources() { - ///@todo: enable *_NO_THROW variants once we have enabled logging if (cusparse_initialized_) { RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroy(cusparse_handle_)); } if (cusolver_dn_initialized_) { RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnDestroy(cusolver_dn_handle_)); @@ -270,11 +309,12 @@ class handle_t { */ class stream_syncer { public: - explicit stream_syncer(const handle_t& handle) : handle_(handle) + explicit stream_syncer(const handle_t& handle) : handle_(handle) { handle_.sync_stream(); } + ~stream_syncer() { - handle_.wait_on_user_stream(); + handle_.wait_stream_pool_on_stream(); + handle_.sync_stream_pool(); } - ~stream_syncer() { handle_.wait_on_internal_streams(); } stream_syncer(const stream_syncer& other) = delete; stream_syncer& operator=(const stream_syncer& other) = delete; diff --git a/cpp/include/raft/label/classlabels.cuh b/cpp/include/raft/label/classlabels.cuh index 4e9e993b78..6cc23576f1 100644 --- a/cpp/include/raft/label/classlabels.cuh +++ b/cpp/include/raft/label/classlabels.cuh @@ -51,16 +51,23 @@ int getUniquelabels(rmm::device_uvector& unique, value_t* y, size_t n, // Query how much temporary storage we will need for cub operations // and allocate it - cub::DeviceRadixSort::SortKeys(NULL, bytes, y, workspace.data(), n); + cub::DeviceRadixSort::SortKeys( + NULL, bytes, y, workspace.data(), n, 0, sizeof(value_t) * 8, stream); cub::DeviceSelect::Unique( - NULL, bytes2, workspace.data(), workspace.data(), d_num_selected.data(), n); + NULL, bytes2, workspace.data(), workspace.data(), d_num_selected.data(), n, stream); bytes = max(bytes, bytes2); rmm::device_uvector cub_storage(bytes, stream); // Select Unique classes - cub::DeviceRadixSort::SortKeys(cub_storage.data(), bytes, y, workspace.data(), n); - cub::DeviceSelect::Unique( - cub_storage.data(), bytes, workspace.data(), workspace.data(), d_num_selected.data(), n); + cub::DeviceRadixSort::SortKeys( + cub_storage.data(), bytes, y, workspace.data(), n, 0, sizeof(value_t) * 8, stream); + cub::DeviceSelect::Unique(cub_storage.data(), + bytes, + workspace.data(), + workspace.data(), + d_num_selected.data(), + n, + stream); int n_unique = d_num_selected.value(stream); // Copy unique classes to output diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 59fce73188..3e787811bd 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -164,7 +164,8 @@ void k_closest_landmarks(const raft::handle_t& handle, std::vector input = {index.get_R()}; std::vector sizes = {index.n_landmarks}; - brute_force_knn_impl(input, + brute_force_knn_impl(handle, + input, sizes, index.n, const_cast(query_pts), @@ -172,9 +173,6 @@ void k_closest_landmarks(const raft::handle_t& handle, R_knn_inds, R_knn_dists, k, - handle.get_stream(), - nullptr, - 0, true, true, nullptr, diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 414c1dc1ce..12b7124773 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -18,6 +18,7 @@ #include #include +#include #include @@ -217,6 +218,7 @@ inline void knn_merge_parts(value_t* inK, */ template void brute_force_knn_impl( + const raft::handle_t& handle, std::vector& input, std::vector& sizes, IntType D, @@ -225,15 +227,14 @@ void brute_force_knn_impl( IdxType* res_I, float* res_D, IntType k, - cudaStream_t userStream, - cudaStream_t* internalStreams = nullptr, - int n_int_streams = 0, bool rowMajorIndex = true, bool rowMajorQuery = true, std::vector* translations = nullptr, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, float metricArg = 0) { + auto userStream = handle.get_stream(); + ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); std::vector* id_ranges; @@ -284,14 +285,14 @@ void brute_force_knn_impl( out_I = all_I.data(); } - // Sync user stream only if using other streams to parallelize query - if (n_int_streams > 0) RAFT_CUDA_TRY(cudaStreamSynchronize(userStream)); + // Make other streams from pool wait on main stream + handle.wait_stream_pool_on_stream(); for (size_t i = 0; i < input.size(); i++) { float* out_d_ptr = out_D + (i * k * n); IdxType* out_i_ptr = out_I + (i * k * n); - cudaStream_t stream = raft::select_stream(userStream, internalStreams, n_int_streams, i); + auto stream = handle.get_next_usable_stream(i); // // TODO: Enable this once we figure out why it's causing pytest failures in cuml. // if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && @@ -358,9 +359,7 @@ void brute_force_knn_impl( // Sync internal streams if used. We don't need to // sync the user stream because we'll already have // fully serial execution. - for (int i = 0; i < n_int_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(internalStreams[i])); - } + handle.sync_stream_pool(); if (input.size() > 1 || translations != nullptr) { // This is necessary for proper index translations. If there are @@ -378,11 +377,7 @@ void brute_force_knn_impl( float p = 0.5; // standard l2 if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; raft::linalg::unaryOp( - res_D, - res_D, - n * k, - [p] __device__(float input) { return powf(fabsf(input), p); }, - userStream); + res_D, res_D, n * k, [p] __device__(float input) { return powf(input, p); }, userStream); } query_metric_processor->revert(search_items); diff --git a/cpp/include/raft/spatial/knn/knn.hpp b/cpp/include/raft/spatial/knn/knn.hpp index eb9a8f1436..e1e1eac248 100644 --- a/cpp/include/raft/spatial/knn/knn.hpp +++ b/cpp/include/raft/spatial/knn/knn.hpp @@ -140,9 +140,8 @@ inline void brute_force_knn(raft::handle_t const& handle, { ASSERT(input.size() == sizes.size(), "input and sizes vectors must be the same size"); - std::vector int_streams = handle.get_internal_streams(); - - detail::brute_force_knn_impl(input, + detail::brute_force_knn_impl(handle, + input, sizes, D, search_items, @@ -150,9 +149,6 @@ inline void brute_force_knn(raft::handle_t const& handle, res_I, res_D, k, - handle.get_stream(), - int_streams.data(), - handle.get_num_internal_streams(), rowMajorIndex, rowMajorQuery, translations, diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index a748b0ef0e..c0598804a8 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -57,11 +57,12 @@ void naiveDistanceAdj(bool* dist, int n, int k, DataType eps, - bool isRowMajor) + bool isRowMajor, + cudaStream_t stream) { static const dim3 TPB(16, 32, 1); dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1); - naiveDistanceAdjKernel<<>>(dist, x, y, m, n, k, eps, isRowMajor); + naiveDistanceAdjKernel<<>>(dist, x, y, m, n, k, eps, isRowMajor); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -106,7 +107,7 @@ class DistanceAdjTest : public ::testing::TestWithParam( x.data(), y.data(), m, n, k); @@ -155,7 +156,7 @@ TEST_P(DistanceAdjTestF, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare())); + ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare(), stream)); } INSTANTIATE_TEST_CASE_P(DistanceAdjTests, DistanceAdjTestF, ::testing::ValuesIn(inputsf)); @@ -174,7 +175,7 @@ TEST_P(DistanceAdjTestD, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare())); + ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare(), stream)); } INSTANTIATE_TEST_CASE_P(DistanceAdjTests, DistanceAdjTestD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_canberra.cu b/cpp/test/distance/dist_canberra.cu index db318605b4..ca90907779 100644 --- a/cpp/test/distance/dist_canberra.cu +++ b/cpp/test/distance/dist_canberra.cu @@ -40,7 +40,7 @@ TEST_P(DistanceCanberraF, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraF, ::testing::ValuesIn(inputsf)); @@ -60,7 +60,7 @@ TEST_P(DistanceCanberraD, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_chebyshev.cu b/cpp/test/distance/dist_chebyshev.cu index c7dccfe712..641b958d72 100644 --- a/cpp/test/distance/dist_chebyshev.cu +++ b/cpp/test/distance/dist_chebyshev.cu @@ -40,7 +40,7 @@ TEST_P(DistanceLinfF, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfF, ::testing::ValuesIn(inputsf)); @@ -60,7 +60,7 @@ TEST_P(DistanceLinfD, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_correlation.cu b/cpp/test/distance/dist_correlation.cu index 0648ed96ca..72df5b10f4 100644 --- a/cpp/test/distance/dist_correlation.cu +++ b/cpp/test/distance/dist_correlation.cu @@ -41,7 +41,7 @@ TEST_P(DistanceCorrelationF, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationF, ::testing::ValuesIn(inputsf)); @@ -61,7 +61,7 @@ TEST_P(DistanceCorrelationD, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_cos.cu b/cpp/test/distance/dist_cos.cu index b3e6a4c97f..a085e82705 100644 --- a/cpp/test/distance/dist_cos.cu +++ b/cpp/test/distance/dist_cos.cu @@ -39,8 +39,8 @@ TEST_P(DistanceExpCosF, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE( - devArrMatch(dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosF, ::testing::ValuesIn(inputsf)); @@ -59,8 +59,8 @@ TEST_P(DistanceExpCosD, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE( - devArrMatch(dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_euc_exp.cu b/cpp/test/distance/dist_euc_exp.cu index 75ff7e682a..f840a91bec 100644 --- a/cpp/test/distance/dist_euc_exp.cu +++ b/cpp/test/distance/dist_euc_exp.cu @@ -39,8 +39,8 @@ TEST_P(DistanceEucExpTestF, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE( - devArrMatch(dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestF, ::testing::ValuesIn(inputsf)); @@ -59,8 +59,8 @@ TEST_P(DistanceEucExpTestD, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE( - devArrMatch(dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_euc_unexp.cu b/cpp/test/distance/dist_euc_unexp.cu index 88affa16d5..6d374f3332 100644 --- a/cpp/test/distance/dist_euc_unexp.cu +++ b/cpp/test/distance/dist_euc_unexp.cu @@ -40,8 +40,8 @@ TEST_P(DistanceEucUnexpTestF, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE( - devArrMatch(dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestF, ::testing::ValuesIn(inputsf)); @@ -60,8 +60,8 @@ TEST_P(DistanceEucUnexpTestD, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE( - devArrMatch(dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_hamming.cu b/cpp/test/distance/dist_hamming.cu index 631adc751c..e0f1efc3f7 100644 --- a/cpp/test/distance/dist_hamming.cu +++ b/cpp/test/distance/dist_hamming.cu @@ -41,7 +41,7 @@ TEST_P(DistanceHammingF, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingF, ::testing::ValuesIn(inputsf)); @@ -61,7 +61,7 @@ TEST_P(DistanceHammingD, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_hellinger.cu b/cpp/test/distance/dist_hellinger.cu index 8a07c8836f..caa96f189d 100644 --- a/cpp/test/distance/dist_hellinger.cu +++ b/cpp/test/distance/dist_hellinger.cu @@ -41,7 +41,7 @@ TEST_P(DistanceHellingerExpF, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpF, ::testing::ValuesIn(inputsf)); @@ -61,7 +61,7 @@ TEST_P(DistanceHellingerExpD, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_jensen_shannon.cu b/cpp/test/distance/dist_jensen_shannon.cu index 3cda31a852..74b02ef18d 100644 --- a/cpp/test/distance/dist_jensen_shannon.cu +++ b/cpp/test/distance/dist_jensen_shannon.cu @@ -41,7 +41,7 @@ TEST_P(DistanceJensenShannonF, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonF, ::testing::ValuesIn(inputsf)); @@ -61,7 +61,7 @@ TEST_P(DistanceJensenShannonD, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_kl_divergence.cu b/cpp/test/distance/dist_kl_divergence.cu index 4303b8cc8f..e551eda0ab 100644 --- a/cpp/test/distance/dist_kl_divergence.cu +++ b/cpp/test/distance/dist_kl_divergence.cu @@ -41,7 +41,7 @@ TEST_P(DistanceKLDivergenceF, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceF, ::testing::ValuesIn(inputsf)); @@ -61,7 +61,7 @@ TEST_P(DistanceKLDivergenceD, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_l1.cu b/cpp/test/distance/dist_l1.cu index dad160ca41..ac2ee024f6 100644 --- a/cpp/test/distance/dist_l1.cu +++ b/cpp/test/distance/dist_l1.cu @@ -40,7 +40,7 @@ TEST_P(DistanceUnexpL1F, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1F, ::testing::ValuesIn(inputsf)); @@ -60,7 +60,7 @@ TEST_P(DistanceUnexpL1D, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1D, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_minkowski.cu b/cpp/test/distance/dist_minkowski.cu index 34f6d2825e..f0a6833f2b 100644 --- a/cpp/test/distance/dist_minkowski.cu +++ b/cpp/test/distance/dist_minkowski.cu @@ -40,7 +40,7 @@ TEST_P(DistanceLpUnexpF, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpF, ::testing::ValuesIn(inputsf)); @@ -60,7 +60,7 @@ TEST_P(DistanceLpUnexpD, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_russell_rao.cu b/cpp/test/distance/dist_russell_rao.cu index e0bfcd7eb3..42234d4f0b 100644 --- a/cpp/test/distance/dist_russell_rao.cu +++ b/cpp/test/distance/dist_russell_rao.cu @@ -41,7 +41,7 @@ TEST_P(DistanceRussellRaoF, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoF, ::testing::ValuesIn(inputsf)); @@ -61,7 +61,7 @@ TEST_P(DistanceRussellRaoD, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance))); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index ec9d35bb09..8d150a4a87 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -307,7 +307,8 @@ void naiveDistance(DataType* dist, int k, raft::distance::DistanceType type, bool isRowMajor, - DataType metric_arg = 2.0f) + DataType metric_arg = 2.0f, + cudaStream_t stream = 0) { static const dim3 TPB(16, 32, 1); dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1); @@ -317,38 +318,46 @@ void naiveDistance(DataType* dist, case raft::distance::DistanceType::Linf: case raft::distance::DistanceType::L1: naiveL1_Linf_CanberraDistanceKernel - <<>>(dist, x, y, m, n, k, type, isRowMajor); + <<>>(dist, x, y, m, n, k, type, isRowMajor); break; case raft::distance::DistanceType::L2SqrtUnexpanded: case raft::distance::DistanceType::L2Unexpanded: case raft::distance::DistanceType::L2SqrtExpanded: case raft::distance::DistanceType::L2Expanded: - naiveDistanceKernel<<>>(dist, x, y, m, n, k, type, isRowMajor); + naiveDistanceKernel + <<>>(dist, x, y, m, n, k, type, isRowMajor); break; case raft::distance::DistanceType::CosineExpanded: - naiveCosineDistanceKernel<<>>(dist, x, y, m, n, k, isRowMajor); + naiveCosineDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); break; case raft::distance::DistanceType::HellingerExpanded: - naiveHellingerDistanceKernel<<>>(dist, x, y, m, n, k, isRowMajor); + naiveHellingerDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); break; case raft::distance::DistanceType::LpUnexpanded: naiveLpUnexpDistanceKernel - <<>>(dist, x, y, m, n, k, isRowMajor, metric_arg); + <<>>(dist, x, y, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::HammingUnexpanded: - naiveHammingDistanceKernel<<>>(dist, x, y, m, n, k, isRowMajor); + naiveHammingDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); break; case raft::distance::DistanceType::JensenShannon: - naiveJensenShannonDistanceKernel<<>>(dist, x, y, m, n, k, isRowMajor); + naiveJensenShannonDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); break; case raft::distance::DistanceType::RusselRaoExpanded: - naiveRussellRaoDistanceKernel<<>>(dist, x, y, m, n, k, isRowMajor); + naiveRussellRaoDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); break; case raft::distance::DistanceType::KLDivergence: - naiveKLDivergenceDistanceKernel<<>>(dist, x, y, m, n, k, isRowMajor); + naiveKLDivergenceDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); break; case raft::distance::DistanceType::CorrelationExpanded: - naiveCorrelationDistanceKernel<<>>(dist, x, y, m, n, k, isRowMajor); + naiveCorrelationDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); break; default: FAIL() << "should be here\n"; } @@ -433,7 +442,7 @@ class DistanceTest : public ::testing::TestWithParam> { r.uniform(y.data(), n * k, DataType(-1.0), DataType(1.0), stream); } naiveDistance( - dist_ref.data(), x.data(), y.data(), m, n, k, distanceType, isRowMajor, metric_arg); + dist_ref.data(), x.data(), y.data(), m, n, k, distanceType, isRowMajor, metric_arg, stream); size_t worksize = raft::distance::getWorkspaceSize( x.data(), y.data(), m, n, k); rmm::device_uvector workspace(worksize, stream); diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 68ad220734..072176e503 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -250,7 +250,7 @@ TEST_P(FusedL2NNTestF_Sq, Result) { runTest(min.data()); ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance))); + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestF_Sq, ::testing::ValuesIn(inputsf)); typedef FusedL2NNTest FusedL2NNTestF_Sqrt; @@ -258,7 +258,7 @@ TEST_P(FusedL2NNTestF_Sqrt, Result) { runTest(min.data()); ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance))); + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestF_Sqrt, ::testing::ValuesIn(inputsf)); @@ -285,7 +285,7 @@ TEST_P(FusedL2NNTestD_Sq, Result) { runTest(min.data()); ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance))); + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestD_Sq, ::testing::ValuesIn(inputsd)); typedef FusedL2NNTest FusedL2NNTestD_Sqrt; @@ -293,7 +293,7 @@ TEST_P(FusedL2NNTestD_Sqrt, Result) { runTest(min.data()); ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance))); + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestD_Sqrt, ::testing::ValuesIn(inputsd)); @@ -330,7 +330,7 @@ TEST_P(FusedL2NNDetTestF_Sq, Result) runTest(min.data()); // assumed to be golden for (int i = 0; i < NumRepeats; ++i) { runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP())); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); } } INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestF_Sq, ::testing::ValuesIn(inputsf)); @@ -340,7 +340,7 @@ TEST_P(FusedL2NNDetTestF_Sqrt, Result) runTest(min.data()); // assumed to be golden for (int i = 0; i < NumRepeats; ++i) { runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP())); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); } } INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestF_Sqrt, ::testing::ValuesIn(inputsf)); @@ -351,7 +351,7 @@ TEST_P(FusedL2NNDetTestD_Sq, Result) runTest(min.data()); // assumed to be golden for (int i = 0; i < NumRepeats; ++i) { runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP())); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); } } INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestD_Sq, ::testing::ValuesIn(inputsd)); @@ -361,7 +361,7 @@ TEST_P(FusedL2NNDetTestD_Sqrt, Result) runTest(min.data()); // assumed to be golden for (int i = 0; i < NumRepeats; ++i) { runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP())); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); } } INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestD_Sqrt, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/handle.cpp b/cpp/test/handle.cpp index 81b8bb6c6c..ddc0806a65 100644 --- a/cpp/test/handle.cpp +++ b/cpp/test/handle.cpp @@ -26,7 +26,7 @@ TEST(Raft, HandleDefault) { handle_t h; ASSERT_EQ(0, h.get_device()); - ASSERT_EQ(nullptr, h.get_stream()); + ASSERT_EQ(rmm::cuda_stream_per_thread, h.get_stream()); ASSERT_NE(nullptr, h.get_cublas_handle()); ASSERT_NE(nullptr, h.get_cusolver_dn_handle()); ASSERT_NE(nullptr, h.get_cusolver_sp_handle()); @@ -35,44 +35,33 @@ TEST(Raft, HandleDefault) TEST(Raft, Handle) { - handle_t h(4); - ASSERT_EQ(4, h.get_num_internal_streams()); + // test stream pool creation + constexpr std::size_t n_streams = 4; + auto stream_pool = std::make_shared(n_streams); + handle_t h(rmm::cuda_stream_default, stream_pool); + ASSERT_EQ(n_streams, h.get_stream_pool_size()); + + // test non default stream handle cudaStream_t stream; RAFT_CUDA_TRY(cudaStreamCreate(&stream)); - h.set_stream(stream); - ASSERT_EQ(stream, h.get_stream()); + rmm::cuda_stream_view stream_view(stream); + handle_t handle(stream_view); + ASSERT_EQ(stream_view, handle.get_stream()); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } -TEST(Raft, GetInternalStreams) -{ - handle_t h(4); - auto streams = h.get_internal_streams(); - ASSERT_EQ(4U, streams.size()); -} - TEST(Raft, GetHandleFromPool) { - handle_t parent(4); - - handle_t child(parent, 2); - ASSERT_EQ(parent.get_internal_stream(2), child.get_stream()); + constexpr std::size_t n_streams = 4; + auto stream_pool = std::make_shared(n_streams); + handle_t parent(rmm::cuda_stream_default, stream_pool); - child.set_stream(parent.get_internal_stream(3)); - ASSERT_EQ(parent.get_internal_stream(3), child.get_stream()); - ASSERT_NE(parent.get_internal_stream(2), child.get_stream()); - - ASSERT_EQ(parent.get_device(), child.get_device()); + for (std::size_t i = 0; i < n_streams; i++) { + auto worker_stream = parent.get_stream_from_stream_pool(i); + handle_t child(worker_stream); + ASSERT_EQ(parent.get_stream_from_stream_pool(i), child.get_stream()); + } } -TEST(Raft, GetHandleStreamViews) -{ - handle_t parent(4); - - handle_t child(parent, 2); - ASSERT_EQ(parent.get_internal_stream_view(2), child.get_stream_view()); - ASSERT_EQ(parent.get_internal_stream_view(2).value(), child.get_stream_view().value()); - EXPECT_FALSE(child.get_stream_view().is_default()); -} } // namespace raft diff --git a/cpp/test/label/merge_labels.cu b/cpp/test/label/merge_labels.cu index dd67f0fd89..726c5c427b 100644 --- a/cpp/test/label/merge_labels.cu +++ b/cpp/test/label/merge_labels.cu @@ -65,7 +65,7 @@ class MergeLabelsTest : public ::testing::TestWithParam( - expected.data(), labels_a.data(), params.N, raft::Compare())); + expected.data(), labels_a.data(), params.N, raft::Compare(), stream)); } protected: diff --git a/cpp/test/linalg/add.cu b/cpp/test/linalg/add.cu index 2b51f4640a..b65a8665bc 100644 --- a/cpp/test/linalg/add.cu +++ b/cpp/test/linalg/add.cu @@ -45,7 +45,7 @@ class AddTest : public ::testing::TestWithParam> { int len = params.len; r.uniform(in1.data(), len, InT(-1.0), InT(1.0), stream); r.uniform(in2.data(), len, InT(-1.0), InT(1.0), stream); - naiveAddElem(out_ref.data(), in1.data(), in2.data(), len); + naiveAddElem(out_ref.data(), in1.data(), in2.data(), len, stream); add(out.data(), in1.data(), in2.data(), len, stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } @@ -53,7 +53,7 @@ class AddTest : public ::testing::TestWithParam> { void compare() { ASSERT_TRUE(raft::devArrMatch( - out_ref.data(), out.data(), params.len, raft::CompareApprox(params.tolerance))); + out_ref.data(), out.data(), params.len, raft::CompareApprox(params.tolerance), stream)); } protected: diff --git a/cpp/test/linalg/add.cuh b/cpp/test/linalg/add.cuh index 5e887e0040..70e4866407 100644 --- a/cpp/test/linalg/add.cuh +++ b/cpp/test/linalg/add.cuh @@ -30,11 +30,11 @@ __global__ void naiveAddElemKernel(OutT* out, const InT* in1, const InT* in2, in } template -void naiveAddElem(OutT* out, const InT* in1, const InT* in2, int len) +void naiveAddElem(OutT* out, const InT* in1, const InT* in2, int len, cudaStream_t stream) { static const int TPB = 64; int nblks = raft::ceildiv(len, TPB); - naiveAddElemKernel<<>>(out, in1, in2, len); + naiveAddElemKernel<<>>(out, in1, in2, len); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/test/linalg/binary_op.cu b/cpp/test/linalg/binary_op.cu index bb62ddced3..3de29c6ee8 100644 --- a/cpp/test/linalg/binary_op.cu +++ b/cpp/test/linalg/binary_op.cu @@ -76,8 +76,8 @@ const std::vector> inputsf_i32 = {{0.000001f, 1024 * typedef BinaryOpTest BinaryOpTestF_i32; TEST_P(BinaryOpTestF_i32, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(BinaryOpTests, BinaryOpTestF_i32, ::testing::ValuesIn(inputsf_i32)); @@ -85,8 +85,8 @@ const std::vector> inputsf_i64 = {{0.000001f, 1024 typedef BinaryOpTest BinaryOpTestF_i64; TEST_P(BinaryOpTestF_i64, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(BinaryOpTests, BinaryOpTestF_i64, ::testing::ValuesIn(inputsf_i64)); @@ -95,8 +95,8 @@ const std::vector> inputsf_i32_d = { typedef BinaryOpTest BinaryOpTestF_i32_D; TEST_P(BinaryOpTestF_i32_D, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(BinaryOpTests, BinaryOpTestF_i32_D, ::testing::ValuesIn(inputsf_i32_d)); @@ -104,8 +104,8 @@ const std::vector> inputsd_i32 = {{0.00000001, 1024 typedef BinaryOpTest BinaryOpTestD_i32; TEST_P(BinaryOpTestD_i32, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(BinaryOpTests, BinaryOpTestD_i32, ::testing::ValuesIn(inputsd_i32)); @@ -114,24 +114,18 @@ const std::vector> inputsd_i64 = { typedef BinaryOpTest BinaryOpTestD_i64; TEST_P(BinaryOpTestD_i64, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(BinaryOpTests, BinaryOpTestD_i64, ::testing::ValuesIn(inputsd_i64)); template class BinaryOpAlignment : public ::testing::Test { protected: - BinaryOpAlignment() - { - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); - handle.set_stream(stream); - } - void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } - public: void Misaligned() { + auto stream = handle.get_stream(); // Test to trigger cudaErrorMisalignedAddress if veclen is incorrectly // chosen. int n = 1024; @@ -146,11 +140,10 @@ class BinaryOpAlignment : public ::testing::Test { y.data() + 19, 256, [] __device__(math_t x, math_t y) { return x + y; }, - stream); + handle.get_stream()); } raft::handle_t handle; - cudaStream_t stream; }; typedef ::testing::Types FloatTypes; TYPED_TEST_CASE(BinaryOpAlignment, FloatTypes); diff --git a/cpp/test/linalg/cholesky_r1.cu b/cpp/test/linalg/cholesky_r1.cu index 1c3d99a883..0326cf5a47 100644 --- a/cpp/test/linalg/cholesky_r1.cu +++ b/cpp/test/linalg/cholesky_r1.cu @@ -38,9 +38,7 @@ class CholeskyR1Test : public ::testing::Test { devInfo(handle.get_stream()), workspace(0, handle.get_stream()) { - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); - handle.set_stream(stream); - raft::update_device(G.data(), G_host, n_rows * n_rows, stream); + raft::update_device(G.data(), G_host, n_rows * n_rows, handle.get_stream()); // Allocate workspace solver_handle = handle.get_cusolver_dn_handle(); @@ -49,27 +47,31 @@ class CholeskyR1Test : public ::testing::Test { int n_bytes = 0; // Initializing in CUBLAS_FILL_MODE_LOWER, because that has larger workspace // requirements. - raft::linalg::choleskyRank1Update( - handle, L.data(), n_rows, n_rows, nullptr, &n_bytes, CUBLAS_FILL_MODE_LOWER, stream); + raft::linalg::choleskyRank1Update(handle, + L.data(), + n_rows, + n_rows, + nullptr, + &n_bytes, + CUBLAS_FILL_MODE_LOWER, + handle.get_stream()); Lwork = std::max(Lwork * sizeof(math_t), (size_t)n_bytes); - workspace.resize(Lwork, stream); + workspace.resize(Lwork, handle.get_stream()); } - void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } - void testR1Update() { int n = n_rows * n_rows; std::vector fillmode{CUBLAS_FILL_MODE_LOWER, CUBLAS_FILL_MODE_UPPER}; for (auto uplo : fillmode) { - raft::copy(L.data(), G.data(), n, stream); + raft::copy(L.data(), G.data(), n, handle.get_stream()); for (int rank = 1; rank <= n_rows; rank++) { std::stringstream ss; ss << "Rank " << rank << ((uplo == CUBLAS_FILL_MODE_LOWER) ? ", lower" : ", upper"); SCOPED_TRACE(ss.str()); // Expected solution using Cholesky factorization from scratch - raft::copy(L_exp.data(), G.data(), n, stream); + raft::copy(L_exp.data(), G.data(), n, handle.get_stream()); RAFT_CUSOLVER_TRY(raft::linalg::cusolverDnpotrf(solver_handle, uplo, rank, @@ -78,33 +80,36 @@ class CholeskyR1Test : public ::testing::Test { (math_t*)workspace.data(), Lwork, devInfo.data(), - stream)); + handle.get_stream())); // Incremental Cholesky factorization using rank one updates. raft::linalg::choleskyRank1Update( - handle, L.data(), rank, n_rows, workspace.data(), &Lwork, uplo, stream); + handle, L.data(), rank, n_rows, workspace.data(), &Lwork, uplo, handle.get_stream()); - ASSERT_TRUE(raft::devArrMatch( - L_exp.data(), L.data(), n_rows * rank, raft::CompareApprox(3e-3))); + ASSERT_TRUE(raft::devArrMatch(L_exp.data(), + L.data(), + n_rows * rank, + raft::CompareApprox(3e-3), + handle.get_stream())); } } } void testR1Error() { - raft::update_device(G.data(), G2_host, 4, stream); + raft::update_device(G.data(), G2_host, 4, handle.get_stream()); std::vector fillmode{CUBLAS_FILL_MODE_LOWER, CUBLAS_FILL_MODE_UPPER}; for (auto uplo : fillmode) { - raft::copy(L.data(), G.data(), 4, stream); + raft::copy(L.data(), G.data(), 4, handle.get_stream()); ASSERT_NO_THROW(raft::linalg::choleskyRank1Update( - handle, L.data(), 1, 2, workspace.data(), &Lwork, uplo, stream)); + handle, L.data(), 1, 2, workspace.data(), &Lwork, uplo, handle.get_stream())); ASSERT_THROW(raft::linalg::choleskyRank1Update( - handle, L.data(), 2, 2, workspace.data(), &Lwork, uplo, stream), + handle, L.data(), 2, 2, workspace.data(), &Lwork, uplo, handle.get_stream()), raft::exception); math_t eps = std::numeric_limits::epsilon(); ASSERT_NO_THROW(raft::linalg::choleskyRank1Update( - handle, L.data(), 2, 2, workspace.data(), &Lwork, uplo, stream, eps)); + handle, L.data(), 2, 2, workspace.data(), &Lwork, uplo, handle.get_stream(), eps)); } } diff --git a/cpp/test/linalg/coalesced_reduction.cu b/cpp/test/linalg/coalesced_reduction.cu index 910e6a2365..4773ecf50f 100644 --- a/cpp/test/linalg/coalesced_reduction.cu +++ b/cpp/test/linalg/coalesced_reduction.cu @@ -101,15 +101,21 @@ const std::vector> inputsd = {{0.000000001, 102 typedef coalescedReductionTest coalescedReductionTestF; TEST_P(coalescedReductionTestF, Result) { - ASSERT_TRUE(raft::devArrMatch( - dots_exp.data(), dots_act.data(), params.rows, raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(dots_exp.data(), + dots_act.data(), + params.rows, + raft::CompareApprox(params.tolerance), + stream)); } typedef coalescedReductionTest coalescedReductionTestD; TEST_P(coalescedReductionTestD, Result) { - ASSERT_TRUE(raft::devArrMatch( - dots_exp.data(), dots_act.data(), params.rows, raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(dots_exp.data(), + dots_act.data(), + params.rows, + raft::CompareApprox(params.tolerance), + stream)); } INSTANTIATE_TEST_CASE_P(coalescedReductionTests, diff --git a/cpp/test/linalg/divide.cu b/cpp/test/linalg/divide.cu index 7f57c79a7e..d2d2f24397 100644 --- a/cpp/test/linalg/divide.cu +++ b/cpp/test/linalg/divide.cu @@ -79,7 +79,7 @@ typedef DivideTest DivideTestF; TEST_P(DivideTestF, Result) { ASSERT_TRUE(devArrMatch( - out_ref.data(), out.data(), params.len, raft::CompareApprox(params.tolerance))); + out_ref.data(), out.data(), params.len, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(DivideTests, DivideTestF, ::testing::ValuesIn(inputsf)); @@ -88,7 +88,7 @@ const std::vector> inputsd = {{0.000001f, 1024 * 1024, 2.f TEST_P(DivideTestD, Result) { ASSERT_TRUE(devArrMatch( - out_ref.data(), out.data(), params.len, raft::CompareApprox(params.tolerance))); + out_ref.data(), out.data(), params.len, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(DivideTests, DivideTestD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/linalg/eig.cu b/cpp/test/linalg/eig.cu index c9d95d2058..6bdd880118 100644 --- a/cpp/test/linalg/eig.cu +++ b/cpp/test/linalg/eig.cu @@ -160,7 +160,8 @@ TEST_P(EigTestValF, Result) ASSERT_TRUE(raft::devArrMatch(eig_vals_ref.data(), eig_vals.data(), params.n_col, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigTest EigTestValD; @@ -169,7 +170,8 @@ TEST_P(EigTestValD, Result) ASSERT_TRUE(raft::devArrMatch(eig_vals_ref.data(), eig_vals.data(), params.n_col, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigTest EigTestVecF; @@ -178,7 +180,8 @@ TEST_P(EigTestVecF, Result) ASSERT_TRUE(raft::devArrMatch(eig_vectors_ref.data(), eig_vectors.data(), params.len, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigTest EigTestVecD; @@ -187,7 +190,8 @@ TEST_P(EigTestVecD, Result) ASSERT_TRUE(raft::devArrMatch(eig_vectors_ref.data(), eig_vectors.data(), params.len, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigTest EigTestValJacobiF; @@ -196,7 +200,8 @@ TEST_P(EigTestValJacobiF, Result) ASSERT_TRUE(raft::devArrMatch(eig_vals_ref.data(), eig_vals_jacobi.data(), params.n_col, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigTest EigTestValJacobiD; @@ -205,7 +210,8 @@ TEST_P(EigTestValJacobiD, Result) ASSERT_TRUE(raft::devArrMatch(eig_vals_ref.data(), eig_vals_jacobi.data(), params.n_col, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigTest EigTestVecJacobiF; @@ -214,7 +220,8 @@ TEST_P(EigTestVecJacobiF, Result) ASSERT_TRUE(raft::devArrMatch(eig_vectors_ref.data(), eig_vectors_jacobi.data(), params.len, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigTest EigTestVecJacobiD; @@ -223,7 +230,8 @@ TEST_P(EigTestVecJacobiD, Result) ASSERT_TRUE(raft::devArrMatch(eig_vectors_ref.data(), eig_vectors_jacobi.data(), params.len, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigTest EigTestVecCompareF; @@ -232,7 +240,8 @@ TEST_P(EigTestVecCompareF, Result) ASSERT_TRUE(raft::devArrMatch(eig_vectors_large.data(), eig_vectors_jacobi_large.data(), (params.n * params.n), - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigTest EigTestVecCompareD; @@ -241,7 +250,8 @@ TEST_P(EigTestVecCompareD, Result) ASSERT_TRUE(raft::devArrMatch(eig_vectors_large.data(), eig_vectors_jacobi_large.data(), (params.n * params.n), - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } INSTANTIATE_TEST_SUITE_P(EigTests, EigTestValF, ::testing::ValuesIn(inputsf2)); diff --git a/cpp/test/linalg/eig_sel.cu b/cpp/test/linalg/eig_sel.cu index 518dce4048..e41651ef61 100644 --- a/cpp/test/linalg/eig_sel.cu +++ b/cpp/test/linalg/eig_sel.cu @@ -117,7 +117,8 @@ TEST_P(EigSelTestValF, Result) ASSERT_TRUE(raft::devArrMatch(eig_vals_ref.data(), eig_vals.data(), params.n_col, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigSelTest EigSelTestValD; @@ -126,7 +127,8 @@ TEST_P(EigSelTestValD, Result) ASSERT_TRUE(raft::devArrMatch(eig_vals_ref.data(), eig_vals.data(), params.n_col, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigSelTest EigSelTestVecF; @@ -135,7 +137,8 @@ TEST_P(EigSelTestVecF, Result) ASSERT_TRUE(raft::devArrMatch(eig_vectors_ref.data(), eig_vectors.data(), 12, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } typedef EigSelTest EigSelTestVecD; @@ -144,7 +147,8 @@ TEST_P(EigSelTestVecD, Result) ASSERT_TRUE(raft::devArrMatch(eig_vectors_ref.data(), eig_vectors.data(), 12, - raft::CompareApproxAbs(params.tolerance))); + raft::CompareApproxAbs(params.tolerance), + stream)); } INSTANTIATE_TEST_SUITE_P(EigSelTest, EigSelTestValF, ::testing::ValuesIn(inputsf2)); diff --git a/cpp/test/linalg/eltwise.cu b/cpp/test/linalg/eltwise.cu index 023b04f8ed..1f6c411b79 100644 --- a/cpp/test/linalg/eltwise.cu +++ b/cpp/test/linalg/eltwise.cu @@ -95,15 +95,15 @@ const std::vector> inputsd1 = { typedef ScalarMultiplyTest ScalarMultiplyTestF; TEST_P(ScalarMultiplyTestF, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } typedef ScalarMultiplyTest ScalarMultiplyTestD; TEST_P(ScalarMultiplyTestD, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(ScalarMultiplyTests, ScalarMultiplyTestF, ::testing::ValuesIn(inputsf1)); @@ -182,15 +182,15 @@ const std::vector> inputsd2 = {{0.00000001, 1024 * 1024 typedef EltwiseAddTest EltwiseAddTestF; TEST_P(EltwiseAddTestF, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } typedef EltwiseAddTest EltwiseAddTestD; TEST_P(EltwiseAddTestD, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(EltwiseAddTests, EltwiseAddTestF, ::testing::ValuesIn(inputsf2)); diff --git a/cpp/test/linalg/gemm_layout.cu b/cpp/test/linalg/gemm_layout.cu index da07ed797e..6f512aec71 100644 --- a/cpp/test/linalg/gemm_layout.cu +++ b/cpp/test/linalg/gemm_layout.cu @@ -105,6 +105,7 @@ class GemmLayoutTest : public ::testing::TestWithParam> { params.xLayout, params.yLayout, stream); + handle.sync_stream(); } void TearDown() override diff --git a/cpp/test/linalg/gemv.cu b/cpp/test/linalg/gemv.cu index 4d5472f38c..962b17fa24 100644 --- a/cpp/test/linalg/gemv.cu +++ b/cpp/test/linalg/gemv.cu @@ -106,7 +106,7 @@ class GemvTest : public ::testing::TestWithParam> { dim3 blocks(raft::ceildiv(yElems, 256), 1, 1); dim3 threads(256, 1, 1); - naiveGemv<<>>( + naiveGemv<<>>( refy.data(), A.data(), x.data(), params.n_rows, params.n_cols, params.lda, params.trans_a); gemv(handle, @@ -118,6 +118,7 @@ class GemvTest : public ::testing::TestWithParam> { y.data(), params.trans_a, stream); + handle.sync_stream(); } void TearDown() override {} diff --git a/cpp/test/linalg/map_then_reduce.cu b/cpp/test/linalg/map_then_reduce.cu index c35e1ea9ef..0baeba5807 100644 --- a/cpp/test/linalg/map_then_reduce.cu +++ b/cpp/test/linalg/map_then_reduce.cu @@ -103,16 +103,16 @@ const std::vector> inputsf = {{0.001f, 1024 * 1024, 1234U typedef MapReduceTest MapReduceTestFF; TEST_P(MapReduceTestFF, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(MapReduceTests, MapReduceTestFF, ::testing::ValuesIn(inputsf)); typedef MapReduceTest MapReduceTestFD; TEST_P(MapReduceTestFD, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(MapReduceTests, MapReduceTestFD, ::testing::ValuesIn(inputsf)); @@ -120,8 +120,8 @@ const std::vector> inputsd = {{0.000001, 1024 * 1024, 12 typedef MapReduceTest MapReduceTestDD; TEST_P(MapReduceTestDD, Result) { - ASSERT_TRUE( - devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_SUITE_P(MapReduceTests, MapReduceTestDD, ::testing::ValuesIn(inputsd)); @@ -133,37 +133,37 @@ class MapGenericReduceTest : public ::testing::Test { protected: MapGenericReduceTest() : input(n, handle.get_stream()), output(handle.get_stream()) { - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); - handle.set_stream(stream); - initInput(input.data(), input.size(), stream); + initInput(input.data(), input.size(), handle.get_stream()); } - void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } - public: void initInput(InType* input, int n, cudaStream_t stream) { raft::random::Rng r(137); - r.uniform(input, n, InType(2), InType(3), stream); + r.uniform(input, n, InType(2), InType(3), handle.get_stream()); InType val = 1; - raft::update_device(input + 42, &val, 1, stream); + raft::update_device(input + 42, &val, 1, handle.get_stream()); val = 5; - raft::update_device(input + 337, &val, 1, stream); + raft::update_device(input + 337, &val, 1, handle.get_stream()); } void testMin() { auto op = [] __device__(InType in) { return in; }; const OutType neutral = std::numeric_limits::max(); - mapThenReduce(output.data(), input.size(), neutral, op, cub::Min(), stream, input.data()); - EXPECT_TRUE(raft::devArrMatch(OutType(1), output.data(), 1, raft::Compare())); + mapThenReduce( + output.data(), input.size(), neutral, op, cub::Min(), handle.get_stream(), input.data()); + EXPECT_TRUE(raft::devArrMatch( + OutType(1), output.data(), 1, raft::Compare(), handle.get_stream())); } void testMax() { auto op = [] __device__(InType in) { return in; }; const OutType neutral = std::numeric_limits::min(); - mapThenReduce(output.data(), input.size(), neutral, op, cub::Max(), stream, input.data()); - EXPECT_TRUE(raft::devArrMatch(OutType(5), output.data(), 1, raft::Compare())); + mapThenReduce( + output.data(), input.size(), neutral, op, cub::Max(), handle.get_stream(), input.data()); + EXPECT_TRUE(raft::devArrMatch( + OutType(5), output.data(), 1, raft::Compare(), handle.get_stream())); } protected: diff --git a/cpp/test/linalg/matrix_vector_op.cu b/cpp/test/linalg/matrix_vector_op.cu index 9f2a1ac78f..b471972304 100644 --- a/cpp/test/linalg/matrix_vector_op.cu +++ b/cpp/test/linalg/matrix_vector_op.cu @@ -111,7 +111,8 @@ class MatVecOpTest : public ::testing::TestWithParam> N, params.rowMajor, params.bcastAlongRows, - (T)1.0); + (T)1.0, + stream); } else { naiveMatVec(out_ref.data(), in.data(), @@ -120,7 +121,8 @@ class MatVecOpTest : public ::testing::TestWithParam> N, params.rowMajor, params.bcastAlongRows, - (T)1.0); + (T)1.0, + stream); } matrixVectorOpLaunch(out.data(), in.data(), diff --git a/cpp/test/linalg/matrix_vector_op.cuh b/cpp/test/linalg/matrix_vector_op.cuh index 70a68fb542..e51802c135 100644 --- a/cpp/test/linalg/matrix_vector_op.cuh +++ b/cpp/test/linalg/matrix_vector_op.cuh @@ -54,12 +54,14 @@ void naiveMatVec(Type* out, IdxType N, bool rowMajor, bool bcastAlongRows, - Type scalar) + Type scalar, + cudaStream_t stream) { static const IdxType TPB = 64; IdxType len = N * D; IdxType nblks = raft::ceildiv(len, TPB); - naiveMatVecKernel<<>>(out, mat, vec, D, N, rowMajor, bcastAlongRows, scalar); + naiveMatVecKernel + <<>>(out, mat, vec, D, N, rowMajor, bcastAlongRows, scalar); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -98,13 +100,14 @@ void naiveMatVec(Type* out, IdxType N, bool rowMajor, bool bcastAlongRows, - Type scalar) + Type scalar, + cudaStream_t stream) { static const IdxType TPB = 64; IdxType len = N * D; IdxType nblks = raft::ceildiv(len, TPB); naiveMatVecKernel - <<>>(out, mat, vec1, vec2, D, N, rowMajor, bcastAlongRows, scalar); + <<>>(out, mat, vec1, vec2, D, N, rowMajor, bcastAlongRows, scalar); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index a1001f3816..1e11062a87 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -47,11 +47,11 @@ __global__ void nativeSqrtKernel(Type* in, Type* out, int len) } template -void naiveSqrt(Type* in, Type* out, int len) +void naiveSqrt(Type* in, Type* out, int len, cudaStream_t stream) { static const int TPB = 64; int nblks = raft::ceildiv(len, TPB); - nativeSqrtKernel<<>>(in, out, len); + nativeSqrtKernel<<>>(in, out, len); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -86,9 +86,9 @@ __global__ void naiveSignFlipKernel(Type* in, Type* out, int rowCount, int colCo } template -void naiveSignFlip(Type* in, Type* out, int rowCount, int colCount) +void naiveSignFlip(Type* in, Type* out, int rowCount, int colCount, cudaStream_t stream) { - naiveSignFlipKernel<<>>(in, out, rowCount, colCount); + naiveSignFlipKernel<<>>(in, out, rowCount, colCount); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -149,12 +149,13 @@ class MathTest : public ::testing::TestWithParam> { naivePower(in_power.data(), out_power_ref.data(), len, stream); power(in_power.data(), len, stream); - naiveSqrt(in_sqrt.data(), out_sqrt_ref.data(), len); + naiveSqrt(in_sqrt.data(), out_sqrt_ref.data(), len, stream); seqRoot(in_sqrt.data(), len, stream); ratio(handle, in_ratio.data(), in_ratio.data(), 4, stream); - naiveSignFlip(in_sign_flip.data(), out_sign_flip_ref.data(), params.n_row, params.n_col); + naiveSignFlip( + in_sign_flip.data(), out_sign_flip_ref.data(), params.n_row, params.n_col, stream); signFlip(in_sign_flip.data(), params.n_row, params.n_col, stream); // default threshold is 1e-15 @@ -196,43 +197,55 @@ const std::vector> inputsd = {{0.00001, 1024, 1024, 1024 * 10 typedef MathTest MathPowerTestF; TEST_P(MathPowerTestF, Result) { - ASSERT_TRUE(devArrMatch( - in_power.data(), out_power_ref.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch(in_power.data(), + out_power_ref.data(), + params.len, + CompareApprox(params.tolerance), + stream)); } typedef MathTest MathPowerTestD; TEST_P(MathPowerTestD, Result) { - ASSERT_TRUE(devArrMatch( - in_power.data(), out_power_ref.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch(in_power.data(), + out_power_ref.data(), + params.len, + CompareApprox(params.tolerance), + stream)); } typedef MathTest MathSqrtTestF; TEST_P(MathSqrtTestF, Result) { - ASSERT_TRUE(devArrMatch( - in_sqrt.data(), out_sqrt_ref.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch(in_sqrt.data(), + out_sqrt_ref.data(), + params.len, + CompareApprox(params.tolerance), + stream)); } typedef MathTest MathSqrtTestD; TEST_P(MathSqrtTestD, Result) { - ASSERT_TRUE(devArrMatch( - in_sqrt.data(), out_sqrt_ref.data(), params.len, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch(in_sqrt.data(), + out_sqrt_ref.data(), + params.len, + CompareApprox(params.tolerance), + stream)); } typedef MathTest MathRatioTestF; TEST_P(MathRatioTestF, Result) { - ASSERT_TRUE( - devArrMatch(in_ratio.data(), out_ratio_ref.data(), 4, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + in_ratio.data(), out_ratio_ref.data(), 4, CompareApprox(params.tolerance), stream)); } typedef MathTest MathRatioTestD; TEST_P(MathRatioTestD, Result) { - ASSERT_TRUE( - devArrMatch(in_ratio.data(), out_ratio_ref.data(), 4, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + in_ratio.data(), out_ratio_ref.data(), 4, CompareApprox(params.tolerance), stream)); } typedef MathTest MathSignFlipTestF; @@ -241,7 +254,8 @@ TEST_P(MathSignFlipTestF, Result) ASSERT_TRUE(devArrMatch(in_sign_flip.data(), out_sign_flip_ref.data(), params.len, - CompareApprox(params.tolerance))); + CompareApprox(params.tolerance), + stream)); } typedef MathTest MathSignFlipTestD; @@ -250,49 +264,62 @@ TEST_P(MathSignFlipTestD, Result) ASSERT_TRUE(devArrMatch(in_sign_flip.data(), out_sign_flip_ref.data(), params.len, - CompareApprox(params.tolerance))); + CompareApprox(params.tolerance), + stream)); } typedef MathTest MathReciprocalTestF; TEST_P(MathReciprocalTestF, Result) { - ASSERT_TRUE( - devArrMatch(in_recip.data(), in_recip_ref.data(), 4, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + in_recip.data(), in_recip_ref.data(), 4, CompareApprox(params.tolerance), stream)); // 4-th term tests `setzero=true` functionality, not present in this version of `reciprocal`. - ASSERT_TRUE( - devArrMatch(out_recip.data(), in_recip_ref.data(), 3, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_recip.data(), in_recip_ref.data(), 3, CompareApprox(params.tolerance), stream)); } typedef MathTest MathReciprocalTestD; TEST_P(MathReciprocalTestD, Result) { - ASSERT_TRUE( - devArrMatch(in_recip.data(), in_recip_ref.data(), 4, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + in_recip.data(), in_recip_ref.data(), 4, CompareApprox(params.tolerance), stream)); // 4-th term tests `setzero=true` functionality, not present in this version of `reciprocal`. - ASSERT_TRUE( - devArrMatch(out_recip.data(), in_recip_ref.data(), 3, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + out_recip.data(), in_recip_ref.data(), 3, CompareApprox(params.tolerance), stream)); } typedef MathTest MathSetSmallZeroTestF; TEST_P(MathSetSmallZeroTestF, Result) { - ASSERT_TRUE(devArrMatch( - in_smallzero.data(), out_smallzero_ref.data(), 4, CompareApprox(params.tolerance))); - - ASSERT_TRUE(devArrMatch( - out_smallzero.data(), out_smallzero_ref.data(), 4, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch(in_smallzero.data(), + out_smallzero_ref.data(), + 4, + CompareApprox(params.tolerance), + stream)); + + ASSERT_TRUE(devArrMatch(out_smallzero.data(), + out_smallzero_ref.data(), + 4, + CompareApprox(params.tolerance), + stream)); } typedef MathTest MathSetSmallZeroTestD; TEST_P(MathSetSmallZeroTestD, Result) { - ASSERT_TRUE(devArrMatch( - in_smallzero.data(), out_smallzero_ref.data(), 4, CompareApprox(params.tolerance))); - - ASSERT_TRUE(devArrMatch( - out_smallzero.data(), out_smallzero_ref.data(), 4, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch(in_smallzero.data(), + out_smallzero_ref.data(), + 4, + CompareApprox(params.tolerance), + stream)); + + ASSERT_TRUE(devArrMatch(out_smallzero.data(), + out_smallzero_ref.data(), + 4, + CompareApprox(params.tolerance), + stream)); } INSTANTIATE_TEST_SUITE_P(MathTests, MathPowerTestF, ::testing::ValuesIn(inputsf)); diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 696ef2dd08..85bf780112 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -84,7 +84,8 @@ TEST_P(MatrixTestF, Result) ASSERT_TRUE(raft::devArrMatch(in1.data(), in2.data(), params.n_row * params.n_col, - raft::CompareApprox(params.tolerance))); + raft::CompareApprox(params.tolerance), + stream)); } typedef MatrixTest MatrixTestD; @@ -93,7 +94,8 @@ TEST_P(MatrixTestD, Result) ASSERT_TRUE(raft::devArrMatch(in1.data(), in2.data(), params.n_row * params.n_col, - raft::CompareApprox(params.tolerance))); + raft::CompareApprox(params.tolerance), + stream)); } INSTANTIATE_TEST_SUITE_P(MatrixTests, MatrixTestF, ::testing::ValuesIn(inputsf2)); @@ -108,12 +110,11 @@ class MatrixCopyRowsTest : public ::testing::Test { protected: MatrixCopyRowsTest() - : input(n_cols * n_rows, handle.get_stream()), + : stream(handle.get_stream()), + input(n_cols * n_rows, handle.get_stream()), indices(n_selected, handle.get_stream()), output(n_cols * n_selected, handle.get_stream()) { - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); - handle.set_stream(stream); raft::update_device(indices.data(), indices_host, n_selected, stream); // Init input array thrust::counting_iterator first(0); @@ -121,17 +122,28 @@ class MatrixCopyRowsTest : public ::testing::Test { thrust::copy(handle.get_thrust_policy(), first, first + n_cols * n_rows, ptr); } - void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } - void testCopyRows() { - copyRows( - input.data(), n_rows, n_cols, output.data(), indices.data(), n_selected, stream, false); + copyRows(input.data(), + n_rows, + n_cols, + output.data(), + indices.data(), + n_selected, + handle.get_stream(), + false); EXPECT_TRUE(raft::devArrMatchHost( - output_exp_colmajor, output.data(), n_selected * n_cols, raft::Compare())); - copyRows(input.data(), n_rows, n_cols, output.data(), indices.data(), n_selected, stream, true); + output_exp_colmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); + copyRows(input.data(), + n_rows, + n_cols, + output.data(), + indices.data(), + n_selected, + handle.get_stream(), + true); EXPECT_TRUE(raft::devArrMatchHost( - output_exp_rowmajor, output.data(), n_selected * n_cols, raft::Compare())); + output_exp_rowmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); } protected: diff --git a/cpp/test/sparse/add.cu b/cpp/test/sparse/add.cu index e1223b90a3..74f419be23 100644 --- a/cpp/test/sparse/add.cu +++ b/cpp/test/sparse/add.cu @@ -103,7 +103,7 @@ class CSRAddTest : public ::testing::TestWithParam> ASSERT_TRUE(nnz == nnz_result); ASSERT_TRUE(raft::devArrMatch( - ind_verify.data(), ind_result.data(), n_rows, raft::Compare())); + ind_verify.data(), ind_result.data(), n_rows, raft::Compare(), stream)); linalg::csr_add_finalize(ind_a.data(), ind_ptr_a.data(), @@ -120,9 +120,9 @@ class CSRAddTest : public ::testing::TestWithParam> stream); ASSERT_TRUE(raft::devArrMatch( - ind_ptr_verify.data(), ind_ptr_result.data(), nnz, raft::Compare())); + ind_ptr_verify.data(), ind_ptr_result.data(), nnz, raft::Compare(), stream)); ASSERT_TRUE(raft::devArrMatch( - values_verify.data(), values_result.data(), nnz, raft::Compare())); + values_verify.data(), values_result.data(), nnz, raft::Compare(), stream)); } protected: diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 3b69c9240c..d78cc2d026 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -77,7 +77,7 @@ TEST_P(SortedCOOToCSR, Result) convert::sorted_coo_to_csr(in.data(), nnz, out.data(), 4, stream); - ASSERT_TRUE(raft::devArrMatch(out.data(), exp.data(), 4, raft::Compare())); + ASSERT_TRUE(raft::devArrMatch(out.data(), exp.data(), 4, raft::Compare(), stream)); cudaStreamDestroy(stream); diff --git a/cpp/test/sparse/filter.cu b/cpp/test/sparse/filter.cu index 77c66e2133..dc9b2d63ad 100644 --- a/cpp/test/sparse/filter.cu +++ b/cpp/test/sparse/filter.cu @@ -96,6 +96,7 @@ TEST_P(COORemoveZeros, Result) raft::update_device(out_ref.vals(), out_vals_ref_h, 2, stream); op::coo_remove_zeros(&in, &out, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); ASSERT_TRUE(raft::devArrMatch(out_ref.rows(), out.rows(), 2, raft::Compare())); ASSERT_TRUE(raft::devArrMatch(out_ref.cols(), out.cols(), 2, raft::Compare())); diff --git a/cpp/test/sparse/norm.cu b/cpp/test/sparse/norm.cu index be26b6f24b..59c0961699 100644 --- a/cpp/test/sparse/norm.cu +++ b/cpp/test/sparse/norm.cu @@ -73,6 +73,7 @@ class CSRRowNormalizeTest : public ::testing::TestWithParam(verify.data(), result.data(), nnz, raft::Compare())); diff --git a/cpp/test/sparse/reduce.cu b/cpp/test/sparse/reduce.cu index f66cd873d5..41328b5f78 100644 --- a/cpp/test/sparse/reduce.cu +++ b/cpp/test/sparse/reduce.cu @@ -78,7 +78,7 @@ class SparseReduceTest : public ::testing::TestWithParam( out_rows.data(), out.rows(), out.nnz, raft::Compare())); ASSERT_TRUE(raft::devArrMatch( diff --git a/cpp/test/sparse/row_op.cu b/cpp/test/sparse/row_op.cu index e650661c0d..be523bc97f 100644 --- a/cpp/test/sparse/row_op.cu +++ b/cpp/test/sparse/row_op.cu @@ -78,8 +78,8 @@ class CSRRowOpTest : public ::testing::TestWithParam(ex_scan.data(), n_rows, nnz, result.data(), stream); - ASSERT_TRUE( - raft::devArrMatch(verify.data(), result.data(), nnz, raft::Compare())); + ASSERT_TRUE(raft::devArrMatch( + verify.data(), result.data(), nnz, raft::Compare(), stream)); } protected: diff --git a/cpp/test/sparse/sort.cu b/cpp/test/sparse/sort.cu index 0a0864ce15..85ee0fe79b 100644 --- a/cpp/test/sparse/sort.cu +++ b/cpp/test/sparse/sort.cu @@ -78,8 +78,8 @@ TEST_P(COOSort, Result) op::coo_sort( params.m, params.n, params.nnz, in_rows.data(), in_cols.data(), in_vals.data(), stream); - ASSERT_TRUE( - raft::devArrMatch(verify.data(), in_rows.data(), params.nnz, raft::Compare())); + ASSERT_TRUE(raft::devArrMatch( + verify.data(), in_rows.data(), params.nnz, raft::Compare(), stream)); delete[] in_rows_h; delete[] in_cols_h; diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 00f83254c3..73c0f87fdd 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -105,10 +105,10 @@ void compute_bfknn(const raft::handle_t& handle, std::vector input_vec = {const_cast(X1)}; std::vector sizes_vec = {n}; - cudaStream_t* int_streams = nullptr; std::vector* translations = nullptr; - raft::spatial::knn::detail::brute_force_knn_impl(input_vec, + raft::spatial::knn::detail::brute_force_knn_impl(handle, + input_vec, sizes_vec, d, const_cast(X2), @@ -116,9 +116,6 @@ void compute_bfknn(const raft::handle_t& handle, inds, dists, k, - handle.get_stream(), - int_streams, - 0, true, true, translations, @@ -251,13 +248,13 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { ToRadians()); } - cudaStream_t* int_streams = nullptr; std::vector* translations = nullptr; std::vector input_vec = {d_train_inputs.data()}; std::vector sizes_vec = {n}; - raft::spatial::knn::detail::brute_force_knn_impl(input_vec, + raft::spatial::knn::detail::brute_force_knn_impl(handle, + input_vec, sizes_vec, d, d_train_inputs.data(), @@ -265,9 +262,6 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { d_ref_I.data(), d_ref_D.data(), k, - handle.get_stream(), - int_streams, - 0, true, true, translations, diff --git a/cpp/test/spatial/haversine.cu b/cpp/test/spatial/haversine.cu index e268dc0c55..171b698265 100644 --- a/cpp/test/spatial/haversine.cu +++ b/cpp/test/spatial/haversine.cu @@ -121,9 +121,10 @@ typedef HaversineKNNTest HaversineKNNTestF; TEST_F(HaversineKNNTestF, Fit) { + ASSERT_TRUE(raft::devArrMatch( + d_ref_D.data(), d_pred_D.data(), n * n, raft::CompareApprox(1e-3), stream)); ASSERT_TRUE( - raft::devArrMatch(d_ref_D.data(), d_pred_D.data(), n * n, raft::CompareApprox(1e-3))); - ASSERT_TRUE(raft::devArrMatch(d_ref_I.data(), d_pred_I.data(), n * n, raft::Compare())); + raft::devArrMatch(d_ref_I.data(), d_pred_I.data(), n * n, raft::Compare(), stream)); } } // namespace knn diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index 8ab33745f3..2fb9bd2ca5 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -103,7 +103,7 @@ class KNNTest : public ::testing::TestWithParam { expected_labels_.data(), rows_, k_, search_labels_.data()); ASSERT_TRUE(devArrMatch( - expected_labels_.data(), actual_labels_.data(), rows_ * k_, raft::Compare())); + expected_labels_.data(), actual_labels_.data(), rows_ * k_, raft::Compare(), stream)); } void SetUp() override diff --git a/cpp/test/stats/mean_center.cu b/cpp/test/stats/mean_center.cu index 8f2e2ecef1..e14a9062d3 100644 --- a/cpp/test/stats/mean_center.cu +++ b/cpp/test/stats/mean_center.cu @@ -77,7 +77,8 @@ class MeanCenterTest : public ::testing::TestWithParam StdDevTestF; TEST_P(StdDevTestF, Result) { ASSERT_TRUE(devArrMatch( - params.stddev, stddev_act.data(), params.cols, CompareApprox(params.tolerance))); + params.stddev, stddev_act.data(), params.cols, CompareApprox(params.tolerance), stream)); - ASSERT_TRUE(devArrMatch( - stddev_act.data(), vars_act.data(), params.cols, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch(stddev_act.data(), + vars_act.data(), + params.cols, + CompareApprox(params.tolerance), + stream)); } typedef StdDevTest StdDevTestD; TEST_P(StdDevTestD, Result) { - ASSERT_TRUE(devArrMatch( - params.stddev, stddev_act.data(), params.cols, CompareApprox(params.tolerance))); - - ASSERT_TRUE(devArrMatch( - stddev_act.data(), vars_act.data(), params.cols, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch(params.stddev, + stddev_act.data(), + params.cols, + CompareApprox(params.tolerance), + stream)); + + ASSERT_TRUE(devArrMatch(stddev_act.data(), + vars_act.data(), + params.cols, + CompareApprox(params.tolerance), + stream)); } INSTANTIATE_TEST_SUITE_P(StdDevTests, StdDevTestF, ::testing::ValuesIn(inputsf)); diff --git a/python/raft/common/handle.pxd b/python/raft/common/handle.pxd index 884d81bed1..d2ae0a401d 100644 --- a/python/raft/common/handle.pxd +++ b/python/raft/common/handle.pxd @@ -22,7 +22,10 @@ from libcpp.memory cimport shared_ptr from .cuda cimport _Stream - +from rmm._lib.cuda_stream_view cimport cuda_stream_view +from rmm._lib.cuda_stream_pool cimport cuda_stream_pool +from libcpp.memory cimport shared_ptr +from libcpp.memory cimport unique_ptr cdef extern from "raft/mr/device/allocator.hpp" \ namespace "raft::mr::device" nogil: @@ -32,7 +35,15 @@ cdef extern from "raft/mr/device/allocator.hpp" \ cdef extern from "raft/handle.hpp" namespace "raft" nogil: cdef cppclass handle_t: handle_t() except + - handle_t(int ns) except + - void set_stream(_Stream s) except + - _Stream get_stream() except + - int get_num_internal_streams() except + + handle_t(cuda_stream_view stream_view) except + + handle_t(cuda_stream_view stream_view, + shared_ptr[cuda_stream_pool] stream_pool) except + + void set_device_allocator(shared_ptr[allocator] a) except + + shared_ptr[allocator] get_device_allocator() except + + cuda_stream_view get_stream() except + + void sync_stream() except + + +cdef class Handle: + cdef unique_ptr[handle_t] c_obj + cdef shared_ptr[cuda_stream_pool] stream_pool + cdef int n_streams diff --git a/python/raft/common/handle.pyx b/python/raft/common/handle.pyx index 7198695cb4..1accf9e679 100644 --- a/python/raft/common/handle.pyx +++ b/python/raft/common/handle.pyx @@ -21,6 +21,8 @@ # import raft from libcpp.memory cimport shared_ptr +from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread +from rmm._lib.cuda_stream_view cimport cuda_stream_view from .cuda cimport _Stream, _Error, cudaStreamSynchronize from .cuda import CudaRuntimeError @@ -38,8 +40,7 @@ cdef class Handle: from raft.common import Stream, Handle stream = Stream() - handle = Handle() - handle.setStream(stream) + handle = Handle(stream) # call algos here @@ -50,51 +51,39 @@ cdef class Handle: del handle # optional! """ - # handle_t doesn't have copy operator. So, use pointer for the object - # python world cannot access to this raw object directly, hence use - # 'size_t'! - cdef size_t h - - # not using __dict__ unless we need it to keep this Extension as lean as - # possible - cdef int n_streams - - def __cinit__(self, n_streams=0): + def __cinit__(self, stream=None, n_streams=0): self.n_streams = n_streams - self.h = (new handle_t(n_streams)) - - def __dealloc__(self): - h_ = self.h - del h_ - - def setStream(self, stream): - cdef size_t s = stream.getStream() - cdef handle_t* h_ = self.h - h_.set_stream(<_Stream>s) + if n_streams > 0: + self.stream_pool.reset(new cuda_stream_pool(n_streams)) + + cdef cuda_stream_view c_stream + if stream is None: + # this constructor will construct a "main" handle on + # per-thread default stream, which is non-blocking + self.c_obj.reset(new handle_t(cuda_stream_per_thread, + self.stream_pool)) + else: + # this constructor constructs a handle on user stream + c_stream = cuda_stream_view(<_Stream> stream.getStream()) + self.c_obj.reset(new handle_t(c_stream, + self.stream_pool)) def sync(self): """ Issues a sync on the stream set for this handle. - - Once we make `raft.common.cuda.Stream` as a mandatory option - for creating `raft.common.Handle`, this should go away """ - cdef handle_t* h_ = self.h - cdef _Stream stream = h_.get_stream() - cdef _Error e = cudaStreamSynchronize(stream) - if e != 0: - raise CudaRuntimeError("Stream sync") + self.c_obj.get()[0].sync_stream() def getHandle(self): - return self.h - - def getNumInternalStreams(self): - cdef handle_t* h_ = self.h - return h_.get_num_internal_streams() + return self.c_obj.get() def __getstate__(self): return self.n_streams def __setstate__(self, state): self.n_streams = state - self.h = (new handle_t(self.n_streams)) + if self.n_streams > 0: + self.stream_pool.reset(new cuda_stream_pool(self.n_streams)) + + self.c_obj.reset(new handle_t(cuda_stream_per_thread, + self.stream_pool)) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index 27533dfb9a..ee768b41ff 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -509,7 +509,7 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): sessionId=sessionId, state_object=worker ) - handle = Handle(streams_per_handle) + handle = Handle(n_streams=streams_per_handle) nccl_comm = raft_comm_state["nccl"] eps = raft_comm_state["ucp_eps"] nWorkers = raft_comm_state["nworkers"] @@ -546,7 +546,7 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): topic="info", msg="Finished injecting comms on handle." ) - handle = Handle(streams_per_handle) + handle = Handle(n_streams=streams_per_handle) raft_comm_state = get_raft_comm_state( sessionId=sessionId, state_object=worker diff --git a/python/raft/dask/common/comms_utils.pyx b/python/raft/dask/common/comms_utils.pyx index 20f004b1d6..7370085805 100644 --- a/python/raft/dask/common/comms_utils.pyx +++ b/python/raft/dask/common/comms_utils.pyx @@ -225,7 +225,7 @@ def perform_test_comms_device_multicast_sendrecv(handle, n_trials): n_trilas : int Number of test trials """ - cdef const handle_t *h = handle.getHandle() + cdef const handle_t *h = handle.getHandle() return test_pointToPoint_device_multicast_sendrecv(deref(h), n_trials)