Add multi-node-multi-gpu Logistic Regression in C++ #5477

Merged
12 commits merged on Jul 24, 2023
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
@@ -524,6 +524,7 @@ if(BUILD_CUML_CPP_LIBRARY)
src/glm/ols_mg.cu
src/glm/preprocess_mg.cu
src/glm/ridge_mg.cu
src/glm/qn_mg.cu
src/kmeans/kmeans_mg.cu
src/knn/knn_mg.cu
src/knn/knn_classify_mg.cu
56 changes: 56 additions & 0 deletions cpp/include/cuml/linear_model/qn_mg.hpp
@@ -0,0 +1,56 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuda_runtime.h>
#include <cuml/common/logger.hpp>
#include <cuml/linear_model/qn.h>
#include <raft/core/comms.hpp>

#include <cumlprims/opg/matrix/data.hpp>
#include <cumlprims/opg/matrix/part_descriptor.hpp>
using namespace MLCommon;

namespace ML {
namespace GLM {
namespace opg {

/**
* @brief Performs the MNMG fit operation for logistic regression using quasi-Newton methods
* @param[in] handle: the internal cuml handle object
* @param[in] input_data: vector holding all partitions for that rank
* @param[in] input_desc: PartDescriptor object for the input
* @param[in] labels: labels data
* @param[out] coef: learned coefficients
* @param[in] pams: model parameters
* @param[in] X_col_major: true if X is stored column-major
* @param[in] n_classes: number of outputs (number of classes or `1` for regression)
* @param[out] f: host pointer holding the final objective value
* @param[out] num_iters: host pointer holding the actual number of iterations taken
*/
void qnFit(raft::handle_t& handle,
std::vector<Matrix::Data<float>*>& input_data,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<float>*>& labels,
float* coef,
const qn_params& pams,
bool X_col_major,
int n_classes,
float* f,
int* num_iters);

}; // namespace opg
}; // namespace GLM
}; // namespace ML
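
For orientation, a minimal sketch (not part of the diff) of how one worker rank might call the new API. The handle with an initialized communicator, the device-resident partitions, and the PartDescriptor are assumed to be prepared by the existing OPG infrastructure; the driver function name fit_logistic_mg is hypothetical, only the qn_params fields exercised by this PR are set (the remaining solver settings are assumed to carry reasonable defaults), and the 2-class restriction enforced in qn_mg.cu is respected.

#include <cuml/linear_model/qn_mg.hpp>
#include <raft/core/handle.hpp>
#include <vector>

// Hypothetical driver: handle, input_data, input_desc, labels and the device buffer
// behind coef are assumed to be set up by the caller's multi-GPU bootstrap code.
void fit_logistic_mg(raft::handle_t& handle,
                     std::vector<MLCommon::Matrix::Data<float>*>& input_data,
                     MLCommon::Matrix::PartDescriptor& input_desc,
                     std::vector<MLCommon::Matrix::Data<float>*>& labels,
                     float* coef)  // device pointer, D + fit_intercept coefficients
{
  qn_params pams;
  pams.loss               = QN_LOSS_LOGISTIC;  // the only loss qn_mg.cu accepts for now
  pams.fit_intercept      = true;
  pams.penalty_l2         = 1.0;
  pams.penalty_normalized = true;  // divides l2 by the global sample count

  float objective = 0.f;
  int num_iters   = 0;
  // Every rank calls qnFit on its local partitions; the allreduces inside the solver
  // keep the learned coefficients identical across ranks.
  ML::GLM::opg::qnFit(handle, input_data, input_desc, labels, coef, pams,
                      /*X_col_major=*/false, /*n_classes=*/2, &objective, &num_iters);
}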
166 changes: 166 additions & 0 deletions cpp/src/glm/qn/glm_base_mg.cuh
@@ -0,0 +1,166 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/core/comms.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/multiply.cuh>
#include <raft/util/cudart_utils.hpp>

#include <glm/qn/glm_base.cuh>
#include <glm/qn/glm_logistic.cuh>
#include <glm/qn/glm_regularizer.cuh>
#include <glm/qn/qn_solvers.cuh>
#include <glm/qn/qn_util.cuh>

namespace ML {
namespace GLM {
namespace opg {
// multi-GPU version of linearBwd
template <typename T>
inline void linearBwdMG(const raft::handle_t& handle,
SimpleDenseMat<T>& G,
const SimpleMat<T>& X,
const SimpleDenseMat<T>& dZ,
bool setZero,
const int64_t n_samples,
const int n_ranks)
{
cudaStream_t stream = handle.get_stream();
// Backward pass:
// - compute G <- dZ * X.T
// - for bias: Gb = mean(dZ, 1)

const bool has_bias = X.n != G.n;
const int D = X.n;
const T beta = setZero ? T(0) : T(1);
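// note: when setZero is false, beta is scaled by 1/n_ranks below so that whatever already
// sits in G (e.g. the regularization gradient, computed identically on every rank) is not
// counted n_ranks times by the caller's subsequent allreduce-sum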

if (has_bias) {
SimpleVec<T> Gbias;
SimpleDenseMat<T> Gweights;

col_ref(G, Gbias, D);

col_slice(G, Gweights, 0, D);

// TODO can this be fused somehow?
Gweights.assign_gemm(handle, 1.0 / n_samples, dZ, false, X, false, beta / n_ranks, stream);
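
// bias gradient below: raft::stats::mean gives the local mean of dZ over this rank's samples;
// rescaling it by (local samples / n_samples) turns it into local_sum / n_samples, so the
// caller's allreduce-sum across ranks yields the global mean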

raft::stats::mean(Gbias.data, dZ.data, dZ.m, dZ.n, false, true, stream);
T bias_factor = 1.0 * dZ.n / n_samples;
raft::linalg::multiplyScalar(Gbias.data, Gbias.data, bias_factor, dZ.m, stream);

} else {
CUML_LOG_DEBUG("has bias not enabled");
G.assign_gemm(handle, 1.0 / n_samples, dZ, false, X, false, beta / n_ranks, stream);
}
}

/**
* @brief Aggregates local gradient vectors and loss values from local training data. This
* class is the multi-node-multi-gpu version of GLMWithData.
*
* The implementation overrides the existing GLMWithData::operator(). The purpose is to
* aggregate local gradient vectors and loss values from the distributed X and y, where X holds
* the input vectors and y holds the labels.
*
* GLMWithData::operator() currently invokes three functions: linearFwd, getLossAndDz and
* linearBwd. linearFwd multiplies the local input vectors with the coefficient vector (i.e.
* coef_), so it requires no communication. getLossAndDz calculates the local loss, so an
* allreduce is required to obtain the global loss. linearBwd calculates the local gradient
* vector, so an allreduce is required to obtain the global gradient vector. The global loss and
* the global gradient vector are then used by min_lbfgs to update the coefficients. The update
* runs independently on every GPU and, when it finishes, all GPUs hold the same coefficient
* values.
*/
template <typename T, class GLMObjective>
struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
Review comment (Member): Could we have a doxygen here just to help future developers?
Reply (Contributor Author): added, please check!

const raft::handle_t* handle_p;
int rank;
int64_t n_samples;
int n_ranks;

GLMWithDataMG(raft::handle_t const& handle,
int rank,
int n_ranks,
int64_t n_samples,
GLMObjective* obj,
const SimpleMat<T>& X,
const SimpleVec<T>& y,
SimpleDenseMat<T>& Z)
: ML::GLM::detail::GLMWithData<T, GLMObjective>(obj, X, y, Z)
{
this->handle_p = &handle;
this->rank = rank;
this->n_ranks = n_ranks;
this->n_samples = n_samples;
}

inline T operator()(const SimpleVec<T>& wFlat,
SimpleVec<T>& gradFlat,
T* dev_scalar,
cudaStream_t stream)
{
SimpleDenseMat<T> W(wFlat.data, this->C, this->dims);
SimpleDenseMat<T> G(gradFlat.data, this->C, this->dims);
SimpleVec<T> lossVal(dev_scalar, 1);

// apply regularization
auto regularizer_obj = this->objective;
auto lossFunc = regularizer_obj->loss;
auto reg = regularizer_obj->reg;
G.fill(0, stream);
reg->reg_grad(dev_scalar, G, W, lossFunc->fit_intercept, stream);
T reg_host;
raft::update_host(&reg_host, dev_scalar, 1, stream);
// note: avoid syncing here because there's a sync before reg_host is used.

// apply linearFwd, getLossAndDz, linearBwd
ML::GLM::detail::linearFwd(
lossFunc->handle, *(this->Z), *(this->X), W); // linear part: forward pass

raft::comms::comms_t const& communicator = raft::resource::get_comms(*(this->handle_p));

lossFunc->getLossAndDZ(dev_scalar, *(this->Z), *(this->y), stream); // loss specific part

// normalize local loss before allreduce sum
T factor = 1.0 * (*this->y).len / this->n_samples;
raft::linalg::multiplyScalar(dev_scalar, dev_scalar, factor, 1, stream);

communicator.allreduce(dev_scalar, dev_scalar, 1, raft::comms::op_t::SUM, stream);
communicator.sync_stream(stream);

linearBwdMG(lossFunc->handle,
G,
*(this->X),
*(this->Z),
false,
n_samples,
n_ranks); // linear part: backward pass

communicator.allreduce(G.data, G.data, this->C * this->dims, raft::comms::op_t::SUM, stream);
communicator.sync_stream(stream);

T loss_host;
raft::update_host(&loss_host, dev_scalar, 1, stream);
raft::resource::sync_stream(*(this->handle_p));
loss_host += reg_host;
lossVal.fill(loss_host, stream);

return loss_host;
}
};
}; // namespace opg
}; // namespace GLM
}; // namespace ML
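
As the GLMWithDataMG doxygen above describes, each rank scales its local mean loss by (local sample count / global sample count) before the allreduce-sum, so the summed value is already the global mean loss. A small standalone sketch of that arithmetic, in plain C++ with the allreduce simulated by a loop over ranks and made-up numbers:

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
  // Per-rank local mean losses and local sample counts (illustrative values only).
  std::vector<double> local_loss = {0.70, 0.50, 0.60};
  std::vector<int> local_samples = {100, 300, 600};
  int n_samples                  = 100 + 300 + 600;  // global count (input_desc.M in the PR)

  // Each rank multiplies its local mean loss by len(y_local) / n_samples,
  // then the allreduce-SUM (simulated here) yields the global mean loss.
  double global_loss = 0.0;
  for (std::size_t r = 0; r < local_loss.size(); ++r) {
    global_loss += local_loss[r] * local_samples[r] / n_samples;
  }
  std::printf("global mean loss = %f\n", global_loss);  // prints 0.58
  return 0;
}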
157 changes: 157 additions & 0 deletions cpp/src/glm/qn_mg.cu
@@ -0,0 +1,157 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "qn/glm_logistic.cuh"
#include "qn/glm_regularizer.cuh"
#include "qn/qn_util.cuh"
#include "qn/simple_mat/dense.hpp"
#include <cuml/common/logger.hpp>
#include <cuml/linear_model/qn.h>
#include <cuml/linear_model/qn_mg.hpp>
#include <raft/core/comms.hpp>
#include <raft/core/error.hpp>
#include <raft/core/handle.hpp>
#include <raft/util/cudart_utils.hpp>
using namespace MLCommon;

#include "qn/glm_base_mg.cuh"

#include <cuda_runtime.h>

namespace ML {
namespace GLM {
namespace opg {

template <typename T>
void qnFit_impl(const raft::handle_t& handle,
const qn_params& pams,
T* X,
bool X_col_major,
T* y,
size_t N,
size_t D,
size_t C,
T* w0,
T* f,
int* num_iters,
size_t n_samples,
int rank,
int n_ranks)
{
switch (pams.loss) {
case QN_LOSS_LOGISTIC: {
RAFT_EXPECTS(
C == 2,
"qn_mg.cu: only the LOGISTIC loss is supported currently. The number of classes must be 2");
} break;
default: {
RAFT_EXPECTS(false, "qn_mg.cu: unknown loss function type (id = %d).", pams.loss);
}
}

cudaStream_t stream = raft::resource::get_cuda_stream(handle);
auto X_simple = SimpleDenseMat<T>(X, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);
auto y_simple = SimpleVec<T>(y, N);
SimpleVec<T> coef_simple(w0, D + pams.fit_intercept);

ML::GLM::detail::LBFGSParam<T> opt_param(pams);

// prepare regularizer regularizer_obj
ML::GLM::detail::LogisticLoss<T> loss_func(handle, D, pams.fit_intercept);
T l2 = pams.penalty_l2;
if (pams.penalty_normalized) {
l2 /= n_samples; // normalize the penalty by the global number of samples (MG analogue of l2 /= X.m)
}
ML::GLM::detail::Tikhonov<T> reg(l2);
ML::GLM::detail::RegularizedGLM<T, ML::GLM::detail::LogisticLoss<T>, decltype(reg)>
regularizer_obj(&loss_func, &reg);

// prepare GLMWithDataMG
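// with C == 2 (binary logistic regression) a single output column suffices; C > 2 would
// need one column per class, which the loss check above currently rules out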
int n_targets = C == 2 ? 1 : C;
rmm::device_uvector<T> tmp(n_targets * N, stream);
SimpleDenseMat<T> Z(tmp.data(), n_targets, N);
auto obj_function =
GLMWithDataMG(handle, rank, n_ranks, n_samples, &regularizer_obj, X_simple, y_simple, Z);

// prepare temporary variables fx, k, workspace
T fx = -1;
int k = -1;
rmm::device_uvector<T> tmp_workspace(lbfgs_workspace_size(opt_param, coef_simple.len), stream);
SimpleVec<T> workspace(tmp_workspace.data(), tmp_workspace.size());

// call min_lbfgs
min_lbfgs(opt_param, obj_function, coef_simple, fx, &k, workspace, stream, 5);
}

template <typename T>
void qnFit_impl(raft::handle_t& handle,
std::vector<Matrix::Data<T>*>& input_data,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<T>*>& labels,
T* coef,
const qn_params& pams,
bool X_col_major,
int n_classes,
T* f,
int* num_iters)
{
RAFT_EXPECTS(input_data.size() == 1,
"qn_mg.cu currently does not accept more than one input matrix");
RAFT_EXPECTS(labels.size() == input_data.size(), "labels size does not equal input_data size");

auto data_X = input_data[0];
auto data_y = labels[0];

size_t n_samples = 0;
for (auto p : input_desc.partsToRanks) {
n_samples += p->size;
}

qnFit_impl<T>(handle,
pams,
data_X->ptr,
X_col_major,
data_y->ptr,
input_desc.totalElementsOwnedBy(input_desc.rank),
input_desc.N,
n_classes,
coef,
f,
num_iters,
input_desc.M,
input_desc.rank,
input_desc.uniqueRanks().size());
}

void qnFit(raft::handle_t& handle,
std::vector<Matrix::Data<float>*>& input_data,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<float>*>& labels,
float* coef,
const qn_params& pams,
bool X_col_major,
int n_classes,
float* f,
int* num_iters)
{
qnFit_impl<float>(
handle, input_data, input_desc, labels, coef, pams, X_col_major, n_classes, f, num_iters);
}

}; // namespace opg
}; // namespace GLM
}; // namespace ML
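
For context on the two row counts used above: the N passed to the per-rank call is input_desc.totalElementsOwnedBy(rank) (the local rows), while n_samples is the global row count obtained by summing all partition sizes (equivalently input_desc.M), and the normalized l2 penalty is divided by the global count so every rank applies the same effective regularization. A small standalone illustration with made-up partition sizes:

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
  // Hypothetical partition sizes, one entry per rank (stand-in for input_desc.partsToRanks).
  std::vector<std::size_t> part_sizes = {400, 350, 250};

  std::size_t n_samples = 0;  // global row count, computed the same way as in qnFit_impl
  for (std::size_t s : part_sizes) { n_samples += s; }

  double penalty_l2       = 1.0;   // pams.penalty_l2, identical on every rank
  bool penalty_normalized = true;  // pams.penalty_normalized
  double l2               = penalty_l2;
  if (penalty_normalized) { l2 /= n_samples; }  // same effective strength on every rank

  std::printf("n_samples = %zu, effective l2 = %g\n", n_samples, l2);  // 1000, 0.001
  return 0;
}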