From 269b7b17b492301c1ce59b9a5790a23dc9db820f Mon Sep 17 00:00:00 2001 From: Victor Lafargue Date: Mon, 30 Aug 2021 18:33:19 +0200 Subject: [PATCH] Apply modifications to account for RAFT changes (#4077) This PR apply modifications to the cuML codebase to account for changes in RAFT and RMM : - https://github.com/rapidsai/raft/pull/283 - https://github.com/rapidsai/raft/pull/285 - https://github.com/rapidsai/raft/pull/286 - https://github.com/rapidsai/rmm/pull/816 Authors: - Victor Lafargue (https://github.com/viclafargue) - Dante Gama Dessavre (https://github.com/dantegd) Approvers: - William Hicks (https://github.com/wphicks) - Micka (https://github.com/lowener) - Dante Gama Dessavre (https://github.com/dantegd) - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuml/pull/4077 --- cpp/bench/common/ml_benchmark.hpp | 17 +- cpp/bench/prims/add.cu | 9 +- cpp/bench/prims/distance_common.cuh | 41 +- cpp/bench/prims/fused_l2_nn.cu | 9 +- cpp/bench/prims/gram_matrix.cu | 36 +- cpp/bench/prims/make_blobs.cu | 25 +- cpp/bench/prims/map_then_reduce.cu | 9 +- cpp/bench/prims/matrix_vector_op.cu | 9 +- cpp/bench/prims/permute.cu | 8 +- cpp/bench/prims/reduce.cu | 9 +- cpp/bench/prims/rng.cu | 9 +- cpp/bench/sg/arima_loglikelihood.cu | 61 ++- cpp/bench/sg/benchmark.cuh | 12 +- cpp/bench/sg/dataset.cuh | 48 ++- cpp/bench/sg/dataset_ts.cuh | 18 +- cpp/bench/sg/dbscan.cu | 4 +- cpp/bench/sg/fil.cu | 12 +- cpp/bench/sg/kmeans.cu | 4 +- cpp/bench/sg/linkage.cu | 2 +- cpp/bench/sg/rf_classifier.cu | 4 +- cpp/bench/sg/svc.cu | 18 +- cpp/bench/sg/svr.cu | 24 +- cpp/bench/sg/umap.cu | 12 +- cpp/examples/dbscan/dbscan_example.cpp | 5 - cpp/examples/kmeans/kmeans_example.cpp | 6 - cpp/include/cuml/common/device_buffer.hpp | 43 -- cpp/include/cuml/common/host_buffer.hpp | 47 --- cpp/include/cuml/random_projection/rproj_c.h | 30 +- cpp/include/cuml/svm/svc.hpp | 14 +- cpp/include/cuml/svm/svm_model.h | 2 +- cpp/include/cuml/svm/svm_parameter.h | 4 +- cpp/include/cuml/svm/svr.hpp | 8 +- cpp/include/cuml/tsa/arima_common.h | 42 +- cpp/src/arima/batched_arima.cu | 38 +- cpp/src/arima/batched_kalman.cu | 132 ++----- cpp/src/common/cuml_api.cpp | 52 --- cpp/src/common/tensor.hpp | 25 +- cpp/src/datasets/make_arima.cu | 5 +- cpp/src/datasets/make_blobs.cu | 4 - cpp/src/datasets/make_regression.cu | 1 - cpp/src/dbscan/adjgraph/algo.cuh | 11 +- cpp/src/dbscan/adjgraph/naive.cuh | 12 +- cpp/src/dbscan/corepoints/compute.cuh | 5 +- cpp/src/dbscan/dbscan.cuh | 3 +- cpp/src/dbscan/runner.cuh | 31 +- .../batched-levelalgo/builder.cuh | 4 +- .../batched-levelalgo/builder_base.cuh | 20 +- cpp/src/decisiontree/decisiontree.cu | 1 - cpp/src/decisiontree/decisiontree.cuh | 17 +- cpp/src/decisiontree/quantile/quantile.cuh | 31 +- cpp/src/decisiontree/quantile/quantile.h | 9 +- cpp/src/fil/fil.cu | 65 ++- cpp/src/glm/ols.cuh | 4 +- cpp/src/glm/preprocess.cuh | 7 +- cpp/src/glm/qn/simple_mat/dense.hpp | 14 +- cpp/src/glm/qn/simple_mat/sparse.hpp | 6 +- cpp/src/glm/ridge.cuh | 9 +- cpp/src/hdbscan/condensed_hierarchy.cu | 8 +- cpp/src/hdbscan/detail/condense.cuh | 2 +- cpp/src/hdbscan/detail/extract.cuh | 4 +- cpp/src/hdbscan/detail/membership.cuh | 2 +- cpp/src/hdbscan/detail/reachability.cuh | 8 +- cpp/src/hdbscan/detail/select.cuh | 20 +- cpp/src/hdbscan/detail/stabilities.cuh | 4 +- cpp/src/hdbscan/detail/utils.h | 7 +- cpp/src/hdbscan/runner.h | 8 +- cpp/src/hierarchy/pw_dist_graph.cuh | 4 +- cpp/src/holtwinters/internal/hw_decompose.cuh | 24 +- cpp/src/holtwinters/internal/hw_eval.cuh | 3 +- cpp/src/holtwinters/internal/hw_optim.cuh | 3 +- cpp/src/holtwinters/runner.cuh | 48 +-- cpp/src/kmeans/common.cuh | 112 +++--- cpp/src/kmeans/kmeans_mg_impl.cuh | 120 +++--- cpp/src/kmeans/sg_impl.cuh | 121 +++--- cpp/src/knn/knn.cu | 32 +- cpp/src/knn/knn_opg_common.cuh | 45 +-- cpp/src/metrics/accuracy_score.cu | 3 +- cpp/src/metrics/adjusted_rand_index.cu | 4 +- cpp/src/metrics/completeness_score.cu | 9 +- cpp/src/metrics/entropy.cu | 2 +- cpp/src/metrics/homogeneity_score.cu | 9 +- cpp/src/metrics/kl_divergence.cu | 6 +- cpp/src/metrics/mutual_info_score.cu | 9 +- cpp/src/metrics/pairwise_distance_canberra.cu | 5 +- .../metrics/pairwise_distance_chebyshev.cu | 5 +- cpp/src/metrics/pairwise_distance_cosine.cu | 5 +- .../metrics/pairwise_distance_euclidean.cu | 5 +- .../metrics/pairwise_distance_hellinger.cu | 5 +- cpp/src/metrics/pairwise_distance_l1.cu | 5 +- .../metrics/pairwise_distance_minkowski.cu | 5 +- cpp/src/metrics/rand_index.cu | 3 +- cpp/src/metrics/silhouette_score.cu | 12 +- cpp/src/metrics/v_measure.cu | 9 +- cpp/src/ml_mg_utils.cuh | 19 +- cpp/src/pca/pca.cuh | 20 +- cpp/src/pca/pca_mg.cu | 18 +- cpp/src/pca/sign_flip_mg.cu | 59 +-- cpp/src/random_projection/rproj.cu | 2 - cpp/src/random_projection/rproj.cuh | 22 +- cpp/src/random_projection/rproj_utils.cuh | 19 +- cpp/src/randomforest/randomforest.cuh | 35 +- cpp/src/solver/cd.cuh | 14 +- cpp/src/solver/cd_mg.cu | 18 +- cpp/src/solver/lars_impl.cuh | 52 ++- cpp/src/solver/sgd.cuh | 19 +- cpp/src/svm/kernelcache.cuh | 32 +- cpp/src/svm/results.cuh | 48 ++- cpp/src/svm/smosolver.cuh | 59 +-- cpp/src/svm/svc.cu | 20 +- cpp/src/svm/svc_impl.cuh | 50 +-- cpp/src/svm/svm_api.cpp | 14 +- cpp/src/svm/svr.cu | 8 +- cpp/src/svm/svr_impl.cuh | 6 +- cpp/src/svm/workingset.cuh | 91 +++-- cpp/src/tsa/auto_arima.cu | 15 +- cpp/src/tsa/auto_arima.cuh | 44 +-- cpp/src/tsa/stationarity.cu | 4 +- cpp/src/tsne/barnes_hut_tsne.cuh | 47 ++- cpp/src/tsne/distances.cuh | 3 +- cpp/src/tsne/exact_kernels.cuh | 1 - cpp/src/tsne/exact_tsne.cuh | 18 +- cpp/src/tsne/fft_tsne.cuh | 25 +- cpp/src/tsne/tsne_runner.cuh | 4 +- cpp/src/tsvd/tsvd.cuh | 45 +-- cpp/src/tsvd/tsvd_mg.cu | 20 +- cpp/src/umap/fuzzy_simpl_set/naive.cuh | 16 +- cpp/src/umap/fuzzy_simpl_set/runner.cuh | 5 +- cpp/src/umap/init_embed/spectral_algo.cuh | 4 +- cpp/src/umap/knn_graph/algo.cuh | 9 - cpp/src/umap/knn_graph/runner.cuh | 5 +- cpp/src/umap/optimize.cuh | 36 +- cpp/src/umap/runner.cuh | 97 ++--- cpp/src/umap/simpl_set_embed/algo.cuh | 24 +- cpp/src/umap/simpl_set_embed/runner.cuh | 3 +- cpp/src/umap/supervised.cuh | 50 +-- cpp/src/umap/umap.cu | 3 +- cpp/src_prims/cache/cache.cuh | 51 ++- cpp/src_prims/functions/penalty.cuh | 3 +- cpp/src_prims/label/classlabels.cuh | 79 +--- cpp/src_prims/linalg/batched/matrix.cuh | 107 ++--- cpp/src_prims/linalg/lstsq.cuh | 10 +- cpp/src_prims/linalg/rsvd.cuh | 51 ++- cpp/src_prims/metrics/adjusted_rand_index.cuh | 40 +- .../metrics/batched/silhouette_score.cuh | 11 +- cpp/src_prims/metrics/completeness_score.cuh | 10 +- cpp/src_prims/metrics/dispersion.cuh | 9 +- cpp/src_prims/metrics/entropy.cuh | 26 +- cpp/src_prims/metrics/homogeneity_score.cuh | 9 +- cpp/src_prims/metrics/kl_divergence.cuh | 13 +- cpp/src_prims/metrics/mutual_info_score.cuh | 20 +- cpp/src_prims/metrics/rand_index.cuh | 8 +- cpp/src_prims/metrics/scores.cuh | 87 ++-- cpp/src_prims/metrics/silhouette_score.cuh | 36 +- .../metrics/trustworthiness_score.cuh | 13 +- cpp/src_prims/metrics/v_measure.cuh | 8 +- cpp/src_prims/random/make_arima.cuh | 17 +- cpp/src_prims/random/make_blobs.cuh | 7 +- cpp/src_prims/random/make_regression.cuh | 55 +-- cpp/src_prims/selection/knn.cuh | 24 +- cpp/src_prims/selection/processing.cuh | 39 +- cpp/src_prims/sparse/batched/csr.cuh | 45 +-- cpp/src_prims/timeSeries/arima_helpers.cuh | 14 +- cpp/src_prims/timeSeries/jones_transform.cuh | 4 - cpp/src_prims/timeSeries/stationarity.cuh | 23 +- cpp/test/CMakeLists.txt | 1 - cpp/test/mg/knn.cu | 14 +- cpp/test/mg/knn_regress.cu | 3 - cpp/test/mg/knn_test_helper.cuh | 7 +- cpp/test/mg/pca.cu | 16 +- cpp/test/prims/add_sub_dev_scalar.cu | 47 +-- cpp/test/prims/adjusted_rand_index.cu | 35 +- cpp/test/prims/batched/csr.cu | 13 +- cpp/test/prims/batched/gemv.cu | 53 +-- .../prims/batched/information_criterion.cu | 2 +- cpp/test/prims/batched/make_symm.cu | 35 +- cpp/test/prims/batched/matrix.cu | 21 +- cpp/test/prims/cache.cu | 206 +++++----- cpp/test/prims/columnSort.cu | 73 ++-- cpp/test/prims/completeness_score.cu | 33 +- cpp/test/prims/contingencyMatrix.cu | 73 ++-- cpp/test/prims/cov.cu | 77 ++-- cpp/test/prims/decoupled_lookback.cu | 30 +- cpp/test/prims/device_utils.cu | 30 +- cpp/test/prims/dispersion.cu | 47 +-- cpp/test/prims/dist_adj.cu | 64 +-- cpp/test/prims/distance_base.cuh | 69 ++-- cpp/test/prims/eltwise2d.cu | 41 +- cpp/test/prims/entropy.cu | 21 +- cpp/test/prims/epsilon_neighborhood.cu | 53 ++- cpp/test/prims/fast_int_div.cu | 20 +- cpp/test/prims/gather.cu | 60 ++- cpp/test/prims/gram.cu | 24 +- cpp/test/prims/grid_sync.cu | 36 +- cpp/test/prims/hinge.cu | 42 +- cpp/test/prims/histogram.cu | 36 +- cpp/test/prims/homogeneity_score.cu | 33 +- cpp/test/prims/host_buffer.cu | 113 ------ cpp/test/prims/jones_transform.cu | 50 +-- cpp/test/prims/kl_divergence.cu | 26 +- cpp/test/prims/knn_classify.cu | 28 +- cpp/test/prims/knn_regression.cu | 11 +- cpp/test/prims/kselection.cu | 47 +-- cpp/test/prims/label.cu | 53 ++- cpp/test/prims/linalg_block.cu | 42 +- cpp/test/prims/linearReg.cu | 42 +- cpp/test/prims/log.cu | 38 +- cpp/test/prims/logisticReg.cu | 44 +-- cpp/test/prims/make_arima.cu | 20 +- cpp/test/prims/make_blobs.cu | 22 +- cpp/test/prims/make_regression.cu | 12 +- cpp/test/prims/merge_labels.cu | 2 +- cpp/test/prims/minmax.cu | 57 +-- cpp/test/prims/mutual_info_score.cu | 31 +- cpp/test/prims/mvg.cu | 88 +++-- cpp/test/prims/penalty.cu | 28 +- cpp/test/prims/permute.cu | 57 ++- cpp/test/prims/power.cu | 59 ++- cpp/test/prims/rand_index.cu | 29 +- cpp/test/prims/reduce_cols_by_key.cu | 45 ++- cpp/test/prims/reduce_rows_by_key.cu | 14 +- cpp/test/prims/reverse.cu | 46 ++- cpp/test/prims/rsvd.cu | 133 +++---- cpp/test/prims/score.cu | 50 +-- cpp/test/prims/sigmoid.cu | 34 +- cpp/test/prims/silhouette_score.cu | 56 +-- cpp/test/prims/sqrt.cu | 43 +- cpp/test/prims/ternary_op.cu | 18 +- cpp/test/prims/trustworthiness.cu | 16 +- cpp/test/prims/v_measure.cu | 37 +- cpp/test/prims/weighted_mean.cu | 4 +- cpp/test/sg/cd_test.cu | 22 +- cpp/test/sg/dbscan_test.cu | 28 +- cpp/test/sg/decisiontree_batchedlevel_algo.cu | 214 ++++++++++ .../sg/decisiontree_batchedlevel_unittest.cu | 374 ++++++++++++++++++ cpp/test/sg/fil_test.cu | 24 +- cpp/test/sg/hdbscan_test.cu | 23 +- cpp/test/sg/holtwinters_test.cu | 4 +- cpp/test/sg/kmeans_test.cu | 64 +-- cpp/test/sg/knn_test.cu | 43 +- cpp/test/sg/lars_test.cu | 106 +++-- cpp/test/sg/linkage_test.cu | 31 +- cpp/test/sg/ols.cu | 40 +- cpp/test/sg/pca_test.cu | 44 +-- cpp/test/sg/quasi_newton.cu | 44 +-- cpp/test/sg/rf_test.cu | 18 +- cpp/test/sg/rf_treelite_test.cu | 56 ++- cpp/test/sg/ridge.cu | 40 +- cpp/test/sg/rproj_test.cu | 46 +-- cpp/test/sg/sgd.cu | 38 +- cpp/test/sg/shap_kernel.cu | 17 +- cpp/test/sg/svc_test.cu | 210 +++++----- cpp/test/sg/trustworthiness_test.cu | 18 +- cpp/test/sg/tsne_test.cu | 11 +- cpp/test/sg/tsvd_test.cu | 24 +- cpp/test/sg/umap_parametrizable_test.cu | 61 ++- .../random_projection/random_projection.pyx | 8 +- python/cuml/svm/svc.pyx | 23 +- python/cuml/svm/svm_base.pyx | 50 +-- python/cuml/svm/svr.pyx | 26 +- python/cuml/test/test_naive_bayes.py | 1 + 260 files changed, 3714 insertions(+), 4409 deletions(-) delete mode 100644 cpp/include/cuml/common/device_buffer.hpp delete mode 100644 cpp/include/cuml/common/host_buffer.hpp delete mode 100644 cpp/test/prims/host_buffer.cu create mode 100644 cpp/test/sg/decisiontree_batchedlevel_algo.cu create mode 100644 cpp/test/sg/decisiontree_batchedlevel_unittest.cu diff --git a/cpp/bench/common/ml_benchmark.hpp b/cpp/bench/common/ml_benchmark.hpp index 15a606b502..ee9f0289b1 100644 --- a/cpp/bench/common/ml_benchmark.hpp +++ b/cpp/bench/common/ml_benchmark.hpp @@ -80,7 +80,7 @@ struct CudaEventTimer { private: ::benchmark::State* state; - cudaStream_t stream; + cudaStream_t stream = 0; cudaEvent_t start; cudaEvent_t stop; }; // end struct CudaEventTimer @@ -88,11 +88,7 @@ struct CudaEventTimer { /** Main fixture to be inherited and used by all other c++ benchmarks in cuml */ class Fixture : public ::benchmark::Fixture { public: - Fixture(const std::string& name, std::shared_ptr _alloc) - : ::benchmark::Fixture(), d_alloc(_alloc) - { - SetName(name.c_str()); - } + Fixture(const std::string& name) : ::benchmark::Fixture() { SetName(name.c_str()); } Fixture() = delete; void SetUp(const ::benchmark::State& state) override @@ -163,19 +159,20 @@ class Fixture : public ::benchmark::Fixture { template void alloc(T*& ptr, size_t len, bool init = false) { - auto nBytes = len * sizeof(T); - ptr = (T*)d_alloc->allocate(nBytes, stream); + auto nBytes = len * sizeof(T); + auto d_alloc = rmm::mr::get_current_device_resource(); + ptr = (T*)d_alloc->allocate(nBytes, stream); if (init) { CUDA_CHECK(cudaMemsetAsync(ptr, 0, nBytes, stream)); } } template void dealloc(T* ptr, size_t len) { + auto d_alloc = rmm::mr::get_current_device_resource(); d_alloc->deallocate(ptr, len * sizeof(T), stream); } - std::shared_ptr d_alloc; - cudaStream_t stream; + cudaStream_t stream = 0; int l2CacheSize; char* scratchBuffer; }; // class Fixture diff --git a/cpp/bench/prims/add.cu b/cpp/bench/prims/add.cu index 25a6a0acb0..1665ad7656 100644 --- a/cpp/bench/prims/add.cu +++ b/cpp/bench/prims/add.cu @@ -16,7 +16,6 @@ #include #include -#include namespace MLCommon { namespace Bench { @@ -28,13 +27,7 @@ struct AddParams { template struct AddBench : public Fixture { - AddBench(const std::string& name, const AddParams& p) - : Fixture( - name, - std::shared_ptr(new raft::mr::device::default_allocator)), - params(p) - { - } + AddBench(const std::string& name, const AddParams& p) : Fixture(name), params(p) {} protected: void allocateBuffers(const ::benchmark::State& state) override diff --git a/cpp/bench/prims/distance_common.cuh b/cpp/bench/prims/distance_common.cuh index 465d45be15..cc4eff27db 100644 --- a/cpp/bench/prims/distance_common.cuh +++ b/cpp/bench/prims/distance_common.cuh @@ -17,7 +17,6 @@ #include #include #include -#include namespace MLCommon { namespace Bench { @@ -31,42 +30,34 @@ struct Params { template struct Distance : public Fixture { Distance(const std::string& name, const Params& p) - : Fixture( - name, - std::shared_ptr(new raft::mr::device::default_allocator)), - params(p) + : Fixture(name), params(p), x(0, stream), y(0, stream), out(0, stream), workspace(0, stream) { } protected: void allocateBuffers(const ::benchmark::State& state) override { - alloc(x, params.m * params.k, true); - alloc(y, params.n * params.k, true); - alloc(out, params.m * params.n, true); - workspace = nullptr; - worksize = raft::distance::getWorkspaceSize(x, y, params.m, params.n, params.k); - if (worksize != 0) { alloc(workspace, worksize, false); } - } - - void deallocateBuffers(const ::benchmark::State& state) override - { - dealloc(x, params.m * params.k); - dealloc(y, params.n * params.k); - dealloc(out, params.m * params.n); - dealloc(workspace, worksize); + x.resize(params.m * params.k, stream); + y.resize(params.n * params.k, stream); + out.resize(params.m * params.n, stream); + CUDA_CHECK(cudaMemsetAsync(x.data(), 0, x.size() * sizeof(T), stream)); + CUDA_CHECK(cudaMemsetAsync(y.data(), 0, y.size() * sizeof(T), stream)); + CUDA_CHECK(cudaMemsetAsync(out.data(), 0, out.size() * sizeof(T), stream)); + worksize = raft::distance::getWorkspaceSize( + x.data(), y.data(), params.m, params.n, params.k); + workspace.resize(worksize, stream); } void runBenchmark(::benchmark::State& state) override { loopOnState(state, [this]() { - raft::distance::distance(x, - y, - out, + raft::distance::distance(x.data(), + y.data(), + out.data(), params.m, params.n, params.k, - (void*)workspace, + (void*)workspace.data(), worksize, stream, params.isRowMajor); @@ -75,8 +66,8 @@ struct Distance : public Fixture { private: Params params; - T *x, *y, *out; - char* workspace; + rmm::device_uvector x, y, out; + rmm::device_uvector workspace; size_t worksize; }; // struct Distance diff --git a/cpp/bench/prims/fused_l2_nn.cu b/cpp/bench/prims/fused_l2_nn.cu index d3a35f3e7e..ef21a03881 100644 --- a/cpp/bench/prims/fused_l2_nn.cu +++ b/cpp/bench/prims/fused_l2_nn.cu @@ -19,7 +19,6 @@ #include #include #include -#include #include namespace MLCommon { @@ -32,13 +31,7 @@ struct FLNParams { template struct FusedL2NN : public Fixture { - FusedL2NN(const std::string& name, const FLNParams& p) - : Fixture( - name, - std::shared_ptr(new raft::mr::device::default_allocator)), - params(p) - { - } + FusedL2NN(const std::string& name, const FLNParams& p) : Fixture(name), params(p) {} protected: void allocateBuffers(const ::benchmark::State& state) override diff --git a/cpp/bench/prims/gram_matrix.cu b/cpp/bench/prims/gram_matrix.cu index c561a875c2..8efb858b30 100644 --- a/cpp/bench/prims/gram_matrix.cu +++ b/cpp/bench/prims/gram_matrix.cu @@ -15,11 +15,11 @@ */ #include +#include #include #include #include #include -#include #include #include #include @@ -42,10 +42,7 @@ struct GramTestParams { template struct GramMatrix : public Fixture { GramMatrix(const std::string& name, const GramTestParams& p) - : Fixture( - name, - std::shared_ptr(new raft::mr::device::default_allocator)), - params(p) + : Fixture(name), params(p), A(0, stream), B(0, stream), C(0, stream) { std::vector kernel_names{"linear", "poly", "rbf", "tanh"}; std::ostringstream oss; @@ -63,31 +60,24 @@ struct GramMatrix : public Fixture { protected: void allocateBuffers(const ::benchmark::State& state) override { - alloc(A, params.m * params.k); - alloc(B, params.k * params.n); - alloc(C, params.m * params.n); + A.resize(params.m * params.k, stream); + B.resize(params.k * params.n, stream); + C.resize(params.m * params.n, stream); raft::random::Rng r(123456ULL); - r.uniform(A, params.m * params.k, T(-1.0), T(1.0), stream); - r.uniform(B, params.k * params.n, T(-1.0), T(1.0), stream); - } - - void deallocateBuffers(const ::benchmark::State& state) override - { - dealloc(A, params.m * params.k); - dealloc(B, params.k * params.n); - dealloc(C, params.m * params.n); + r.uniform(A.data(), params.m * params.k, T(-1.0), T(1.0), stream); + r.uniform(B.data(), params.k * params.n, T(-1.0), T(1.0), stream); } void runBenchmark(::benchmark::State& state) override { if (!this->kernel) { state.SkipWithError("Kernel matrix is not initialized"); } loopOnState(state, [this]() { - (*this->kernel)(this->A, + (*this->kernel)(A.data(), this->params.m, this->params.k, - this->B, + B.data(), this->params.n, - this->C, + C.data(), this->params.is_row_major, this->stream); }); @@ -98,9 +88,9 @@ struct GramMatrix : public Fixture { std::unique_ptr> kernel; GramTestParams params; - T* A; // input matrix A, size [m * k] - T* B; // input matrix B, size [n * k] - T* C; // output matrix C, size [m*n] + rmm::device_uvector A; // input matrix A, size [m * k] + rmm::device_uvector B; // input matrix B, size [n * k] + rmm::device_uvector C; // output matrix C, size [m*n] }; static std::vector getInputs() diff --git a/cpp/bench/prims/make_blobs.cu b/cpp/bench/prims/make_blobs.cu index dacc6d0688..68d8109f25 100644 --- a/cpp/bench/prims/make_blobs.cu +++ b/cpp/bench/prims/make_blobs.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace MLCommon { @@ -30,35 +29,25 @@ struct Params { template struct MakeBlobs : public Fixture { MakeBlobs(const std::string& name, const Params& p) - : Fixture( - name, - std::shared_ptr(new raft::mr::device::default_allocator)), - params(p) + : Fixture(name), params(p), data(0, stream), labels(0, stream) { } protected: void allocateBuffers(const ::benchmark::State& state) override { - alloc(data, params.rows * params.cols); - alloc(labels, params.rows); - } - - void deallocateBuffers(const ::benchmark::State& state) override - { - dealloc(data, params.rows * params.cols); - dealloc(labels, params.rows); + data.resize(params.rows * params.cols, stream); + labels.resize(params.rows, stream); } void runBenchmark(::benchmark::State& state) override { loopOnState(state, [this]() { - MLCommon::Random::make_blobs(data, - labels, + MLCommon::Random::make_blobs(data.data(), + labels.data(), params.rows, params.cols, params.clusters, - this->d_alloc, this->stream, params.row_major); }); @@ -66,8 +55,8 @@ struct MakeBlobs : public Fixture { private: Params params; - T* data; - int* labels; + rmm::device_uvector data; + rmm::device_uvector labels; }; // struct MakeBlobs static std::vector getInputs() diff --git a/cpp/bench/prims/map_then_reduce.cu b/cpp/bench/prims/map_then_reduce.cu index 87c565e71a..6f451672ba 100644 --- a/cpp/bench/prims/map_then_reduce.cu +++ b/cpp/bench/prims/map_then_reduce.cu @@ -16,7 +16,6 @@ #include #include -#include namespace MLCommon { namespace Bench { @@ -33,13 +32,7 @@ struct Identity { template struct MapThenReduce : public Fixture { - MapThenReduce(const std::string& name, const Params& p) - : Fixture( - name, - std::shared_ptr(new raft::mr::device::default_allocator)), - params(p) - { - } + MapThenReduce(const std::string& name, const Params& p) : Fixture(name), params(p) {} protected: void allocateBuffers(const ::benchmark::State& state) override diff --git a/cpp/bench/prims/matrix_vector_op.cu b/cpp/bench/prims/matrix_vector_op.cu index a67680fb74..35cc0122d5 100644 --- a/cpp/bench/prims/matrix_vector_op.cu +++ b/cpp/bench/prims/matrix_vector_op.cu @@ -16,7 +16,6 @@ #include #include -#include namespace MLCommon { namespace Bench { @@ -29,13 +28,7 @@ struct Params { template struct MatVecOp : public Fixture { - MatVecOp(const std::string& name, const Params& p) - : Fixture( - name, - std::shared_ptr(new raft::mr::device::default_allocator)), - params(p) - { - } + MatVecOp(const std::string& name, const Params& p) : Fixture(name), params(p) {} protected: void allocateBuffers(const ::benchmark::State& state) override diff --git a/cpp/bench/prims/permute.cu b/cpp/bench/prims/permute.cu index 0404a79679..34475d18ca 100644 --- a/cpp/bench/prims/permute.cu +++ b/cpp/bench/prims/permute.cu @@ -31,13 +31,7 @@ struct Params { template struct Permute : public Fixture { - Permute(const std::string& name, const Params& p) - : Fixture( - name, - std::shared_ptr(new raft::mr::device::default_allocator)), - params(p) - { - } + Permute(const std::string& name, const Params& p) : Fixture(name), params(p) {} protected: void allocateBuffers(const ::benchmark::State& state) override diff --git a/cpp/bench/prims/reduce.cu b/cpp/bench/prims/reduce.cu index d97b4120d3..cb593c2a3d 100644 --- a/cpp/bench/prims/reduce.cu +++ b/cpp/bench/prims/reduce.cu @@ -16,7 +16,6 @@ #include #include -#include namespace MLCommon { namespace Bench { @@ -29,13 +28,7 @@ struct Params { template struct Reduce : public Fixture { - Reduce(const std::string& name, const Params& p) - : Fixture( - name, - std::shared_ptr(new raft::mr::device::default_allocator)), - params(p) - { - } + Reduce(const std::string& name, const Params& p) : Fixture(name), params(p) {} protected: void allocateBuffers(const ::benchmark::State& state) override diff --git a/cpp/bench/prims/rng.cu b/cpp/bench/prims/rng.cu index b7a32ee7b9..5eb6caa31a 100644 --- a/cpp/bench/prims/rng.cu +++ b/cpp/bench/prims/rng.cu @@ -16,7 +16,6 @@ #include #include -#include #include namespace MLCommon { @@ -45,13 +44,7 @@ struct Params { template struct RngBench : public Fixture { - RngBench(const std::string& name, const Params& p) - : Fixture( - name, - std::shared_ptr(new raft::mr::device::default_allocator)), - params(p) - { - } + RngBench(const std::string& name, const Params& p) : Fixture(name), params(p) {} protected: void allocateBuffers(const ::benchmark::State& state) override { alloc(ptr, params.len); } diff --git a/cpp/bench/sg/arima_loglikelihood.cu b/cpp/bench/sg/arima_loglikelihood.cu index 4cffe92bfb..2f7cce35eb 100644 --- a/cpp/bench/sg/arima_loglikelihood.cu +++ b/cpp/bench/sg/arima_loglikelihood.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "benchmark.cuh" @@ -39,7 +40,12 @@ template class ArimaLoglikelihood : public TsFixtureRandom { public: ArimaLoglikelihood(const std::string& name, const ArimaParams& p) - : TsFixtureRandom(name, p.data), order(p.order) + : TsFixtureRandom(name, p.data), + order(p.order), + param(0, rmm::cuda_stream_default), + loglike(0, rmm::cuda_stream_default), + residual(0, rmm::cuda_stream_default), + temp_mem(0, rmm::cuda_stream_default) { } @@ -55,9 +61,9 @@ class ArimaLoglikelihood : public TsFixtureRandom { // Generate random parameters int N = order.complexity(); raft::random::Rng gpu_gen(this->params.seed, raft::random::GenPhilox); - gpu_gen.uniform(param, N * this->params.batch_size, -1.0, 1.0, stream); + gpu_gen.uniform(param.data(), N * this->params.batch_size, -1.0, 1.0, stream); // Set sigma2 parameters to 1.0 - DataT* x = param; // copy the object attribute for thrust + DataT* x = param.data(); // copy the object attribute for thrust thrust::for_each(thrust::cuda::par.on(stream), counting, counting + this->params.batch_size, @@ -67,18 +73,19 @@ class ArimaLoglikelihood : public TsFixtureRandom { // Benchmark loop this->loopOnState(state, [this]() { - ARIMAMemory arima_mem(order, this->params.batch_size, this->params.n_obs, temp_mem); + ARIMAMemory arima_mem( + order, this->params.batch_size, this->params.n_obs, temp_mem.data()); // Evaluate log-likelihood batched_loglike(*this->handle, arima_mem, - this->data.X, + this->data.X.data(), this->params.batch_size, this->params.n_obs, order, - param, - loglike, - residual, + param.data(), + loglike.data(), + residual.data(), true, false); }); @@ -88,46 +95,30 @@ class ArimaLoglikelihood : public TsFixtureRandom { { Fixture::allocateBuffers(state); - auto& handle = *this->handle; - auto stream = handle.get_stream(); - auto allocator = handle.get_device_allocator(); + auto& handle = *this->handle; + auto stream = handle.get_stream(); // Buffer for the model parameters - param = (DataT*)allocator->allocate( - order.complexity() * this->params.batch_size * sizeof(DataT), stream); + param.resize(order.complexity() * this->params.batch_size, stream); // Buffers for the log-likelihood and residuals - loglike = (DataT*)allocator->allocate(this->params.batch_size * sizeof(DataT), stream); - residual = (DataT*)allocator->allocate( - this->params.batch_size * this->params.n_obs * sizeof(DataT), stream); + loglike.resize(this->params.batch_size, stream); + residual.resize(this->params.batch_size * this->params.n_obs, stream); // Temporary memory size_t temp_buf_size = ARIMAMemory::compute_size(order, this->params.batch_size, this->params.n_obs); - temp_mem = (char*)allocator->allocate(temp_buf_size, stream); + temp_mem.resize(temp_buf_size, stream); } - void deallocateBuffers(const ::benchmark::State& state) - { - Fixture::deallocateBuffers(state); - - auto& handle = *this->handle; - auto stream = handle.get_stream(); - auto allocator = handle.get_device_allocator(); - - allocator->deallocate( - param, order.complexity() * this->params.batch_size * sizeof(DataT), stream); - allocator->deallocate(loglike, this->params.batch_size * sizeof(DataT), stream); - allocator->deallocate( - residual, this->params.batch_size * this->params.n_obs * sizeof(DataT), stream); - } + void deallocateBuffers(const ::benchmark::State& state) { Fixture::deallocateBuffers(state); } protected: ARIMAOrder order; - DataT* param; - DataT* loglike; - DataT* residual; - char* temp_mem; + rmm::device_uvector param; + rmm::device_uvector loglike; + rmm::device_uvector residual; + rmm::device_uvector temp_mem; }; std::vector getInputs() diff --git a/cpp/bench/sg/benchmark.cuh b/cpp/bench/sg/benchmark.cuh index 2537ea3723..c2cd8a9ce6 100644 --- a/cpp/bench/sg/benchmark.cuh +++ b/cpp/bench/sg/benchmark.cuh @@ -32,17 +32,12 @@ namespace Bench { /** Main fixture to be inherited and used by all algos in cuML benchmark */ class Fixture : public MLCommon::Bench::Fixture { public: - Fixture(const std::string& name) - : MLCommon::Bench::Fixture( - name, std::shared_ptr(new raft::mr::device::default_allocator)) - { - } + Fixture(const std::string& name) : MLCommon::Bench::Fixture(name) {} Fixture() = delete; void SetUp(const ::benchmark::State& state) override { handle.reset(new raft::handle_t(NumStreams)); - d_alloc = handle->get_device_allocator(); MLCommon::Bench::Fixture::SetUp(state); handle->set_stream(stream); } @@ -176,11 +171,6 @@ class TsFixtureRandom : public Fixture { data.random(*handle, params); } - void deallocateData(const ::benchmark::State& state) override - { - data.deallocate(*handle, params); - } - TimeSeriesParams params; TimeSeriesDataset data; }; // end class TsFixtureRandom diff --git a/cpp/bench/sg/dataset.cuh b/cpp/bench/sg/dataset.cuh index 8af8f8c764..b6312f5c9e 100644 --- a/cpp/bench/sg/dataset.cuh +++ b/cpp/bench/sg/dataset.cuh @@ -74,27 +74,26 @@ struct RegressionParams { */ template struct Dataset { + Dataset() : X(0, rmm::cuda_stream_default), y(0, rmm::cuda_stream_default) {} /** input data */ - D* X; + rmm::device_uvector X; /** labels or output associated with each row of input data */ - L* y; + rmm::device_uvector y; /** allocate space needed for the dataset */ void allocate(const raft::handle_t& handle, const DatasetParams& p) { - auto allocator = handle.get_device_allocator(); - auto stream = handle.get_stream(); - X = (D*)allocator->allocate(p.nrows * p.ncols * sizeof(D), stream); - y = (L*)allocator->allocate(p.nrows * sizeof(L), stream); + auto stream = handle.get_stream(); + X.resize(p.nrows * p.ncols, stream); + y.resize(p.nrows, stream); } /** free-up the buffers */ void deallocate(const raft::handle_t& handle, const DatasetParams& p) { - auto allocator = handle.get_device_allocator(); - auto stream = handle.get_stream(); - allocator->deallocate(X, p.nrows * p.ncols * sizeof(D), stream); - allocator->deallocate(y, p.nrows * sizeof(L), stream); + auto stream = handle.get_stream(); + X.release(); + y.release(); } /** whether the current dataset is for classification or regression */ @@ -109,19 +108,20 @@ struct Dataset { const auto& handle_impl = handle; auto stream = handle_impl.get_stream(); auto cublas_handle = handle_impl.get_cublas_handle(); - auto allocator = handle_impl.get_device_allocator(); // Make blobs will generate labels of type IdxT which has to be an integer // type. We cast it to a different output type if needed. IdxT* tmpY; + rmm::device_uvector tmpY_vec(0, stream); if (std::is_same::value) { - tmpY = (IdxT*)y; + tmpY = (IdxT*)y.data(); } else { - tmpY = (IdxT*)allocator->allocate(p.nrows * sizeof(IdxT), stream); + tmpY_vec.resize(p.nrows, stream); + tmpY = tmpY_vec.data(); } ML::Datasets::make_blobs(handle, - X, + X.data(), tmpY, p.nrows, p.ncols, @@ -136,8 +136,7 @@ struct Dataset { b.seed); if (!std::is_same::value) { raft::linalg::unaryOp( - y, tmpY, p.nrows, [] __device__(IdxT z) { return (L)z; }, stream); - allocator->deallocate(tmpY, p.nrows * sizeof(IdxT), stream); + y.data(), tmpY, p.nrows, [] __device__(IdxT z) { return (L)z; }, stream); } } @@ -152,14 +151,16 @@ struct Dataset { auto stream = handle_impl.get_stream(); auto cublas_handle = handle_impl.get_cublas_handle(); auto cusolver_handle = handle_impl.get_cusolver_dn_handle(); - auto allocator = handle_impl.get_device_allocator(); - D* tmpX = X; - - if (!p.rowMajor) { tmpX = (D*)allocator->allocate(p.nrows * p.ncols * sizeof(D), stream); } + D* tmpX = X.data(); + rmm::device_uvector tmpX_vec(0, stream); + if (!p.rowMajor) { + tmpX_vec.resize(p.nrows * p.ncols, stream); + tmpX = tmpX_vec.data(); + } MLCommon::Random::make_regression(handle, tmpX, - y, + y.data(), p.nrows, p.ncols, r.n_informative, @@ -172,10 +173,7 @@ struct Dataset { D(r.noise), r.shuffle, r.seed); - if (!p.rowMajor) { - raft::linalg::transpose(handle, tmpX, X, p.nrows, p.ncols, stream); - allocator->deallocate(tmpX, p.nrows * p.ncols * sizeof(D), stream); - } + if (!p.rowMajor) { raft::linalg::transpose(handle, tmpX, X.data(), p.nrows, p.ncols, stream); } } /** diff --git a/cpp/bench/sg/dataset_ts.cuh b/cpp/bench/sg/dataset_ts.cuh index dcc940aa2d..a8b9ac0790 100644 --- a/cpp/bench/sg/dataset_ts.cuh +++ b/cpp/bench/sg/dataset_ts.cuh @@ -37,23 +37,15 @@ struct TimeSeriesParams { */ template struct TimeSeriesDataset { + TimeSeriesDataset() : X(0, rmm::cuda_stream_default) {} + /** input data */ - DataT* X; + rmm::device_uvector X; /** allocate space needed for the dataset */ void allocate(const raft::handle_t& handle, const TimeSeriesParams& p) { - auto allocator = handle.get_device_allocator(); - auto stream = handle.get_stream(); - X = (DataT*)allocator->allocate(p.batch_size * p.n_obs * sizeof(DataT), stream); - } - - /** free-up the buffers */ - void deallocate(const raft::handle_t& handle, const TimeSeriesParams& p) - { - auto allocator = handle.get_device_allocator(); - auto stream = handle.get_stream(); - allocator->deallocate(X, p.batch_size * p.n_obs * sizeof(DataT), stream); + X.resize(p.batch_size * p.n_obs, handle.get_stream()); } /** generate random time series (normal distribution) */ @@ -63,7 +55,7 @@ struct TimeSeriesDataset { DataT sigma = 1) { raft::random::Rng gpu_gen(p.seed, raft::random::GenPhilox); - gpu_gen.normal(X, p.batch_size * p.n_obs, mu, sigma, handle.get_stream()); + gpu_gen.normal(X.data(), p.batch_size * p.n_obs, mu, sigma, handle.get_stream()); } }; diff --git a/cpp/bench/sg/dbscan.cu b/cpp/bench/sg/dbscan.cu index 544e0a45c7..799bf972eb 100644 --- a/cpp/bench/sg/dbscan.cu +++ b/cpp/bench/sg/dbscan.cu @@ -51,13 +51,13 @@ class Dbscan : public BlobsFixture { if (!this->params.rowMajor) { state.SkipWithError("Dbscan only supports row-major inputs"); } this->loopOnState(state, [this, &state]() { ML::Dbscan::fit(*this->handle, - this->data.X, + this->data.X.data(), this->params.nrows, this->params.ncols, D(dParams.eps), dParams.min_pts, raft::distance::L2SqrtUnexpanded, - this->data.y, + this->data.y.data(), this->core_sample_indices, dParams.max_bytes_per_batch); state.SetItemsProcessed(this->params.nrows * this->params.ncols); diff --git a/cpp/bench/sg/fil.cu b/cpp/bench/sg/fil.cu index 2108c1c2a1..a8f89481f1 100644 --- a/cpp/bench/sg/fil.cu +++ b/cpp/bench/sg/fil.cu @@ -72,14 +72,14 @@ class FIL : public RegressionFixture { if (!params.rowMajor) { state.SkipWithError("FIL only supports row-major inputs"); } if (params.nclasses > 1) { // convert regression ranges into [0..nclasses-1] - regression_to_classification(data.y, params.nrows, params.nclasses, stream); + regression_to_classification(data.y.data(), params.nrows, params.nclasses, stream); } // create model ML::RandomForestRegressorF rf_model; auto* mPtr = &rf_model; mPtr->trees = nullptr; size_t train_nrows = std::min(params.nrows, 1000); - fit(*handle, mPtr, data.X, train_nrows, params.ncols, data.y, p_rest.rf); + fit(*handle, mPtr, data.X.data(), train_nrows, params.ncols, data.y.data(), p_rest.rf); CUDA_CHECK(cudaStreamSynchronize(stream)); ML::build_treelite_forest(&model, &rf_model, params.ncols, params.nclasses > 1 ? 2 : 1); @@ -99,8 +99,12 @@ class FIL : public RegressionFixture { // Dataset allocates y assuming one output value per input row, // so not supporting predict_proba yet for (int i = 0; i < p_rest.predict_repetitions; i++) { - ML::fil::predict( - *this->handle, this->forest, this->data.y, this->data.X, this->params.nrows, false); + ML::fil::predict(*this->handle, + this->forest, + this->data.y.data(), + this->data.X.data(), + this->params.nrows, + false); } }); } diff --git a/cpp/bench/sg/kmeans.cu b/cpp/bench/sg/kmeans.cu index a74b9f091d..267baea305 100644 --- a/cpp/bench/sg/kmeans.cu +++ b/cpp/bench/sg/kmeans.cu @@ -45,12 +45,12 @@ class KMeans : public BlobsFixture { this->loopOnState(state, [this]() { ML::kmeans::fit_predict(*this->handle, kParams, - this->data.X, + this->data.X.data(), this->params.nrows, this->params.ncols, nullptr, centroids, - this->data.y, + this->data.y.data(), inertia, nIter); }); diff --git a/cpp/bench/sg/linkage.cu b/cpp/bench/sg/linkage.cu index cf0e5954c9..d9bd7b9fe5 100644 --- a/cpp/bench/sg/linkage.cu +++ b/cpp/bench/sg/linkage.cu @@ -48,7 +48,7 @@ class Linkage : public BlobsFixture { out_arrs.children = out_children; ML::single_linkage_neighbors(*this->handle, - this->data.X, + this->data.X.data(), this->params.nrows, this->params.ncols, &out_arrs, diff --git a/cpp/bench/sg/rf_classifier.cu b/cpp/bench/sg/rf_classifier.cu index b451d79075..9aa540454f 100644 --- a/cpp/bench/sg/rf_classifier.cu +++ b/cpp/bench/sg/rf_classifier.cu @@ -63,10 +63,10 @@ class RFClassifier : public BlobsFixture { mPtr->trees = nullptr; fit(*this->handle, mPtr, - this->data.X, + this->data.X.data(), this->params.nrows, this->params.ncols, - this->data.y, + this->data.y.data(), this->params.nclasses, rfParams); CUDA_CHECK(cudaStreamSynchronize(this->stream)); diff --git a/cpp/bench/sg/svc.cu b/cpp/bench/sg/svc.cu index 8d22775b5f..4a281658bc 100644 --- a/cpp/bench/sg/svc.cu +++ b/cpp/bench/sg/svc.cu @@ -32,8 +32,8 @@ struct SvcParams { DatasetParams data; BlobsParams blobs; MLCommon::Matrix::KernelParams kernel; - ML::SVM::svmParameter svm_param; - ML::SVM::svmModel model; + ML::SVM::SvmParameter svm_param; + ML::SVM::SvmModel model; }; template @@ -60,10 +60,10 @@ class SVC : public BlobsFixture { } this->loopOnState(state, [this]() { ML::SVM::svcFit(*this->handle, - this->data.X, + this->data.X.data(), this->params.nrows, this->params.ncols, - this->data.y, + this->data.y.data(), this->svm_param, this->kernel, this->model); @@ -74,8 +74,8 @@ class SVC : public BlobsFixture { private: MLCommon::Matrix::KernelParams kernel; - ML::SVM::svmParameter svm_param; - ML::SVM::svmModel model; + ML::SVM::SvmParameter svm_param; + ML::SVM::SvmModel model; }; template @@ -95,9 +95,9 @@ std::vector> getInputs() p.blobs.center_box_max = 2.0; p.blobs.seed = 12345ULL; - // svmParameter{C, cache_size, max_iter, nochange_steps, tol, verbosity}) - p.svm_param = ML::SVM::svmParameter{1, 200, 100, 100, 1e-3, CUML_LEVEL_INFO, 0, ML::SVM::C_SVC}; - p.model = ML::SVM::svmModel{0, 0, 0, nullptr, nullptr, nullptr, 0, nullptr}; + // SvmParameter{C, cache_size, max_iter, nochange_steps, tol, verbosity}) + p.svm_param = ML::SVM::SvmParameter{1, 200, 100, 100, 1e-3, CUML_LEVEL_INFO, 0, ML::SVM::C_SVC}; + p.model = ML::SVM::SvmModel{0, 0, 0, nullptr, nullptr, nullptr, 0, nullptr}; std::vector rowcols = {{50000, 2, 2}, {2048, 100000, 2}, {50000, 1000, 2}}; diff --git a/cpp/bench/sg/svr.cu b/cpp/bench/sg/svr.cu index 31d6dc2ba5..31be755472 100644 --- a/cpp/bench/sg/svr.cu +++ b/cpp/bench/sg/svr.cu @@ -32,8 +32,8 @@ struct SvrParams { DatasetParams data; RegressionParams regression; MLCommon::Matrix::KernelParams kernel; - ML::SVM::svmParameter svm_param; - ML::SVM::svmModel model; + ML::SVM::SvmParameter svm_param; + ML::SVM::SvmModel* model; }; template @@ -60,22 +60,22 @@ class SVR : public RegressionFixture { } this->loopOnState(state, [this]() { ML::SVM::svrFit(*this->handle, - this->data.X, + this->data.X.data(), this->params.nrows, this->params.ncols, - this->data.y, + this->data.y.data(), this->svm_param, this->kernel, - this->model); + *(this->model)); CUDA_CHECK(cudaStreamSynchronize(this->stream)); - ML::SVM::svmFreeBuffers(*this->handle, this->model); + ML::SVM::svmFreeBuffers(*this->handle, *(this->model)); }); } private: MLCommon::Matrix::KernelParams kernel; - ML::SVM::svmParameter svm_param; - ML::SVM::svmModel model; + ML::SVM::SvmParameter svm_param; + ML::SVM::SvmModel* model; }; template @@ -96,11 +96,11 @@ std::vector> getInputs() p.regression.tail_strength = 0.5; // unused when effective_rank = -1 p.regression.noise = 1; - // svmParameter{C, cache_size, max_iter, nochange_steps, tol, verbosity, + // SvmParameter{C, cache_size, max_iter, nochange_steps, tol, verbosity, // epsilon, svmType}) p.svm_param = - ML::SVM::svmParameter{1, 200, 200, 100, 1e-3, CUML_LEVEL_INFO, 0.1, ML::SVM::EPSILON_SVR}; - p.model = ML::SVM::svmModel{0, 0, 0, nullptr, nullptr, nullptr, 0, nullptr}; + ML::SVM::SvmParameter{1, 200, 200, 100, 1e-3, CUML_LEVEL_INFO, 0.1, ML::SVM::EPSILON_SVR}; + p.model = new ML::SVM::SvmModel{0, 0, 0, 0}; std::vector rowcols = {{50000, 2, 2}, {1024, 10000, 10}, {3000, 200, 200}}; @@ -130,4 +130,4 @@ ML_BENCH_REGISTER(SvrParams, SVR, "regression", getInputs { { alloc(yFloat, this->params.nrows); alloc(embeddings, this->params.nrows * uParams.n_components); - cast(yFloat, this->data.y, this->params.nrows, this->stream); + cast(yFloat, this->data.y.data(), this->params.nrows, this->stream); } void deallocateTempBuffers(const ::benchmark::State& state) override @@ -116,7 +116,7 @@ class UmapSupervised : public UmapBase { void coreBenchmarkMethod() { UMAP::fit(*this->handle, - this->data.X, + this->data.X.data(), yFloat, this->params.nrows, this->params.ncols, @@ -136,7 +136,7 @@ class UmapUnsupervised : public UmapBase { void coreBenchmarkMethod() { UMAP::fit(*this->handle, - this->data.X, + this->data.X.data(), nullptr, this->params.nrows, this->params.ncols, @@ -156,12 +156,12 @@ class UmapTransform : public UmapBase { void coreBenchmarkMethod() { UMAP::transform(*this->handle, - this->data.X, + this->data.X.data(), this->params.nrows, this->params.ncols, nullptr, nullptr, - this->data.X, + this->data.X.data(), this->params.nrows, embeddings, this->params.nrows, @@ -174,7 +174,7 @@ class UmapTransform : public UmapBase { auto& handle = *this->handle; alloc(transformed, this->params.nrows * uParams.n_components); UMAP::fit(handle, - this->data.X, + this->data.X.data(), yFloat, this->params.nrows, this->params.ncols, diff --git a/cpp/examples/dbscan/dbscan_example.cpp b/cpp/examples/dbscan/dbscan_example.cpp index 7bb882fed2..af7fd5e6bf 100644 --- a/cpp/examples/dbscan/dbscan_example.cpp +++ b/cpp/examples/dbscan/dbscan_example.cpp @@ -24,7 +24,6 @@ #include #include -#include #include @@ -139,10 +138,6 @@ int main(int argc, char* argv[]) raft::handle_t handle; - std::shared_ptr allocator(new raft::mr::device::default_allocator()); - - handle.set_device_allocator(allocator); - std::vector h_inputData; if (input == "") { diff --git a/cpp/examples/kmeans/kmeans_example.cpp b/cpp/examples/kmeans/kmeans_example.cpp index 69bd8db8ff..3aa9c20a4c 100644 --- a/cpp/examples/kmeans/kmeans_example.cpp +++ b/cpp/examples/kmeans/kmeans_example.cpp @@ -24,7 +24,6 @@ #include #include -#include #include @@ -130,11 +129,6 @@ int main(int argc, char* argv[]) raft::handle_t handle; - std::shared_ptr allocator( - new raft::mr::device::default_allocator()); - - handle.set_device_allocator(allocator); - cudaStream_t stream; CUDA_RT_CALL(cudaStreamCreate(&stream)); handle.set_stream(stream); diff --git a/cpp/include/cuml/common/device_buffer.hpp b/cpp/include/cuml/common/device_buffer.hpp deleted file mode 100644 index 2c42960ea9..0000000000 --- a/cpp/include/cuml/common/device_buffer.hpp +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace MLCommon { - -/** - * RAII object owning a contigous typed device buffer. The passed in allocator supports asynchronus - * allocation and deallocation so this can be used for temporary memory - * @code{.cpp} - * template - * void foo( const raft::handle_t& h, ..., cudaStream_t stream ) - * { - * ... - * device_buffer temp( h.get_device_allocator(), stream, 0 ) - * - * temp.resize(n, stream); - * kernelA<<>>(...,temp.data(),...); - * kernelB<<>>(...,temp.data(),...); - * temp.release(stream); - * } - * @endcode - */ -template -using device_buffer = raft::mr::device::buffer; - -} // namespace MLCommon diff --git a/cpp/include/cuml/common/host_buffer.hpp b/cpp/include/cuml/common/host_buffer.hpp deleted file mode 100644 index 423899f603..0000000000 --- a/cpp/include/cuml/common/host_buffer.hpp +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace MLCommon { - -/** - * RAII object owning a contigous typed host buffer. The passed in allocator supports asynchronus - * allocation and deallocation so this can be used for temporary memory - * @code{.cpp} - * template - * void foo( const raft::handle_t& h, const T* in_d , T* out_d, ..., cudaStream_t stream ) - * { - * ... - * host_buffer temp( handle->get_host_allocator(), stream, 0 ) - * - * temp.resize(n, stream); - * cudaMemcpyAsync( temp.data(), in_d, temp.size()*sizeof(T), cudaMemcpyDeviceToHost ); - * ... - * cudaMemcpyAsync( out_d, temp.data(), temp.size()*sizeof(T), cudaMemcpyHostToDevice ); - * temp.release(stream); - * } - * @endcode - * @todo: Add missing doxygen documentation - */ - -template -using host_buffer = raft::mr::host::buffer; - -} // namespace MLCommon diff --git a/cpp/include/cuml/random_projection/rproj_c.h b/cpp/include/cuml/random_projection/rproj_c.h index d4f1702b54..7e14e14e0d 100644 --- a/cpp/include/cuml/random_projection/rproj_c.h +++ b/cpp/include/cuml/random_projection/rproj_c.h @@ -16,10 +16,8 @@ #pragma once -#include - #include -#include +#include namespace ML { @@ -50,11 +48,11 @@ enum random_matrix_type { unset, dense, sparse }; template struct rand_mat { - rand_mat(std::shared_ptr allocator, cudaStream_t stream) - : dense_data(allocator, stream), - indices(allocator, stream), - indptr(allocator, stream), - sparse_data(allocator, stream), + rand_mat(cudaStream_t stream) + : dense_data(0, stream), + indices(0, stream), + indptr(0, stream), + sparse_data(0, stream), stream(stream), type(unset) { @@ -63,12 +61,12 @@ struct rand_mat { ~rand_mat() { this->reset(); } // For dense matrices - MLCommon::device_buffer dense_data; + rmm::device_uvector dense_data; // For sparse CSC matrices - MLCommon::device_buffer indices; - MLCommon::device_buffer indptr; - MLCommon::device_buffer sparse_data; + rmm::device_uvector indices; + rmm::device_uvector indptr; + rmm::device_uvector sparse_data; cudaStream_t stream; @@ -76,10 +74,10 @@ struct rand_mat { void reset() { - this->dense_data.release(this->stream); - this->indices.release(this->stream); - this->indptr.release(this->stream); - this->sparse_data.release(this->stream); + this->dense_data.release(); + this->indices.release(); + this->indptr.release(); + this->sparse_data.release(); this->type = unset; }; }; diff --git a/cpp/include/cuml/svm/svc.hpp b/cpp/include/cuml/svm/svc.hpp index f9770a665c..e56bdb26f3 100644 --- a/cpp/include/cuml/svm/svc.hpp +++ b/cpp/include/cuml/svm/svc.hpp @@ -55,9 +55,9 @@ void svcFit(const raft::handle_t& handle, int n_rows, int n_cols, math_t* labels, - const svmParameter& param, + const SvmParameter& param, MLCommon::Matrix::KernelParams& kernel_params, - svmModel& model, + SvmModel& model, const math_t* sample_weight = nullptr); /** @@ -95,19 +95,19 @@ void svcPredict(const raft::handle_t& handle, int n_rows, int n_cols, MLCommon::Matrix::KernelParams& kernel_params, - const svmModel& model, + const SvmModel& model, math_t* preds, math_t buffer_size, bool predict_class = true); /** - * Deallocate device buffers in the svmModel struct. + * Deallocate device buffers in the SvmModel struct. * * @param [in] handle cuML handle * @param [inout] m SVM model parameters */ template -void svmFreeBuffers(const raft::handle_t& handle, svmModel& m); +void svmFreeBuffers(const raft::handle_t& handle, SvmModel& m); /** * @brief C-Support Vector Classification @@ -134,8 +134,8 @@ class SVC { // Public members for easier access during testing from Python. MLCommon::Matrix::KernelParams kernel_params; - svmParameter param; - svmModel model; + SvmParameter param; + SvmModel model; /** * @brief Constructs a support vector classifier * @param handle cuML handle diff --git a/cpp/include/cuml/svm/svm_model.h b/cpp/include/cuml/svm/svm_model.h index 8b981f3316..edc4bd2fa5 100644 --- a/cpp/include/cuml/svm/svm_model.h +++ b/cpp/include/cuml/svm/svm_model.h @@ -23,7 +23,7 @@ namespace SVM { * All pointers are device pointers. */ template -struct svmModel { +struct SvmModel { int n_support; //!< Number of support vectors int n_cols; //!< Number of features math_t b; //!< Constant used in the decision function diff --git a/cpp/include/cuml/svm/svm_parameter.h b/cpp/include/cuml/svm/svm_parameter.h index f6be63060e..c5fc4ef2d0 100644 --- a/cpp/include/cuml/svm/svm_parameter.h +++ b/cpp/include/cuml/svm/svm_parameter.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ enum SvmType { C_SVC, NU_SVC, EPSILON_SVR, NU_SVR }; * - the diff is changing less then 0.001*tol in nochange_steps consecutive * outer iterations. */ -struct svmParameter { +struct SvmParameter { double C; //!< Penalty term C double cache_size; //!< kernel cache size in MiB //! maximum number of outer SMO iterations. Use -1 to let the SMO solver set diff --git a/cpp/include/cuml/svm/svr.hpp b/cpp/include/cuml/svm/svr.hpp index 6c8573f248..e03fb93a84 100644 --- a/cpp/include/cuml/svm/svr.hpp +++ b/cpp/include/cuml/svm/svr.hpp @@ -23,8 +23,8 @@ namespace ML { namespace SVM { template -struct svmModel; -struct svmParameter; +struct SvmModel; +struct SvmParameter; // Forward declarations of the stateless API /** @@ -52,9 +52,9 @@ void svrFit(const raft::handle_t& handle, int n_rows, int n_cols, math_t* y, - const svmParameter& param, + const SvmParameter& param, MLCommon::Matrix::KernelParams& kernel_params, - svmModel& model, + SvmModel& model, const math_t* sample_weight = nullptr); // For prediction we use svcPredict diff --git a/cpp/include/cuml/tsa/arima_common.h b/cpp/include/cuml/tsa/arima_common.h index 17dc2ec3b6..67c4874328 100644 --- a/cpp/include/cuml/tsa/arima_common.h +++ b/cpp/include/cuml/tsa/arima_common.h @@ -70,23 +70,18 @@ struct ARIMAParams { * @tparam AllocatorT Type of allocator used * @param[in] order ARIMA order * @param[in] batch_size Batch size - * @param[in] alloc Allocator * @param[in] stream CUDA stream * @param[in] tr Whether these are the transformed parameters */ - template - void allocate(const ARIMAOrder& order, - int batch_size, - AllocatorT& alloc, - cudaStream_t stream, - bool tr = false) + void allocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false) { - if (order.k && !tr) mu = (DataT*)alloc->allocate(batch_size * sizeof(DataT), stream); - if (order.p) ar = (DataT*)alloc->allocate(order.p * batch_size * sizeof(DataT), stream); - if (order.q) ma = (DataT*)alloc->allocate(order.q * batch_size * sizeof(DataT), stream); - if (order.P) sar = (DataT*)alloc->allocate(order.P * batch_size * sizeof(DataT), stream); - if (order.Q) sma = (DataT*)alloc->allocate(order.Q * batch_size * sizeof(DataT), stream); - sigma2 = (DataT*)alloc->allocate(batch_size * sizeof(DataT), stream); + rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource(); + if (order.k && !tr) mu = (DataT*)rmm_alloc->allocate(batch_size * sizeof(DataT), stream); + if (order.p) ar = (DataT*)rmm_alloc->allocate(order.p * batch_size * sizeof(DataT), stream); + if (order.q) ma = (DataT*)rmm_alloc->allocate(order.q * batch_size * sizeof(DataT), stream); + if (order.P) sar = (DataT*)rmm_alloc->allocate(order.P * batch_size * sizeof(DataT), stream); + if (order.Q) sma = (DataT*)rmm_alloc->allocate(order.Q * batch_size * sizeof(DataT), stream); + sigma2 = (DataT*)rmm_alloc->allocate(batch_size * sizeof(DataT), stream); } /** @@ -95,23 +90,18 @@ struct ARIMAParams { * @tparam AllocatorT Type of allocator used * @param[in] order ARIMA order * @param[in] batch_size Batch size - * @param[in] alloc Allocator * @param[in] stream CUDA stream * @param[in] tr Whether these are the transformed parameters */ - template - void deallocate(const ARIMAOrder& order, - int batch_size, - AllocatorT& alloc, - cudaStream_t stream, - bool tr = false) + void deallocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false) { - if (order.k && !tr) alloc->deallocate(mu, batch_size * sizeof(DataT), stream); - if (order.p) alloc->deallocate(ar, order.p * batch_size * sizeof(DataT), stream); - if (order.q) alloc->deallocate(ma, order.q * batch_size * sizeof(DataT), stream); - if (order.P) alloc->deallocate(sar, order.P * batch_size * sizeof(DataT), stream); - if (order.Q) alloc->deallocate(sma, order.Q * batch_size * sizeof(DataT), stream); - alloc->deallocate(sigma2, batch_size * sizeof(DataT), stream); + rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource(); + if (order.k && !tr) rmm_alloc->deallocate(mu, batch_size * sizeof(DataT), stream); + if (order.p) rmm_alloc->deallocate(ar, order.p * batch_size * sizeof(DataT), stream); + if (order.q) rmm_alloc->deallocate(ma, order.q * batch_size * sizeof(DataT), stream); + if (order.P) rmm_alloc->deallocate(sar, order.P * batch_size * sizeof(DataT), stream); + if (order.Q) rmm_alloc->deallocate(sma, order.Q * batch_size * sizeof(DataT), stream); + rmm_alloc->deallocate(sigma2, batch_size * sizeof(DataT), stream); } /** diff --git a/cpp/src/arima/batched_arima.cu b/cpp/src/arima/batched_arima.cu index acf8fbc4f9..9ebdd577c4 100644 --- a/cpp/src/arima/batched_arima.cu +++ b/cpp/src/arima/batched_arima.cu @@ -29,12 +29,12 @@ #include #include -#include #include #include #include #include #include +#include #include namespace ML { @@ -87,7 +87,6 @@ void predict(raft::handle_t& handle, double* d_upper) { ML::PUSH_RANGE(__func__); - auto allocator = handle.get_device_allocator(); const auto stream = handle.get_stream(); bool diff = order.need_diff() && pre_diff && level == 0; @@ -113,7 +112,7 @@ void predict(raft::handle_t& handle, // Create temporary array for the forecasts int num_steps = std::max(end - n_obs, 0); - MLCommon::device_buffer fc_buffer(allocator, stream, num_steps * batch_size); + rmm::device_uvector fc_buffer(num_steps * batch_size, stream); double* d_y_fc = fc_buffer.data(); // Compute the residual and forecast @@ -357,8 +356,7 @@ void batched_loglike(raft::handle_t& handle, { ML::PUSH_RANGE(__func__); - auto allocator = handle.get_device_allocator(); - auto stream = handle.get_stream(); + auto stream = handle.get_stream(); ARIMAParams Tparams = {arima_mem.Tparams_mu, arima_mem.Tparams_ar, @@ -374,12 +372,17 @@ void batched_loglike(raft::handle_t& handle, if (trans) { MLCommon::TimeSeries::batched_jones_transform( - order, batch_size, false, params, Tparams, allocator, stream); + order, batch_size, false, params, Tparams, stream); Tparams.mu = params.mu; } else { // non-transformed case: just use original parameters - Tparams = params; + Tparams.mu = params.mu; + Tparams.ar = params.ar; + Tparams.ma = params.ma; + Tparams.sar = params.sar; + Tparams.sma = params.sma; + Tparams.sigma2 = params.sigma2; } if (method == CSS) { @@ -430,8 +433,7 @@ void batched_loglike(raft::handle_t& handle, ML::PUSH_RANGE(__func__); // unpack parameters - auto allocator = handle.get_device_allocator(); - auto stream = handle.get_stream(); + auto stream = handle.get_stream(); ARIMAParams params = {arima_mem.params_mu, arima_mem.params_ar, @@ -478,10 +480,9 @@ void batched_loglike_grad(raft::handle_t& handle, int truncate) { ML::PUSH_RANGE(__func__); - auto allocator = handle.get_device_allocator(); - auto stream = handle.get_stream(); - auto counting = thrust::make_counting_iterator(0); - int N = order.complexity(); + auto stream = handle.get_stream(); + auto counting = thrust::make_counting_iterator(0); + int N = order.complexity(); // Initialize the perturbed x vector double* d_x_pert = arima_mem.x_pert; @@ -555,8 +556,7 @@ void information_criterion(raft::handle_t& handle, int ic_type) { ML::PUSH_RANGE(__func__); - auto allocator = handle.get_device_allocator(); - auto stream = handle.get_stream(); + auto stream = handle.get_stream(); double* d_vs = arima_mem.vs; @@ -636,7 +636,6 @@ void _arma_least_squares(raft::handle_t& handle, const auto& handle_impl = handle; auto stream = handle_impl.get_stream(); auto cublas_handle = handle_impl.get_cublas_handle(); - auto allocator = handle_impl.get_device_allocator(); auto counting = thrust::make_counting_iterator(0); int batch_size = bm_y.batches(); @@ -662,7 +661,7 @@ void _arma_least_squares(raft::handle_t& handle, * side by side. The left side will be used to estimate AR, the right * side to estimate MA */ MLCommon::LinAlg::Batched::Matrix bm_ls_ar_res( - n_obs - r, p + q + k, batch_size, cublas_handle, allocator, stream, false); + n_obs - r, p + q + k, batch_size, cublas_handle, stream, false); int ar_offset = r - ps; int res_offset = r - p_ar - qs; @@ -717,7 +716,7 @@ void _arma_least_squares(raft::handle_t& handle, // The residuals will be computed only if sigma2 is requested MLCommon::LinAlg::Batched::Matrix bm_final_residual( - n_obs - r, 1, batch_size, cublas_handle, allocator, stream, false); + n_obs - r, 1, batch_size, cublas_handle, stream, false); if (estimate_sigma2) { raft::copy( bm_final_residual.raw_data(), bm_arma_fit.raw_data(), (n_obs - r) * batch_size, stream); @@ -842,11 +841,10 @@ void estimate_x0(raft::handle_t& handle, const auto& handle_impl = handle; auto stream = handle_impl.get_stream(); auto cublas_handle = handle_impl.get_cublas_handle(); - auto allocator = handle_impl.get_device_allocator(); // Difference if necessary, copy otherwise MLCommon::LinAlg::Batched::Matrix bm_yd( - n_obs - order.d - order.s * order.D, 1, batch_size, cublas_handle, allocator, stream, false); + n_obs - order.d - order.s * order.D, 1, batch_size, cublas_handle, stream, false); MLCommon::TimeSeries::prepare_data( bm_yd.raw_data(), d_y, batch_size, n_obs, order.d, order.D, order.s, stream); diff --git a/cpp/src/arima/batched_kalman.cu b/cpp/src/arima/batched_kalman.cu index be64430401..604312faf1 100644 --- a/cpp/src/arima/batched_kalman.cu +++ b/cpp/src/arima/batched_kalman.cu @@ -25,12 +25,14 @@ #include #include -#include -#include -#include #include #include #include +#include + +#include +#include +#include #include namespace ML { @@ -584,7 +586,6 @@ void _batched_kalman_device_loop_large(const ARIMAMemory& arima_mem, "Gemm and gemv policies: block size mismatch"); auto stream = T.stream(); - auto allocator = T.allocator(); auto cublasHandle = T.cublasHandle(); int batch_size = T.batches(); @@ -595,18 +596,10 @@ void _batched_kalman_device_loop_large(const ARIMAMemory& arima_mem, cublasHandle, arima_mem.m_tmp_batches, arima_mem.m_tmp_dense, - allocator, stream, false); - MLCommon::LinAlg::Batched::Matrix TP(rd, - rd, - batch_size, - cublasHandle, - arima_mem.TP_batches, - arima_mem.TP_dense, - allocator, - stream, - false); + MLCommon::LinAlg::Batched::Matrix TP( + rd, rd, batch_size, cublasHandle, arima_mem.TP_batches, arima_mem.TP_dense, stream, false); int grid_size = std::min(batch_size, 65536); size_t shared_mem_size = 4 * rd * sizeof(double); @@ -1051,7 +1044,6 @@ void _lyapunov_wrapper(raft::handle_t& handle, if (r <= 5) { auto stream = handle.get_stream(); auto cublasHandle = handle.get_cublas_handle(); - auto allocator = handle.get_device_allocator(); int batch_size = A.batches(); int r2 = r * r; @@ -1065,7 +1057,6 @@ void _lyapunov_wrapper(raft::handle_t& handle, cublasHandle, arima_mem.I_m_AxA_batches, arima_mem.I_m_AxA_dense, - allocator, stream, false); MLCommon::LinAlg::Batched::Matrix I_m_AxA_inv(r2, @@ -1074,7 +1065,6 @@ void _lyapunov_wrapper(raft::handle_t& handle, cublasHandle, arima_mem.I_m_AxA_inv_batches, arima_mem.I_m_AxA_inv_dense, - allocator, stream, false); @@ -1111,7 +1101,6 @@ void _batched_kalman_filter(raft::handle_t& handle, const size_t batch_size = Zb.batches(); auto stream = handle.get_stream(); auto cublasHandle = handle.get_cublas_handle(); - auto allocator = handle.get_device_allocator(); auto counting = thrust::make_counting_iterator(0); @@ -1119,15 +1108,8 @@ void _batched_kalman_filter(raft::handle_t& handle, int rd = order.rd(); int r = order.r(); - MLCommon::LinAlg::Batched::Matrix RQb(rd, - 1, - batch_size, - cublasHandle, - arima_mem.RQ_batches, - arima_mem.RQ_dense, - allocator, - stream, - true); + MLCommon::LinAlg::Batched::Matrix RQb( + rd, 1, batch_size, cublasHandle, arima_mem.RQ_batches, arima_mem.RQ_dense, stream, true); double* d_RQ = RQb.raw_data(); const double* d_R = Rb.raw_data(); thrust::for_each( @@ -1137,28 +1119,14 @@ void _batched_kalman_filter(raft::handle_t& handle, d_RQ[bid * rd + i] = d_R[bid * rd + i] * sigma2; } }); - MLCommon::LinAlg::Batched::Matrix RQR(rd, - rd, - batch_size, - cublasHandle, - arima_mem.RQR_batches, - arima_mem.RQR_dense, - allocator, - stream, - false); + MLCommon::LinAlg::Batched::Matrix RQR( + rd, rd, batch_size, cublasHandle, arima_mem.RQR_batches, arima_mem.RQR_dense, stream, false); MLCommon::LinAlg::Batched::b_gemm(false, true, rd, rd, 1, 1.0, RQb, Rb, 0.0, RQR); // Durbin Koopman "Time Series Analysis" pg 138 ML::PUSH_RANGE("Init P"); - MLCommon::LinAlg::Batched::Matrix P(rd, - rd, - batch_size, - cublasHandle, - arima_mem.P_batches, - arima_mem.P_dense, - allocator, - stream, - true); + MLCommon::LinAlg::Batched::Matrix P( + rd, rd, batch_size, cublasHandle, arima_mem.P_batches, arima_mem.P_dense, stream, true); { double* d_P = P.raw_data(); @@ -1175,33 +1143,18 @@ void _batched_kalman_filter(raft::handle_t& handle, }); // Initialize the stationary part by solving a Lyapunov equation - MLCommon::LinAlg::Batched::Matrix Ts(r, - r, - batch_size, - cublasHandle, - arima_mem.Ts_batches, - arima_mem.Ts_dense, - allocator, - stream, - false); + MLCommon::LinAlg::Batched::Matrix Ts( + r, r, batch_size, cublasHandle, arima_mem.Ts_batches, arima_mem.Ts_dense, stream, false); MLCommon::LinAlg::Batched::Matrix RQRs(r, r, batch_size, cublasHandle, arima_mem.RQRs_batches, arima_mem.RQRs_dense, - allocator, stream, false); - MLCommon::LinAlg::Batched::Matrix Ps(r, - r, - batch_size, - cublasHandle, - arima_mem.Ps_batches, - arima_mem.Ps_dense, - allocator, - stream, - false); + MLCommon::LinAlg::Batched::Matrix Ps( + r, r, batch_size, cublasHandle, arima_mem.Ps_batches, arima_mem.Ps_dense, stream, false); MLCommon::LinAlg::Batched::b_2dcopy(Tb, Ts, n_diff, n_diff, r, r); MLCommon::LinAlg::Batched::b_2dcopy(RQR, RQRs, n_diff, n_diff, r, r); @@ -1229,20 +1182,12 @@ void _batched_kalman_filter(raft::handle_t& handle, handle.get_cublas_handle(), arima_mem.alpha_batches, arima_mem.alpha_dense, - handle.get_device_allocator(), stream, false); if (intercept) { // Compute I-T* - MLCommon::LinAlg::Batched::Matrix ImT(r, - r, - batch_size, - cublasHandle, - arima_mem.ImT_batches, - arima_mem.ImT_dense, - allocator, - stream, - false); + MLCommon::LinAlg::Batched::Matrix ImT( + r, r, batch_size, cublasHandle, arima_mem.ImT_batches, arima_mem.ImT_dense, stream, false); const double* d_T = Tb.raw_data(); double* d_ImT = ImT.raw_data(); thrust::for_each( @@ -1271,7 +1216,6 @@ void _batched_kalman_filter(raft::handle_t& handle, cublasHandle, arima_mem.ImT_inv_batches, arima_mem.ImT_inv_dense, - allocator, stream, false); MLCommon::LinAlg::Batched::Matrix::inv( @@ -1465,38 +1409,16 @@ void batched_kalman_filter(raft::handle_t& handle, auto cublasHandle = handle.get_cublas_handle(); auto stream = handle.get_stream(); - auto allocator = handle.get_device_allocator(); // see (3.18) in TSA by D&K int rd = order.rd(); - MLCommon::LinAlg::Batched::Matrix Zb(1, - rd, - batch_size, - cublasHandle, - arima_mem.Z_batches, - arima_mem.Z_dense, - allocator, - stream, - false); - MLCommon::LinAlg::Batched::Matrix Tb(rd, - rd, - batch_size, - cublasHandle, - arima_mem.T_batches, - arima_mem.T_dense, - allocator, - stream, - false); - MLCommon::LinAlg::Batched::Matrix Rb(rd, - 1, - batch_size, - cublasHandle, - arima_mem.R_batches, - arima_mem.R_dense, - allocator, - stream, - false); + MLCommon::LinAlg::Batched::Matrix Zb( + 1, rd, batch_size, cublasHandle, arima_mem.Z_batches, arima_mem.Z_dense, stream, false); + MLCommon::LinAlg::Batched::Matrix Tb( + rd, rd, batch_size, cublasHandle, arima_mem.T_batches, arima_mem.T_dense, stream, false); + MLCommon::LinAlg::Batched::Matrix Rb( + rd, 1, batch_size, cublasHandle, arima_mem.R_batches, arima_mem.R_dense, stream, false); init_batched_kalman_matrices(handle, params.ar, @@ -1545,7 +1467,6 @@ void batched_jones_transform(raft::handle_t& handle, double* h_Tparams) { int N = order.complexity(); - auto allocator = handle.get_device_allocator(); auto stream = handle.get_stream(); double* d_params = arima_mem.d_params; double* d_Tparams = arima_mem.d_Tparams; @@ -1566,8 +1487,7 @@ void batched_jones_transform(raft::handle_t& handle, params.unpack(order, batch_size, d_params, stream); - MLCommon::TimeSeries::batched_jones_transform( - order, batch_size, isInv, params, Tparams, allocator, stream); + MLCommon::TimeSeries::batched_jones_transform(order, batch_size, isInv, params, Tparams, stream); Tparams.mu = params.mu; Tparams.pack(order, batch_size, d_Tparams, stream); diff --git a/cpp/src/common/cuml_api.cpp b/cpp/src/common/cuml_api.cpp index cca2793bca..6284a8aa6f 100644 --- a/cpp/src/common/cuml_api.cpp +++ b/cpp/src/common/cuml_api.cpp @@ -140,58 +140,6 @@ extern "C" cumlError_t cumlGetStream(cumlHandle_t handle, cudaStream_t* stream) return status; } -extern "C" cumlError_t cumlSetDeviceAllocator(cumlHandle_t handle, - cuml_allocate allocate_fn, - cuml_deallocate deallocate_fn) -{ - cumlError_t status; - raft::handle_t* handle_ptr; - std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); - if (status == CUML_SUCCESS) { - try { - std::shared_ptr allocator( - new ML::detail::deviceAllocatorFunctionWrapper(allocate_fn, deallocate_fn)); - handle_ptr->set_device_allocator(allocator); - } - // TODO: Implement this - // catch (const MLCommon::Exception& e) - //{ - // //log e.what()? - // status = e.getErrorCode(); - //} - catch (...) { - status = CUML_ERROR_UNKNOWN; - } - } - return status; -} - -extern "C" cumlError_t cumlSetHostAllocator(cumlHandle_t handle, - cuml_allocate allocate_fn, - cuml_deallocate deallocate_fn) -{ - cumlError_t status; - raft::handle_t* handle_ptr; - std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); - if (status == CUML_SUCCESS) { - try { - std::shared_ptr allocator( - new ML::detail::hostAllocatorFunctionWrapper(allocate_fn, deallocate_fn)); - handle_ptr->set_host_allocator(allocator); - } - // TODO: Implement this - // catch (const MLCommon::Exception& e) - //{ - // //log e.what()? - // status = e.getErrorCode(); - //} - catch (...) { - status = CUML_ERROR_UNKNOWN; - } - } - return status; -} - extern "C" cumlError_t cumlDestroy(cumlHandle_t handle) { return ML::handleMap.removeAndDestroyHandle(handle); diff --git a/cpp/src/common/tensor.hpp b/cpp/src/common/tensor.hpp index 8578556199..b76428b6c3 100644 --- a/cpp/src/common/tensor.hpp +++ b/cpp/src/common/tensor.hpp @@ -17,8 +17,7 @@ #pragma once #include -#include -#include +#include #include @@ -33,10 +32,12 @@ class Tensor { __host__ ~Tensor() { if (_state == AllocState::Owner) { + if (memory_type(_data) == cudaMemoryTypeHost) { delete _data; } + if (memory_type(_data) == cudaMemoryTypeDevice) { - _dAllocator->deallocate(_data, this->getSizeInBytes(), _stream); + rmm_alloc->deallocate(_data, this->getSizeInBytes(), _stream); } else if (memory_type(_data) == cudaMemoryTypeHost) { - _hAllocator->deallocate(_data, this->getSizeInBytes(), _stream); + delete _data; } } } @@ -62,10 +63,8 @@ class Tensor { // allocate the data using the allocator and release when the object goes out of scope // allocating tensor is the owner of the data - __host__ Tensor(const std::vector& sizes, - std::shared_ptr allocator, - cudaStream_t stream) - : _stream(stream), _dAllocator(allocator), _state(AllocState::Owner) + __host__ Tensor(const std::vector& sizes, cudaStream_t stream) + : _stream(stream), _state(AllocState::Owner) { static_assert(Dim > 0, "must have > 0 dimensions"); @@ -80,9 +79,8 @@ class Tensor { _stride[j] = _stride[j + 1] * _size[j + 1]; } - _data = static_cast(_dAllocator->allocate(this->getSizeInBytes(), _stream)); - - CUDA_CHECK(cudaStreamSynchronize(_stream)); + rmm_alloc = rmm::mr::get_current_device_resource(); + _data = (DataT*)rmm_alloc->allocate(this->getSizeInBytes(), _stream); ASSERT(this->data() || (this->getSizeInBytes() == 0), "device allocation failed"); } @@ -168,9 +166,6 @@ class Tensor { }; protected: - std::shared_ptr _dAllocator; - std::shared_ptr _hAllocator; - /// Raw pointer to where the tensor data begins DataPtrT _data{}; @@ -183,6 +178,8 @@ class Tensor { AllocState _state{}; cudaStream_t _stream{}; + + rmm::mr::device_memory_resource* rmm_alloc; }; }; // end namespace ML diff --git a/cpp/src/datasets/make_arima.cu b/cpp/src/datasets/make_arima.cu index fb91a8366c..f28bdd8e02 100644 --- a/cpp/src/datasets/make_arima.cu +++ b/cpp/src/datasets/make_arima.cu @@ -31,11 +31,10 @@ inline void make_arima_helper(const raft::handle_t& handle, DataT intercept_scale, uint64_t seed) { - auto stream = handle.get_stream(); - auto allocator = handle.get_device_allocator(); + auto stream = handle.get_stream(); MLCommon::Random::make_arima( - out, batch_size, n_obs, order, allocator, stream, scale, noise_scale, intercept_scale, seed); + out, batch_size, n_obs, order, stream, scale, noise_scale, intercept_scale, seed); } void make_arima(const raft::handle_t& handle, diff --git a/cpp/src/datasets/make_blobs.cu b/cpp/src/datasets/make_blobs.cu index 38b611fe4d..88ca7b70e5 100644 --- a/cpp/src/datasets/make_blobs.cu +++ b/cpp/src/datasets/make_blobs.cu @@ -40,7 +40,6 @@ void make_blobs(const raft::handle_t& handle, n_rows, n_cols, n_clusters, - handle.get_device_allocator(), handle.get_stream(), row_major, centers, @@ -72,7 +71,6 @@ void make_blobs(const raft::handle_t& handle, n_rows, n_cols, n_clusters, - handle.get_device_allocator(), handle.get_stream(), row_major, centers, @@ -104,7 +102,6 @@ void make_blobs(const raft::handle_t& handle, n_rows, n_cols, n_clusters, - handle.get_device_allocator(), handle.get_stream(), row_major, centers, @@ -136,7 +133,6 @@ void make_blobs(const raft::handle_t& handle, n_rows, n_cols, n_clusters, - handle.get_device_allocator(), handle.get_stream(), row_major, centers, diff --git a/cpp/src/datasets/make_regression.cu b/cpp/src/datasets/make_regression.cu index 8fc6f4b00c..8b95e02c6c 100644 --- a/cpp/src/datasets/make_regression.cu +++ b/cpp/src/datasets/make_regression.cu @@ -40,7 +40,6 @@ void make_regression_helper(const raft::handle_t& handle, cudaStream_t stream = handle_impl.get_stream(); cublasHandle_t cublas_handle = handle_impl.get_cublas_handle(); cusolverDnHandle_t cusolver_handle = handle_impl.get_cusolver_dn_handle(); - auto allocator = handle_impl.get_device_allocator(); MLCommon::Random::make_regression(handle, out, diff --git a/cpp/src/dbscan/adjgraph/algo.cuh b/cpp/src/dbscan/adjgraph/algo.cuh index 13cbf3eae6..c987bae89c 100644 --- a/cpp/src/dbscan/adjgraph/algo.cuh +++ b/cpp/src/dbscan/adjgraph/algo.cuh @@ -16,17 +16,15 @@ #pragma once +#include +#include + #include "../common.cuh" #include "pack.h" -#include - #include #include -#include -#include - using namespace thrust; namespace ML { @@ -49,8 +47,7 @@ void launcher(const raft::handle_t& handle, device_ptr dev_vd = device_pointer_cast(data.vd); device_ptr dev_ex_scan = device_pointer_cast(data.ex_scan); - ML::thrustAllocatorAdapter alloc(handle.get_device_allocator(), stream); - exclusive_scan(thrust::cuda::par(alloc).on(stream), dev_vd, dev_vd + batch_size, dev_ex_scan); + exclusive_scan(handle.get_thrust_policy(), dev_vd, dev_vd + batch_size, dev_ex_scan); raft::sparse::convert::csr_adj_graph_batched( data.ex_scan, data.N, data.adjnnz, batch_size, data.adj, data.adj_graph, stream); diff --git a/cpp/src/dbscan/adjgraph/naive.cuh b/cpp/src/dbscan/adjgraph/naive.cuh index afb1e6befe..6ef2830c7d 100644 --- a/cpp/src/dbscan/adjgraph/naive.cuh +++ b/cpp/src/dbscan/adjgraph/naive.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include #include "../common.cuh" #include "pack.h" @@ -35,14 +35,14 @@ void launcher(const raft::handle_t& handle, { Index_ k = 0; Index_ N = data.N; - MLCommon::host_buffer host_vd(handle.get_host_allocator(), stream, batch_size + 1); - MLCommon::host_buffer host_adj(handle.get_host_allocator(), stream, batch_size * N); - MLCommon::host_buffer host_ex_scan(handle.get_host_allocator(), stream, batch_size); - raft::update_host(host_adj.data(), data.adj, batch_size * N, stream); + std::vector host_vd(batch_size + 1); + std::vector host_adj(((batch_size * N) / 8) + 1); + std::vector host_ex_scan(batch_size); + raft::update_host((bool*)host_adj.data(), data.adj, batch_size * N, stream); raft::update_host(host_vd.data(), data.vd, batch_size + 1, stream); CUDA_CHECK(cudaStreamSynchronize(stream)); size_t adjgraph_size = size_t(host_vd[batch_size]); - MLCommon::host_buffer host_adj_graph(handle.get_host_allocator(), stream, adjgraph_size); + std::vector host_adj_graph(adjgraph_size); for (Index_ i = 0; i < batch_size; i++) { for (Index_ j = 0; j < N; j++) { /// TODO: change layout or remove; cf #3414 diff --git a/cpp/src/dbscan/corepoints/compute.cuh b/cpp/src/dbscan/corepoints/compute.cuh index 486ff23f79..5945f00280 100644 --- a/cpp/src/dbscan/corepoints/compute.cuh +++ b/cpp/src/dbscan/corepoints/compute.cuh @@ -41,10 +41,9 @@ void compute(const raft::handle_t& handle, Index_ batch_size, cudaStream_t stream) { - auto execution_policy = ML::thrust_exec_policy(handle.get_device_allocator(), stream); - auto counting = thrust::make_counting_iterator(0); + auto counting = thrust::make_counting_iterator(0); thrust::for_each( - execution_policy->on(stream), counting, counting + batch_size, [=] __device__(Index_ idx) { + handle.get_thrust_policy(), counting, counting + batch_size, [=] __device__(Index_ idx) { mask[idx + start_vertex_id] = vd[idx] >= min_pts; }); } diff --git a/cpp/src/dbscan/dbscan.cuh b/cpp/src/dbscan/dbscan.cuh index 467ecb0839..647420db1b 100644 --- a/cpp/src/dbscan/dbscan.cuh +++ b/cpp/src/dbscan/dbscan.cuh @@ -21,7 +21,6 @@ #include #include -#include #include #include @@ -182,7 +181,7 @@ void dbscanFitImpl(const raft::handle_t& handle, CUML_LOG_DEBUG("Workspace size: %lf MB", (double)workspaceSize * 1e-6); - MLCommon::device_buffer workspace(handle.get_device_allocator(), stream, workspaceSize); + rmm::device_uvector workspace(workspaceSize, stream); Dbscan::run(handle, input, n_rows, diff --git a/cpp/src/dbscan/runner.cuh b/cpp/src/dbscan/runner.cuh index 178a29887d..e6c68eb5c9 100644 --- a/cpp/src/dbscan/runner.cuh +++ b/cpp/src/dbscan/runner.cuh @@ -16,6 +16,10 @@ #pragma once +#include +#include +#include