-
Notifications
You must be signed in to change notification settings - Fork 550
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEA] Support no regularization in MNMG LogisticRegression (#5558)
Also adopted the code structure of the SG class to prepare for future PRs. This PR depends on and has included [PR 5567](#5567) Authors: - Jinfeng Li (https://github.com/lijinf2) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) URL: #5558
- Loading branch information
Showing
7 changed files
with
184 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "glm_base_mg.cuh" | ||
#include <glm/qn/glm_logistic.cuh> | ||
#include <glm/qn/glm_regularizer.cuh> | ||
#include <glm/qn/glm_softmax.cuh> | ||
#include <glm/qn/glm_svm.cuh> | ||
#include <glm/qn/qn_solvers.cuh> | ||
#include <glm/qn/qn_util.cuh> | ||
|
||
#include <cuml/linear_model/qn.h> | ||
#include <rmm/device_uvector.hpp> | ||
|
||
namespace ML { | ||
namespace GLM { | ||
namespace opg { | ||
using namespace ML::GLM::detail; | ||
|
||
template <typename T, typename LossFunction> | ||
int qn_fit_mg(const raft::handle_t& handle, | ||
const qn_params& pams, | ||
LossFunction& loss, | ||
const SimpleMat<T>& X, | ||
const SimpleVec<T>& y, | ||
SimpleDenseMat<T>& Z, | ||
T* w0_data, // initial value and result | ||
T* fx, | ||
int* num_iters, | ||
size_t n_samples, | ||
int rank, | ||
int n_ranks) | ||
{ | ||
cudaStream_t stream = handle.get_stream(); | ||
LBFGSParam<T> opt_param(pams); | ||
SimpleVec<T> w0(w0_data, loss.n_param); | ||
|
||
// Scale the regularization strength with the number of samples. | ||
T l1 = 0; | ||
T l2 = pams.penalty_l2; | ||
if (pams.penalty_normalized) { l2 /= n_samples; } | ||
|
||
ML::GLM::detail::Tikhonov<T> reg(l2); | ||
ML::GLM::detail::RegularizedGLM<T, LossFunction, decltype(reg)> regularizer_obj(&loss, ®); | ||
|
||
auto obj_function = GLMWithDataMG(handle, rank, n_ranks, n_samples, ®ularizer_obj, X, y, Z); | ||
return ML::GLM::detail::qn_minimize( | ||
handle, w0, fx, num_iters, obj_function, l1, opt_param, pams.verbose); | ||
} | ||
|
||
template <typename T> | ||
inline void qn_fit_x_mg(const raft::handle_t& handle, | ||
const qn_params& pams, | ||
SimpleMat<T>& X, | ||
T* y_data, | ||
int C, | ||
T* w0_data, | ||
T* f, | ||
int* num_iters, | ||
int64_t n_samples, | ||
int rank, | ||
int n_ranks, | ||
T* sample_weight = nullptr, | ||
T svr_eps = 0) | ||
{ | ||
/* | ||
NB: | ||
N - number of data rows | ||
D - number of data columns (features) | ||
C - number of output classes | ||
X in R^[N, D] | ||
w in R^[D, C] | ||
y in {0, 1}^[N, C] or {cat}^N | ||
Dimensionality of w0 depends on loss, so we initialize it later. | ||
*/ | ||
cudaStream_t stream = handle.get_stream(); | ||
int N = X.m; | ||
int D = X.n; | ||
int n_targets = ML::GLM::detail::qn_is_classification(pams.loss) && C == 2 ? 1 : C; | ||
rmm::device_uvector<T> tmp(n_targets * N, stream); | ||
SimpleDenseMat<T> Z(tmp.data(), n_targets, N); | ||
SimpleVec<T> y(y_data, N); | ||
|
||
switch (pams.loss) { | ||
case QN_LOSS_LOGISTIC: { | ||
ASSERT(C == 2, "qn_mg.cuh: logistic loss invalid C"); | ||
ML::GLM::detail::LogisticLoss<T> loss(handle, D, pams.fit_intercept); | ||
ML::GLM::opg::qn_fit_mg<T, decltype(loss)>( | ||
handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks); | ||
} break; | ||
default: { | ||
ASSERT(false, "qn_mg.cuh: unknown loss function type (id = %d).", pams.loss); | ||
} | ||
} | ||
} | ||
|
||
}; // namespace opg | ||
}; // namespace GLM | ||
}; // namespace ML |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters