Skip to content

Commit

Permalink
Rewrite CD solver using more BLAS (#4446)
Browse files Browse the repository at this point in the history
Reduce the frequency of device-host data transfers and replace some operations with BLAS axpy/gemv routines.

This brings approximately 1.2x-3x speedup against the previous version (more speedup for smaller problem sizes).

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #4446
  • Loading branch information
achirkin authored Feb 9, 2022
1 parent 1dd32dc commit f3c1544
Showing 1 changed file with 99 additions and 48 deletions.
147 changes: 99 additions & 48 deletions cpp/src/solver/cd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
#include <functions/penalty.cuh>
#include <functions/softThres.cuh>
#include <glm/preprocess.cuh>
#include <raft/common/nvtx.hpp>
#include <raft/cuda_utils.cuh>
#include <raft/cudart_utils.h>
#include <raft/linalg/add.hpp>
// #TODO: Replace with public header when ready
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/axpy.hpp>
#include <raft/linalg/eltwise.hpp>
#include <raft/linalg/gemm.hpp>
#include <raft/linalg/gemv.hpp>
#include <raft/linalg/multiply.hpp>
#include <raft/linalg/subtract.hpp>
#include <raft/linalg/unary_op.hpp>
Expand All @@ -40,8 +41,54 @@ namespace Solver {

using namespace MLCommon;

namespace {

/** Epoch and iteration -related state. */
template <typename math_t>
struct ConvState {
math_t coef;
math_t coefMax;
math_t diffMax;
};

/**
* Fits a linear, lasso, and elastic-net regression model using Coordinate Descent solver
* Update a single CD coefficient and the corresponding convergence criteria.
*
* @param[inout] coefLoc pointer to the coefficient (arr ptr + column index offset)
* @param[in] squaredLoc pointer to the precomputed data - L2 norm of input for across rows
* @param[inout] convStateLoc pointer to the structure holding the convergence state
* @param[in] l1_alpha L1 regularization coef
*/
template <typename math_t>
__global__ void __launch_bounds__(1, 1) cdUpdateCoefKernel(math_t* coefLoc,
const math_t* squaredLoc,
ConvState<math_t>* convStateLoc,
const math_t l1_alpha)
{
auto coef = *coefLoc;
auto r = coef > l1_alpha ? coef - l1_alpha : (coef < -l1_alpha ? coef + l1_alpha : 0);
auto squared = *squaredLoc;
r = squared > math_t(1e-5) ? r / squared : math_t(0);
auto diff = raft::myAbs(convStateLoc->coef - r);
if (convStateLoc->diffMax < diff) convStateLoc->diffMax = diff;
auto absv = raft::myAbs(r);
if (convStateLoc->coefMax < absv) convStateLoc->coefMax = absv;
convStateLoc->coef = -r;
*coefLoc = r;
}

} // namespace

/**
* Fits a linear, lasso, and elastic-net regression model using Coordinate Descent solver.
*
* i.e. finds coefficients that minimize the following loss function:
*
* f(coef) = 1/2 * || labels - input * coef ||^2
* + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2
* + alpha * l1_ratio * ||coef||_1
*
*
* @param handle
* Reference of raft::handle_t
* @param input
Expand Down Expand Up @@ -96,22 +143,18 @@ void cdFit(const raft::handle_t& handle,
math_t tol,
cudaStream_t stream)
{
raft::common::nvtx::range fun_scope("ML::Solver::cdFit-%d-%d", n_rows, n_cols);
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
ASSERT(loss == ML::loss_funct::SQRD_LOSS,
"Parameter loss: Only SQRT_LOSS function is supported for now");

cublasHandle_t cublas_handle = handle.get_cublas_handle();

rmm::device_uvector<math_t> pred(n_rows, stream);
rmm::device_uvector<math_t> residual(n_rows, stream);
rmm::device_uvector<math_t> squared(n_cols, stream);
rmm::device_uvector<math_t> mu_input(0, stream);
rmm::device_uvector<math_t> mu_labels(0, stream);
rmm::device_uvector<math_t> norm2_input(0, stream);

std::vector<math_t> h_coef(n_cols, math_t(0));

if (fit_intercept) {
mu_input.resize(n_cols, stream);
mu_labels.resize(1, stream);
Expand All @@ -136,9 +179,11 @@ void cdFit(const raft::handle_t& handle,
initShuffle(ri, g);

math_t l2_alpha = (1 - l1_ratio) * alpha * n_rows;
alpha = l1_ratio * alpha * n_rows;
math_t l1_alpha = l1_ratio * alpha * n_rows;

// Precompute the residual
if (normalize) {
// if we normalized the data during preprocessing, no need to compute the norm again.
math_t scalar = math_t(1.0) + l2_alpha;
raft::matrix::setValue(squared.data(), squared.data(), scalar, n_cols, stream);
} else {
Expand All @@ -149,57 +194,63 @@ void cdFit(const raft::handle_t& handle,

raft::copy(residual.data(), labels, n_rows, stream);

ConvState<math_t> h_convState;
rmm::device_uvector<ConvState<math_t>> convStateBuf(1, stream);
auto convStateLoc = convStateBuf.data();

rmm::device_scalar<math_t> cublas_alpha(1.0, stream);
rmm::device_scalar<math_t> cublas_beta(0.0, stream);

for (int i = 0; i < epochs; i++) {
raft::common::nvtx::range epoch_scope("ML::Solver::cdFit::epoch-%d", i);
if (i > 0 && shuffle) { Solver::shuffle(ri, g); }

math_t coef_max = 0.0;
math_t d_coef_max = 0.0;
math_t coef_prev = 0.0;
RAFT_CUDA_TRY(cudaMemsetAsync(convStateLoc, 0, sizeof(ConvState<math_t>), stream));

for (int j = 0; j < n_cols; j++) {
raft::common::nvtx::range iter_scope("ML::Solver::cdFit::col-%d", j);
int ci = ri[j];
math_t* coef_loc = coef + ci;
math_t* squared_loc = squared.data() + ci;
math_t* input_col_loc = input + (ci * n_rows);

raft::linalg::multiplyScalar(pred.data(), input_col_loc, h_coef[ci], n_rows, stream);
raft::linalg::add(residual.data(), residual.data(), pred.data(), n_rows, stream);
raft::linalg::gemm(handle,
input_col_loc,
n_rows,
1,
residual.data(),
coef_loc,
1,
1,
CUBLAS_OP_T,
CUBLAS_OP_N,
stream);

if (l1_ratio > math_t(0.0)) Functions::softThres(coef_loc, coef_loc, alpha, 1, stream);

raft::linalg::eltwiseDivideCheckZero(coef_loc, coef_loc, squared_loc, 1, stream);

coef_prev = h_coef[ci];
raft::update_host(&(h_coef[ci]), coef_loc, 1, stream);
handle.sync_stream(stream);

math_t diff = abs(coef_prev - h_coef[ci]);

if (diff > d_coef_max) d_coef_max = diff;

if (abs(h_coef[ci]) > coef_max) coef_max = abs(h_coef[ci]);

raft::linalg::multiplyScalar(pred.data(), input_col_loc, h_coef[ci], n_rows, stream);
raft::linalg::subtract(residual.data(), residual.data(), pred.data(), n_rows, stream);
// remember current coef
raft::copy(&(convStateLoc->coef), coef_loc, 1, stream);
// calculate the residual without the contribution from column ci
// residual[:] += coef[ci] * X[:, ci]
raft::linalg::axpy<math_t, true>(
handle, n_rows, coef_loc, input_col_loc, 1, residual.data(), 1, stream);

// coef[ci] = dot(X[:, ci], residual[:])
raft::linalg::gemv<math_t, true>(handle,
false,
1,
n_rows,
cublas_alpha.data(),
input_col_loc,
1,
residual.data(),
1,
cublas_beta.data(),
coef_loc,
1,
stream);

// Calculate the new coefficient that minimizes f along coordinate line ci
// coef[ci] = SoftTreshold(dot(X[:, ci], residual[:]), l1_alpha) / dot(X[:, ci], X[:, ci]))
// Also, update the convergence criteria.
cdUpdateCoefKernel<math_t><<<dim3(1, 1, 1), dim3(1, 1, 1), 0, stream>>>(
coef_loc, squared_loc, convStateLoc, l1_alpha);
RAFT_CUDA_TRY(cudaGetLastError());

// Restore the residual using the updated coeffecient
raft::linalg::axpy<math_t, true>(
handle, n_rows, &(convStateLoc->coef), input_col_loc, 1, residual.data(), 1, stream);
}
raft::update_host(&h_convState, convStateLoc, 1, stream);
handle.sync_stream(stream);

bool flag_continue = true;
if (coef_max == math_t(0)) { flag_continue = false; }

if ((d_coef_max / coef_max) < tol) { flag_continue = false; }

if (!flag_continue) { break; }
if (h_convState.coefMax < tol || (h_convState.diffMax / h_convState.coefMax) < tol) break;
}

if (fit_intercept) {
Expand Down

0 comments on commit f3c1544

Please sign in to comment.