From 7e7832d4022c093d58b1ac36e42cb2184c74ae72 Mon Sep 17 00:00:00 2001
From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com>
Date: Tue, 8 Feb 2022 22:01:25 +0100
Subject: [PATCH] Interruptible execution (#4463)

### Cooperative-style interruptible C++ threads.

This proposal attempts to make the cuML experience more responsive by providing an easier way to interrupt/cancel long-running cuML tasks. It replaces calls to `cudaStreamSynchronize` with `raft::interruptible::synchronize`, which serve as cancellation points in the algorithms. With a small extra hook on the Python side, Ctrl+C requests can now interrupt the execution (almost) immediately.

At the moment, I have adapted just a few models as a proof of concept. Example:

```python
import sklearn.datasets
import cuml.svm

X, y = sklearn.datasets.fetch_olivetti_faces(return_X_y=True)
model = cuml.svm.SVC()
print("Data loaded; fitting... (try Ctrl+C now)")
try:
    model.fit(X, y)
    print("Done! Score:", model.score(X, y))
except Exception as e:
    print("Canceled!")
    print(e)
```

#### Implementation details

https://github.com/rapidsai/raft/pull/433

#### Adoption costs

From the changeset in this PR you can see that I introduce two types of changes:

1. Change `cudaStreamSynchronize` to either `handle.sync_stream` or `raft::interruptible::synchronize`.
2. Wrap the Cython calls with [`cuda_interruptible`](https://github.com/rapidsai/raft/blob/36e8de5f73e9ec7e604b38a4290ac82bc35be4b7/python/raft/common/interruptible.pyx#L28) and `nogil`.

Change (1) is straightforward and can mostly be automated.

Change (2) is a bit more involved. You definitely have to wrap a C++ call with `interruptibleCpp` to make `Ctrl+C` work, but that is also rather simple. The tricky part is adding `nogil`, because you have to make sure there are no Python objects within the `with nogil` block. However, `nogil` does not seem to be strictly required for the signal handler to successfully interrupt the C++ thread; it worked in my tests without `nogil` as well. Yet I chose to add `nogil` in the code where possible, because in theory it should reduce the interrupt latency and enable more multithreading.

#### Motivation

In general, this proposal makes executing threads (and thus algos/models) more controllable. The main use cases I see:

1. Being able to Ctrl+C the running model using signal handlers.
2. Stopping the thread programmatically, e.g. we can create tests of the sort "if running for more than n seconds, stop and fail".

Resolves https://github.com/rapidsai/cuml/issues/4384

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J.
Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuml/pull/4463 --- cpp/bench/sg/arima_loglikelihood.cu | 2 +- cpp/bench/sg/fil.cu | 2 +- cpp/bench/sg/rf_classifier.cu | 2 +- cpp/bench/sg/rf_regressor.cu | 2 +- cpp/bench/sg/svc.cu | 2 +- cpp/bench/sg/svr.cu | 4 +- cpp/src/dbscan/adjgraph/naive.cuh | 4 +- cpp/src/dbscan/runner.cuh | 2 +- .../batched-levelalgo/builder.cuh | 2 +- cpp/src/fil/fil.cu | 2 +- cpp/src/fil/treelite_import.cu | 2 +- cpp/src/genetic/genetic.cu | 4 +- cpp/src/glm/ols_mg.cu | 4 +- cpp/src/glm/qn/glm_base.cuh | 2 +- cpp/src/glm/qn/glm_regularizer.cuh | 2 +- cpp/src/glm/qn/glm_softmax.cuh | 2 +- cpp/src/glm/qn/simple_mat/base.hpp | 3 +- cpp/src/glm/qn/simple_mat/dense.hpp | 11 +++--- cpp/src/glm/qn/simple_mat/sparse.hpp | 6 +-- cpp/src/glm/ridge_mg.cu | 4 +- cpp/src/hdbscan/detail/condense.cuh | 2 +- cpp/src/hdbscan/detail/extract.cuh | 4 +- cpp/src/hdbscan/detail/select.cuh | 4 +- cpp/src/kmeans/common.cuh | 6 +-- cpp/src/kmeans/kmeans_mg_impl.cuh | 4 +- cpp/src/kmeans/sg_impl.cuh | 6 +-- cpp/src/knn/knn_opg_common.cuh | 10 ++--- cpp/src/pca/pca_mg.cu | 16 ++++---- cpp/src/pca/sign_flip_mg.cu | 4 +- cpp/src/random_projection/rproj_utils.cuh | 2 +- cpp/src/randomforest/randomforest.cuh | 6 +-- cpp/src/solver/cd.cuh | 2 +- cpp/src/solver/cd_mg.cu | 10 ++--- cpp/src/solver/lars_impl.cuh | 4 +- cpp/src/solver/sgd.cuh | 2 +- cpp/src/svm/kernelcache.cuh | 2 +- cpp/src/svm/results.cuh | 10 ++--- cpp/src/svm/smosolver.cuh | 4 +- cpp/src/svm/svc_impl.cuh | 4 +- cpp/src/svm/workingset.cuh | 4 +- cpp/src/tsa/auto_arima.cuh | 4 +- cpp/src/tsne/exact_kernels.cuh | 2 +- cpp/src/tsne/fft_tsne.cuh | 2 +- cpp/src/tsvd/tsvd.cuh | 2 +- cpp/src/tsvd/tsvd_mg.cu | 14 +++---- cpp/src/umap/fuzzy_simpl_set/naive.cuh | 2 +- cpp/src/umap/optimize.cuh | 4 +- cpp/src_prims/cache/cache.cuh | 3 +- cpp/src_prims/label/classlabels.cuh | 2 +- cpp/src_prims/label/merge_labels.cuh | 5 ++- cpp/src_prims/linalg/lstsq.cuh | 2 +- cpp/src_prims/metrics/dispersion.cuh | 3 +- cpp/src_prims/metrics/entropy.cuh | 2 +- cpp/src_prims/metrics/kl_divergence.cuh | 2 +- cpp/src_prims/metrics/mutual_info_score.cuh | 2 +- cpp/src_prims/metrics/rand_index.cuh | 2 +- cpp/src_prims/metrics/scores.cuh | 4 +- cpp/src_prims/metrics/silhouette_score.cuh | 2 +- cpp/src_prims/random/mvg.cuh | 5 ++- cpp/src_prims/selection/knn.cuh | 2 +- cpp/test/mg/knn.cu | 4 +- cpp/test/mg/knn_test_helper.cuh | 4 +- cpp/test/prims/batched/csr.cu | 2 +- cpp/test/prims/batched/matrix.cu | 10 ++--- cpp/test/prims/cache.cu | 5 ++- cpp/test/prims/contingencyMatrix.cu | 2 +- cpp/test/prims/decoupled_lookback.cu | 3 +- cpp/test/prims/device_utils.cu | 3 +- cpp/test/prims/dispersion.cu | 3 +- cpp/test/prims/entropy.cu | 3 +- cpp/test/prims/fillna.cu | 4 +- cpp/test/prims/gather.cu | 3 +- cpp/test/prims/gram.cu | 2 +- cpp/test/prims/histogram.cu | 3 +- cpp/test/prims/knn_classify.cu | 2 +- cpp/test/prims/knn_regression.cu | 2 +- cpp/test/prims/label.cu | 3 +- cpp/test/prims/linalg_block.cu | 14 +++---- cpp/test/prims/make_arima.cu | 5 ++- cpp/test/prims/permute.cu | 7 ++-- cpp/test/prims/reduce_cols_by_key.cu | 7 ++-- cpp/test/prims/reduce_rows_by_key.cu | 2 +- cpp/test/prims/score.cu | 9 +++-- cpp/test/prims/silhouette_score.cu | 2 +- cpp/test/prims/test_utils.h | 13 ++++--- cpp/test/sg/dbscan_test.cu | 8 ++-- cpp/test/sg/fil_test.cu | 2 +- cpp/test/sg/genetic/program_test.cu | 6 +-- cpp/test/sg/hdbscan_test.cu | 8 ++-- cpp/test/sg/holtwinters_test.cu | 2 +- cpp/test/sg/kmeans_test.cu | 4 +- cpp/test/sg/knn_test.cu | 2 
+- cpp/test/sg/linkage_test.cu | 4 +- cpp/test/sg/quasi_newton.cu | 20 +++++----- cpp/test/sg/rproj_test.cu | 4 +- cpp/test/sg/shap_kernel.cu | 2 +- cpp/test/sg/svc_test.cu | 14 +++---- cpp/test/sg/tsne_test.cu | 8 ++-- cpp/test/sg/umap_parametrizable_test.cu | 14 +++---- python/cuml/common/logger.pyx | 36 ++++------------- python/cuml/svm/linear.pyx | 39 ++++++++++++------- python/cuml/svm/svc.pyx | 23 +++++++---- python/cuml/svm/svm_base.pyx | 9 +---- python/cuml/svm/svr.pyx | 4 +- wiki/cpp/DEVELOPER_GUIDE.md | 36 ++++++++--------- 105 files changed, 304 insertions(+), 294 deletions(-) diff --git a/cpp/bench/sg/arima_loglikelihood.cu b/cpp/bench/sg/arima_loglikelihood.cu index 9cf24576af..119eb5027a 100644 --- a/cpp/bench/sg/arima_loglikelihood.cu +++ b/cpp/bench/sg/arima_loglikelihood.cu @@ -68,7 +68,7 @@ class ArimaLoglikelihood : public TsFixtureRandom { counting + this->params.batch_size, [=] __device__(int bid) { x[(bid + 1) * N - 1] = 1.0; }); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); // Benchmark loop this->loopOnState(state, [this]() { diff --git a/cpp/bench/sg/fil.cu b/cpp/bench/sg/fil.cu index 707d9d84d8..8128276d3e 100644 --- a/cpp/bench/sg/fil.cu +++ b/cpp/bench/sg/fil.cu @@ -79,7 +79,7 @@ class FIL : public RegressionFixture { auto* mPtr = &rf_model; size_t train_nrows = std::min(params.nrows, 1000); fit(*handle, mPtr, data.X.data(), train_nrows, params.ncols, data.y.data(), p_rest.rf); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle->sync_stream(stream); ML::build_treelite_forest(&model, &rf_model, params.ncols); ML::fil::treelite_params_t tl_params = { diff --git a/cpp/bench/sg/rf_classifier.cu b/cpp/bench/sg/rf_classifier.cu index 22e1441ec3..0141bb2798 100644 --- a/cpp/bench/sg/rf_classifier.cu +++ b/cpp/bench/sg/rf_classifier.cu @@ -68,7 +68,7 @@ class RFClassifier : public BlobsFixture { this->data.y.data(), this->params.nclasses, rfParams); - RAFT_CUDA_TRY(cudaStreamSynchronize(this->stream)); + this->handle->sync_stream(this->stream); }); } diff --git a/cpp/bench/sg/rf_regressor.cu b/cpp/bench/sg/rf_regressor.cu index 9e0fca8c3a..2985d7fcf6 100644 --- a/cpp/bench/sg/rf_regressor.cu +++ b/cpp/bench/sg/rf_regressor.cu @@ -67,7 +67,7 @@ class RFRegressor : public RegressionFixture { this->params.ncols, this->data.y, rfParams); - RAFT_CUDA_TRY(cudaStreamSynchronize(this->stream)); + handle->sync_stream(this->stream); }); } diff --git a/cpp/bench/sg/svc.cu b/cpp/bench/sg/svc.cu index c280acf04c..ea4037a822 100644 --- a/cpp/bench/sg/svc.cu +++ b/cpp/bench/sg/svc.cu @@ -68,7 +68,7 @@ class SVC : public BlobsFixture { this->kernel, this->model, static_cast(nullptr)); - RAFT_CUDA_TRY(cudaStreamSynchronize(this->stream)); + this->handle->sync_stream(this->stream); ML::SVM::svmFreeBuffers(*this->handle, this->model); }); } diff --git a/cpp/bench/sg/svr.cu b/cpp/bench/sg/svr.cu index 7698c27ad3..71b0b123b8 100644 --- a/cpp/bench/sg/svr.cu +++ b/cpp/bench/sg/svr.cu @@ -67,7 +67,7 @@ class SVR : public RegressionFixture { this->svm_param, this->kernel, *(this->model)); - RAFT_CUDA_TRY(cudaStreamSynchronize(this->stream)); + this->handle->sync_stream(this->stream); ML::SVM::svmFreeBuffers(*this->handle, *(this->model)); }); } @@ -130,4 +130,4 @@ ML_BENCH_REGISTER(SvrParams, SVR, "regression", getInputs host_ex_scan(batch_size); raft::update_host((bool*)host_adj.data(), data.adj, batch_size * N, stream); raft::update_host(host_vd.data(), data.vd, batch_size + 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + 
handle.sync_stream(stream); size_t adjgraph_size = size_t(host_vd[batch_size]); ML::pinned_host_vector host_adj_graph(adjgraph_size); for (Index_ i = 0; i < batch_size; i++) { @@ -62,4 +62,4 @@ void launcher(const raft::handle_t& handle, } // namespace Naive } // namespace AdjGraph } // namespace Dbscan -} // namespace ML \ No newline at end of file +} // namespace ML diff --git a/cpp/src/dbscan/runner.cuh b/cpp/src/dbscan/runner.cuh index 3e0cbd722f..e3e8dcd8aa 100644 --- a/cpp/src/dbscan/runner.cuh +++ b/cpp/src/dbscan/runner.cuh @@ -223,7 +223,7 @@ std::size_t run(const raft::handle_t& handle, raft::common::nvtx::pop_range(); } raft::update_host(&curradjlen, vd + n_points, 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); CUML_LOG_DEBUG("--> Computing adjacency graph with %ld nnz.", (unsigned long)curradjlen); raft::common::nvtx::push_range("Trace::Dbscan::AdjGraph"); diff --git a/cpp/src/decisiontree/batched-levelalgo/builder.cuh b/cpp/src/decisiontree/batched-levelalgo/builder.cuh index edf3a7bf7a..f51ad51d54 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder.cuh @@ -410,7 +410,7 @@ struct Builder { RAFT_CUDA_TRY(cudaGetLastError()); raft::common::nvtx::pop_range(); raft::update_host(h_splits, splits, work_items.size(), builder_stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(builder_stream)); + handle.sync_stream(builder_stream); return std::make_tuple(h_splits, work_items.size()); } diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index f00e310a6b..6617ec30eb 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -411,7 +411,7 @@ struct dense_forest : forest { dispatch_on_fil_template_params(opt_into_arch_dependent_shmem(max_shm_), static_cast(class_ssp_)); // copy must be finished before freeing the host data - RAFT_CUDA_TRY(cudaStreamSynchronize(h.get_stream())); + h.sync_stream(); h_nodes_.clear(); h_nodes_.shrink_to_fit(); } diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index 8a3d2121cd..54b0c80750 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -636,7 +636,7 @@ struct tl2fil_t { handle, pforest, cat_sets_.accessor(), vector_leaf_, roots_.data(), nodes_.data(), ¶ms_); // sync is necessary as nodes_ are used in init(), // but destructed at the end of this function - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); if (tl_params_.pforest_shape_str) { *tl_params_.pforest_shape_str = sprintf_shape(model_, nodes_, roots_, cat_sets_); } diff --git a/cpp/src/genetic/genetic.cu b/cpp/src/genetic/genetic.cu index ece6c6d81c..fb46efa745 100644 --- a/cpp/src/genetic/genetic.cu +++ b/cpp/src/genetic/genetic.cu @@ -191,7 +191,7 @@ void parallel_evolve(const raft::handle_t& h, RAFT_CUDA_TRY(cudaPeekAtLastError()); // Make sure tournaments have finished running before copying win indices - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + h.sync_stream(stream); // Perform host mutations @@ -242,7 +242,7 @@ void parallel_evolve(const raft::handle_t& h, } // Make sure all copying is done - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + h.sync_stream(stream); // Update raw fitness for all programs set_batched_fitness( diff --git a/cpp/src/glm/ols_mg.cu b/cpp/src/glm/ols_mg.cu index 325566908e..3b16f8e4c6 100644 --- a/cpp/src/glm/ols_mg.cu +++ b/cpp/src/glm/ols_mg.cu @@ -153,7 +153,7 @@ void fit_impl(raft::handle_t& handle, verbose); for (int i = 0; i < n_streams; i++) 
{ - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (int i = 0; i < n_streams; i++) { @@ -227,7 +227,7 @@ void predict_impl(raft::handle_t& handle, handle, input_data, input_desc, coef, intercept, preds_data, streams, n_streams, verbose); for (int i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (int i = 0; i < n_streams; i++) { diff --git a/cpp/src/glm/qn/glm_base.cuh b/cpp/src/glm/qn/glm_base.cuh index 126855637e..cfe975b48c 100644 --- a/cpp/src/glm/qn/glm_base.cuh +++ b/cpp/src/glm/qn/glm_base.cuh @@ -219,7 +219,7 @@ struct GLMWithData : GLMDims { objective->loss_grad(dev_scalar, G, W, *X, *y, *Z, stream); T loss_host; raft::update_host(&loss_host, dev_scalar, 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); return loss_host; } }; diff --git a/cpp/src/glm/qn/glm_regularizer.cuh b/cpp/src/glm/qn/glm_regularizer.cuh index 9e4aa7067b..4650205fc2 100644 --- a/cpp/src/glm/qn/glm_regularizer.cuh +++ b/cpp/src/glm/qn/glm_regularizer.cuh @@ -81,7 +81,7 @@ struct RegularizedGLM : GLMDims { loss->loss_grad(lossVal.data, G, W, Xb, yb, Zb, stream, false); raft::update_host(&loss_host, lossVal.data, 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); lossVal.fill(loss_host + reg_host, stream); } diff --git a/cpp/src/glm/qn/glm_softmax.cuh b/cpp/src/glm/qn/glm_softmax.cuh index 91a18f15b5..f0f3835403 100644 --- a/cpp/src/glm/qn/glm_softmax.cuh +++ b/cpp/src/glm/qn/glm_softmax.cuh @@ -152,7 +152,7 @@ void launchLogsoftmax( T* loss_val, T* dldZ, const T* Z, const T* labels, int C, int N, cudaStream_t stream) { RAFT_CUDA_TRY(cudaMemsetAsync(loss_val, 0, sizeof(T), stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); if (C <= 4) { dim3 bs(4, 64); dim3 gs(ceildiv(N, 64)); diff --git a/cpp/src/glm/qn/simple_mat/base.hpp b/cpp/src/glm/qn/simple_mat/base.hpp index 8bbe0b7ac8..e72119a310 100644 --- a/cpp/src/glm/qn/simple_mat/base.hpp +++ b/cpp/src/glm/qn/simple_mat/base.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,6 +17,7 @@ #include #include +#include namespace ML { diff --git a/cpp/src/glm/qn/simple_mat/dense.hpp b/cpp/src/glm/qn/simple_mat/dense.hpp index efd6de68a5..df37f573ad 100644 --- a/cpp/src/glm/qn/simple_mat/dense.hpp +++ b/cpp/src/glm/qn/simple_mat/dense.hpp @@ -289,7 +289,8 @@ inline T dot(const SimpleVec& u, const SimpleVec& v, T* tmp_dev, cudaStrea raft::linalg::mapThenSumReduce(tmp_dev, u.len, f, stream, u.data, v.data); T tmp_host; raft::update_host(&tmp_host, tmp_dev, 1, stream); - cudaStreamSynchronize(stream); + + raft::interruptible::synchronize(stream); return tmp_host; } @@ -307,7 +308,7 @@ inline T nrmMax(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data); T tmp_host; raft::update_host(&tmp_host, tmp_dev, 1, stream); - cudaStreamSynchronize(stream); + raft::interruptible::synchronize(stream); return tmp_host; } @@ -324,7 +325,7 @@ inline T nrm1(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) tmp_dev, u.data, u.len, 1, raft::linalg::L1Norm, true, stream, raft::Nop()); T tmp_host; raft::update_host(&tmp_host, tmp_dev, 1, stream); - cudaStreamSynchronize(stream); + raft::interruptible::synchronize(stream); return tmp_host; } @@ -333,7 +334,7 @@ std::ostream& operator<<(std::ostream& os, const SimpleVec& v) { std::vector out(v.len); raft::update_host(&out[0], v.data, v.len, 0); - RAFT_CUDA_TRY(cudaStreamSynchronize(0)); + raft::interruptible::synchronize(rmm::cuda_stream_view()); int it = 0; for (; it < v.len - 1;) { os << out[it] << " "; @@ -349,7 +350,7 @@ std::ostream& operator<<(std::ostream& os, const SimpleDenseMat& mat) os << "ord=" << (mat.ord == COL_MAJOR ? "CM" : "RM") << "\n"; std::vector out(mat.len); raft::update_host(&out[0], mat.data, mat.len, rmm::cuda_stream_default); - RAFT_CUDA_TRY(cudaStreamSynchronize(0)); + raft::interruptible::synchronize(rmm::cuda_stream_view()); if (mat.ord == COL_MAJOR) { for (int r = 0; r < mat.m; r++) { int idx = r; diff --git a/cpp/src/glm/qn/simple_mat/sparse.hpp b/cpp/src/glm/qn/simple_mat/sparse.hpp index 0cfa750338..29ad5c8f5c 100644 --- a/cpp/src/glm/qn/simple_mat/sparse.hpp +++ b/cpp/src/glm/qn/simple_mat/sparse.hpp @@ -144,7 +144,7 @@ struct SimpleSparseMat : SimpleMat { &bufferSize, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); rmm::device_uvector tmp(bufferSize, stream); RAFT_CUSPARSE_TRY(raft::sparse::cusparsespmm(handle.get_cusparse_handle(), @@ -170,7 +170,7 @@ inline void check_csr(const SimpleSparseMat& mat, cudaStream_t stream) { int row_ids_nnz; raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); ASSERT(row_ids_nnz == mat.nnz, "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and " "the last element must be equal nnz."); @@ -188,7 +188,7 @@ std::ostream& operator<<(std::ostream& os, const SimpleSparseMat& mat) raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default); raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default); raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default); - RAFT_CUDA_TRY(cudaStreamSynchronize(0)); + raft::interruptible::synchronize(rmm::cuda_stream_view()); int i, row_end = 0; for (int row = 0; row < mat.m; row++) { diff --git a/cpp/src/glm/ridge_mg.cu b/cpp/src/glm/ridge_mg.cu index 3710eef28b..84b062c7b0 100644 --- a/cpp/src/glm/ridge_mg.cu +++ b/cpp/src/glm/ridge_mg.cu @@ 
-267,7 +267,7 @@ void fit_impl(raft::handle_t& handle, verbose); for (int i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (int i = 0; i < n_streams; i++) { @@ -341,7 +341,7 @@ void predict_impl(raft::handle_t& handle, handle, input_data, input_desc, coef, intercept, preds_data, streams, n_streams, verbose); for (int i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (int i = 0; i < n_streams; i++) { diff --git a/cpp/src/hdbscan/detail/condense.cuh b/cpp/src/hdbscan/detail/condense.cuh index d7b913a7f2..ccedd9c1b2 100644 --- a/cpp/src/hdbscan/detail/condense.cuh +++ b/cpp/src/hdbscan/detail/condense.cuh @@ -147,7 +147,7 @@ void build_condensed_hierarchy(const raft::handle_t& handle, n_elements_to_traverse = thrust::reduce(exec_policy, frontier.data(), frontier.data() + root + 1, 0); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } condensed_tree.condense(out_parent.data(), out_child.data(), out_lambda.data(), out_size.data()); diff --git a/cpp/src/hdbscan/detail/extract.cuh b/cpp/src/hdbscan/detail/extract.cuh index 4a9429e98b..7bb0578a0b 100644 --- a/cpp/src/hdbscan/detail/extract.cuh +++ b/cpp/src/hdbscan/detail/extract.cuh @@ -119,7 +119,7 @@ void do_labelling_on_host(const raft::handle_t& handle, raft::update_host( lambda_h.data(), condensed_tree.get_lambdas(), condensed_tree.get_n_edges(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); auto parents = thrust::device_pointer_cast(condensed_tree.get_parents()); auto thrust_policy = handle.get_thrust_policy(); @@ -230,7 +230,7 @@ value_idx extract_clusters(const raft::handle_t& handle, std::vector is_cluster_h(is_cluster.size()); raft::update_host(is_cluster_h.data(), is_cluster.data(), is_cluster_h.size(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); std::set clusters; for (std::size_t i = 0; i < is_cluster_h.size(); i++) { diff --git a/cpp/src/hdbscan/detail/select.cuh b/cpp/src/hdbscan/detail/select.cuh index 4aeb3a5640..ef5b097bf7 100644 --- a/cpp/src/hdbscan/detail/select.cuh +++ b/cpp/src/hdbscan/detail/select.cuh @@ -93,7 +93,7 @@ void perform_bfs(const raft::handle_t& handle, thrust::fill(thrust_policy, next_frontier.begin(), next_frontier.end(), 0); n_elements_to_traverse = thrust::reduce(thrust_policy, frontier, frontier + n_clusters, 0); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } } @@ -200,7 +200,7 @@ void excess_of_mass(const raft::handle_t& handle, std::vector indptr_h(indptr.size(), 0); if (cluster_tree_edges > 0) raft::update_host(indptr_h.data(), indptr.data(), indptr.size(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); // Loop through stabilities in "reverse topological order" (e.g. reverse sorted order) value_idx tree_top = allow_single_cluster ? 
0 : 1; diff --git a/cpp/src/kmeans/common.cuh b/cpp/src/kmeans/common.cuh index 811e155439..5ff7c750ab 100644 --- a/cpp/src/kmeans/common.cuh +++ b/cpp/src/kmeans/common.cuh @@ -233,7 +233,7 @@ Tensor sampleCentroids(const raft::handle_t& handle, int nPtsSampledInRank = 0; raft::copy(&nPtsSampledInRank, nSelected.data(), nSelected.numElements(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); int* rawPtr_isSampleCentroid = isSampleCentroid.data(); thrust::for_each_n(handle.get_thrust_policy(), @@ -769,7 +769,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, // Choose 'n_trials' centroid candidates from X with probability proportional to the squared // distance to the nearest existing cluster raft::copy(h_wt.data(), minClusterDistance.data(), minClusterDistance.numElements(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); // Note - n_trials is relative small here, we don't need MLCommon::gather call std::discrete_distribution<> d(h_wt.begin(), h_wt.end()); @@ -880,7 +880,7 @@ void checkWeights(const raft::handle_t& handle, DataT wt_sum = 0; raft::copy(&wt_sum, wt_aggr.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); if (wt_sum != n_samples) { LOG(handle, diff --git a/cpp/src/kmeans/kmeans_mg_impl.cuh b/cpp/src/kmeans/kmeans_mg_impl.cuh index df4867bdf3..d6f739018e 100644 --- a/cpp/src/kmeans/kmeans_mg_impl.cuh +++ b/cpp/src/kmeans/kmeans_mg_impl.cuh @@ -426,7 +426,7 @@ void checkWeights(const raft::handle_t& handle, raft::comms::op_t::SUM, stream); DataT wt_sum = wt_aggr.value(stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); if (wt_sum != n_samples) { LOG(handle, @@ -662,7 +662,7 @@ void fit(const raft::handle_t& handle, priorClusteringCost = curClusteringCost; } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); if (sqrdNormError < params.tol) done = true; if (done) { diff --git a/cpp/src/kmeans/sg_impl.cuh b/cpp/src/kmeans/sg_impl.cuh index be5dde69dd..80d9f51d35 100644 --- a/cpp/src/kmeans/sg_impl.cuh +++ b/cpp/src/kmeans/sg_impl.cuh @@ -232,7 +232,7 @@ void fit(const raft::handle_t& handle, DataT curClusteringCost = 0; raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); ASSERT(curClusteringCost != (DataT)0.0, "Too few points and centriods being found is getting 0 cost from " "centers"); @@ -244,7 +244,7 @@ void fit(const raft::handle_t& handle, priorClusteringCost = curClusteringCost; } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); if (sqrdNormError < params.tol) done = true; if (done) { @@ -425,7 +425,7 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, // <<< End of Step-2 >>> // Scalable kmeans++ paper claims 8 rounds is sufficient - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); int niter = std::min(8, (int)ceil(log(psi))); LOG(handle, "KMeans||: psi = %g, log(psi) = %g, niter = %d ", psi, log(psi), niter); diff --git a/cpp/src/knn/knn_opg_common.cuh b/cpp/src/knn/knn_opg_common.cuh index 0ea333993d..bc231a5cea 100644 --- a/cpp/src/knn/knn_opg_common.cuh +++ b/cpp/src/knn/knn_opg_common.cuh @@ -458,7 +458,7 @@ void perform_local_knn(opg_knn_param& params, params.rowMajorQuery, &start_indices_long, raft::distance::DistanceType::L2SqrtExpanded); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + 
handle.sync_stream(handle.get_stream()); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -538,7 +538,7 @@ void copy_label_outputs_from_index_parts(opg_knn_param& params, handle.get_stream()); } } - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); break; } i++; @@ -733,7 +733,7 @@ void reduce(opg_knn_param& params, params.k, handle.get_stream(), trans.data()); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); RAFT_CUDA_TRY(cudaPeekAtLastError()); if (params.knn_op != knn_operation::knn) { @@ -767,7 +767,7 @@ void reduce(opg_knn_param& params, perform_local_operation( params, work, handle, outputs, probas_with_offsets, merged_outputs_b.data(), batch_size); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); RAFT_CUDA_TRY(cudaPeekAtLastError()); } } diff --git a/cpp/src/pca/pca_mg.cu b/cpp/src/pca/pca_mg.cu index 87b7fee68d..9b88016d00 100644 --- a/cpp/src/pca/pca_mg.cu +++ b/cpp/src/pca/pca_mg.cu @@ -129,7 +129,7 @@ void fit_impl(raft::handle_t& handle, n_streams, verbose); for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } } else if (prms.algorithm == mg_solver::QR) { const raft::handle_t& h = handle; @@ -141,7 +141,7 @@ void fit_impl(raft::handle_t& handle, Stats::opg::mean(handle, mu_data, input_data, input_desc, streams, n_streams); Stats::opg::mean_center(input_data, input_desc, mu_data, comm, streams, n_streams); for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } // Allocate Q, S and V and call QR @@ -205,7 +205,7 @@ void fit_impl(raft::handle_t& handle, } for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (std::uint32_t i = 0; i < n_streams; i++) { @@ -283,7 +283,7 @@ void transform_impl(raft::handle_t& handle, } for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } } @@ -342,7 +342,7 @@ void transform_impl(raft::handle_t& handle, verbose); for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (std::uint32_t i = 0; i < n_streams; i++) { @@ -411,7 +411,7 @@ void inverse_transform_impl(raft::handle_t& handle, } for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } } @@ -468,7 +468,7 @@ void inverse_transform_impl(raft::handle_t& handle, verbose); for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (std::uint32_t i = 0; i < n_streams; i++) { @@ -550,7 +550,7 @@ void fit_transform_impl(raft::handle_t& handle, sign_flip(handle, trans_data, input_desc, components, prms.n_components, streams, n_streams); for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (std::uint32_t i = 0; i < n_streams; i++) { diff --git a/cpp/src/pca/sign_flip_mg.cu b/cpp/src/pca/sign_flip_mg.cu index 5653c0a933..b7d5a0498d 100644 --- a/cpp/src/pca/sign_flip_mg.cu +++ b/cpp/src/pca/sign_flip_mg.cu @@ -148,7 +148,7 @@ void sign_flip_imp(raft::handle_t& handle, } for (std::uint32_t i = 0; i < 
n_stream; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } findMaxAbsOfColumns( @@ -166,7 +166,7 @@ void sign_flip_imp(raft::handle_t& handle, } for (std::uint32_t i = 0; i < n_stream; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } flip(components, input_desc.N, n_components, max_vals.data(), streams[0]); diff --git a/cpp/src/random_projection/rproj_utils.cuh b/cpp/src/random_projection/rproj_utils.cuh index aaf8173580..b4b1fc3fa4 100644 --- a/cpp/src/random_projection/rproj_utils.cuh +++ b/cpp/src/random_projection/rproj_utils.cuh @@ -86,7 +86,7 @@ inline size_t binomial(const raft::handle_t& h, size_t n, double p, int random_s int ret = 0; raft::update_host(&ret, successes.data(), 1, h.get_stream()); - cudaStreamSynchronize(h.get_stream()); + h.sync_stream(); RAFT_CUDA_TRY(cudaPeekAtLastError()); return n - ret; diff --git a/cpp/src/randomforest/randomforest.cuh b/cpp/src/randomforest/randomforest.cuh index 39782dcb31..ab45bcb1f0 100644 --- a/cpp/src/randomforest/randomforest.cuh +++ b/cpp/src/randomforest/randomforest.cuh @@ -190,7 +190,7 @@ class RandomForest { } // Cleanup handle.sync_stream_pool(); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(); } /** @@ -217,7 +217,7 @@ class RandomForest { std::vector h_input(n_rows * n_cols); raft::update_host(h_input.data(), input, n_rows * n_cols, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + user_handle.sync_stream(stream); int row_size = n_cols; @@ -254,7 +254,7 @@ class RandomForest { } raft::update_device(predictions, h_predictions.data(), n_rows, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + user_handle.sync_stream(stream); } /** diff --git a/cpp/src/solver/cd.cuh b/cpp/src/solver/cd.cuh index b534358273..543688dbce 100644 --- a/cpp/src/solver/cd.cuh +++ b/cpp/src/solver/cd.cuh @@ -182,7 +182,7 @@ void cdFit(const raft::handle_t& handle, coef_prev = h_coef[ci]; raft::update_host(&(h_coef[ci]), coef_loc, 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); math_t diff = abs(coef_prev - h_coef[ci]); diff --git a/cpp/src/solver/cd_mg.cu b/cpp/src/solver/cd_mg.cu index 837d0d83ba..d7b60af948 100644 --- a/cpp/src/solver/cd_mg.cu +++ b/cpp/src/solver/cd_mg.cu @@ -192,7 +192,7 @@ void fit_impl(raft::handle_t& handle, } for (int k = 0; k < n_streams; k++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[k])); + handle.sync_stream(streams[k]); } coef_loc_data.ptr = coef_loc; @@ -206,7 +206,7 @@ void fit_impl(raft::handle_t& handle, coef_prev = h_coef[ci]; raft::update_host(&(h_coef[ci]), coef_loc, 1, streams[0]); - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[0])); + handle.sync_stream(streams[0]); T diff = abs(coef_prev - h_coef[ci]); @@ -231,7 +231,7 @@ void fit_impl(raft::handle_t& handle, } for (int k = 0; k < n_streams; k++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[k])); + handle.sync_stream(streams[k]); } } @@ -330,7 +330,7 @@ void fit_impl(raft::handle_t& handle, verbose); for (int i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (int i = 0; i < n_streams; i++) { @@ -405,7 +405,7 @@ void predict_impl(raft::handle_t& handle, handle, input_data, input_desc, coef, intercept, preds_data, streams, n_streams, verbose); for (int i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (int i = 0; i < n_streams; i++) { 
diff --git a/cpp/src/solver/lars_impl.cuh b/cpp/src/solver/lars_impl.cuh index 09646c64a8..49dfd04152 100644 --- a/cpp/src/solver/lars_impl.cuh +++ b/cpp/src/solver/lars_impl.cuh @@ -90,7 +90,7 @@ LarsFitStatus selectMostCorrelated(idx_t n_active, thrust::device_ptr ptr(workspace.data() + n_active - start); auto max_ptr = thrust::max_element(thrust::cuda::par.on(stream), ptr, ptr + n - n_active); raft::update_host(cj, max_ptr.get(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); *max_idx = n_active + (max_ptr - ptr); // the index of the maximum element @@ -505,7 +505,7 @@ LarsFitStatus calcEquiangularVec(const raft::handle_t& handle, raft::update_host(&ws_host, ws, 1, stream); math_t diag_host; // U[n_active-1, n_active-1] raft::update_host(&diag_host, U + ld_U * (n_active - 1) + n_active - 1, 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); if (diag_host < 1e-7) { CUML_LOG_WARN( "Vanising diagonal in Cholesky factorization (%e). This indicates " diff --git a/cpp/src/solver/sgd.cuh b/cpp/src/solver/sgd.cuh index a3b77fbdd4..df25f66c50 100644 --- a/cpp/src/solver/sgd.cuh +++ b/cpp/src/solver/sgd.cuh @@ -273,7 +273,7 @@ void sgdFit(const raft::handle_t& handle, } raft::update_host(&curr_loss_value, loss_value.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); if (i > 0) { if (curr_loss_value > (prev_loss_value - tol)) { diff --git a/cpp/src/svm/kernelcache.cuh b/cpp/src/svm/kernelcache.cuh index d1a011df49..b2900f4674 100644 --- a/cpp/src/svm/kernelcache.cuh +++ b/cpp/src/svm/kernelcache.cuh @@ -378,7 +378,7 @@ class KernelCache { n_ws, stream); raft::update_host(n_unique, d_num_selected_out.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } }; diff --git a/cpp/src/svm/results.cuh b/cpp/src/svm/results.cuh index 85123d224c..67a6f87f19 100644 --- a/cpp/src/svm/results.cuh +++ b/cpp/src/svm/results.cuh @@ -126,7 +126,7 @@ class Results { *x_support = nullptr; } // Make sure that all pending GPU calculations finished before we return - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } /** @@ -189,7 +189,7 @@ class Results { *n_support = SelectByCoef(val_tmp, n_rows, val_tmp, select_op, val_selected.data()); *dual_coefs = (math_t*)rmm_alloc->allocate(*n_support * sizeof(math_t), stream); raft::copy(*dual_coefs, val_selected.data(), *n_support, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } /** @@ -272,7 +272,7 @@ class Results { cub_storage.data(), cub_bytes, val, flag.data(), out, d_num_selected.data(), n, stream); int n_selected; raft::update_host(&n_selected, d_num_selected.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); return n_selected; } @@ -350,7 +350,7 @@ class Results { cub_storage.data(), cub_bytes, val, flag.data(), out, d_num_selected.data(), n, stream); int n_selected; raft::update_host(&n_selected, d_num_selected.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); return n_selected; } @@ -377,7 +377,7 @@ class Results { stream); int n_selected; raft::update_host(&n_selected, d_num_selected.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); math_t res = 0; ASSERT(n_selected > 0, "Incorrect training: cannot calculate the constant in the decision " diff --git 
a/cpp/src/svm/smosolver.cuh b/cpp/src/svm/smosolver.cuh index f8b52c73a2..a6bc8a6968 100644 --- a/cpp/src/svm/smosolver.cuh +++ b/cpp/src/svm/smosolver.cuh @@ -178,7 +178,7 @@ class SmoSolver { UpdateF(f.data(), n_rows, delta_alpha.data(), cache.GetUniqueSize(), cacheTile); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); math_t diff = host_return_buff[0]; keep_going = CheckStoppingCondition(diff); @@ -509,4 +509,4 @@ class SmoSolver { }; }; // end namespace SVM -}; // end namespace ML \ No newline at end of file +}; // end namespace ML diff --git a/cpp/src/svm/svc_impl.cuh b/cpp/src/svm/svc_impl.cuh index 548f860727..89fcc44aa8 100644 --- a/cpp/src/svm/svc_impl.cuh +++ b/cpp/src/svm/svc_impl.cuh @@ -70,7 +70,7 @@ void svcFit(const raft::handle_t& handle, rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource(); model.unique_labels = (math_t*)rmm_alloc->allocate(model.n_classes * sizeof(math_t), stream); raft::copy(model.unique_labels, unique_labels.data(), model.n_classes, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle_impl.sync_stream(stream); } ASSERT(model.n_classes == 2, "Only binary classification is implemented at the moment"); @@ -208,7 +208,7 @@ void svcPredict(const raft::handle_t& handle, raft::linalg::unaryOp( preds, y.data(), n_rows, [b] __device__(math_t y) { return y + b; }, stream); } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle_impl.sync_stream(stream); delete kernel; } diff --git a/cpp/src/svm/workingset.cuh b/cpp/src/svm/workingset.cuh index ba71677952..6b112ffd86 100644 --- a/cpp/src/svm/workingset.cuh +++ b/cpp/src/svm/workingset.cuh @@ -440,7 +440,7 @@ class WorkingSet { d_num_selected.data(), n_train); int n_selected = d_num_selected.value(stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); // Copy to output int n_copy = n_selected > n_needed ? n_needed : n_selected; @@ -487,7 +487,7 @@ class WorkingSet { n_ws, op); int n_selected = d_num_selected.value(stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); int n_copy = n_selected < n_needed ? 
n_selected : n_needed; raft::copy(idx.data() + n_already_selected, ws_idx_selected.data(), n_copy, stream); return n_copy; diff --git a/cpp/src/tsa/auto_arima.cuh b/cpp/src/tsa/auto_arima.cuh index 71da2d991b..6608a0678d 100644 --- a/cpp/src/tsa/auto_arima.cuh +++ b/cpp/src/tsa/auto_arima.cuh @@ -105,7 +105,7 @@ inline int divide_by_mask_build_index(const bool* d_mask, // Compute and return the number of true elements in the mask int true_elements; raft::update_host(&true_elements, index1.data() + batch_size - 1, 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); return true_elements; } @@ -246,7 +246,7 @@ inline void divide_by_min_build_index(const DataT* d_matrix, d_size[j] = d_cumul[(j + 1) * batch_size - 1]; }); raft::update_host(h_size, d_size, n_sub, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } /** diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh index d744c4a060..6ad3da2120 100644 --- a/cpp/src/tsne/exact_kernels.cuh +++ b/cpp/src/tsne/exact_kernels.cuh @@ -159,7 +159,7 @@ void perplexity_search(const value_t* restrict distances, sigmas_kernel<<>>( distances, P, perplexity, desired_entropy, epochs, tol, n, dim); RAFT_CUDA_TRY(cudaPeekAtLastError()); - cudaStreamSynchronize(stream); + handle.sync_stream(stream); } /****************************************/ diff --git a/cpp/src/tsne/fft_tsne.cuh b/cpp/src/tsne/fft_tsne.cuh index c9ed35d78a..2f5e80d10f 100644 --- a/cpp/src/tsne/fft_tsne.cuh +++ b/cpp/src/tsne/fft_tsne.cuh @@ -135,7 +135,7 @@ std::pair min_max(const value_t* Y, const value_idx n, cudaStr min_h = min_d.value(stream); max_h = max_d.value(stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); return std::make_pair(std::move(min_h), std::move(max_h)); } diff --git a/cpp/src/tsvd/tsvd.cuh b/cpp/src/tsvd/tsvd.cuh index f452fd613f..c571aebc84 100644 --- a/cpp/src/tsvd/tsvd.cuh +++ b/cpp/src/tsvd/tsvd.cuh @@ -289,7 +289,7 @@ void tsvdFitTransform(const raft::handle_t& handle, math_t total_vars_h; raft::update_host(&total_vars_h, total_vars.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); math_t scalar = math_t(1) / total_vars_h; raft::linalg::scalarMultiply( diff --git a/cpp/src/tsvd/tsvd_mg.cu b/cpp/src/tsvd/tsvd_mg.cu index 56c02d4d27..43035389e8 100644 --- a/cpp/src/tsvd/tsvd_mg.cu +++ b/cpp/src/tsvd/tsvd_mg.cu @@ -114,7 +114,7 @@ void fit_impl(raft::handle_t& handle, handle, input_data, input_desc, components, singular_vals, prms, streams, n_streams, verbose); for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (std::uint32_t i = 0; i < n_streams; i++) { @@ -158,7 +158,7 @@ void transform_impl(raft::handle_t& handle, } for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } } @@ -201,7 +201,7 @@ void transform_impl(raft::handle_t& handle, handle, input_data, input_desc, components, trans_data, prms, streams, n_streams, verbose); for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (std::uint32_t i = 0; i < n_streams; i++) { @@ -243,7 +243,7 @@ void inverse_transform_impl(raft::handle_t& handle, } for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); 
+ handle.sync_stream(streams[i]); } } @@ -287,7 +287,7 @@ void inverse_transform_impl(raft::handle_t& handle, handle, trans_data, trans_desc, components, input_data, prms, streams, n_streams, verbose); for (std::uint32_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (std::uint32_t i = 0; i < n_streams; i++) { RAFT_CUDA_TRY(cudaStreamDestroy(streams[i])); @@ -366,14 +366,14 @@ void fit_transform_impl(raft::handle_t& handle, T total_vars_h; raft::update_host(&total_vars_h, total_vars.data(), std::size_t(1), streams[0]); - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[0])); + handle.sync_stream(streams[0]); T scalar = T(1) / total_vars_h; raft::linalg::scalarMultiply( explained_var_ratio, explained_var, scalar, prms.n_components, streams[0]); for (std::size_t i = 0; i < n_streams; i++) { - RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i])); + handle.sync_stream(streams[i]); } for (std::size_t i = 0; i < n_streams; i++) { RAFT_CUDA_TRY(cudaStreamDestroy(streams[i])); diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh index 51f7526fc9..3b7d7080ed 100644 --- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh @@ -257,7 +257,7 @@ void smooth_knn_dist(int n, value_t mean_dist = 0.0; raft::update_host(&mean_dist, dist_means_dev.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); /** * Smooth kNN distances to be continuous diff --git a/cpp/src/umap/optimize.cuh b/cpp/src/umap/optimize.cuh index 685d95d32b..502b7a5821 100644 --- a/cpp/src/umap/optimize.cuh +++ b/cpp/src/umap/optimize.cuh @@ -156,7 +156,7 @@ void optimize_params(T* input, T* grads_h = (T*)malloc(2 * sizeof(T)); raft::update_host(grads_h, grads.data(), 2, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (int i = 0; i < 2; i++) { if (abs(grads_h[i]) - tolerance <= 0) tol_grads += 1; @@ -210,7 +210,7 @@ void find_params_ab(UMAPParams* params, cudaStream_t stream) raft::update_host(&(params->a), coeffs.data(), 1, stream); raft::update_host(&(params->b), coeffs.data() + 1, 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); CUML_LOG_DEBUG("a=%f, b=%f", params->a, params->b); } diff --git a/cpp/src_prims/cache/cache.cuh b/cpp/src_prims/cache/cache.cuh index 828da58d4f..af442a5268 100644 --- a/cpp/src_prims/cache/cache.cuh +++ b/cpp/src_prims/cache/cache.cuh @@ -23,6 +23,7 @@ #include #include +#include #include #include #include @@ -313,7 +314,7 @@ class Cache { n, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } /** diff --git a/cpp/src_prims/label/classlabels.cuh b/cpp/src_prims/label/classlabels.cuh index 17b9aa030a..fe4573da33 100644 --- a/cpp/src_prims/label/classlabels.cuh +++ b/cpp/src_prims/label/classlabels.cuh @@ -49,7 +49,7 @@ int getUniqueLabels(math_t* y, size_t n, math_t* unique, cudaStream_t stream) rmm::device_uvector unique_v(0, stream); auto n_unique = raft::label::getUniquelabels(unique_v, y, n, stream); raft::copy(unique, unique_v.data(), n_unique, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); return n_unique; } diff --git a/cpp/src_prims/label/merge_labels.cuh b/cpp/src_prims/label/merge_labels.cuh index 163aaac9cb..4cff001a2c 100644 --- a/cpp/src_prims/label/merge_labels.cuh +++ 
b/cpp/src_prims/label/merge_labels.cuh @@ -22,6 +22,7 @@ #include #include #include +#include namespace MLCommon { namespace Label { @@ -142,7 +143,7 @@ void merge_labels(Index_* labels_a, RAFT_CUDA_TRY(cudaPeekAtLastError()); raft::update_host(&host_m, m, 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } while (host_m); // Step 2: re-assign minimum equivalent label @@ -152,4 +153,4 @@ void merge_labels(Index_* labels_a, } }; // namespace Label -}; // namespace MLCommon \ No newline at end of file +}; // namespace MLCommon diff --git a/cpp/src_prims/linalg/lstsq.cuh b/cpp/src_prims/linalg/lstsq.cuh index 890775b274..562ec4e1cf 100644 --- a/cpp/src_prims/linalg/lstsq.cuh +++ b/cpp/src_prims/linalg/lstsq.cuh @@ -72,7 +72,7 @@ struct DeviceEvent { } void wait() { - if (e != nullptr) RAFT_CUDA_TRY(cudaEventSynchronize(e)); + if (e != nullptr) raft::interruptible::synchronize(e); } DeviceEvent& operator=(const DeviceEvent& other) = delete; }; diff --git a/cpp/src_prims/metrics/dispersion.cuh b/cpp/src_prims/metrics/dispersion.cuh index 0af5b93dca..b2d3c007fb 100644 --- a/cpp/src_prims/metrics/dispersion.cuh +++ b/cpp/src_prims/metrics/dispersion.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -127,7 +128,7 @@ DataT dispersion(const DataT* centroids, RAFT_CUDA_TRY(cudaGetLastError()); DataT h_result; raft::update_host(&h_result, result.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); return sqrt(h_result); } diff --git a/cpp/src_prims/metrics/entropy.cuh b/cpp/src_prims/metrics/entropy.cuh index e6cd6a21e2..55650a3345 100644 --- a/cpp/src_prims/metrics/entropy.cuh +++ b/cpp/src_prims/metrics/entropy.cuh @@ -143,7 +143,7 @@ double entropy(const T* clusterArray, double h_entropy; raft::update_host(&h_entropy, d_entropy.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); return h_entropy; } diff --git a/cpp/src_prims/metrics/kl_divergence.cuh b/cpp/src_prims/metrics/kl_divergence.cuh index cb6e69d951..bce3bf7283 100644 --- a/cpp/src_prims/metrics/kl_divergence.cuh +++ b/cpp/src_prims/metrics/kl_divergence.cuh @@ -74,7 +74,7 @@ DataT kl_divergence(const DataT* modelPDF, const DataT* candidatePDF, int size, raft::update_host(&h_KLDVal, d_KLDVal.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); return h_KLDVal; } diff --git a/cpp/src_prims/metrics/mutual_info_score.cuh b/cpp/src_prims/metrics/mutual_info_score.cuh index b08b5a309d..f20de778e4 100644 --- a/cpp/src_prims/metrics/mutual_info_score.cuh +++ b/cpp/src_prims/metrics/mutual_info_score.cuh @@ -167,7 +167,7 @@ double mutual_info_score(const T* firstClusterArray, // updating in the host memory h_MI = d_MI.value(stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); return h_MI / size; } diff --git a/cpp/src_prims/metrics/rand_index.cuh b/cpp/src_prims/metrics/rand_index.cuh index e13ce93240..f1acf30ac5 100644 --- a/cpp/src_prims/metrics/rand_index.cuh +++ b/cpp/src_prims/metrics/rand_index.cuh @@ -148,7 +148,7 @@ double compute_rand_index(T* firstClusterArray, // synchronizing and updating the calculated values of a and b from device to host uint64_t ab_host[2] = {0}; raft::update_host(ab_host, arr_buf.data(), 2, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); // error handling 
RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/src_prims/metrics/scores.cuh b/cpp/src_prims/metrics/scores.cuh index 5fde1df62f..ec13ca7a2c 100644 --- a/cpp/src_prims/metrics/scores.cuh +++ b/cpp/src_prims/metrics/scores.cuh @@ -174,7 +174,7 @@ void regression_metrics(const T* predictions, predictions, ref_predictions, n, abs_diffs_array.data(), tmp_sums.data()); RAFT_CUDA_TRY(cudaGetLastError()); raft::update_host(&mean_errors[0], tmp_sums.data(), 2, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); mean_abs_error = mean_errors[0] / n; mean_squared_error = mean_errors[1] / n; @@ -202,7 +202,7 @@ void regression_metrics(const T* predictions, stream)); raft::update_host(h_sorted_abs_diffs.data(), sorted_abs_diffs.data(), n, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); int middle = n / 2; if (n % 2 == 1) { diff --git a/cpp/src_prims/metrics/silhouette_score.cuh b/cpp/src_prims/metrics/silhouette_score.cuh index bd212924a2..7686d80d8d 100644 --- a/cpp/src_prims/metrics/silhouette_score.cuh +++ b/cpp/src_prims/metrics/silhouette_score.cuh @@ -320,7 +320,7 @@ DataT silhouette_score( DataT avgSilhouetteScore = d_avgSilhouetteScore.value(stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); avgSilhouetteScore /= nRows; diff --git a/cpp/src_prims/random/mvg.cuh b/cpp/src_prims/random/mvg.cuh index 6c8465d856..420d72bf4a 100644 --- a/cpp/src_prims/random/mvg.cuh +++ b/cpp/src_prims/random/mvg.cuh @@ -19,6 +19,7 @@ #include #include #include +#include // #TODO: Replace with public header when ready #include // #TODO: Replace with public header when ready @@ -218,7 +219,7 @@ class MultiVarGaussian { cusolverHandle, jobz, uplo, dim, P, dim, eig, workspace_decomp, Lwork, info, cudaStream)); } raft::update_host(&info_h, info, 1, cudaStream); - RAFT_CUDA_TRY(cudaStreamSynchronize(cudaStream)); + raft::interruptible::synchronize(cudaStream); ASSERT(info_h == 0, "mvg: error in syevj/syevd/potrf, info=%d | expected=0", info_h); T mean = 0.0, stddv = 1.0; // generate nxN gaussian nums in X @@ -259,7 +260,7 @@ class MultiVarGaussian { // checking if any eigen vals were negative raft::update_host(&info_h, info, 1, cudaStream); - RAFT_CUDA_TRY(cudaStreamSynchronize(cudaStream)); + raft::interruptible::synchronize(cudaStream); ASSERT(info_h == 0, "mvg: Cov matrix has %dth Eigenval negative", info_h); // Got Q = eigvect*eigvals.sqrt in P, Q*X in X below diff --git a/cpp/src_prims/selection/knn.cuh b/cpp/src_prims/selection/knn.cuh index 31d5284c35..bdba95d082 100644 --- a/cpp/src_prims/selection/knn.cuh +++ b/cpp/src_prims/selection/knn.cuh @@ -339,7 +339,7 @@ void knn_regress(const raft::handle_t& handle, <<(TPB_X)), TPB_X, 0, stream>>>( out, knn_indices, y[i], n_query_rows, k, y.size(), i); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); } } diff --git a/cpp/test/mg/knn.cu b/cpp/test/mg/knn.cu index c3fcf3e361..3368fb5921 100644 --- a/cpp/test/mg/knn.cu +++ b/cpp/test/mg/knn.cu @@ -148,7 +148,7 @@ class BruteForceKNNTest : public ::testing::TestWithParam { Matrix::PartDescriptor query_desc( params.min_rows * params.n_query_parts, params.n_cols, queryPartsToRanks, comm.get_rank()); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); /** * Execute brute_force_knn() @@ -164,7 +164,7 @@ class BruteForceKNNTest : public ::testing::TestWithParam { params.batch_size, true); - 
RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); std::cout << raft::arr2Str(out_i_parts[0]->ptr, 10, "final_out_I", stream) << std::endl; std::cout << raft::arr2Str(out_d_parts[0]->ptr, 10, "final_out_D", stream) << std::endl; diff --git a/cpp/test/mg/knn_test_helper.cuh b/cpp/test/mg/knn_test_helper.cuh index 5c1cdc90bd..0cf6fdaac5 100644 --- a/cpp/test/mg/knn_test_helper.cuh +++ b/cpp/test/mg/knn_test_helper.cuh @@ -163,12 +163,12 @@ class KNNTestHelper { this->out_i_parts.push_back(out_i); } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void display_results() { - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); std::cout << "Finished!" << std::endl; diff --git a/cpp/test/prims/batched/csr.cu b/cpp/test/prims/batched/csr.cu index e2ab47da5d..076bed10d0 100644 --- a/cpp/test/prims/batched/csr.cu +++ b/cpp/test/prims/batched/csr.cu @@ -162,7 +162,7 @@ class CSRTest : public ::testing::TestWithParam> { break; } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } void TearDown() override diff --git a/cpp/test/prims/batched/matrix.cu b/cpp/test/prims/batched/matrix.cu index cc949e8ba4..7505688e85 100644 --- a/cpp/test/prims/batched/matrix.cu +++ b/cpp/test/prims/batched/matrix.cu @@ -203,7 +203,7 @@ class MatrixTest : public ::testing::TestWithParam> { // Check that H is in Hessenberg form std::vector H = std::vector(n * n * params.batch_size); raft::update_host(H.data(), HbM.raw_data(), H.size(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (int ib = 0; ib < params.batch_size; ib++) { for (int j = 0; j < n - 2; j++) { for (int i = j + 2; i < n; i++) { @@ -215,7 +215,7 @@ class MatrixTest : public ::testing::TestWithParam> { // Check that U is unitary (UU'=I) std::vector UUt = std::vector(n * n * params.batch_size); raft::update_host(UUt.data(), b_gemm(UbM, UbM, false, true).raw_data(), UUt.size(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (int ib = 0; ib < params.batch_size; ib++) { for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { @@ -240,7 +240,7 @@ class MatrixTest : public ::testing::TestWithParam> { // Check that S is in Schur form std::vector S = std::vector(n * n * params.batch_size); raft::update_host(S.data(), SbM.raw_data(), S.size(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (int ib = 0; ib < params.batch_size; ib++) { for (int j = 0; j < n - 2; j++) { for (int i = j + 2; i < n; i++) { @@ -259,7 +259,7 @@ class MatrixTest : public ::testing::TestWithParam> { // Check that U is unitary (UU'=I) std::vector UUt = std::vector(n * n * params.batch_size); raft::update_host(UUt.data(), b_gemm(UbM, UbM, false, true).raw_data(), UUt.size(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (int ib = 0; ib < params.batch_size; ib++) { for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { @@ -369,7 +369,7 @@ class MatrixTest : public ::testing::TestWithParam> { break; } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } void TearDown() override diff --git a/cpp/test/prims/cache.cu b/cpp/test/prims/cache.cu index 36d3f707d5..6c41469fc5 100644 --- a/cpp/test/prims/cache.cu +++ b/cpp/test/prims/cache.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include 
namespace MLCommon { @@ -262,7 +263,7 @@ TEST_F(CacheTest, TestStoreCollect) raft::update_host(cache_idx_host, cache_idx_dev.data(), n_cached, stream); int keys_host[10]; raft::update_host(keys_host, keys_dev.data(), n_cached, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (int i = 0; i < n_cached; i++) { EXPECT_TRUE(devArrMatch(x_dev.data() + keys_host[i] * n_cols, tile_dev.data() + i * n_cols, @@ -292,7 +293,7 @@ TEST_F(CacheTest, TestStoreCollect) raft::update_host(cache_idx_host, cache_idx_dev.data(), 10, stream); raft::update_host(keys_host, keys_dev.data(), 10, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (int i = 0; i < 10; i++) { if (cache_idx_host[i] >= 0) { EXPECT_TRUE(devArrMatch(x_dev.data() + keys_host[i] * n_cols, diff --git a/cpp/test/prims/contingencyMatrix.cu b/cpp/test/prims/contingencyMatrix.cu index 1aee21c98f..bbdb309682 100644 --- a/cpp/test/prims/contingencyMatrix.cu +++ b/cpp/test/prims/contingencyMatrix.cu @@ -115,7 +115,7 @@ class ContingencyMatrixTest : public ::testing::TestWithParam #include #include +#include #include namespace MLCommon { @@ -76,7 +77,7 @@ template { std::vector act_h(size); raft::update_host(&(act_h[0]), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (size_t i(0); i < size; ++i) { auto act = act_h[i]; auto expected = (T)i; diff --git a/cpp/test/prims/device_utils.cu b/cpp/test/prims/device_utils.cu index c32d150e8a..abecf4c36f 100644 --- a/cpp/test/prims/device_utils.cu +++ b/cpp/test/prims/device_utils.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include namespace MLCommon { @@ -92,7 +93,7 @@ class BatchedBlockReduceTest : public ::testing::TestWithParam #include #include +#include #include #include #include @@ -85,7 +86,7 @@ class DispersionTest : public ::testing::TestWithParam> { } } expectedVal = sqrt(expectedVal); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } diff --git a/cpp/test/prims/entropy.cu b/cpp/test/prims/entropy.cu index cd33fcfc35..8afcfa60ef 100644 --- a/cpp/test/prims/entropy.cu +++ b/cpp/test/prims/entropy.cu @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -77,7 +78,7 @@ class entropyTest : public ::testing::TestWithParam { rmm::device_uvector clusterArray(nElements, stream); raft::update_device(clusterArray.data(), &arr1[0], (int)nElements, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); // calling the entropy CUDA implementation computedEntropy = MLCommon::Metrics::entropy( clusterArray.data(), nElements, lowerLabelRange, upperLabelRange, stream); diff --git a/cpp/test/prims/fillna.cu b/cpp/test/prims/fillna.cu index a14fb4a9f4..9a6b07c07b 100644 --- a/cpp/test/prims/fillna.cu +++ b/cpp/test/prims/fillna.cu @@ -87,7 +87,7 @@ class FillnaTest : public ::testing::TestWithParam> { /* Copy to device */ raft::update_device( y.data(), h_y.data(), params.n_obs * params.batch_size, handle.get_stream()); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); /* Compute using tested prims */ fillna(y.data(), params.batch_size, params.n_obs, handle.get_stream()); @@ -173,4 +173,4 @@ INSTANTIATE_TEST_CASE_P(FillnaTests, FillnaTestF, ::testing::ValuesIn(inputsf)); 
INSTANTIATE_TEST_CASE_P(FillnaTests, FillnaTestD, ::testing::ValuesIn(inputsd)); } // namespace TimeSeries -} // namespace MLCommon \ No newline at end of file +} // namespace MLCommon diff --git a/cpp/test/prims/gather.cu b/cpp/test/prims/gather.cu index cefcc6250c..45884cc8c7 100644 --- a/cpp/test/prims/gather.cu +++ b/cpp/test/prims/gather.cu @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -107,7 +108,7 @@ class GatherTest : public ::testing::TestWithParam { // launch device version of the kernel gatherLaunch(d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } diff --git a/cpp/test/prims/gram.cu b/cpp/test/prims/gram.cu index d6605243c0..b0c5ae2086 100644 --- a/cpp/test/prims/gram.cu +++ b/cpp/test/prims/gram.cu @@ -116,7 +116,7 @@ class GramMatrixTest : public ::testing::TestWithParam { raft::update_host(x1_host.data(), x1.data(), x1.size(), stream); std::vector x2_host(x2.size()); raft::update_host(x2_host.data(), x2.data(), x2.size(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); for (int i = 0; i < params.n1; i++) { for (int j = 0; j < params.n2; j++) { diff --git a/cpp/test/prims/histogram.cu b/cpp/test/prims/histogram.cu index e446c5a250..8c6640b2c7 100644 --- a/cpp/test/prims/histogram.cu +++ b/cpp/test/prims/histogram.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -82,7 +83,7 @@ class HistTest : public ::testing::TestWithParam { naiveHist(ref_bins.data(), params.nbins, in.data(), params.nrows, params.ncols, stream); histogram( params.type, bins.data(), params.nbins, in.data(), params.nrows, params.ncols, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } diff --git a/cpp/test/prims/knn_classify.cu b/cpp/test/prims/knn_classify.cu index 366ec7b25d..14954c265f 100644 --- a/cpp/test/prims/knn_classify.cu +++ b/cpp/test/prims/knn_classify.cu @@ -103,7 +103,7 @@ class KNNClassifyTest : public ::testing::TestWithParam { uniq_labels, n_unique); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/prims/knn_regression.cu b/cpp/test/prims/knn_regression.cu index 5e2fddb006..282a2ba5e0 100644 --- a/cpp/test/prims/knn_regression.cu +++ b/cpp/test/prims/knn_regression.cu @@ -115,7 +115,7 @@ class KNNRegressionTest : public ::testing::TestWithParam { knn_regress( handle, pred_labels.data(), knn_indices.data(), y, params.rows, params.rows, params.k); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void SetUp() override { basicTest(); } diff --git a/cpp/test/prims/label.cu b/cpp/test/prims/label.cu index aee27b668e..0146406542 100644 --- a/cpp/test/prims/label.cu +++ b/cpp/test/prims/label.cu @@ -21,6 +21,7 @@ #include "test_utils.h" #include #include +#include #include #include @@ -59,7 +60,7 @@ TEST_F(MakeMonotonicTest, Result) make_monotonic(actual.data(), data.data(), m, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); ASSERT_TRUE(devArrMatch(actual.data(), expected.data(), m, raft::Compare(), stream)); diff --git a/cpp/test/prims/linalg_block.cu b/cpp/test/prims/linalg_block.cu index d26a9220ab..75839c3b2c 100644 --- 
a/cpp/test/prims/linalg_block.cu +++ b/cpp/test/prims/linalg_block.cu @@ -103,7 +103,7 @@ class BlockGemmTest : public ::testing::TestWithParam> { h_a.data(), a.data(), params.m * params.k * params.batch_size, handle.get_stream()); raft::update_host( h_b.data(), b.data(), params.k * params.n * params.batch_size, handle.get_stream()); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); /* Compute using tested prims */ block_gemm_test_kernel @@ -328,7 +328,7 @@ class BlockGemvTest : public ::testing::TestWithParam> { raft::update_host( h_a.data(), a.data(), params.m * params.n * params.batch_size, handle.get_stream()); raft::update_host(h_x.data(), x.data(), params.n * params.batch_size, handle.get_stream()); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); /* Compute using tested prims */ int shared_mem_size = params.n * sizeof(T); @@ -459,7 +459,7 @@ class BlockDotTest : public ::testing::TestWithParam> { /* Copy to host */ raft::update_host(h_x.data(), x.data(), params.n * params.batch_size, handle.get_stream()); raft::update_host(h_y.data(), y.data(), params.n * params.batch_size, handle.get_stream()); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); /* Compute using tested prims */ constexpr int BlockSize = 64; @@ -588,7 +588,7 @@ class BlockXaxtTest : public ::testing::TestWithParam> { raft::update_host(h_x.data(), x.data(), params.n * params.batch_size, handle.get_stream()); raft::update_host( h_A.data(), A.data(), params.n * params.n * params.batch_size, handle.get_stream()); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); /* Compute using tested prims */ constexpr int BlockSize = 64; @@ -701,7 +701,7 @@ class BlockAxTest : public ::testing::TestWithParam> { /* Copy to host */ raft::update_host(h_x.data(), x.data(), params.n * params.batch_size, handle.get_stream()); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); /* Compute using tested prims */ constexpr int BlockSize = 64; @@ -796,7 +796,7 @@ class BlockCovStabilityTest : public ::testing::TestWithParam @@ -878,4 +878,4 @@ INSTANTIATE_TEST_CASE_P(BlockCovStabilityTests, ::testing::ValuesIn(cs_inputsd)); } // namespace LinAlg -} // namespace MLCommon \ No newline at end of file +} // namespace MLCommon diff --git a/cpp/test/prims/make_arima.cu b/cpp/test/prims/make_arima.cu index 4acfd5cc60..3eae72be2a 100644 --- a/cpp/test/prims/make_arima.cu +++ b/cpp/test/prims/make_arima.cu @@ -21,6 +21,7 @@ #include "test_utils.h" #include #include +#include #include namespace MLCommon { @@ -83,11 +84,11 @@ const std::vector make_arima_inputs = { {10000, 150, 2, 1, 2, 0, 1, 2, 4, 0, raft::random::GenPhilox, 1234ULL}}; typedef MakeArimaTest MakeArimaTestF; -TEST_P(MakeArimaTestF, Result) { RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } +TEST_P(MakeArimaTestF, Result) { raft::interruptible::synchronize(stream); } INSTANTIATE_TEST_CASE_P(MakeArimaTests, MakeArimaTestF, ::testing::ValuesIn(make_arima_inputs)); typedef MakeArimaTest MakeArimaTestD; -TEST_P(MakeArimaTestD, Result) { RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } +TEST_P(MakeArimaTestD, Result) { raft::interruptible::synchronize(stream); } INSTANTIATE_TEST_CASE_P(MakeArimaTests, MakeArimaTestD, ::testing::ValuesIn(make_arima_inputs)); } // end namespace Random diff --git a/cpp/test/prims/permute.cu 
b/cpp/test/prims/permute.cu index aeaeedbe95..ca528cdd77 100644 --- a/cpp/test/prims/permute.cu +++ b/cpp/test/prims/permute.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -67,7 +68,7 @@ class PermTest : public ::testing::TestWithParam> { r.uniform(in_ptr, len, T(-1.0), T(1.0), stream); } permute(outPerms_ptr, out_ptr, in_ptr, D, N, params.rowMajor, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } @@ -88,7 +89,7 @@ template { std::vector act_h(size); raft::update_host(&(act_h[0]), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); if (doSort) std::sort(act_h.begin(), act_h.end()); for (size_t i(0); i < size; ++i) { auto act = act_h[i]; @@ -116,7 +117,7 @@ template std::vector h_out(N * D), h_in(N * D); raft::update_host(&(h_out[0]), out, N * D, stream); raft::update_host(&(h_in[0]), in, N * D, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (int i = 0; i < N; ++i) { for (int j = 0; j < D; ++j) { int outPos = rowMajor ? i * D + j : j * N + i; diff --git a/cpp/test/prims/reduce_cols_by_key.cu b/cpp/test/prims/reduce_cols_by_key.cu index 3afbec0f85..a372a8d31d 100644 --- a/cpp/test/prims/reduce_cols_by_key.cu +++ b/cpp/test/prims/reduce_cols_by_key.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include namespace MLCommon { @@ -36,7 +37,7 @@ void naiveReduceColsByKey(const T* in, raft::copy(&(h_keys[0]), keys, ncols, stream); std::vector h_in(nrows * ncols); raft::copy(&(h_in[0]), in, nrows * ncols, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); std::vector out(nrows * nkeys, T(0)); for (uint32_t i = 0; i < nrows; ++i) { for (uint32_t j = 0; j < ncols; ++j) { @@ -44,7 +45,7 @@ void naiveReduceColsByKey(const T* in, } } raft::copy(out_ref, &(out[0]), nrows * nkeys, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } template @@ -83,7 +84,7 @@ class ReduceColsTest : public ::testing::TestWithParam> { r.uniformInt(keys.data(), ncols, 0u, params.nkeys, stream); naiveReduceColsByKey(in.data(), keys.data(), out_ref.data(), nrows, ncols, nkeys, stream); reduce_cols_by_key(in.data(), keys.data(), out.data(), nrows, ncols, nkeys, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } diff --git a/cpp/test/prims/reduce_rows_by_key.cu b/cpp/test/prims/reduce_rows_by_key.cu index a2e79fd751..d7b40e1ad6 100644 --- a/cpp/test/prims/reduce_rows_by_key.cu +++ b/cpp/test/prims/reduce_rows_by_key.cu @@ -141,7 +141,7 @@ class ReduceRowTest : public ::testing::TestWithParam> { reduce_rows_by_key( in.data(), cols, keys.data(), scratch_buf.data(), nobs, cols, nkeys, out.data(), stream); } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/prims/score.cu b/cpp/test/prims/score.cu index 164179756f..e0a0102c39 100644 --- a/cpp/test/prims/score.cu +++ b/cpp/test/prims/score.cu @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -125,7 +126,7 @@ class AccuracyTest : public ::testing::TestWithParam { rmm::device_uvector ref_predictions(params.n, stream); r.normal(ref_predictions.data(), params.n, (T)0.0, (T)1.0, 
stream); raft::copy_async(predictions.data(), ref_predictions.data(), params.n, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); // Modify params.changed_n unique predictions to a different value. New value is irrelevant. if (params.changed_n > 0) { @@ -136,7 +137,7 @@ class AccuracyTest : public ::testing::TestWithParam { change_vals<<>>( predictions.data(), ref_predictions.data(), params.changed_n); RAFT_CUDA_TRY(cudaGetLastError()); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); } computed_accuracy = MLCommon::Score::accuracy_score( @@ -281,7 +282,7 @@ class RegressionMetricsTest : public ::testing::TestWithParam #include #include +#include namespace raft { @@ -84,7 +85,7 @@ testing::AssertionResult devArrMatch( std::unique_ptr act_h(new T[size]); raft::update_host(exp_h.get(), expected, size, stream); raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (size_t i(0); i < size; ++i) { auto exp = exp_h.get()[i]; auto act = act_h.get()[i]; @@ -101,7 +102,7 @@ testing::AssertionResult devArrMatch( { std::unique_ptr act_h(new T[size]); raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (size_t i(0); i < size; ++i) { auto act = act_h.get()[i]; if (!eq_compare(expected, act)) { @@ -125,7 +126,7 @@ testing::AssertionResult devArrMatch(const T* expected, std::unique_ptr act_h(new T[size]); raft::update_host(exp_h.get(), expected, size, stream); raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (size_t i(0); i < rows; ++i) { for (size_t j(0); j < cols; ++j) { auto idx = i * cols + j; // row major assumption! @@ -147,7 +148,7 @@ testing::AssertionResult devArrMatch( size_t size = rows * cols; std::unique_ptr act_h(new T[size]); raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (size_t i(0); i < rows; ++i) { for (size_t j(0); j < cols; ++j) { auto idx = i * cols + j; // row major assumption! @@ -178,7 +179,7 @@ testing::AssertionResult devArrMatchHost( { std::unique_ptr act_h(new T[size]); raft::update_host(act_h.get(), actual_d, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); bool ok = true; auto fail = testing::AssertionFailure(); for (size_t i(0); i < size; ++i) { @@ -210,7 +211,7 @@ testing::AssertionResult diagonalMatch( size_t size = rows * cols; std::unique_ptr act_h(new T[size]); raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); for (size_t i(0); i < rows; ++i) { for (size_t j(0); j < cols; ++j) { if (i != j) continue; diff --git a/cpp/test/sg/dbscan_test.cu b/cpp/test/sg/dbscan_test.cu index 9afd801f43..958ab959c0 100644 --- a/cpp/test/sg/dbscan_test.cu +++ b/cpp/test/sg/dbscan_test.cu @@ -108,7 +108,7 @@ class DbscanTest : public ::testing::TestWithParam> { raft::copy(labels_ref.data(), l.data(), params.n_row, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); Dbscan::fit(handle, params.metric == raft::distance::Precomputed ? 
dist.data() : out.data(), @@ -121,7 +121,7 @@ class DbscanTest : public ::testing::TestWithParam> { nullptr, params.max_bytes_per_batch); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); score = adjusted_rand_index(handle, labels_ref.data(), labels.data(), params.n_row); @@ -225,7 +225,7 @@ class Dbscan2DSimple : public ::testing::TestWithParam> { raft::copy(inputs.data(), params.points, params.n_row * 2, stream); raft::copy(labels_ref.data(), params.out, params.n_out, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); Dbscan::fit(handle, inputs.data(), @@ -237,7 +237,7 @@ class Dbscan2DSimple : public ::testing::TestWithParam> { labels.data(), core_sample_indices_d.data()); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); score = adjusted_rand_index(handle, labels_ref.data(), labels.data(), (int)params.n_out); diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index 7cf02ad5f2..9627e79e6a 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -867,7 +867,7 @@ class TreeliteFilTest : public BaseFilTest { params.n_items = ps.n_items; params.pforest_shape_str = ps.print_forest_shape ? &forest_shape_str : nullptr; fil::from_treelite(handle, pforest, (ModelHandle)model.get(), ¶ms); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); if (ps.print_forest_shape) { std::string str(forest_shape_str); for (const char* substr : {"model size", diff --git a/cpp/test/sg/genetic/program_test.cu b/cpp/test/sg/genetic/program_test.cu index d5f57e310a..5ce753f205 100644 --- a/cpp/test/sg/genetic/program_test.cu +++ b/cpp/test/sg/genetic/program_test.cu @@ -650,7 +650,7 @@ TEST_F(GeneticProgramTest, ProgramExecution) n_progs * n_samples * sizeof(float), cudaMemcpyDeviceToHost, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); // Check results @@ -690,7 +690,7 @@ TEST_F(GeneticProgramTest, ProgramFitnessScore) dx2.data(), dy2.data(), dw2.data()); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } RAFT_CUDA_TRY(cudaMemcpyAsync(hactualscores.data(), @@ -708,4 +708,4 @@ TEST_F(GeneticProgramTest, ProgramFitnessScore) } } // namespace genetic -} // namespace cuml \ No newline at end of file +} // namespace cuml diff --git a/cpp/test/sg/hdbscan_test.cu b/cpp/test/sg/hdbscan_test.cu index b0e6ecb703..b8df86f0f9 100644 --- a/cpp/test/sg/hdbscan_test.cu +++ b/cpp/test/sg/hdbscan_test.cu @@ -105,7 +105,7 @@ class HDBSCANTest : public ::testing::TestWithParam> { hdbscan_params, out); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); score = MLCommon::Metrics::compute_adjusted_rand_index( out.get_labels(), labels_ref.data(), params.n_row, handle.get_stream()); @@ -184,7 +184,7 @@ class ClusterCondensingTest : public ::testing::TestWithParam labels(params.n_row, handle.get_stream()); rmm::device_uvector stabilities(condensed_tree.get_n_clusters(), handle.get_stream()); @@ -294,7 +294,7 @@ class ClusterSelectionTest : public ::testing::TestWithParam> { season_ptr.data(), forecast_ptr.data()); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void SetUp() override { basicTest(); } diff --git a/cpp/test/sg/kmeans_test.cu b/cpp/test/sg/kmeans_test.cu index 26d637b82d..4daefb3aa5 100644 --- a/cpp/test/sg/kmeans_test.cu +++ b/cpp/test/sg/kmeans_test.cu @@ -100,7 +100,7 @@ class KmeansTest : public 
::testing::TestWithParam> { raft::copy(d_labels_ref.data(), labels.data(), n_samples, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); T inertia = 0; int n_iter = 0; @@ -116,7 +116,7 @@ class KmeansTest : public ::testing::TestWithParam> { inertia, n_iter); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); score = adjusted_rand_index(handle, d_labels_ref.data(), d_labels.data(), n_samples); diff --git a/cpp/test/sg/knn_test.cu b/cpp/test/sg/knn_test.cu index 3aef62f950..9c56d47fdd 100644 --- a/cpp/test/sg/knn_test.cu +++ b/cpp/test/sg/knn_test.cu @@ -257,7 +257,7 @@ class KNNTest : public ::testing::TestWithParam { index_labels_float.data(), index_labels.data(), index_labels_float.size()); to_float<<>>( query_labels_float.data(), search_labels.data(), params.n_query_row); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); rmm::device_uvector actual_labels_float(params.n_query_row, stream); diff --git a/cpp/test/sg/linkage_test.cu b/cpp/test/sg/linkage_test.cu index 1a9bd7f7f9..fa931da991 100644 --- a/cpp/test/sg/linkage_test.cu +++ b/cpp/test/sg/linkage_test.cu @@ -80,7 +80,7 @@ class LinkageTest : public ::testing::TestWithParam> { raft::copy(data.data(), params.data.data(), data.size(), handle.get_stream()); raft::copy(labels_ref.data(), params.expected_labels.data(), params.n_row, handle.get_stream()); - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); raft::hierarchy::linkage_output out_arrs; out_arrs.labels = labels.data(); @@ -107,7 +107,7 @@ class LinkageTest : public ::testing::TestWithParam> { params.n_clusters); } - RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); + handle.sync_stream(handle.get_stream()); } void SetUp() override { basicTest(); } diff --git a/cpp/test/sg/quasi_newton.cu b/cpp/test/sg/quasi_newton.cu index 8a39a3c81a..31b2e6c03a 100644 --- a/cpp/test/sg/quasi_newton.cu +++ b/cpp/test/sg/quasi_newton.cu @@ -52,7 +52,7 @@ struct QuasiNewtonTest : ::testing::Test { raft::update_device(Xdev->data, &X[0][0], Xdev->len, stream); ydev.reset(new SimpleVecOwning(N, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void TearDown() {} }; @@ -93,7 +93,7 @@ template SimpleVecOwning w_ref(dims.n_param, stream); raft::update_device(w_ref.data, &w_ref_cm[0], C * D, stream); if (fit_intercept) { raft::update_device(&w_ref.data[C * D], host_bias, C, stream); } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); return raft::devArrMatch(w_ref.data, w, w_ref.len, comp); } @@ -224,7 +224,7 @@ TEST_F(QuasiNewtonTest, binary_logistic_vs_sklearn) // Test case generated in python and solved with sklearn double y[N] = {1, 1, 1, 0, 1, 0, 1, 0, 1, 0}; raft::update_device(ydev->data, &y[0], ydev->len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); double alpha = 0.01 * N; @@ -305,7 +305,7 @@ TEST_F(QuasiNewtonTest, multiclass_logistic_vs_sklearn) raft::CompareApprox compApprox(tol); double y[N] = {2, 2, 0, 3, 3, 0, 0, 0, 1, 0}; raft::update_device(ydev->data, &y[0], ydev->len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); double fx, l1, l2; int C = 4; @@ -378,7 +378,7 @@ TEST_F(QuasiNewtonTest, linear_regression_vs_sklearn) 0.7966520536712608, -1.0767450516284769}; raft::update_device(ydev->data, &y[0], ydev->len, stream); - 
RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); double fx, l1, l2; double alpha = 0.01 * N; @@ -458,7 +458,7 @@ TEST_F(QuasiNewtonTest, predict) qnPredict( handle, Xdev->data, false, N, D, 2, false, w.data, QN_LOSS_LOGISTIC, preds.data, stream); raft::update_host(&preds_host[0], preds.data, preds.len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); for (int it = 0; it < N; it++) { ASSERT_TRUE(X[it][0] > 0 ? compApprox(preds_host[it], 1) : compApprox(preds_host[it], 0)); @@ -466,7 +466,7 @@ TEST_F(QuasiNewtonTest, predict) qnPredict(handle, Xdev->data, false, N, D, 1, false, w.data, QN_LOSS_SQUARED, preds.data, stream); raft::update_host(&preds_host[0], preds.data, preds.len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); for (int it = 0; it < N; it++) { ASSERT_TRUE(compApprox(X[it][0], preds_host[it])); @@ -489,7 +489,7 @@ TEST_F(QuasiNewtonTest, predict_softmax) qnPredict(handle, Xdev->data, false, N, D, C, false, w.data, QN_LOSS_SOFTMAX, preds.data, stream); raft::update_host(&preds_host[0], preds.data, preds.len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); for (int it = 0; it < N; it++) { if (X[it][0] < 0 && X[it][1] < 0) { @@ -530,7 +530,7 @@ TEST_F(QuasiNewtonTest, dense_vs_sparse_logistic) raft::CompareApprox compApprox(tol); double y[N] = {2, 2, 0, 3, 3, 0, 0, 0, 1, 0}; raft::update_device(ydev->data, &y[0], ydev->len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); int C = 4; QN_LOSS_TYPE loss_type = QN_LOSS_SOFTMAX; // Softmax (loss_b, loss_no_b) @@ -581,7 +581,7 @@ TEST_F(QuasiNewtonTest, dense_vs_sparse_logistic) raft::update_host(&preds_dense_host[0], preds_dense.data, preds_dense.len, stream); raft::update_host(&preds_sparse_host[0], preds_sparse.data, preds_sparse.len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); for (int i = 0; i < N; i++) { ASSERT_TRUE(compApprox(preds_dense_host[i], preds_sparse_host[i])); } diff --git a/cpp/test/sg/rproj_test.cu b/cpp/test/sg/rproj_test.cu index 967223f765..34d5526fd3 100644 --- a/cpp/test/sg/rproj_test.cu +++ b/cpp/test/sg/rproj_test.cu @@ -84,7 +84,7 @@ class RPROJTest : public ::testing::Test { params1.n_components, stream); // From column major to row major - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void sparseTest() @@ -113,7 +113,7 @@ class RPROJTest : public ::testing::Test { params2.n_components, stream); // From column major to row major - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void SetUp() override diff --git a/cpp/test/sg/shap_kernel.cu b/cpp/test/sg/shap_kernel.cu index fe7b19efda..6a27021cef 100644 --- a/cpp/test/sg/shap_kernel.cu +++ b/cpp/test/sg/shap_kernel.cu @@ -113,7 +113,7 @@ class MakeKSHAPDatasetTest : public ::testing::TestWithParam kidx_h(n_ws); raft::update_host(kidx_h.data(), kColIdx, n_ws, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); // Note: kernel cache can permute the working set, so we have to look // up which rows we compare for (int i = 0; i < n_ws; i++) { @@ -618,7 +618,7 @@ void checkResults(SvmModel model, } math_t* dual_coefs_host = new math_t[model.n_support]; raft::update_host(dual_coefs_host, model.dual_coefs, model.n_support, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); math_t ay = 0; for (int i = 0; 
i < model.n_support; i++) { ay += dual_coefs_host[i]; @@ -641,7 +641,7 @@ void checkResults(SvmModel model, math_t* x_support_host = new math_t[model.n_support * model.n_cols]; raft::update_host(x_support_host, model.x_support, model.n_support * model.n_cols, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); if (w_exp) { std::vector w(model.n_cols, 0); @@ -725,7 +725,7 @@ class SmoSolverTest : public ::testing::Test { math_t return_buff[2]; raft::update_host(return_buff, return_buff_dev.data(), 2, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); EXPECT_FLOAT_EQ(return_buff[0], 2.0f) << return_buff[0]; EXPECT_LT(return_buff[1], 100) << return_buff[1]; @@ -802,7 +802,7 @@ class SmoSolverTest : public ::testing::Test { math_t return_buff[2]; raft::update_host(return_buff, return_buff_dev.data(), 2, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); EXPECT_LT(return_buff[1], 10) << return_buff[1]; math_t alpha_exp[] = {0, 0.8, 0.8, 0}; @@ -1188,7 +1188,7 @@ TYPED_TEST(SmoSolverTest, MemoryLeak) } else { svc.fit(x.data(), p.n_rows, p.n_cols, y.data()); rmm::device_uvector y_pred(p.n_rows, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); RAFT_CUDA_TRY(cudaMemGetInfo(&free2, &total)); float delta = (free1 - free2); // Just to make sure that we measure any mem consumption at all: @@ -1197,7 +1197,7 @@ TYPED_TEST(SmoSolverTest, MemoryLeak) // it (one could additionally control the exec time by the max_iter arg to // SVC). EXPECT_GT(delta, p.n_rows * p.n_cols * 4); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::interruptible::synchronize(stream); svc.predict(x.data(), p.n_rows, p.n_cols, y_pred.data()); } } diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index 35595b3f52..c7c79d8535 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -134,7 +134,7 @@ class TSNETest : public ::testing::TestWithParam { k_graph.knn_dists = input_dists.data(); TSNE::get_distances(handle, input, k_graph, stream); } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); TSNE_runner, knn_indices_dense_t, float> runner( handle, input, k_graph, model_params); results.kl_div = runner.run(); @@ -149,7 +149,7 @@ class TSNETest : public ::testing::TestWithParam { model_params.dim, raft::distance::DistanceType::L2Expanded, false); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); // Compute theorical KL div results.kl_div_ref = @@ -159,7 +159,7 @@ class TSNETest : public ::testing::TestWithParam { float* embeddings_h = (float*)malloc(sizeof(float) * n * model_params.dim); assert(embeddings_h != NULL); raft::update_host(embeddings_h, Y_d.data(), n * model_params.dim, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); // Move embeddings to host. // This can be used for printing if needed. 
int k = 0; @@ -170,7 +170,7 @@ class TSNETest : public ::testing::TestWithParam { } // Move transposed embeddings back to device, as trustworthiness requires C contiguous format raft::update_device(Y_d.data(), C_contiguous_embedding, n * model_params.dim, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); free(embeddings_h); // Produce trustworthiness score diff --git a/cpp/test/sg/umap_parametrizable_test.cu b/cpp/test/sg/umap_parametrizable_test.cu index 95a3e4d679..23e8fa501b 100644 --- a/cpp/test/sg/umap_parametrizable_test.cu +++ b/cpp/test/sg/umap_parametrizable_test.cu @@ -152,7 +152,7 @@ class UMAPParametrizableTest : public ::testing::Test { knn_dists, umap_params.n_neighbors); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } float* model_embedding = nullptr; @@ -168,7 +168,7 @@ class UMAPParametrizableTest : public ::testing::Test { RAFT_CUDA_TRY(cudaMemsetAsync( model_embedding, 0, n_samples * umap_params.n_components * sizeof(float), stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); if (test_params.supervised) { ML::UMAP::fit( @@ -198,13 +198,13 @@ class UMAPParametrizableTest : public ::testing::Test { handle, X, n_samples, n_features, cgraph_coo.get(), &umap_params, model_embedding); } } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); if (!test_params.fit_transform) { RAFT_CUDA_TRY(cudaMemsetAsync( embedding_ptr, 0, n_samples * umap_params.n_components * sizeof(float), stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); ML::UMAP::transform(handle, X, @@ -219,7 +219,7 @@ class UMAPParametrizableTest : public ::testing::Test { &umap_params, embedding_ptr); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); delete model_embedding_b; } @@ -294,11 +294,11 @@ class UMAPParametrizableTest : public ::testing::Test { 10.f, 1234ULL); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); MLCommon::LinAlg::convert_array((float*)y_d.data(), y_d.data(), n_samples, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); rmm::device_uvector embeddings1(n_samples * umap_params.n_components, stream); diff --git a/python/cuml/common/logger.pyx b/python/cuml/common/logger.pyx index 44620a8d59..73f0bdfff4 100644 --- a/python/cuml/common/logger.pyx +++ b/python/cuml/common/logger.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -75,7 +75,7 @@ level_critical = CUML_LEVEL_CRITICAL """Disables all log messages""" level_off = CUML_LEVEL_OFF -cdef void _log_callback(int lvl, const char * msg) nogil: +cdef void _log_callback(int lvl, const char * msg) with gil: """ Default spdlogs callback function to redirect logs correctly to sys.stdout @@ -86,35 +86,15 @@ cdef void _log_callback(int lvl, const char * msg) nogil: msg : char * Message to be logged """ - with gil: - print(msg.decode('utf-8'), end='') + print(msg.decode('utf-8'), end='') -cdef void _nogil_log_callback(int lvl, const char * msg) nogil: - """ - Wrapper for _log_callback to explicitly disable Cython's automatic GIL - acquire - """ - with nogil: - _log_callback(lvl, msg) - - -cdef void _log_flush() nogil: +cdef void _log_flush() with gil: """ Default spdlogs callback function to flush logs """ - with gil: - if sys.stdout is not None: - sys.stdout.flush() - - -cdef void _nogil_log_flush() nogil: - """ - Wrapper for _log_flush to explicitly disable Cython's automatic GIL - acquire - """ - with nogil: - _log_flush() + if sys.stdout is not None: + sys.stdout.flush() class LogLevelSetter: @@ -366,5 +346,5 @@ def flush(): # Set callback functions to handle redirected sys.stdout in Python -Logger.get().setCallback(_nogil_log_callback) -Logger.get().setFlush(_nogil_log_flush) +Logger.get().setCallback(_log_callback) +Logger.get().setFlush(_log_flush) diff --git a/python/cuml/svm/linear.pyx b/python/cuml/svm/linear.pyx index dc00dae0ea..ec598545e2 100644 --- a/python/cuml/svm/linear.pyx +++ b/python/cuml/svm/linear.pyx @@ -30,6 +30,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.array import CumlArray from cuml.common.base import Base from cuml.raft.common.handle cimport handle_t +from cuml.raft.common.interruptible import cuda_interruptible from cuml.common import input_to_cuml_array from libc.stdint cimport uintptr_t from libcpp cimport bool as cppbool @@ -43,7 +44,7 @@ from cuda.ccudart cimport( __all__ = ['LinearSVM', 'LinearSVM_defaults'] -cdef extern from "cuml/svm/linear.hpp" namespace "ML::SVM": +cdef extern from "cuml/svm/linear.hpp" namespace "ML::SVM" nogil: cdef enum Penalty "ML::SVM::LinearSVMParams::Penalty": L1 "ML::SVM::LinearSVMParams::L1" @@ -288,29 +289,39 @@ cdef class LinearSVMWrapper: self.dtype = X.dtype if do_training else coefs.dtype cdef cuda_stream_view stream = self.handle.get_stream() nClasses = 0 - nCols = 0 if self.dtype != np.float32 and self.dtype != np.float64: raise TypeError('Input data type must be float32 or float64') + cdef uintptr_t Xptr = X.ptr + cdef uintptr_t yptr = y.ptr + cdef uintptr_t swptr = sampleWeight.ptr \ + if sampleWeight is not None else 0 + cdef size_t nCols = 0 + cdef size_t nRows = 0 if do_training: nCols = X.shape[1] + nRows = X.shape[0] sw_ptr = sampleWeight.ptr if sampleWeight is not None else 0 if self.dtype == np.float32: - self.model.float32 = LinearSVMModel[float].fit( - deref(self.handle), self.params, - X.ptr, - X.shape[0], nCols, - y.ptr, - sw_ptr) + with cuda_interruptible(): + with nogil: + self.model.float32 = LinearSVMModel[float].fit( + deref(self.handle), self.params, + Xptr, + nRows, nCols, + yptr, + swptr) nClasses = self.model.float32.nClasses elif self.dtype == np.float64: - self.model.float64 = LinearSVMModel[double].fit( - deref(self.handle), self.params, - X.ptr, - X.shape[0], nCols, - y.ptr, - sw_ptr) + with cuda_interruptible(): + with nogil: + self.model.float64 = LinearSVMModel[double].fit( + deref(self.handle), self.params, + Xptr, + nRows, 
nCols, + yptr, + swptr) nClasses = self.model.float64.nClasses else: nCols = coefs.shape[1] diff --git a/python/cuml/svm/svc.pyx b/python/cuml/svm/svc.pyx index 9c65be4e60..a19a406a2d 100644 --- a/python/cuml/svm/svc.pyx +++ b/python/cuml/svm/svc.pyx @@ -35,6 +35,7 @@ from cuml.common.doc_utils import generate_docstring from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.logger import warn from cuml.raft.common.handle cimport handle_t +from cuml.raft.common.interruptible import cuda_interruptible from cuml.common import input_to_cuml_array, input_to_host_array, with_cupy_rmm from cuml.common.input_utils import input_to_cupy_array from cuml.preprocessing import LabelEncoder @@ -90,7 +91,7 @@ cdef extern from "cuml/svm/svm_model.h" namespace "ML::SVM": int n_classes math_t *unique_labels -cdef extern from "cuml/svm/svc.hpp" namespace "ML::SVM": +cdef extern from "cuml/svm/svc.hpp" namespace "ML::SVM" nogil: cdef void svcFit[math_t](const handle_t &handle, math_t *input, int n_rows, int n_cols, math_t *labels, @@ -482,17 +483,25 @@ class SVC(SVMBase, cdef SvmModel[double] *model_d cdef handle_t* handle_ = self.handle.getHandle() + cdef int n_rows = self.n_rows + cdef int n_cols = self.n_cols if self.dtype == np.float32: model_f = new SvmModel[float]() - svcFit(handle_[0], X_ptr, self.n_rows, - self.n_cols, y_ptr, param, _kernel_params, - model_f[0], sample_weight_ptr) + with cuda_interruptible(): + with nogil: + svcFit( + deref(handle_), X_ptr, n_rows, + n_cols, y_ptr, param, _kernel_params, + deref(model_f), sample_weight_ptr) self._model = model_f elif self.dtype == np.float64: model_d = new SvmModel[double]() - svcFit(handle_[0], X_ptr, self.n_rows, - self.n_cols, y_ptr, param, _kernel_params, - model_d[0], sample_weight_ptr) + with cuda_interruptible(): + with nogil: + svcFit( + deref(handle_), X_ptr, n_rows, + n_cols, y_ptr, param, _kernel_params, + deref(model_d), sample_weight_ptr) self._model = model_d else: raise TypeError('Input data type should be float32 or float64') diff --git a/python/cuml/svm/svm_base.pyx b/python/cuml/svm/svm_base.pyx index 9d68aae2c1..46c65cfc20 100644 --- a/python/cuml/svm/svm_base.pyx +++ b/python/cuml/svm/svm_base.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -83,13 +83,6 @@ cdef extern from "cuml/svm/svm_model.h" namespace "ML::SVM": cdef extern from "cuml/svm/svc.hpp" namespace "ML::SVM": - cdef void svcFit[math_t](const handle_t &handle, math_t *input, - int n_rows, int n_cols, math_t *labels, - const SvmParameter ¶m, - KernelParams &kernel_params, - SvmModel[math_t] &model, - const math_t *sample_weight) except+ - cdef void svcPredict[math_t]( const handle_t &handle, math_t *input, int n_rows, int n_cols, KernelParams &kernel_params, const SvmModel[math_t] &model, diff --git a/python/cuml/svm/svr.pyx b/python/cuml/svm/svr.pyx index bd3073c169..e853cc679a 100644 --- a/python/cuml/svm/svr.pyx +++ b/python/cuml/svm/svr.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -89,7 +89,7 @@ cdef extern from "cuml/svm/svc.hpp" namespace "ML::SVM": cdef void svmFreeBuffers[math_t](const handle_t &handle, SvmModel[math_t] &m) except + -cdef extern from "cuml/svm/svr.hpp" namespace "ML::SVM": +cdef extern from "cuml/svm/svr.hpp" namespace "ML::SVM" nogil: cdef void svrFit[math_t](const handle_t &handle, math_t *X, int n_rows, int n_cols, math_t *y, diff --git a/wiki/cpp/DEVELOPER_GUIDE.md b/wiki/cpp/DEVELOPER_GUIDE.md index 7479d8dac7..0792f7e2b4 100644 --- a/wiki/cpp/DEVELOPER_GUIDE.md +++ b/wiki/cpp/DEVELOPER_GUIDE.md @@ -10,50 +10,48 @@ Please start by reading [CONTRIBUTING.md](../../CONTRIBUTING.md). ## Threading Model -With the exception of the raft::handle_t, cuML algorithms should maintain thread-safety and are, in general, -assumed to be single threaded. This means they should be able to be called from multiple host threads so +With the exception of the raft::handle_t, cuML algorithms should maintain thread-safety and are, in general, +assumed to be single threaded. This means they should be able to be called from multiple host threads so long as different instances of `raft::handle_t` are used. Exceptions are made for algorithms that can take advantage of multiple CUDA streams within multiple host threads -in order to oversubscribe or increase occupancy on a single GPU. In these cases, the use of multiple host -threads within cuML algorithms should be used only to maintain concurrency of the underlying CUDA streams. -Multiple host threads should be used sparingly, be bounded, and should steer clear of performing CPU-intensive +in order to oversubscribe or increase occupancy on a single GPU. In these cases, the use of multiple host +threads within cuML algorithms should be used only to maintain concurrency of the underlying CUDA streams. +Multiple host threads should be used sparingly, be bounded, and should steer clear of performing CPU-intensive computations. A good example of an acceptable use of host threads within a cuML algorithm might look like the following ``` -cudaStreamSynchronize(handle.get_stream()); +handle.sync_stream(); int n_streams = handle.get_num_internal_streams(); #pragma omp parallel for num_threads(n_threads) for(int i = 0; i < n; i++) { int thread_num = omp_get_thread_num() % n_threads; - cudaStream_t s = handle.getInternalStream(thread_num); + cudaStream_t s = handle.get_stream_from_stream_pool(thread_num); ... possible light cpu pre-processing ... my_kernel1<<>>(...); ... ... some possible async d2h / h2d copies ... my_kernel2<<>>(...); ... - cudaStreamSynchronize(s); + handle.sync_stream(s); ... possible light cpu post-processing ... } ``` In the example above, if there is no CPU pre-processing at the beginning of the for-loop, an event can be registered in -each of the streams within the for-loop to make them wait on the stream from the handle. - -This can be done easily by replacing `cudaStreamSynchronize(handle.get_stream())` with `handle.wait_on_user_stream()` -for a lighter-weight synchronization. If there is no CPU post-processing at the end of each for-loop iteration, -`cudaStreamSynchronize(s)` can be replaced with a single `handle.wait_on_internal_streams()` after the for-loop. +each of the streams within the for-loop to make them wait on the stream from the handle. If there is no CPU post-processing +at the end of each for-loop iteration, `handle.sync_stream(s)` can be replaced with a single `handle.sync_stream_pool()` +after the for-loop. 
To avoid compatibility issues between different threading models, the only threading programming allowed in cuML is OpenMP. Though cuML's build enables OpenMP by default, cuML algorithms should still function properly even when OpenMP has been disabled. If the CPU pre- and post-processing were not needed in the example above, OpenMP would not be needed. -The use of threads in third-party libraries is allowed, though they should still avoid depending on a specific OpenMP runtime. +The use of threads in third-party libraries is allowed, though they should still avoid depending on a specific OpenMP runtime. ## Public cuML interface ### Terminology @@ -302,7 +300,7 @@ void foo(const raft::handle_t& h, ..., cudaStream_t stream ) { ... MLCommon::device_buffer temp( h.get_device_allocator(), stream, 0 ) - + temp.resize(n, stream); kernelA<<>>(..., temp.data(), ...); kernelB<<>>(..., temp.data(), ...); @@ -349,7 +347,7 @@ void foo(const raft::handle_t& h, ...) cudaStream_t stream = h.get_stream(); } ``` -When multiple streams are needed, e.g. to manage a pipeline, use the internal streams available in `raft::handle_t` (see [CUDA Resources](#cuda-resources)). If multiple streams are used all operations still must be ordered according to `raft::handle_t::get_stream()`. Before any operation in any of the internal CUDA streams is started, all previous work in `raft::handle_t::get_stream()` must have completed. Any work enqueued in `raft::handle_t::get_stream()` after a cuML function returns should not start before all work enqueued in the internal streams has completed. E.g. if a cuML algorithm is called like this: +When multiple streams are needed, e.g. to manage a pipeline, use the internal streams available in `raft::handle_t` (see [CUDA Resources](#cuda-resources)). If multiple streams are used all operations still must be ordered according to `raft::handle_t::get_stream()`. Before any operation in any of the internal CUDA streams is started, all previous work in `raft::handle_t::get_stream()` must have completed. Any work enqueued in `raft::handle_t::get_stream()` after a cuML function returns should not start before all work enqueued in the internal streams has completed. E.g. if a cuML algorithm is called like this: ```cpp void foo(const double* const srcdata, double* const result) { @@ -370,7 +368,7 @@ void foo(const double* const srcdata, double* const result) ``` No work in any stream should start in `ML::algo` before the `cudaMemcpyAsync` in `stream` launched before the call to `ML::algo` is done. And all work in all streams used in `ML::algo` should be done before the `cudaMemcpyAsync` in `stream` launched after the call to `ML::algo` starts. -This can be ensured by introducing interstream dependencies with CUDA events and `cudaStreamWaitEvent`. For convenience, the header `raft/handle.hpp` provides the class `raft::stream_syncer` which lets all `raft::handle_t` internal CUDA streams wait on `raft::handle_t:get_stream()` in its constructor and in its destructor and lets `raft::handle_t::get_stream()` wait on all work enqueued in the `raft::handle_t` internal CUDA streams. The intended use would be to create a `raft::stream_syncer` object as the first thing in a entry function of the public cuML API: +This can be ensured by introducing interstream dependencies with CUDA events and `cudaStreamWaitEvent`. 
For convenience, the header `raft/handle.hpp` provides the class `raft::stream_syncer` which lets all `raft::handle_t` internal CUDA streams wait on `raft::handle_t:get_stream()` in its constructor and in its destructor and lets `raft::handle_t::get_stream()` wait on all work enqueued in the `raft::handle_t` internal CUDA streams. The intended use would be to create a `raft::stream_syncer` object as the first thing in a entry function of the public cuML API: ```cpp void cumlAlgo(const raft::handle_t& handle, ...) @@ -444,9 +442,9 @@ int main(int argc, char * argv[]) { raft::handle_t raftHandle; initialize_mpi_comms(raftHandle, raft_mpi_comms); - + ... - + ML::mlalgo(raftHandle, ... ); }
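---

Note for reviewers (not part of the patch): the C++ side of this changeset repeatedly swaps `RAFT_CUDA_TRY(cudaStreamSynchronize(stream))` for one of two idioms — `handle.sync_stream(stream)` where a `raft::handle_t` is in scope, and `raft::interruptible::synchronize(stream)` in prims/test code that only holds a raw stream. Below is a minimal, self-contained sketch of both idioms for anyone trying the pattern outside this diff. The header paths (`raft/handle.hpp`, `raft/interruptible.hpp`, `raft/cudart_utils.h`) reflect RAFT of this era but are assumptions here, as the patch's include lines do not show them; the exception type thrown when a synchronize call is cancelled is likewise not specified by this patch. The corresponding Python-side pattern (`cuda_interruptible` plus `nogil` around the Cython call) is already shown verbatim in the `svc.pyx` and `linear.pyx` hunks above.

```cpp
// Minimal sketch, not part of the patch. Header paths are assumptions for
// the RAFT version targeted by this changeset.
#include <raft/cudart_utils.h>    // raft::update_host (assumed location)
#include <raft/handle.hpp>
#include <raft/interruptible.hpp> // raft::interruptible::synchronize (assumed location)

#include <rmm/device_uvector.hpp>

#include <vector>

// Copies a device buffer to the host and synchronizes using the two idioms
// this changeset standardizes on.
void copy_back_and_sync(const raft::handle_t& handle, const rmm::device_uvector<float>& d_out)
{
  cudaStream_t stream = handle.get_stream();

  std::vector<float> h_out(d_out.size());
  raft::update_host(h_out.data(), d_out.data(), d_out.size(), stream);

  // Where a handle is available (algorithms and sg tests in this diff):
  handle.sync_stream(stream);

  // Where only the raw stream is available (prims tests in this diff); this
  // call also serves as a cancellation point for interruptible execution.
  raft::interruptible::synchronize(stream);
}
```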