From fee58673984da1b56f0e83c840d4ec1cba0ed28e Mon Sep 17 00:00:00 2001
From: Mickael Ide
Date: Thu, 22 Dec 2022 20:56:54 +0100
Subject: [PATCH 1/2] Add `detail` namespace to linear model

---
 cpp/include/cuml/linear_model/glm.hpp |  14 +-
 cpp/src/glm/glm.cu                    | 201 +++++++++++---------------
 cpp/src/glm/ols.cuh                   |  16 +-
 cpp/src/glm/qn/glm_base.cuh           |  13 +-
 cpp/src/glm/qn/qn.cuh                 |  88 +++++------
 cpp/src/glm/qn/qn_linesearch.cuh      |   3 +-
 cpp/src/glm/qn/qn_solvers.cuh         |   5 +-
 cpp/src/glm/qn/qn_util.cuh            |   3 +-
 cpp/src/glm/ridge.cuh                 |   5 +-
 cpp/src/svm/linear.cu                 |   6 +-
 cpp/test/sg/quasi_newton.cu           |  23 +--
 11 files changed, 163 insertions(+), 214 deletions(-)

diff --git a/cpp/include/cuml/linear_model/glm.hpp b/cpp/include/cuml/linear_model/glm.hpp
index 912b07cc26..4520996181 100644
--- a/cpp/include/cuml/linear_model/glm.hpp
+++ b/cpp/include/cuml/linear_model/glm.hpp
@@ -149,7 +149,9 @@ void gemmPredict(const raft::handle_t& handle,
  *                  overwritten by final result.
  * @param f         host pointer holding the final objective value
  * @param num_iters host pointer holding the actual number of iterations taken
- * @param sample_weight
+ * @param sample_weight device pointer to sample weight vector of length n_rows (nullptr
+   for uniform weights)
+ * @param svr_eps   epsilon parameter for svr
  */
 template <typename T>
 void qnFit(const raft::handle_t& cuml_handle,
@@ -163,7 +165,8 @@ void qnFit(const raft::handle_t& cuml_handle,
            T* w0,
            T* f,
            int* num_iters,
-           T* sample_weight = nullptr);
+           T* sample_weight = nullptr,
+           T svr_eps = 0);
 
 /**
  * @brief Fit a GLM using quasi newton methods.
@@ -183,7 +186,9 @@ void qnFit(const raft::handle_t& cuml_handle,
  *                  overwritten by final result.
  * @param f         host pointer holding the final objective value
  * @param num_iters host pointer holding the actual number of iterations taken
- * @param sample_weight
+ * @param sample_weight device pointer to sample weight vector of length n_rows (nullptr
+   for uniform weights)
+ * @param svr_eps   epsilon parameter for svr
  */
 template <typename T>
 void qnFitSparse(const raft::handle_t& cuml_handle,
@@ -199,7 +204,8 @@ void qnFitSparse(const raft::handle_t& cuml_handle,
                  T* w0,
                  T* f,
                  int* num_iters,
-                 T* sample_weight = nullptr);
+                 T* sample_weight = nullptr,
+                 T svr_eps = 0);
 
 /**
  * @brief Obtain the confidence scores of samples
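For reference, a minimal sketch of calling the extended qnFit entry point declared above; the handle, device buffers, sizes, and epsilon value are illustrative assumptions, not part of the patch:

    // Sketch: fit an L2 SVR model via the extended public qnFit API.
    // Assumes `handle` is a caller-provided raft::handle_t and X (N x D,
    // row-major), y (length N) and w0 are valid device buffers.
    qn_params pams;
    pams.loss          = QN_LOSS_SVR_L2;  // an SVR loss, so svr_eps is meaningful
    pams.fit_intercept = true;

    float f;        // final objective value (host)
    int num_iters;  // iterations actually taken (host)
    ML::GLM::qnFit(handle,
                   pams,
                   X,           // device pointer, N x D
                   false,       // X_col_major: X is row-major here
                   y,           // device pointer, length N
                   N, D, /*C=*/1,
                   w0,          // initial coefficients; overwritten with the result
                   &f,
                   &num_iters,
                   nullptr,     // sample_weight: nullptr selects uniform weights
                   0.1f);       // svr_eps (illustrative value)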
diff --git a/cpp/src/glm/glm.cu b/cpp/src/glm/glm.cu
index 0fb990c3c6..a7a828f8c8 100644
--- a/cpp/src/glm/glm.cu
+++ b/cpp/src/glm/glm.cu
@@ -38,18 +38,17 @@ void olsFit(const raft::handle_t& handle,
             int algo,
             float* sample_weight)
 {
-  olsFit(handle,
-         input,
-         n_rows,
-         n_cols,
-         labels,
-         coef,
-         intercept,
-         fit_intercept,
-         normalize,
-         handle.get_stream(),
-         algo,
-         sample_weight);
+  detail::olsFit(handle,
+                 input,
+                 n_rows,
+                 n_cols,
+                 labels,
+                 coef,
+                 intercept,
+                 fit_intercept,
+                 normalize,
+                 algo,
+                 sample_weight);
 }
 
 void olsFit(const raft::handle_t& handle,
@@ -64,18 +63,17 @@ void olsFit(const raft::handle_t& handle,
             int algo,
             double* sample_weight)
 {
-  olsFit(handle,
-         input,
-         n_rows,
-         n_cols,
-         labels,
-         coef,
-         intercept,
-         fit_intercept,
-         normalize,
-         handle.get_stream(),
-         algo,
-         sample_weight);
+  detail::olsFit(handle,
+                 input,
+                 n_rows,
+                 n_cols,
+                 labels,
+                 coef,
+                 intercept,
+                 fit_intercept,
+                 normalize,
+                 algo,
+                 sample_weight);
 }
 
 void gemmPredict(const raft::handle_t& handle,
@@ -86,7 +84,7 @@ void gemmPredict(const raft::handle_t& handle,
                  float intercept,
                  float* preds)
 {
-  gemmPredict(handle, input, n_rows, n_cols, coef, intercept, preds, handle.get_stream());
+  detail::gemmPredict(handle, input, n_rows, n_cols, coef, intercept, preds);
 }
 
 void gemmPredict(const raft::handle_t& handle,
@@ -97,7 +95,7 @@ void gemmPredict(const raft::handle_t& handle,
                  double intercept,
                  double* preds)
 {
-  gemmPredict(handle, input, n_rows, n_cols, coef, intercept, preds, handle.get_stream());
+  detail::gemmPredict(handle, input, n_rows, n_cols, coef, intercept, preds);
 }
 
 void ridgeFit(const raft::handle_t& handle,
@@ -114,20 +112,19 @@ void ridgeFit(const raft::handle_t& handle,
               int algo,
               float* sample_weight)
 {
-  ridgeFit(handle,
-           input,
-           n_rows,
-           n_cols,
-           labels,
-           alpha,
-           n_alpha,
-           coef,
-           intercept,
-           fit_intercept,
-           normalize,
-           handle.get_stream(),
-           algo,
-           sample_weight);
+  detail::ridgeFit(handle,
+                   input,
+                   n_rows,
+                   n_cols,
+                   labels,
+                   alpha,
+                   n_alpha,
+                   coef,
+                   intercept,
+                   fit_intercept,
+                   normalize,
+                   algo,
+                   sample_weight);
 }
 
 void ridgeFit(const raft::handle_t& handle,
@@ -144,20 +141,19 @@ void ridgeFit(const raft::handle_t& handle,
               int algo,
               double* sample_weight)
 {
-  ridgeFit(handle,
-           input,
-           n_rows,
-           n_cols,
-           labels,
-           alpha,
-           n_alpha,
-           coef,
-           intercept,
-           fit_intercept,
-           normalize,
-           handle.get_stream(),
-           algo,
-           sample_weight);
+  detail::ridgeFit(handle,
+                   input,
+                   n_rows,
+                   n_cols,
+                   labels,
+                   alpha,
+                   n_alpha,
+                   coef,
+                   intercept,
+                   fit_intercept,
+                   normalize,
+                   algo,
+                   sample_weight);
 }
 
 template <typename T>
@@ -172,21 +168,11 @@ void qnFit(const raft::handle_t& cuml_handle,
            T* w0,
            T* f,
            int* num_iters,
-           T* sample_weight)
+           T* sample_weight,
+           T svr_eps)
 {
-  qnFit(cuml_handle,
-        pams,
-        X,
-        X_col_major,
-        y,
-        N,
-        D,
-        C,
-        w0,
-        f,
-        num_iters,
-        cuml_handle.get_stream(),
-        sample_weight);
+  detail::qnFit(
+    cuml_handle, pams, X, X_col_major, y, N, D, C, w0, f, num_iters, sample_weight, svr_eps);
 }
 
 template void qnFit<float>(const raft::handle_t&,
                            const qn_params&,
                            float*,
                            bool,
                            float*,
                            int,
                            int,
                            int,
                            float*,
                            float*,
                            int*,
-                           float*);
+                           float*,
+                           float);
 template void qnFit<double>(const raft::handle_t&,
                             const qn_params&,
                             double*,
                             bool,
                             double*,
                             int,
                             int,
                             int,
                             double*,
                             double*,
                             int*,
-                            double*);
+                            double*,
+                            double);
 
 template <typename T>
@@ -228,23 +216,24 @@ void qnFitSparse(const raft::handle_t& cuml_handle,
                  T* w0,
                  T* f,
                  int* num_iters,
-                 T* sample_weight)
+                 T* sample_weight,
+                 T svr_eps)
 {
-  qnFitSparse(cuml_handle,
-              pams,
-              X_values,
-              X_cols,
-              X_row_ids,
-              X_nnz,
-              y,
-              N,
-              D,
-              C,
-              w0,
-              f,
-              num_iters,
-              cuml_handle.get_stream(),
-              sample_weight);
+  detail::qnFitSparse(cuml_handle,
+                      pams,
+                      X_values,
+                      X_cols,
+                      X_row_ids,
+                      X_nnz,
+                      y,
+                      N,
+                      D,
+                      C,
+                      w0,
+                      f,
+                      num_iters,
+                      sample_weight,
+                      svr_eps);
 }
 
 template void qnFitSparse<float>(const raft::handle_t&,
                                  const qn_params&,
                                  float*,
                                  int*,
                                  int*,
                                  int,
                                  float*,
                                  int,
                                  int,
                                  int,
                                  float*,
                                  float*,
                                  int*,
-                                 float*);
+                                 float*,
+                                 float);
 template void qnFitSparse<double>(const raft::handle_t&,
                                   const qn_params&,
                                   double*,
                                   int*,
                                   int*,
                                   int,
                                   double*,
                                   int,
                                   int,
                                   int,
                                   double*,
                                   double*,
                                   int*,
-                                  double*);
+                                  double*,
+                                  double);
 
 template <typename T>
@@ -287,8 +278,7 @@ void qnDecisionFunction(const raft::handle_t& cuml_handle,
                         T* params,
                         T* scores)
 {
-  qnDecisionFunction(
-    cuml_handle, pams, X, X_col_major, N, D, C, params, scores, cuml_handle.get_stream());
+  detail::qnDecisionFunction(cuml_handle, pams, X, X_col_major, N, D, C, params, scores);
 }
 
 template void qnDecisionFunction<float>(
@@ -309,18 +299,8 @@ void qnDecisionFunctionSparse(const raft::handle_t& cuml_handle,
                               T* params,
                               T* scores)
 {
-  qnDecisionFunctionSparse(cuml_handle,
-                           pams,
-                           X_values,
-                           X_cols,
-                           X_row_ids,
-                           X_nnz,
-                           N,
-                           D,
-                           C,
-                           params,
-                           scores,
-                           cuml_handle.get_stream());
+  detail::qnDecisionFunctionSparse(
+    cuml_handle, pams, X_values, X_cols, X_row_ids, X_nnz, N, D, C, params, scores);
 }
 
 template void qnDecisionFunctionSparse<float>(
@@ -348,8 +328,7 @@ void qnPredict(const raft::handle_t& cuml_handle,
               T* params,
               T* scores)
 {
-  qnPredict(
-    cuml_handle, pams, X, X_col_major, N, D, C, params, scores, cuml_handle.get_stream());
+  detail::qnPredict(cuml_handle, pams, X, X_col_major, N, D, C, params, scores);
 }
 
 template void qnPredict<float>(
@@ -370,18 +349,8 @@ void qnPredictSparse(const raft::handle_t& cuml_handle,
                     T* params,
                     T* preds)
 {
-  qnPredictSparse(cuml_handle,
-                  pams,
-                  X_values,
-                  X_cols,
-                  X_row_ids,
-                  X_nnz,
-                  N,
-                  D,
-                  C,
-                  params,
-                  preds,
-                  cuml_handle.get_stream());
+  detail::qnPredictSparse(
+    cuml_handle, pams, X_values, X_cols, X_row_ids, X_nnz, N, D, C, params, preds);
 }
 
 template void qnPredictSparse<float>(
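Every wrapper in this file follows the same shape; schematically (someFit is a placeholder name, and the include paths are assumptions, not code from the patch):

    #include <cuda_runtime.h>
    #include <raft/core/handle.hpp>

    namespace ML {
    namespace GLM {
    namespace detail {
    // Implementation: the CUDA stream is now derived from the handle internally.
    void someFit(const raft::handle_t& handle /*, ... */)
    {
      cudaStream_t stream = handle.get_stream();
      // ... actual work, launched on `stream` ...
    }
    }  // namespace detail

    // Public entry point: same signature as before, minus the explicit stream.
    void someFit(const raft::handle_t& handle /*, ... */)
    {
      detail::someFit(handle /*, ... */);
    }
    }  // namespace GLM
    }  // namespace ML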
diff --git a/cpp/src/glm/ols.cuh b/cpp/src/glm/ols.cuh
index f316e9eb9a..77878fa615 100644
--- a/cpp/src/glm/ols.cuh
+++ b/cpp/src/glm/ols.cuh
@@ -36,6 +36,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 /**
  * @brief fit an ordinary least squares model
@@ -48,7 +49,6 @@ namespace GLM {
  * @param intercept   host pointer to hold the solution for bias term of size 1
  * @param fit_intercept if true, fit intercept
  * @param normalize   if true, normalize data to zero mean, unit variance
- * @param stream      cuda stream
  * @param algo        specifies which solver to use (0: SVD, 1: Eigendecomposition, 2:
  * QR-decomposition)
  * @param sample_weight device pointer to sample weight vector of length n_rows (nullptr for uniform
 * weights)
@@ -64,10 +64,10 @@ void olsFit(const raft::handle_t& handle,
             math_t* intercept,
             bool fit_intercept,
             bool normalize,
-            cudaStream_t stream,
             int algo = 0,
             math_t* sample_weight = nullptr)
 {
+  cudaStream_t stream  = handle.get_stream();
   auto cublas_handle   = handle.get_cublas_handle();
   auto cusolver_handle = handle.get_cusolver_dn_handle();
 
@@ -166,7 +166,6 @@ void olsFit(const raft::handle_t& handle,
  * @param coef       coefficients of the model
  * @param intercept  bias term of the model
  * @param preds      device pointer to store predictions of size n_rows
- * @param stream     cuda stream
  */
 template <typename math_t>
 void gemmPredict(const raft::handle_t& handle,
@@ -175,14 +174,14 @@ void gemmPredict(const raft::handle_t& handle,
                  size_t n_cols,
                  const math_t* coef,
                  math_t intercept,
-                 math_t* preds,
-                 cudaStream_t stream)
+                 math_t* preds)
 {
   ASSERT(n_cols > 0, "gemmPredict: number of columns cannot be less than one");
   ASSERT(n_rows > 0, "gemmPredict: number of rows cannot be less than one");
 
-  math_t alpha = math_t(1);
-  math_t beta  = math_t(0);
+  cudaStream_t stream = handle.get_stream();
+  math_t alpha        = math_t(1);
+  math_t beta         = math_t(0);
   raft::linalg::gemm(handle,
                      input,
                      n_rows,
@@ -199,7 +198,6 @@ void gemmPredict(const raft::handle_t& handle,
 
   if (intercept != math_t(0)) raft::linalg::addScalar(preds, preds, intercept, n_rows, stream);
 }
-
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
-// end namespace ML
diff --git a/cpp/src/glm/qn/glm_base.cuh b/cpp/src/glm/qn/glm_base.cuh
index 2f669ba59f..c9756bdfde 100644
--- a/cpp/src/glm/qn/glm_base.cuh
+++ b/cpp/src/glm/qn/glm_base.cuh
@@ -37,9 +37,9 @@ template <typename T>
 inline void linearFwd(const raft::handle_t& handle,
                       SimpleDenseMat<T>& Z,
                       const SimpleMat<T>& X,
-                      const SimpleDenseMat<T>& W,
-                      cudaStream_t stream)
+                      const SimpleDenseMat<T>& W)
 {
+  cudaStream_t stream = handle.get_stream();
   // Forward pass: compute Z <- W * X.T + bias
   const bool has_bias = X.n != W.n;
   const int D         = X.n;
@@ -66,9 +66,9 @@ inline void linearBwd(const raft::handle_t& handle,
                       SimpleDenseMat<T>& G,
                       const SimpleMat<T>& X,
                       const SimpleDenseMat<T>& dZ,
-                      bool setZero,
-                      cudaStream_t stream)
+                      bool setZero)
 {
+  cudaStream_t stream = handle.get_stream();
   // Backward pass:
   // - compute G <- dZ * X.T
   // - for bias: Gb = mean(dZ, 1)
@@ -193,10 +193,9 @@ struct GLMBase : GLMDims {
   {
     Loss* loss = static_cast<Loss*>(this);  // static polymorphism
 
-    linearFwd(handle, Zb, Xb, W, stream);  // linear part: forward pass
+    linearFwd(handle, Zb, Xb, W);                  // linear part: forward pass
     loss->getLossAndDZ(loss_val, Zb, yb, stream);  // loss specific part
-    linearBwd(handle, G, Xb, Zb, initGradZero,
-              stream);  // linear part: backward pass
+    linearBwd(handle, G, Xb, Zb, initGradZero);    // linear part: backward pass
   }
 };
 
diff --git a/cpp/src/glm/qn/qn.cuh b/cpp/src/glm/qn/qn.cuh
index 6fa67b653f..a96da1a0af 100644
--- a/cpp/src/glm/qn/qn.cuh
+++ b/cpp/src/glm/qn/qn.cuh
@@ -32,6 +32,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 template <typename T, typename LossFunction>
 int qn_fit(const raft::handle_t& handle,
@@ -42,9 +43,9 @@ int qn_fit(const raft::handle_t& handle,
            SimpleDenseMat<T>& Z,
            T* w0_data,  // initial value and result
            T* fx,
-           int* num_iters,
-           cudaStream_t stream)
+           int* num_iters)
 {
+  cudaStream_t stream = handle.get_stream();
   LBFGSParam<T> opt_param(pams);
   SimpleVec<T> w0(w0_data, loss.n_param);
@@ -59,14 +60,14 @@
 
   if (l2 == 0) {
     GLMWithData<T, LossFunction> lossWith(&loss, X, y, Z);
 
-    return qn_minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param, stream, pams.verbose);
+    return qn_minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param, pams.verbose);
 
   } else {
     Tikhonov<T> reg(l2);
     RegularizedGLM<T, LossFunction, decltype(reg)> obj(&loss, &reg);
     GLMWithData<T, decltype(obj)> lossWith(&obj, X, y, Z);
 
-    return qn_minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param, stream, pams.verbose);
+    return qn_minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param, pams.verbose);
   }
 }
@@ -79,7 +80,6 @@ inline void qn_fit_x(const raft::handle_t& handle,
                      T* w0_data,
                      T* f,
                      int* num_iters,
-                     cudaStream_t stream,
                      T* sample_weight = nullptr,
                      T svr_eps = 0)
 {
@@ -95,9 +95,10 @@ inline void qn_fit_x(const raft::handle_t& handle,
     Dimensionality of w0 depends on loss, so we initialize it later.
    */
-  int N         = X.m;
-  int D         = X.n;
-  int n_targets = qn_is_classification(pams.loss) && C == 2 ? 1 : C;
+  cudaStream_t stream = handle.get_stream();
+  int N               = X.m;
+  int D               = X.n;
+  int n_targets       = qn_is_classification(pams.loss) && C == 2 ? 1 : C;
   rmm::device_uvector<T> tmp(n_targets * N, stream);
   SimpleDenseMat<T> Z(tmp.data(), n_targets, N);
   SimpleVec<T> y(y_data, N);
@@ -107,49 +108,49 @@ inline void qn_fit_x(const raft::handle_t& handle,
       ASSERT(C == 2, "qn.h: logistic loss invalid C");
       LogisticLoss<T> loss(handle, D, pams.fit_intercept);
       if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
-      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream);
+      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
     } break;
     case QN_LOSS_SQUARED: {
       ASSERT(C == 1, "qn.h: squared loss invalid C");
       SquaredLoss<T> loss(handle, D, pams.fit_intercept);
       if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
-      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream);
+      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
     } break;
    case QN_LOSS_SOFTMAX: {
       ASSERT(C > 2, "qn.h: softmax invalid C");
       Softmax<T> loss(handle, D, C, pams.fit_intercept);
       if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
-      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream);
+      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
     } break;
     case QN_LOSS_SVC_L1: {
       ASSERT(C == 2, "qn.h: SVC-L1 loss invalid C");
       SVCL1Loss<T> loss(handle, D, pams.fit_intercept);
       if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
-      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream);
+      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
     } break;
     case QN_LOSS_SVC_L2: {
       ASSERT(C == 2, "qn.h: SVC-L2 loss invalid C");
       SVCL2Loss<T> loss(handle, D, pams.fit_intercept);
       if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
-      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream);
+      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
     } break;
     case QN_LOSS_SVR_L1: {
       ASSERT(C == 1, "qn.h: SVR-L1 loss invalid C");
       SVRL1Loss<T> loss(handle, D, pams.fit_intercept, svr_eps);
       if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
-      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream);
+      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
     } break;
     case QN_LOSS_SVR_L2: {
       ASSERT(C == 1, "qn.h: SVR-L2 loss invalid C");
       SVRL2Loss<T> loss(handle, D, pams.fit_intercept, svr_eps);
       if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
-      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream);
+      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
     } break;
     case QN_LOSS_ABS: {
       ASSERT(C == 1, "qn.h: abs loss (L1) invalid C");
       AbsLoss<T> loss(handle, D, pams.fit_intercept);
       if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
-      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream);
+      qn_fit<T, decltype(loss)>(handle, pams, loss, X, y, Z, w0_data, f, num_iters);
     } break;
     default: {
       ASSERT(false, "qn.h: unknown loss function type (id = %d).", pams.loss);
@@ -169,12 +170,11 @@ void qnFit(const raft::handle_t& handle,
           T* w0_data,
           T* f,
           int* num_iters,
-          cudaStream_t stream,
           T* sample_weight = nullptr,
           T svr_eps = 0)
 {
   SimpleDenseMat<T> X(X_data, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);
-  qn_fit_x(handle, pams, X, y_data, C, w0_data, f, num_iters, stream, sample_weight, svr_eps);
+  qn_fit_x(handle, pams, X, y_data, C, w0_data, f, num_iters, sample_weight, svr_eps);
 }
 
 template <typename T>
 void qnFitSparse(const raft::handle_t& handle,
@@ -191,22 +191,16 @@ void qnFitSparse(const raft::handle_t& handle,
                  T* w0_data,
                  T* f,
                  int* num_iters,
-                 cudaStream_t stream,
                  T* sample_weight = nullptr,
                  T svr_eps = 0)
 {
   SimpleSparseMat<T> X(X_values, X_cols, X_row_ids, X_nnz, N, D);
-  qn_fit_x(handle, pams, X, y_data, C, w0_data, f, num_iters, stream, sample_weight, svr_eps);
+  qn_fit_x(handle, pams, X, y_data, C, w0_data, f, num_iters, sample_weight, svr_eps);
 }
 
 template <typename T>
-void qn_decision_function(const raft::handle_t& handle,
-                          const qn_params& pams,
-                          SimpleMat<T>& X,
-                          int C,
-                          T* params,
-                          T* scores,
-                          cudaStream_t stream)
+void qn_decision_function(
+  const raft::handle_t& handle, const qn_params& pams, SimpleMat<T>& X, int C, T* params, T* scores)
 {
   // NOTE: While gtests pass X as row-major, and python API passes X as
   // col-major, no extensive testing has been done to ensure that
@@ -215,7 +209,7 @@ void qn_decision_function(const raft::handle_t& handle,
   GLMDims dims(n_targets, X.n, pams.fit_intercept);
   SimpleDenseMat<T> W(params, n_targets, dims.dims);
   SimpleDenseMat<T> Z(scores, n_targets, X.m);
-  linearFwd(handle, Z, X, W, stream);
+  linearFwd(handle, Z, X, W);
 }
 
 template <typename T>
@@ -227,11 +221,10 @@ void qnDecisionFunction(const raft::handle_t& handle,
                         int D,
                         int C,
                         T* params,
-                        T* scores,
-                        cudaStream_t stream)
+                        T* scores)
 {
   SimpleDenseMat<T> X(Xptr, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);
-  qn_decision_function(handle, pams, X, C, params, scores, stream);
+  qn_decision_function(handle, pams, X, C, params, scores);
 }
 
 template <typename T>
@@ -245,26 +238,21 @@ void qnDecisionFunctionSparse(const raft::handle_t& handle,
                               int D,
                               int C,
                               T* params,
-                              T* scores,
-                              cudaStream_t stream)
+                              T* scores)
 {
   SimpleSparseMat<T> X(X_values, X_cols, X_row_ids, X_nnz, N, D);
-  qn_decision_function(handle, pams, X, C, params, scores, stream);
+  qn_decision_function(handle, pams, X, C, params, scores);
 }
 
 template <typename T>
-void qn_predict(const raft::handle_t& handle,
-                const qn_params& pams,
-                SimpleMat<T>& X,
-                int C,
-                T* params,
-                T* preds,
-                cudaStream_t stream)
+void qn_predict(
+  const raft::handle_t& handle, const qn_params& pams, SimpleMat<T>& X, int C, T* params, T* preds)
 {
-  bool is_class = qn_is_classification(pams.loss);
-  int n_targets = is_class && C == 2 ? 1 : C;
+  cudaStream_t stream = handle.get_stream();
+  bool is_class       = qn_is_classification(pams.loss);
+  int n_targets       = is_class && C == 2 ? 1 : C;
   rmm::device_uvector<T> scores(n_targets * X.m, stream);
-  qn_decision_function(handle, pams, X, C, params, scores.data(), stream);
+  qn_decision_function(handle, pams, X, C, params, scores.data());
 
   SimpleDenseMat<T> Z(scores.data(), n_targets, X.m);
   SimpleDenseMat<T> P(preds, 1, X.m);
@@ -289,11 +277,10 @@ void qnPredict(const raft::handle_t& handle,
               int D,
               int C,
               T* params,
-              T* preds,
-              cudaStream_t stream)
+              T* preds)
 {
   SimpleDenseMat<T> X(Xptr, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);
-  qn_predict(handle, pams, X, C, params, preds, stream);
+  qn_predict(handle, pams, X, C, params, preds);
 }
 
 template <typename T>
@@ -307,12 +294,11 @@ void qnPredictSparse(const raft::handle_t& handle,
                     int D,
                     int C,
                     T* params,
-                    T* preds,
-                    cudaStream_t stream)
+                    T* preds)
 {
   SimpleSparseMat<T> X(X_values, X_cols, X_row_ids, X_nnz, N, D);
-  qn_predict(handle, pams, X, C, params, preds, stream);
+  qn_predict(handle, pams, X, C, params, preds);
 }
-
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
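The net effect on internal call sites, before and after this change (sketch only; the argument lists are abbreviated):

    // Before: every solver entry point threaded an explicit stream through:
    //   qn_fit_x(handle, pams, X, y_data, C, w0_data, f, num_iters, stream, sample_weight, svr_eps);
    // After: the stream argument is gone; a caller that still needs a stream
    // for auxiliary work (copies, syncs) reads it off the handle itself:
    cudaStream_t stream = handle.get_stream();
    qn_fit_x(handle, pams, X, y_data, C, w0_data, f, num_iters, sample_weight, svr_eps);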
diff --git a/cpp/src/glm/qn/qn_linesearch.cuh b/cpp/src/glm/qn/qn_linesearch.cuh
index b4a6c84e76..0bc794755e 100644
--- a/cpp/src/glm/qn/qn_linesearch.cuh
+++ b/cpp/src/glm/qn/qn_linesearch.cuh
@@ -23,6 +23,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 template <typename T>
 struct LSProjectedStep {
@@ -205,6 +206,6 @@ LINE_SEARCH_RETCODE ls_backtrack_projected(const LBFGSParam<T>& param,
   }
   return LS_MAX_ITERS_REACHED;
 }
-
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/qn/qn_solvers.cuh b/cpp/src/glm/qn/qn_solvers.cuh
index d52bdffeeb..f36a62b0c7 100644
--- a/cpp/src/glm/qn/qn_solvers.cuh
+++ b/cpp/src/glm/qn/qn_solvers.cuh
@@ -49,6 +49,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 // TODO better way to deal with alignment? Smaller aligne possible?
 constexpr size_t qn_align = 256;
@@ -415,10 +416,10 @@ inline int qn_minimize(const raft::handle_t& handle,
                        LossFunction& loss,
                        const T l1,
                        const LBFGSParam<T>& opt_param,
-                       cudaStream_t stream,
                        const int verbosity = 0)
 {
   // TODO should the worksapce allocation happen outside?
+  cudaStream_t stream = handle.get_stream();
   OPT_RETCODE ret;
   if (l1 == 0.0) {
     rmm::device_uvector<T> tmp(lbfgs_workspace_size(opt_param, x.len), stream);
@@ -466,6 +467,6 @@ inline int qn_minimize(const raft::handle_t& handle,
   }
   return ret;
 }
-
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/qn/qn_util.cuh b/cpp/src/glm/qn/qn_util.cuh
index 43174f91ef..55f594e17c 100644
--- a/cpp/src/glm/qn/qn_util.cuh
+++ b/cpp/src/glm/qn/qn_util.cuh
@@ -24,6 +24,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 enum LINE_SEARCH_ALGORITHM {
   LBFGS_LS_BT_ARMIJO = 1,
@@ -268,6 +269,6 @@ struct op_pseudo_grad {
   HDI T operator()(const T x, const T dlossx) const { return get_pseudo_grad(x, dlossx, l1); }
 };
-
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/ridge.cuh b/cpp/src/glm/ridge.cuh
index 3df7b7f2f6..92926c1809 100644
--- a/cpp/src/glm/ridge.cuh
+++ b/cpp/src/glm/ridge.cuh
@@ -35,6 +35,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 template <typename math_t>
 void ridgeSolve(const raft::handle_t& handle,
@@ -162,10 +163,10 @@ void ridgeFit(const raft::handle_t& handle,
               math_t* intercept,
               bool fit_intercept,
               bool normalize,
-              cudaStream_t stream,
               int algo = 0,
               math_t* sample_weight = nullptr)
 {
+  cudaStream_t stream  = handle.get_stream();
   auto cublas_handle   = handle.get_cublas_handle();
   auto cusolver_handle = handle.get_cusolver_dn_handle();
 
@@ -246,6 +247,6 @@ void ridgeFit(const raft::handle_t& handle,
     *intercept = math_t(0);
   }
 }
-
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/svm/linear.cu b/cpp/src/svm/linear.cu
index a6bb7d2fab..f83f9e3b66 100644
--- a/cpp/src/svm/linear.cu
+++ b/cpp/src/svm/linear.cu
@@ -23,6 +23,7 @@
 
 #include
 #include
+#include
 #include
 #include
 #include
@@ -46,9 +47,6 @@
 #include
 #include
-#include
-#include
-
 #include
 
 namespace ML {
@@ -478,7 +476,6 @@ LinearSVMModel<T> LinearSVMModel<T>::fit(const raft::handle_t& handle,
       wi,
       &target,
       &num_iters,
-      worker.stream,
       (T*)sampleWeight,
       T(params.epsilon));
 
@@ -500,7 +497,6 @@ LinearSVMModel<T> LinearSVMModel<T>::fit(const raft::handle_t& handle,
       psi,
       &target,
       &num_iters,
-      worker.stream,
       (T*)sampleWeight);
   }
   if (parallel) handle.sync_stream_pool();
diff --git a/cpp/test/sg/quasi_newton.cu b/cpp/test/sg/quasi_newton.cu
index 3edc054d76..4057616c85 100644
--- a/cpp/test/sg/quasi_newton.cu
+++ b/cpp/test/sg/quasi_newton.cu
@@ -121,7 +121,7 @@ T run(const raft::handle_t& handle,
 
   T fx;
 
-  qn_fit<T, LossFunction>(handle, pams, loss, X, y, z, w, &fx, &num_iters, stream);
+  detail::qn_fit<T, LossFunction>(handle, pams, loss, X, y, z, w, &fx, &num_iters);
 
   return fx;
 }
@@ -555,7 +555,7 @@ TEST_F(QuasiNewtonTest, predict)
   pams.loss          = QN_LOSS_LOGISTIC;
   pams.fit_intercept = false;
 
-  qnPredict(handle, pams, Xdev->data, false, N, D, 2, w.data, preds.data, stream);
+  qnPredict(handle, pams, Xdev->data, false, N, D, 2, w.data, preds.data);
   raft::update_host(&preds_host[0], preds.data, preds.len, stream);
   handle.sync_stream(stream);
 
@@ -565,7 +565,7 @@ TEST_F(QuasiNewtonTest, predict)
   pams.loss          = QN_LOSS_SQUARED;
   pams.fit_intercept = false;
 
-  qnPredict(handle, pams, Xdev->data, false, N, D, 1, w.data, preds.data, stream);
+  qnPredict(handle, pams, Xdev->data, false, N, D, 1, w.data, preds.data);
   raft::update_host(&preds_host[0], preds.data, preds.len, stream);
   handle.sync_stream(stream);
 
@@ -591,7 +591,7 @@ TEST_F(QuasiNewtonTest, predict_softmax)
   qn_params pams;
   pams.loss          = QN_LOSS_SOFTMAX;
   pams.fit_intercept = false;
-  qnPredict(handle, pams, Xdev->data, false, N, D, C, w.data, preds.data, stream);
+  qnPredict(handle, pams, Xdev->data, false, N, D, C, w.data, preds.data);
   raft::update_host(&preds_host[0], preds.data, preds.len, stream);
   handle.sync_stream(stream);
 
@@ -664,16 +664,8 @@ TEST_F(QuasiNewtonTest, dense_vs_sparse_logistic)
   f_sparse = run(handle, loss, X_sparse, *ydev, l1, l2, w0_sparse.data, z_sparse, 0, stream);
   ASSERT_TRUE(compApprox(f_dense, f_sparse));
 
-  qnPredict(handle,
-            pams,
-            Xdev->data,
-            Xdev->ord == COL_MAJOR,
-            N,
-            D,
-            C,
-            w0_dense.data,
-            preds_dense.data,
-            stream);
+  qnPredict(
+    handle, pams, Xdev->data, Xdev->ord == COL_MAJOR, N, D, C, w0_dense.data, preds_dense.data);
   qnPredictSparse(handle,
                   pams,
                   X_sparse.values,
                   X_sparse.cols,
                   X_sparse.row_ids,
                   X_sparse.nnz,
                   N,
                   D,
                   C,
                   w0_sparse.data,
-                  preds_sparse.data,
-                  stream);
+                  preds_sparse.data);
 
   raft::update_host(&preds_dense_host[0], preds_dense.data, preds_dense.len, stream);
   raft::update_host(&preds_sparse_host[0], preds_sparse.data, preds_sparse.len, stream);

From 1b8ac7de3d3ba1f19b1f915fd7856516342768f9 Mon Sep 17 00:00:00 2001
From: Mickael Ide
Date: Thu, 22 Dec 2022 21:19:08 +0100
Subject: [PATCH 2/2] Add `detail` to GLM functions

---
 cpp/src/glm/qn/glm_base.cuh        | 3 ++-
 cpp/src/glm/qn/glm_linear.cuh      | 3 ++-
 cpp/src/glm/qn/glm_logistic.cuh    | 2 ++
 cpp/src/glm/qn/glm_regularizer.cuh | 2 ++
 cpp/src/glm/qn/glm_softmax.cuh     | 3 ++-
 cpp/src/glm/qn/glm_svm.cuh         | 3 ++-
 cpp/test/sg/quasi_newton.cu        | 4 ++++
 7 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/cpp/src/glm/qn/glm_base.cuh b/cpp/src/glm/qn/glm_base.cuh
index c9756bdfde..55a147031d 100644
--- a/cpp/src/glm/qn/glm_base.cuh
+++ b/cpp/src/glm/qn/glm_base.cuh
@@ -32,6 +32,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 template <typename T>
 inline void linearFwd(const raft::handle_t& handle,
@@ -242,6 +243,6 @@ struct GLMWithData : GLMDims {
     return objective->gradNorm(grad, dev_scalar, stream);
   }
 };
-
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/qn/glm_linear.cuh b/cpp/src/glm/qn/glm_linear.cuh
index 1fb024abed..09de8eb6fe 100644
--- a/cpp/src/glm/qn/glm_linear.cuh
+++ b/cpp/src/glm/qn/glm_linear.cuh
@@ -23,6 +23,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 template <typename T>
 struct SquaredLoss : GLMBase<T, SquaredLoss<T>> {
@@ -76,6 +77,6 @@ struct AbsLoss : GLMBase<T, AbsLoss<T>> {
     return nrm1(grad, dev_scalar, stream);
   }
 };
-
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/qn/glm_logistic.cuh b/cpp/src/glm/qn/glm_logistic.cuh
index 267589507b..70edf11aca 100644
--- a/cpp/src/glm/qn/glm_logistic.cuh
+++ b/cpp/src/glm/qn/glm_logistic.cuh
@@ -23,6 +23,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 template <typename T>
 struct LogisticLoss : GLMBase<T, LogisticLoss<T>> {
@@ -63,5 +64,6 @@ struct LogisticLoss : GLMBase<T, LogisticLoss<T>> {
     return nrmMax(grad, dev_scalar, stream);
   }
 };
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/qn/glm_regularizer.cuh b/cpp/src/glm/qn/glm_regularizer.cuh
index 5385f8fa26..6f88ccc78d 100644
--- a/cpp/src/glm/qn/glm_regularizer.cuh
+++ b/cpp/src/glm/qn/glm_regularizer.cuh
@@ -25,6 +25,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 template <typename T>
 struct Tikhonov {
@@ -91,5 +92,6 @@ struct RegularizedGLM : GLMDims {
     return loss->gradNorm(grad, dev_scalar, stream);
   }
 };
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/qn/glm_softmax.cuh b/cpp/src/glm/qn/glm_softmax.cuh
index 066e330a64..762e62e0a9 100644
--- a/cpp/src/glm/qn/glm_softmax.cuh
+++ b/cpp/src/glm/qn/glm_softmax.cuh
@@ -23,6 +23,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 using raft::ceildiv;
 using raft::myExp;
 using raft::myLog;
@@ -194,6 +195,6 @@ struct Softmax : GLMBase<T, Softmax<T>> {
     return nrmMax(grad, dev_scalar, stream);
   }
 };
-
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/qn/glm_svm.cuh b/cpp/src/glm/qn/glm_svm.cuh
index 1d2bc3623e..8f81a42a74 100644
--- a/cpp/src/glm/qn/glm_svm.cuh
+++ b/cpp/src/glm/qn/glm_svm.cuh
@@ -23,6 +23,7 @@
 
 namespace ML {
 namespace GLM {
+namespace detail {
 
 template <typename T>
 struct SVCL1Loss : GLMBase<T, SVCL1Loss<T>> {
@@ -153,6 +154,6 @@ struct SVRL2Loss : GLMBase<T, SVRL2Loss<T>> {
     return squaredNorm(grad, dev_scalar, stream) * 0.5;
   }
 };
-
+};  // namespace detail
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/test/sg/quasi_newton.cu b/cpp/test/sg/quasi_newton.cu
index 4057616c85..f8939b0dbb 100644
--- a/cpp/test/sg/quasi_newton.cu
+++ b/cpp/test/sg/quasi_newton.cu
@@ -29,6 +29,10 @@
 
 namespace ML {
 namespace GLM {
+using detail::GLMDims;
+using detail::LogisticLoss;
+using detail::Softmax;
+using detail::SquaredLoss;
 
 struct QuasiNewtonTest : ::testing::Test {
   static constexpr int N = 10;
   static constexpr int D = 2;
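Downstream code that reaches into the internal solvers, as the test above does, adapts by importing the moved symbols and dropping the trailing stream argument; a minimal sketch (the handle, qn_params, and the SimpleMat/SimpleVec buffers X, y, z, w0 are assumed to be set up as in the tests):

    using ML::GLM::detail::LogisticLoss;

    qn_params pams;
    pams.loss          = QN_LOSS_LOGISTIC;
    pams.fit_intercept = true;

    LogisticLoss<double> loss(handle, D, pams.fit_intercept);
    double fx;
    int num_iters;
    // X, y, z, w0 are device-backed objects prepared elsewhere.
    ML::GLM::detail::qn_fit<double, LogisticLoss<double>>(
      handle, pams, loss, X, y, z, w0, &fx, &num_iters);  // the stream argument is gone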