From f95e93b80052802dcf165871d768fb563da80a78 Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Tue, 8 Feb 2022 04:19:53 -0500 Subject: [PATCH] Updating RAFT linalg headers (#4515) Depends on https://github.com/rapidsai/raft/pull/383 Authors: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuml/pull/4515 --- cpp/bench/prims/add.cu | 4 +- cpp/bench/prims/fused_l2_nn.cu | 2 +- cpp/bench/prims/gram_matrix.cu | 3 +- cpp/bench/prims/map_then_reduce.cu | 4 +- cpp/bench/prims/matrix_vector_op.cu | 4 +- cpp/bench/prims/reduce.cu | 4 +- cpp/bench/sg/dataset.cuh | 4 +- cpp/bench/sg/linkage.cu | 2 +- cpp/include/cuml/cluster/dbscan.hpp | 2 +- cpp/include/cuml/cluster/hdbscan.hpp | 2 +- cpp/include/cuml/cluster/linkage.hpp | 2 +- cpp/include/cuml/metrics/metrics.hpp | 4 +- cpp/include/cuml/neighbors/knn.hpp | 4 +- cpp/include/cuml/neighbors/knn_sparse.hpp | 2 +- cpp/src/arima/batched_arima.cu | 2 +- cpp/src/arima/batched_kalman.cu | 83 +++--- cpp/src/common/cumlHandle.cpp | 6 +- cpp/src/dbscan/vertexdeg/precomputed.cuh | 4 +- .../decisiontree/batched-levelalgo/split.cuh | 2 +- cpp/src/genetic/fitness.cuh | 10 +- cpp/src/genetic/genetic.cu | 4 +- cpp/src/genetic/program.cu | 2 +- cpp/src/glm/ols.cuh | 8 +- cpp/src/glm/ols_mg.cu | 4 +- cpp/src/glm/preprocess.cuh | 6 +- cpp/src/glm/preprocess_mg.cu | 4 +- cpp/src/glm/qn/glm_base.cuh | 10 +- cpp/src/glm/qn/glm_linear.cuh | 2 +- cpp/src/glm/qn/glm_logistic.cuh | 2 +- cpp/src/glm/qn/glm_regularizer.cuh | 4 +- cpp/src/glm/qn/glm_softmax.cuh | 2 +- cpp/src/glm/qn/glm_svm.cuh | 2 +- cpp/src/glm/qn/simple_mat/dense.hpp | 42 +-- cpp/src/glm/qn/simple_mat/sparse.hpp | 9 +- cpp/src/glm/ridge.cuh | 10 +- cpp/src/glm/ridge_mg.cu | 4 +- cpp/src/hdbscan/detail/reachability.cuh | 2 +- cpp/src/hierarchy/pw_dist_graph.cuh | 2 +- cpp/src/holtwinters/internal/hw_decompose.cuh | 114 +++++---- cpp/src/holtwinters/internal/hw_utils.cuh | 4 +- cpp/src/holtwinters/runner.cuh | 9 +- cpp/src/kmeans/common.cuh | 8 +- cpp/src/metrics/silhouette_score.cu | 2 +- cpp/src/pca/pca.cuh | 7 +- cpp/src/pca/pca_mg.cu | 2 +- cpp/src/random_projection/rproj.cuh | 34 +-- cpp/src/solver/cd.cuh | 15 +- cpp/src/solver/cd_mg.cu | 10 +- cpp/src/solver/lars_impl.cuh | 242 +++++++++--------- cpp/src/solver/sgd.cuh | 13 +- cpp/src/svm/kernelcache.cuh | 2 +- cpp/src/svm/linear.cu | 13 +- cpp/src/svm/results.cuh | 7 +- cpp/src/svm/smosolver.cuh | 66 ++--- cpp/src/svm/svc.cu | 3 +- cpp/src/svm/svc_impl.cuh | 32 +-- cpp/src/svm/svr.cu | 3 +- cpp/src/svm/svr_impl.cuh | 3 +- cpp/src/svm/workingset.cuh | 4 +- cpp/src/tsne/barnes_hut_tsne.cuh | 2 +- cpp/src/tsne/distances.cuh | 4 +- cpp/src/tsne/exact_kernels.cuh | 2 +- cpp/src/tsne/fft_tsne.cuh | 2 +- cpp/src/tsne/utils.cuh | 4 +- cpp/src/tsvd/tsvd.cuh | 11 +- cpp/src/tsvd/tsvd_mg.cu | 2 +- cpp/src/umap/init_embed/spectral_algo.cuh | 4 +- cpp/src/umap/knn_graph/algo.cuh | 4 +- cpp/src/umap/optimize.cuh | 9 +- cpp/src/umap/simpl_set_embed/algo.cuh | 2 +- .../distance/epsilon_neighborhood.cuh | 2 +- cpp/src_prims/functions/hinge.cuh | 15 +- cpp/src_prims/functions/linearReg.cuh | 11 +- cpp/src_prims/functions/log.cuh | 4 +- cpp/src_prims/functions/logisticReg.cuh | 12 +- cpp/src_prims/functions/penalty.cuh | 6 +- cpp/src_prims/functions/sigmoid.cuh | 4 +- cpp/src_prims/functions/sign.cuh | 4 +- cpp/src_prims/functions/softThres.cuh | 4 +- cpp/src_prims/label/classlabels.cuh | 2 +- cpp/src_prims/linalg/batched/matrix.cuh | 79 +++--- cpp/src_prims/linalg/lstsq.cuh | 214 ++++++++-------- cpp/src_prims/linalg/power.cuh | 6 +- cpp/src_prims/linalg/rsvd.cuh | 12 +- cpp/src_prims/linalg/sqrt.cuh | 4 +- cpp/src_prims/matrix/grammatrix.cuh | 67 ++--- cpp/src_prims/matrix/kernelmatrices.cuh | 2 +- cpp/src_prims/metrics/adjusted_rand_index.cuh | 4 +- .../metrics/batched/information_criterion.cuh | 4 +- cpp/src_prims/metrics/dispersion.cuh | 2 +- cpp/src_prims/metrics/entropy.cuh | 4 +- cpp/src_prims/metrics/kl_divergence.cuh | 2 +- cpp/src_prims/metrics/mutual_info_score.cuh | 2 +- cpp/src_prims/metrics/scores.cuh | 4 +- cpp/src_prims/metrics/silhouette_score.cuh | 12 +- cpp/src_prims/random/make_blobs.cuh | 2 +- cpp/src_prims/random/make_regression.cuh | 102 ++++---- cpp/src_prims/random/mvg.cuh | 116 +++++---- cpp/src_prims/selection/knn.cuh | 2 +- cpp/src_prims/selection/processing.cuh | 8 +- cpp/src_prims/sparse/batched/csr.cuh | 1 - cpp/src_prims/stats/cov.cuh | 34 +-- cpp/src_prims/stats/weighted_mean.cuh | 6 +- cpp/src_prims/timeSeries/arima_helpers.cuh | 4 +- cpp/src_prims/timeSeries/fillna.cuh | 4 +- cpp/src_prims/timeSeries/jones_transform.cuh | 2 +- cpp/src_prims/timeSeries/stationarity.cuh | 5 +- cpp/test/mg/pca.cu | 1 - cpp/test/prims/add_sub_dev_scalar.cu | 6 +- cpp/test/prims/batched/matrix.cu | 2 +- cpp/test/prims/knn_regression.cu | 3 +- cpp/test/prims/make_regression.cu | 38 +-- cpp/test/prims/mvg.cu | 33 +-- cpp/test/prims/silhouette_score.cu | 2 +- cpp/test/sg/cd_test.cu | 1 - cpp/test/sg/dbscan_test.cu | 5 +- cpp/test/sg/hdbscan_test.cu | 4 +- cpp/test/sg/lars_test.cu | 25 +- cpp/test/sg/linear_svm_test.cu | 8 +- cpp/test/sg/linkage_test.cu | 4 +- cpp/test/sg/pca_test.cu | 1 - cpp/test/sg/quasi_newton.cu | 2 +- cpp/test/sg/rf_test.cu | 2 +- cpp/test/sg/rproj_test.cu | 2 +- cpp/test/sg/sgd.cu | 1 - cpp/test/sg/svc_test.cu | 6 +- cpp/test/sg/tsne_test.cu | 2 +- python/cuml/metrics/distance_type.pxd | 4 +- python/cuml/metrics/trustworthiness.pyx | 4 +- 129 files changed, 944 insertions(+), 888 deletions(-) diff --git a/cpp/bench/prims/add.cu b/cpp/bench/prims/add.cu index 1665ad7656..5a9340cd2f 100644 --- a/cpp/bench/prims/add.cu +++ b/cpp/bench/prims/add.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ #include -#include +#include namespace MLCommon { namespace Bench { diff --git a/cpp/bench/prims/fused_l2_nn.cu b/cpp/bench/prims/fused_l2_nn.cu index 174618c857..c949e119d3 100644 --- a/cpp/bench/prims/fused_l2_nn.cu +++ b/cpp/bench/prims/fused_l2_nn.cu @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/bench/prims/gram_matrix.cu b/cpp/bench/prims/gram_matrix.cu index 1e3ad2b5ca..d4a83a8e31 100644 --- a/cpp/bench/prims/gram_matrix.cu +++ b/cpp/bench/prims/gram_matrix.cu @@ -19,7 +19,8 @@ #include #include #include -#include +// #TODO: Replace with public header when ready +#include #include #include #include diff --git a/cpp/bench/prims/map_then_reduce.cu b/cpp/bench/prims/map_then_reduce.cu index 6f451672ba..0520562f7b 100644 --- a/cpp/bench/prims/map_then_reduce.cu +++ b/cpp/bench/prims/map_then_reduce.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ #include -#include +#include namespace MLCommon { namespace Bench { diff --git a/cpp/bench/prims/matrix_vector_op.cu b/cpp/bench/prims/matrix_vector_op.cu index 35cc0122d5..e117d96bb2 100644 --- a/cpp/bench/prims/matrix_vector_op.cu +++ b/cpp/bench/prims/matrix_vector_op.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ #include -#include +#include namespace MLCommon { namespace Bench { diff --git a/cpp/bench/prims/reduce.cu b/cpp/bench/prims/reduce.cu index cb593c2a3d..bdfe17c62d 100644 --- a/cpp/bench/prims/reduce.cu +++ b/cpp/bench/prims/reduce.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ #include -#include +#include namespace MLCommon { namespace Bench { diff --git a/cpp/bench/sg/dataset.cuh b/cpp/bench/sg/dataset.cuh index 133529c19e..de5bd470fa 100644 --- a/cpp/bench/sg/dataset.cuh +++ b/cpp/bench/sg/dataset.cuh @@ -22,8 +22,8 @@ #include #include #include -#include -#include +#include +#include #include #include #include diff --git a/cpp/bench/sg/linkage.cu b/cpp/bench/sg/linkage.cu index a6dc8305e9..1003b6cfdb 100644 --- a/cpp/bench/sg/linkage.cu +++ b/cpp/bench/sg/linkage.cu @@ -17,7 +17,7 @@ #include "benchmark.cuh" #include #include -#include +#include #include #include diff --git a/cpp/include/cuml/cluster/dbscan.hpp b/cpp/include/cuml/cluster/dbscan.hpp index d5fe70e992..c71d8539f6 100644 --- a/cpp/include/cuml/cluster/dbscan.hpp +++ b/cpp/include/cuml/cluster/dbscan.hpp @@ -16,7 +16,7 @@ #pragma once -#include +#include #include diff --git a/cpp/include/cuml/cluster/hdbscan.hpp b/cpp/include/cuml/cluster/hdbscan.hpp index 3fb6708312..6162799e7d 100644 --- a/cpp/include/cuml/cluster/hdbscan.hpp +++ b/cpp/include/cuml/cluster/hdbscan.hpp @@ -16,7 +16,7 @@ #pragma once -#include +#include #include diff --git a/cpp/include/cuml/cluster/linkage.hpp b/cpp/include/cuml/cluster/linkage.hpp index bac0b9218b..eb6e88ff81 100644 --- a/cpp/include/cuml/cluster/linkage.hpp +++ b/cpp/include/cuml/cluster/linkage.hpp @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include diff --git a/cpp/include/cuml/metrics/metrics.hpp b/cpp/include/cuml/metrics/metrics.hpp index 66d2459aaa..f1f9a3d218 100644 --- a/cpp/include/cuml/metrics/metrics.hpp +++ b/cpp/include/cuml/metrics/metrics.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include +#include #include diff --git a/cpp/include/cuml/neighbors/knn.hpp b/cpp/include/cuml/neighbors/knn.hpp index 08f726c6af..47af2ffaa0 100644 --- a/cpp/include/cuml/neighbors/knn.hpp +++ b/cpp/include/cuml/neighbors/knn.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include diff --git a/cpp/include/cuml/neighbors/knn_sparse.hpp b/cpp/include/cuml/neighbors/knn_sparse.hpp index 916d89567e..0d0e359eb0 100644 --- a/cpp/include/cuml/neighbors/knn_sparse.hpp +++ b/cpp/include/cuml/neighbors/knn_sparse.hpp @@ -19,7 +19,7 @@ #include #include -#include +#include namespace raft { class handle_t; diff --git a/cpp/src/arima/batched_arima.cu b/cpp/src/arima/batched_arima.cu index 86dffaabce..9bf5cf3225 100644 --- a/cpp/src/arima/batched_arima.cu +++ b/cpp/src/arima/batched_arima.cu @@ -35,7 +35,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/src/arima/batched_kalman.cu b/cpp/src/arima/batched_kalman.cu index 2dc76a7b8a..f0c042f9c4 100644 --- a/cpp/src/arima/batched_kalman.cu +++ b/cpp/src/arima/batched_kalman.cu @@ -26,8 +26,9 @@ #include #include #include -#include -#include +#include +// #TODO: Replace with public header when ready +#include #include #include @@ -1222,48 +1223,50 @@ void _batched_kalman_filter(raft::handle_t& handle, double alpha = 1.0; double beta = 0.0; - RAFT_CUBLAS_TRY(raft::linalg::cublasgemmStridedBatched(cublasHandle, - CUBLAS_OP_N, - CUBLAS_OP_N, - nobs, - 1, - order.n_exog, - &alpha, - d_exog, - nobs, - nobs * order.n_exog, - d_beta, - order.n_exog, - order.n_exog, - &beta, - obs_intercept.data(), - nobs, - nobs, - batch_size, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemmStridedBatched(cublasHandle, + CUBLAS_OP_N, + CUBLAS_OP_N, + nobs, + 1, + order.n_exog, + &alpha, + d_exog, + nobs, + nobs * order.n_exog, + d_beta, + order.n_exog, + order.n_exog, + &beta, + obs_intercept.data(), + nobs, + nobs, + batch_size, + stream)); if (fc_steps > 0) { obs_intercept_fut.resize(fc_steps * batch_size, stream); - RAFT_CUBLAS_TRY(raft::linalg::cublasgemmStridedBatched(cublasHandle, - CUBLAS_OP_N, - CUBLAS_OP_N, - fc_steps, - 1, - order.n_exog, - &alpha, - d_exog_fut, - fc_steps, - fc_steps * order.n_exog, - d_beta, - order.n_exog, - order.n_exog, - &beta, - obs_intercept_fut.data(), - fc_steps, - fc_steps, - batch_size, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemmStridedBatched(cublasHandle, + CUBLAS_OP_N, + CUBLAS_OP_N, + fc_steps, + 1, + order.n_exog, + &alpha, + d_exog_fut, + fc_steps, + fc_steps * order.n_exog, + d_beta, + order.n_exog, + order.n_exog, + &beta, + obs_intercept_fut.data(), + fc_steps, + fc_steps, + batch_size, + stream)); } } diff --git a/cpp/src/common/cumlHandle.cpp b/cpp/src/common/cumlHandle.cpp index 6192b24a22..5133876a14 100644 --- a/cpp/src/common/cumlHandle.cpp +++ b/cpp/src/common/cumlHandle.cpp @@ -18,8 +18,10 @@ #include #include -#include -#include +// #TODO: Replace with public header when ready +#include +// #TODO: Replace with public header when ready +#include #include #include #include diff --git a/cpp/src/dbscan/vertexdeg/precomputed.cuh b/cpp/src/dbscan/vertexdeg/precomputed.cuh index 75e0886642..3cead4bac8 100644 --- a/cpp/src/dbscan/vertexdeg/precomputed.cuh +++ b/cpp/src/dbscan/vertexdeg/precomputed.cuh @@ -22,8 +22,8 @@ #include #include #include -#include -#include +#include +#include #include "pack.h" diff --git a/cpp/src/decisiontree/batched-levelalgo/split.cuh b/cpp/src/decisiontree/batched-levelalgo/split.cuh index ea80412cfc..e69fb81489 100644 --- a/cpp/src/decisiontree/batched-levelalgo/split.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/split.cuh @@ -17,7 +17,7 @@ #pragma once #include -#include +#include namespace ML { namespace DT { diff --git a/cpp/src/genetic/fitness.cuh b/cpp/src/genetic/fitness.cuh index fa32e198c1..a15fb96b54 100644 --- a/cpp/src/genetic/fitness.cuh +++ b/cpp/src/genetic/fitness.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,10 +15,10 @@ */ #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include diff --git a/cpp/src/genetic/genetic.cu b/cpp/src/genetic/genetic.cu index c4aa018f7f..ece6c6d81c 100644 --- a/cpp/src/genetic/genetic.cu +++ b/cpp/src/genetic/genetic.cu @@ -23,8 +23,8 @@ #include #include -#include -#include +#include +#include #include #include diff --git a/cpp/src/genetic/program.cu b/cpp/src/genetic/program.cu index 0f62cedb96..cd69056ca9 100644 --- a/cpp/src/genetic/program.cu +++ b/cpp/src/genetic/program.cu @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/src/glm/ols.cuh b/cpp/src/glm/ols.cuh index 7bb3a10594..0cb009bd3b 100644 --- a/cpp/src/glm/ols.cuh +++ b/cpp/src/glm/ols.cuh @@ -17,10 +17,10 @@ #pragma once #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include diff --git a/cpp/src/glm/ols_mg.cu b/cpp/src/glm/ols_mg.cu index 7e0a1404b1..325566908e 100644 --- a/cpp/src/glm/ols_mg.cu +++ b/cpp/src/glm/ols_mg.cu @@ -22,8 +22,8 @@ #include #include -#include -#include +#include +#include #include #include #include diff --git a/cpp/src/glm/preprocess.cuh b/cpp/src/glm/preprocess.cuh index 8ee77966c2..07e1a8cee5 100644 --- a/cpp/src/glm/preprocess.cuh +++ b/cpp/src/glm/preprocess.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,8 @@ #pragma once #include -#include -#include +#include +#include #include #include #include diff --git a/cpp/src/glm/preprocess_mg.cu b/cpp/src/glm/preprocess_mg.cu index 3769bedd78..655af3b9be 100644 --- a/cpp/src/glm/preprocess_mg.cu +++ b/cpp/src/glm/preprocess_mg.cu @@ -23,8 +23,8 @@ #include #include #include -#include -#include +#include +#include #include #include diff --git a/cpp/src/glm/qn/glm_base.cuh b/cpp/src/glm/qn/glm_base.cuh index 597d4ad2b0..126855637e 100644 --- a/cpp/src/glm/qn/glm_base.cuh +++ b/cpp/src/glm/qn/glm_base.cuh @@ -19,12 +19,10 @@ #include "simple_mat.cuh" #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include #include #include diff --git a/cpp/src/glm/qn/glm_linear.cuh b/cpp/src/glm/qn/glm_linear.cuh index 664f25b2b2..11df1d5833 100644 --- a/cpp/src/glm/qn/glm_linear.cuh +++ b/cpp/src/glm/qn/glm_linear.cuh @@ -19,7 +19,7 @@ #include "glm_base.cuh" #include "simple_mat.cuh" #include -#include +#include namespace ML { namespace GLM { diff --git a/cpp/src/glm/qn/glm_logistic.cuh b/cpp/src/glm/qn/glm_logistic.cuh index 01f732df05..5e76da4843 100644 --- a/cpp/src/glm/qn/glm_logistic.cuh +++ b/cpp/src/glm/qn/glm_logistic.cuh @@ -19,7 +19,7 @@ #include "glm_base.cuh" #include "simple_mat.cuh" #include -#include +#include namespace ML { namespace GLM { diff --git a/cpp/src/glm/qn/glm_regularizer.cuh b/cpp/src/glm/qn/glm_regularizer.cuh index 60958d2a9f..9e4aa7067b 100644 --- a/cpp/src/glm/qn/glm_regularizer.cuh +++ b/cpp/src/glm/qn/glm_regularizer.cuh @@ -19,8 +19,8 @@ #include "simple_mat.cuh" #include #include -#include -#include +#include +#include #include namespace ML { diff --git a/cpp/src/glm/qn/glm_softmax.cuh b/cpp/src/glm/qn/glm_softmax.cuh index 7b80ae61af..91a18f15b5 100644 --- a/cpp/src/glm/qn/glm_softmax.cuh +++ b/cpp/src/glm/qn/glm_softmax.cuh @@ -19,7 +19,7 @@ #include "glm_base.cuh" #include "simple_mat.cuh" #include -#include +#include namespace ML { namespace GLM { diff --git a/cpp/src/glm/qn/glm_svm.cuh b/cpp/src/glm/qn/glm_svm.cuh index 04d41b8c3e..fa71377760 100644 --- a/cpp/src/glm/qn/glm_svm.cuh +++ b/cpp/src/glm/qn/glm_svm.cuh @@ -19,7 +19,7 @@ #include "glm_base.cuh" #include "simple_mat.cuh" #include -#include +#include namespace ML { namespace GLM { diff --git a/cpp/src/glm/qn/simple_mat/dense.hpp b/cpp/src/glm/qn/simple_mat/dense.hpp index d87a8765cf..efd6de68a5 100644 --- a/cpp/src/glm/qn/simple_mat/dense.hpp +++ b/cpp/src/glm/qn/simple_mat/dense.hpp @@ -23,11 +23,12 @@ #include #include #include -#include -#include -#include -#include -#include +#include +// #TODO: Replace with public header when ready +#include +#include +#include +#include #include namespace ML { @@ -89,21 +90,22 @@ struct SimpleDenseMat : SimpleMat { ASSERT(kA == kB, "GEMM invalid dims: k"); if (A.ord == COL_MAJOR && B.ord == COL_MAJOR && C.ord == COL_MAJOR) { - raft::linalg::cublasgemm(handle.get_cublas_handle(), // handle - transA ? CUBLAS_OP_T : CUBLAS_OP_N, // transA - transB ? CUBLAS_OP_T : CUBLAS_OP_N, // transB - C.m, - C.n, - kA, // dimensions m,n,k - &alpha, - A.data, - A.m, // lda - B.data, - B.m, // ldb - &beta, - C.data, - C.m, // ldc, - stream); + // #TODO: Call from public API when ready + raft::linalg::detail::cublasgemm(handle.get_cublas_handle(), // handle + transA ? CUBLAS_OP_T : CUBLAS_OP_N, // transA + transB ? CUBLAS_OP_T : CUBLAS_OP_N, // transB + C.m, + C.n, + kA, // dimensions m,n,k + &alpha, + A.data, + A.m, // lda + B.data, + B.m, // ldb + &beta, + C.data, + C.m, // ldc, + stream); return; } if (A.ord == ROW_MAJOR) { diff --git a/cpp/src/glm/qn/simple_mat/sparse.hpp b/cpp/src/glm/qn/simple_mat/sparse.hpp index ccc46b3e6d..0cfa750338 100644 --- a/cpp/src/glm/qn/simple_mat/sparse.hpp +++ b/cpp/src/glm/qn/simple_mat/sparse.hpp @@ -23,11 +23,10 @@ #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include #include #include diff --git a/cpp/src/glm/ridge.cuh b/cpp/src/glm/ridge.cuh index b6d855f571..1ebeabfe10 100644 --- a/cpp/src/glm/ridge.cuh +++ b/cpp/src/glm/ridge.cuh @@ -17,11 +17,11 @@ #pragma once #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include #include diff --git a/cpp/src/glm/ridge_mg.cu b/cpp/src/glm/ridge_mg.cu index b73137e4cc..3710eef28b 100644 --- a/cpp/src/glm/ridge_mg.cu +++ b/cpp/src/glm/ridge_mg.cu @@ -24,8 +24,8 @@ #include #include #include -#include -#include +#include +#include #include #include diff --git a/cpp/src/hdbscan/detail/reachability.cuh b/cpp/src/hdbscan/detail/reachability.cuh index 26080dee8f..dff60121ba 100644 --- a/cpp/src/hdbscan/detail/reachability.cuh +++ b/cpp/src/hdbscan/detail/reachability.cuh @@ -23,7 +23,7 @@ #include -#include +#include #include #include diff --git a/cpp/src/hierarchy/pw_dist_graph.cuh b/cpp/src/hierarchy/pw_dist_graph.cuh index dca5f100ba..1c45a66af8 100644 --- a/cpp/src/hierarchy/pw_dist_graph.cuh +++ b/cpp/src/hierarchy/pw_dist_graph.cuh @@ -24,7 +24,7 @@ #include #include -#include +#include #include // TODO: Not a good strategy for pluggability but will be diff --git a/cpp/src/holtwinters/internal/hw_decompose.cuh b/cpp/src/holtwinters/internal/hw_decompose.cuh index 5ffbeba9a5..d94cf30ec3 100644 --- a/cpp/src/holtwinters/internal/hw_decompose.cuh +++ b/cpp/src/holtwinters/internal/hw_decompose.cuh @@ -18,6 +18,10 @@ #include #include +// #TODO: Replace with public header when ready +#include +// #TODO: Replace with public header when ready +#include #include #include @@ -180,58 +184,63 @@ void batched_ls(const raft::handle_t& handle, } raft::update_device(A_d.data(), A_h.data(), 2 * trend_len, stream); - RAFT_CUSOLVER_TRY(raft::linalg::cusolverDngeqrf_bufferSize( + // #TODO: Call from public API when ready + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngeqrf_bufferSize( cusolver_h, trend_len, 2, A_d.data(), 2, &geqrf_buffer)); - RAFT_CUSOLVER_TRY(raft::linalg::cusolverDnorgqr_bufferSize( + // #TODO: Call from public API when ready + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnorgqr_bufferSize( cusolver_h, trend_len, 2, 2, A_d.data(), 2, tau_d.data(), &orgqr_buffer)); lwork_size = geqrf_buffer > orgqr_buffer ? geqrf_buffer : orgqr_buffer; rmm::device_uvector lwork_d(lwork_size, stream); // QR decomposition of A - RAFT_CUSOLVER_TRY(raft::linalg::cusolverDngeqrf(cusolver_h, - trend_len, - 2, - A_d.data(), - trend_len, - tau_d.data(), - lwork_d.data(), - lwork_size, - dev_info_d.data(), - stream)); + // #TODO: Call from public API when ready + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngeqrf(cusolver_h, + trend_len, + 2, + A_d.data(), + trend_len, + tau_d.data(), + lwork_d.data(), + lwork_size, + dev_info_d.data(), + stream)); // Single thread kenrel to inverse R RinvKernel<<<1, 1, 0, stream>>>(A_d.data(), Rinv_d.data(), trend_len); // R1QT = inv(R)*transpose(Q) - RAFT_CUSOLVER_TRY(raft::linalg::cusolverDnorgqr(cusolver_h, - trend_len, - 2, - 2, - A_d.data(), - trend_len, - tau_d.data(), - lwork_d.data(), - lwork_size, - dev_info_d.data(), - stream)); + // #TODO: Call from public API when ready + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnorgqr(cusolver_h, + trend_len, + 2, + 2, + A_d.data(), + trend_len, + tau_d.data(), + lwork_d.data(), + lwork_size, + dev_info_d.data(), + stream)); - RAFT_CUBLAS_TRY(raft::linalg::cublasgemm(cublas_h, - CUBLAS_OP_N, - CUBLAS_OP_T, - 2, - trend_len, - 2, - &one, - Rinv_d.data(), - 2, - A_d.data(), - trend_len, - &zero, - R1Qt_d.data(), - 2, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_h, + CUBLAS_OP_N, + CUBLAS_OP_T, + 2, + trend_len, + 2, + &one, + Rinv_d.data(), + 2, + A_d.data(), + trend_len, + &zero, + R1Qt_d.data(), + 2, + stream)); batched_ls_solver_kernel <<>>( @@ -277,20 +286,21 @@ void stl_decomposition_gpu(const raft::handle_t& handle, if (seasonal == ML::SeasonalType::ADDITIVE) { const Dtype one = 1.; const Dtype minus_one = -1.; - RAFT_CUBLAS_TRY(raft::linalg::cublasgeam(cublas_h, - CUBLAS_OP_N, - CUBLAS_OP_N, - trend_len, - batch_size, - &one, - ts + ts_offset, - trend_len, - &minus_one, - trend_d.data(), - trend_len, - season_d.data(), - trend_len, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgeam(cublas_h, + CUBLAS_OP_N, + CUBLAS_OP_N, + trend_len, + batch_size, + &one, + ts + ts_offset, + trend_len, + &minus_one, + trend_d.data(), + trend_len, + season_d.data(), + trend_len, + stream)); } else { rmm::device_uvector aligned_ts(batch_size * trend_len, stream); raft::copy(aligned_ts.data(), ts + ts_offset, batch_size * trend_len, stream); diff --git a/cpp/src/holtwinters/internal/hw_utils.cuh b/cpp/src/holtwinters/internal/hw_utils.cuh index a32f3d7591..ca2c578ad1 100644 --- a/cpp/src/holtwinters/internal/hw_utils.cuh +++ b/cpp/src/holtwinters/internal/hw_utils.cuh @@ -19,9 +19,7 @@ #include #include #include -#include -#include -#include +#include #include #include diff --git a/cpp/src/holtwinters/runner.cuh b/cpp/src/holtwinters/runner.cuh index 59f6059e16..e06bd50543 100644 --- a/cpp/src/holtwinters/runner.cuh +++ b/cpp/src/holtwinters/runner.cuh @@ -22,7 +22,9 @@ #include "internal/hw_optim.cuh" #include #include -#include +// #TODO: Replace with public header when ready +#include +#include #include namespace ML { @@ -89,8 +91,9 @@ void HoltWintersDecompose(const raft::handle_t& handle, raft::copy(start_level, ts + batch_size, batch_size, stream); raft::copy(start_trend, ts + batch_size, batch_size, stream); const Dtype alpha = -1.; - RAFT_CUBLAS_TRY( - raft::linalg::cublasaxpy(cublas_h, batch_size, &alpha, ts, 1, start_trend, 1, stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasaxpy( + cublas_h, batch_size, &alpha, ts, 1, start_trend, 1, stream)); // cublas::axpy(batch_size, (Dtype)-1., ts, start_trend); } else if (start_level != nullptr && start_trend != nullptr && start_season != nullptr) { stl_decomposition_gpu(handle_impl, diff --git a/cpp/src/kmeans/common.cuh b/cpp/src/kmeans/common.cuh index 125a1a82b5..811e155439 100644 --- a/cpp/src/kmeans/common.cuh +++ b/cpp/src/kmeans/common.cuh @@ -32,10 +32,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include diff --git a/cpp/src/metrics/silhouette_score.cu b/cpp/src/metrics/silhouette_score.cu index 4c3e7ccc87..c80fe099f1 100644 --- a/cpp/src/metrics/silhouette_score.cu +++ b/cpp/src/metrics/silhouette_score.cu @@ -18,7 +18,7 @@ #include #include #include -#include +#include namespace ML { diff --git a/cpp/src/pca/pca.cuh b/cpp/src/pca/pca.cuh index 0e886a1ef6..9261ea127c 100644 --- a/cpp/src/pca/pca.cuh +++ b/cpp/src/pca/pca.cuh @@ -19,10 +19,9 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include #include #include #include diff --git a/cpp/src/pca/pca_mg.cu b/cpp/src/pca/pca_mg.cu index 29fd0a1722..87b7fee68d 100644 --- a/cpp/src/pca/pca_mg.cu +++ b/cpp/src/pca/pca_mg.cu @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/src/random_projection/rproj.cuh b/cpp/src/random_projection/rproj.cuh index f266e24664..83f96105c6 100644 --- a/cpp/src/random_projection/rproj.cuh +++ b/cpp/src/random_projection/rproj.cuh @@ -22,7 +22,8 @@ #include #include -#include +// #TODO: Replace with public header when ready +#include #include #include @@ -162,21 +163,22 @@ void RPROJtransform(const raft::handle_t& handle, auto& ldb = k; auto& ldc = m; - RAFT_CUBLAS_TRY(raft::linalg::cublasgemm(cublas_handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - params->n_samples, - n, - k, - &alfa, - input, - lda, - random_matrix->dense_data.data(), - ldb, - &beta, - output, - ldc, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + params->n_samples, + n, + k, + &alfa, + input, + lda, + random_matrix->dense_data.data(), + ldb, + &beta, + output, + ldc, + stream)); } else if (random_matrix->type == sparse) { cusparseHandle_t cusparse_handle = handle.get_cusparse_handle(); diff --git a/cpp/src/solver/cd.cuh b/cpp/src/solver/cd.cuh index 142a1fbc9f..b534358273 100644 --- a/cpp/src/solver/cd.cuh +++ b/cpp/src/solver/cd.cuh @@ -24,13 +24,14 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +// #TODO: Replace with public header when ready +#include +#include +#include +#include +#include +#include #include #include diff --git a/cpp/src/solver/cd_mg.cu b/cpp/src/solver/cd_mg.cu index cc9e409c67..837d0d83ba 100644 --- a/cpp/src/solver/cd_mg.cu +++ b/cpp/src/solver/cd_mg.cu @@ -28,11 +28,11 @@ #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include diff --git a/cpp/src/solver/lars_impl.cuh b/cpp/src/solver/lars_impl.cuh index 9e5ccf91e3..09646c64a8 100644 --- a/cpp/src/solver/lars_impl.cuh +++ b/cpp/src/solver/lars_impl.cuh @@ -26,12 +26,13 @@ #include #include #include -#include -#include -#include -#include -#include -#include +#include +#include +// #TODO: Replace with public header when ready +#include +#include +#include +#include #include #include #include @@ -150,17 +151,22 @@ void swapFeatures(cublasHandle_t handle, { std::swap(indices[j], indices[k]); if (G) { + // #TODO: Call from public API when ready RAFT_CUBLAS_TRY( - raft::linalg::cublasSwap(handle, n_cols, G + ld_G * j, 1, G + ld_G * k, 1, stream)); - RAFT_CUBLAS_TRY(raft::linalg::cublasSwap(handle, n_cols, G + j, ld_G, G + k, ld_G, stream)); + raft::linalg::detail::cublasSwap(handle, n_cols, G + ld_G * j, 1, G + ld_G * k, 1, stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY( + raft::linalg::detail::cublasSwap(handle, n_cols, G + j, ld_G, G + k, ld_G, stream)); } else { // Only swap X if G is nullptr. Only in that case will we use the feature // columns, otherwise all the necessary information is already there in G. + // #TODO: Call from public API when ready RAFT_CUBLAS_TRY( - raft::linalg::cublasSwap(handle, n_rows, X + ld_X * j, 1, X + ld_X * k, 1, stream)); + raft::linalg::detail::cublasSwap(handle, n_rows, X + ld_X * j, 1, X + ld_X * k, 1, stream)); } // swap (c[j], c[k]) - RAFT_CUBLAS_TRY(raft::linalg::cublasSwap(handle, 1, cor + j, 1, cor + k, 1, stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasSwap(handle, 1, cor + j, 1, cor + k, 1, stream)); } /** @@ -280,19 +286,20 @@ void updateCholesky(const raft::handle_t& handle, const math_t* X_row = X + (n_active - 1) * ld_X; math_t one = 1; math_t zero = 0; - RAFT_CUBLAS_TRY(raft::linalg::cublasgemv(handle.get_cublas_handle(), - CUBLAS_OP_T, - n_rows, - n_cols, - &one, - X, - n_rows, - X_row, - 1, - &zero, - G_row, - 1, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_T, + n_rows, + n_cols, + &one, + X, + n_rows, + X_row, + 1, + &zero, + G_row, + 1, + stream)); } else if (G0 != U) { // Copy the new column of G0 into U, because the factorization works in // place. @@ -342,34 +349,36 @@ void calcW0(const raft::handle_t& handle, // First we calculate x by solving equation U.T x = sign_A. raft::copy(ws, sign, n_active, stream); math_t alpha = 1; - RAFT_CUBLAS_TRY(raft::linalg::cublastrsm(handle.get_cublas_handle(), - CUBLAS_SIDE_LEFT, - fillmode, - CUBLAS_OP_T, - CUBLAS_DIAG_NON_UNIT, - n_active, - 1, - &alpha, - U, - ld_U, - ws, - ld_U, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(handle.get_cublas_handle(), + CUBLAS_SIDE_LEFT, + fillmode, + CUBLAS_OP_T, + CUBLAS_DIAG_NON_UNIT, + n_active, + 1, + &alpha, + U, + ld_U, + ws, + ld_U, + stream)); // ws stores x, the solution of U.T x = sign_A. Now we solve U * ws = x - RAFT_CUBLAS_TRY(raft::linalg::cublastrsm(handle.get_cublas_handle(), - CUBLAS_SIDE_LEFT, - fillmode, - CUBLAS_OP_N, - CUBLAS_DIAG_NON_UNIT, - n_active, - 1, - &alpha, - U, - ld_U, - ws, - ld_U, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(handle.get_cublas_handle(), + CUBLAS_SIDE_LEFT, + fillmode, + CUBLAS_OP_N, + CUBLAS_DIAG_NON_UNIT, + n_active, + 1, + &alpha, + U, + ld_U, + ws, + ld_U, + stream)); // Now ws = G0^(-1) sign_A = S GA^{-1} 1_A. } @@ -513,19 +522,20 @@ LarsFitStatus calcEquiangularVec(const raft::handle_t& handle, // Calculate u_eq only in the case if the Gram matrix is not stored. math_t one = 1; math_t zero = 0; - RAFT_CUBLAS_TRY(raft::linalg::cublasgemv(handle.get_cublas_handle(), - CUBLAS_OP_N, - n_rows, - n_active, - &one, - X, - ld_X, - ws, - 1, - &zero, - u_eq, - 1, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_N, + n_rows, + n_active, + &one, + X, + ld_X, + ws, + 1, + &zero, + u_eq, + 1, + stream)); } return LarsFitStatus::kOk; } @@ -601,37 +611,39 @@ void calcMaxStep(const raft::handle_t& handle, // Calculate a = X.T[:,n_active:] * u (2.11) math_t one = 1; math_t zero = 0; - RAFT_CUBLAS_TRY(raft::linalg::cublasgemv(handle.get_cublas_handle(), - CUBLAS_OP_T, - n_rows, - n_inactive, - &one, - X + n_active * ld_X, - ld_X, - u, - 1, - &zero, - a_vec, - 1, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_T, + n_rows, + n_inactive, + &one, + X + n_active * ld_X, + ld_X, + u, + 1, + &zero, + a_vec, + 1, + stream)); } else { // Calculate a = X.T[:,n_A:] * u = X.T[:, n_A:] * X[:,:n_A] * ws // = G[n_A:,:n_A] * ws (2.11) math_t one = 1; math_t zero = 0; - RAFT_CUBLAS_TRY(raft::linalg::cublasgemv(handle.get_cublas_handle(), - CUBLAS_OP_N, - n_inactive, - n_active, - &one, - G + n_active, - ld_G, - ws, - 1, - &zero, - a_vec, - 1, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_N, + n_inactive, + n_active, + &one, + G + n_active, + ld_G, + ws, + 1, + &zero, + a_vec, + 1, + stream)); } const math_t tiny = std::numeric_limits::min(); const math_t huge = std::numeric_limits::max(); @@ -719,19 +731,20 @@ void larsInit(const raft::handle_t& handle, math_t one = 1; math_t zero = 0; // Set initial correlation to X.T * y - RAFT_CUBLAS_TRY(raft::linalg::cublasgemv(handle.get_cublas_handle(), - CUBLAS_OP_T, - n_rows, - n_cols, - &one, - X, - ld_X, - y, - 1, - &zero, - cor.data(), - 1, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_T, + n_rows, + n_cols, + &one, + X, + ld_X, + y, + 1, + &zero, + cor.data(), + 1, + stream)); if (coef_path) { RAFT_CUDA_TRY( cudaMemsetAsync(coef_path, 0, sizeof(math_t) * (*max_iter + 1) * (*max_iter), stream)); @@ -1110,19 +1123,20 @@ void larsPredict(const raft::handle_t& handle, thrust::device_ptr pred_ptr(preds); thrust::fill(execution_policy, pred_ptr, pred_ptr + n_rows, intercept); math_t one = 1; - RAFT_CUBLAS_TRY(raft::linalg::cublasgemv(handle.get_cublas_handle(), - CUBLAS_OP_N, - n_rows, - n_active, - &one, - X, - ld_X, - beta, - 1, - &one, - preds, - 1, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_N, + n_rows, + n_active, + &one, + X, + ld_X, + beta, + 1, + &one, + preds, + 1, + stream)); } }; // namespace Lars }; // namespace Solver diff --git a/cpp/src/solver/sgd.cuh b/cpp/src/solver/sgd.cuh index 7dc5f9d278..a3b77fbdd4 100644 --- a/cpp/src/solver/sgd.cuh +++ b/cpp/src/solver/sgd.cuh @@ -25,13 +25,12 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include #include #include diff --git a/cpp/src/svm/kernelcache.cuh b/cpp/src/svm/kernelcache.cuh index f99efa223c..d1a011df49 100644 --- a/cpp/src/svm/kernelcache.cuh +++ b/cpp/src/svm/kernelcache.cuh @@ -25,7 +25,7 @@ #include #include -#include +#include #include #include #include diff --git a/cpp/src/svm/linear.cu b/cpp/src/svm/linear.cu index eab6141af5..e06d87a2fc 100644 --- a/cpp/src/svm/linear.cu +++ b/cpp/src/svm/linear.cu @@ -26,13 +26,12 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include #include #include diff --git a/cpp/src/svm/results.cuh b/cpp/src/svm/results.cuh index 7c280bb224..85123d224c 100644 --- a/cpp/src/svm/results.cuh +++ b/cpp/src/svm/results.cuh @@ -26,10 +26,9 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include #include #include #include diff --git a/cpp/src/svm/smosolver.cuh b/cpp/src/svm/smosolver.cuh index 6e5367a8bf..f8b52c73a2 100644 --- a/cpp/src/svm/smosolver.cuh +++ b/cpp/src/svm/smosolver.cuh @@ -22,9 +22,10 @@ #include #include -#include -#include -#include +// #TODO: Replace with public header when ready +#include +#include +#include #include #include @@ -44,9 +45,8 @@ #include #include #include -#include -#include -#include +#include +#include #include "results.cuh" @@ -218,35 +218,37 @@ class SmoSolver { { // multipliers used in the equation : f = 1*cachtile * delta_alpha + 1*f math_t one = 1; - RAFT_CUBLAS_TRY(raft::linalg::cublasgemv(handle.get_cublas_handle(), - CUBLAS_OP_N, - n_rows, - n_ws, - &one, - cacheTile, - n_rows, - delta_alpha, - 1, - &one, - f, - 1, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_N, + n_rows, + n_ws, + &one, + cacheTile, + n_rows, + delta_alpha, + 1, + &one, + f, + 1, + stream)); if (svmType == EPSILON_SVR) { // SVR has doubled the number of trainig vectors and we need to update // alpha for both batches individually - RAFT_CUBLAS_TRY(raft::linalg::cublasgemv(handle.get_cublas_handle(), - CUBLAS_OP_N, - n_rows, - n_ws, - &one, - cacheTile, - n_rows, - delta_alpha, - 1, - &one, - f + n_rows, - 1, - stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_N, + n_rows, + n_ws, + &one, + cacheTile, + n_rows, + delta_alpha, + 1, + &one, + f + n_rows, + 1, + stream)); } } diff --git a/cpp/src/svm/svc.cu b/cpp/src/svm/svc.cu index 4c2c18951f..ea3d7032f6 100644 --- a/cpp/src/svm/svc.cu +++ b/cpp/src/svm/svc.cu @@ -22,8 +22,7 @@ #include #include