diff --git a/cpp/src/solver/cd.cuh b/cpp/src/solver/cd.cuh index 543688dbce..4b7f511e56 100644 --- a/cpp/src/solver/cd.cuh +++ b/cpp/src/solver/cd.cuh @@ -22,13 +22,14 @@ #include #include #include +#include #include #include #include -// #TODO: Replace with public header when ready -#include +#include #include #include +#include #include #include #include @@ -40,8 +41,54 @@ namespace Solver { using namespace MLCommon; +namespace { + +/** Epoch- and iteration-related convergence state: scratch coefficient, max |coef| and max |coef change|. */ +template +struct ConvState { + math_t coef; + math_t coefMax; + math_t diffMax; +}; + /** - * Fits a linear, lasso, and elastic-net regression model using Coordinate Descent solver + * Update a single CD coefficient (soft-threshold by l1_alpha, then divide by *squaredLoc, or set to zero when *squaredLoc is not greater than 1e-5) and the corresponding convergence criteria; limited to one thread per block by __launch_bounds__(1, 1). + * + * @param[inout] coefLoc pointer to the coefficient (arr ptr + column index offset) + * @param[in] squaredLoc pointer to the precomputed per-column divisor (derived from the L2 norm of the input column across rows) + * @param[inout] convStateLoc pointer to the structure holding the convergence state; its coef field is read as the previous coefficient and overwritten with the negated new coefficient + * @param[in] l1_alpha L1 regularization coefficient (soft-threshold width) + */ +template +__global__ void __launch_bounds__(1, 1) cdUpdateCoefKernel(math_t* coefLoc, + const math_t* squaredLoc, + ConvState* convStateLoc, + const math_t l1_alpha) +{ + auto coef = *coefLoc; + auto r = coef > l1_alpha ? coef - l1_alpha : (coef < -l1_alpha ? coef + l1_alpha : 0); + auto squared = *squaredLoc; + r = squared > math_t(1e-5) ? r / squared : math_t(0); + auto diff = raft::myAbs(convStateLoc->coef - r); + if (convStateLoc->diffMax < diff) convStateLoc->diffMax = diff; + auto absv = raft::myAbs(r); + if (convStateLoc->coefMax < absv) convStateLoc->coefMax = absv; + convStateLoc->coef = -r; + *coefLoc = r; +} + +} // namespace + +/** + * Fits a linear, lasso, and elastic-net regression model using Coordinate Descent solver. + * + * i.e. 
finds coefficients that minimize the following loss function: + * + * f(coef) = 1/2 * || labels - input * coef ||^2 + * + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2 + * + alpha * l1_ratio * ||coef||_1 + * + * * @param handle * Reference of raft::handle_t * @param input @@ -96,22 +143,18 @@ void cdFit(const raft::handle_t& handle, math_t tol, cudaStream_t stream) { + raft::common::nvtx::range fun_scope("ML::Solver::cdFit-%d-%d", n_rows, n_cols); ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); ASSERT(loss == ML::loss_funct::SQRD_LOSS, "Parameter loss: Only SQRT_LOSS function is supported for now"); - cublasHandle_t cublas_handle = handle.get_cublas_handle(); - - rmm::device_uvector pred(n_rows, stream); rmm::device_uvector residual(n_rows, stream); rmm::device_uvector squared(n_cols, stream); rmm::device_uvector mu_input(0, stream); rmm::device_uvector mu_labels(0, stream); rmm::device_uvector norm2_input(0, stream); - std::vector h_coef(n_cols, math_t(0)); - if (fit_intercept) { mu_input.resize(n_cols, stream); mu_labels.resize(1, stream); @@ -136,9 +179,11 @@ void cdFit(const raft::handle_t& handle, initShuffle(ri, g); math_t l2_alpha = (1 - l1_ratio) * alpha * n_rows; - alpha = l1_ratio * alpha * n_rows; + math_t l1_alpha = l1_ratio * alpha * n_rows; + // Precompute the residual if (normalize) { + // if we normalized the data during preprocessing, no need to compute the norm again. 
math_t scalar = math_t(1.0) + l2_alpha; raft::matrix::setValue(squared.data(), squared.data(), scalar, n_cols, stream); } else { @@ -149,57 +194,63 @@ void cdFit(const raft::handle_t& handle, raft::copy(residual.data(), labels, n_rows, stream); + ConvState h_convState; + rmm::device_uvector> convStateBuf(1, stream); + auto convStateLoc = convStateBuf.data(); + + rmm::device_scalar cublas_alpha(1.0, stream); + rmm::device_scalar cublas_beta(0.0, stream); + for (int i = 0; i < epochs; i++) { + raft::common::nvtx::range epoch_scope("ML::Solver::cdFit::epoch-%d", i); if (i > 0 && shuffle) { Solver::shuffle(ri, g); } - math_t coef_max = 0.0; - math_t d_coef_max = 0.0; - math_t coef_prev = 0.0; + RAFT_CUDA_TRY(cudaMemsetAsync(convStateLoc, 0, sizeof(ConvState), stream)); for (int j = 0; j < n_cols; j++) { + raft::common::nvtx::range iter_scope("ML::Solver::cdFit::col-%d", j); int ci = ri[j]; math_t* coef_loc = coef + ci; math_t* squared_loc = squared.data() + ci; math_t* input_col_loc = input + (ci * n_rows); - raft::linalg::multiplyScalar(pred.data(), input_col_loc, h_coef[ci], n_rows, stream); - raft::linalg::add(residual.data(), residual.data(), pred.data(), n_rows, stream); - raft::linalg::gemm(handle, - input_col_loc, - n_rows, - 1, - residual.data(), - coef_loc, - 1, - 1, - CUBLAS_OP_T, - CUBLAS_OP_N, - stream); - - if (l1_ratio > math_t(0.0)) Functions::softThres(coef_loc, coef_loc, alpha, 1, stream); - - raft::linalg::eltwiseDivideCheckZero(coef_loc, coef_loc, squared_loc, 1, stream); - - coef_prev = h_coef[ci]; - raft::update_host(&(h_coef[ci]), coef_loc, 1, stream); - handle.sync_stream(stream); - - math_t diff = abs(coef_prev - h_coef[ci]); - - if (diff > d_coef_max) d_coef_max = diff; - - if (abs(h_coef[ci]) > coef_max) coef_max = abs(h_coef[ci]); - - raft::linalg::multiplyScalar(pred.data(), input_col_loc, h_coef[ci], n_rows, stream); - raft::linalg::subtract(residual.data(), residual.data(), pred.data(), n_rows, stream); + // remember current coef + 
raft::copy(&(convStateLoc->coef), coef_loc, 1, stream); + // calculate the residual without the contribution from column ci + // residual[:] += coef[ci] * X[:, ci] + raft::linalg::axpy( + handle, n_rows, coef_loc, input_col_loc, 1, residual.data(), 1, stream); + + // coef[ci] = dot(X[:, ci], residual[:]) + raft::linalg::gemv(handle, + false, + 1, + n_rows, + cublas_alpha.data(), + input_col_loc, + 1, + residual.data(), + 1, + cublas_beta.data(), + coef_loc, + 1, + stream); + + // Calculate the new coefficient that minimizes f along coordinate line ci + // coef[ci] = SoftThreshold(dot(X[:, ci], residual[:]), l1_alpha) / dot(X[:, ci], X[:, ci]) + // Also, update the convergence criteria. + cdUpdateCoefKernel<<>>( + coef_loc, squared_loc, convStateLoc, l1_alpha); + RAFT_CUDA_TRY(cudaGetLastError()); + + // Restore the residual using the updated coefficient + raft::linalg::axpy( + handle, n_rows, &(convStateLoc->coef), input_col_loc, 1, residual.data(), 1, stream); } + raft::update_host(&h_convState, convStateLoc, 1, stream); + handle.sync_stream(stream); - bool flag_continue = true; - if (coef_max == math_t(0)) { flag_continue = false; } - - if ((d_coef_max / coef_max) < tol) { flag_continue = false; } - - if (!flag_continue) { break; } + if (h_convState.coefMax < tol || (h_convState.diffMax / h_convState.coefMax) < tol) break; } if (fit_intercept) {