Commit

clean code
lijinf2 committed Jun 26, 2023
1 parent b60b3e4 commit 5a9cb6c
Showing 4 changed files with 3 additions and 107 deletions.
46 changes: 1 addition & 45 deletions cpp/include/cuml/linear_model/qn_mg.hpp
@@ -1,16 +1,8 @@
#include <raft/core/comms.hpp>
// #include <raft/core/handle.hpp>
// #include <raft/core/device_mdarray.hpp>
// #include <raft/util/cudart_utils.hpp>
// #include <raft/comms/std_comms.hpp>
#include <cuml/common/logger.hpp>
#include <cuml/linear_model/qn.h> // to use qn_params

#include <cuml/linear_model/qn.h>
#include <cuda_runtime.h>

// #include <vector>
#include <iostream>

namespace ML {
namespace GLM {
namespace opg {
@@ -35,42 +27,6 @@ void qnFit(const raft::handle_t &handle,
int rank,
int n_ranks);

/**
* @brief Fit a GLM using quasi newton methods.
*
* @param cuml_handle reference to raft::handle_t object
* @param params model parameters
* @param X device pointer to a contiguous feature matrix of dimension [N, D]
* @param X_col_major true if X is stored column-major
* @param y device pointer to label vector of length N
* @param N number of examples
* @param D number of features
* @param C number of outputs (number of classes or `1` for regression)
* @param w0 device pointer of size (D + (fit_intercept ? 1 : 0)) * C with initial point,
* overwritten by final result.
* @param f host pointer holding the final objective value
* @param num_iters host pointer holding the actual number of iterations taken
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr
for uniform weights)
* @param svr_eps epsilon parameter for svr
*/
/*
template <typename T, typename I = int>
void qnFit(const raft::handle_t& cuml_handle,
const qn_params& params,
float* X,
bool X_col_major,
float* y,
int N,
int D,
int C,
float* w0,
float* f,
int* num_iters,
float* sample_weight = nullptr,
T svr_eps = 0);
*/

}; // namespace opg
}; // namespace GLM
}; // namespace ML
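For reference, the doc comment above describes the never-enabled qnFit overload that this commit deletes (it only ever lived inside a comment block). Below is a minimal, hypothetical call sketch matching that documented parameter list; the function name, variable names, and qn_params fields used here are assumptions for illustration, not code from this commit, and the call itself stays commented out because the overload was never compiled.

// Hypothetical usage sketch for the documented qnFit overload removed above.
// fit_logistic_example, X_dev, y_dev and w0_dev are illustrative names only.
#include <cuml/linear_model/qn.h>  // qn_params
#include <raft/core/handle.hpp>

void fit_logistic_example(const raft::handle_t& handle,
                          float* X_dev,   // [N, D] features, device memory
                          float* y_dev,   // [N] labels, device memory
                          float* w0_dev,  // (D + 1) * C coefficients, device memory
                          int N, int D, int C)
{
  qn_params params;             // solver configuration (assumed defaults)
  params.fit_intercept = true;  // matches the (D + 1) sizing of w0_dev

  float objective = 0.0f;       // host: final objective value
  int num_iters   = 0;          // host: iterations actually taken

  // Argument order follows the removed declaration:
  //   handle, params, X, X_col_major, y, N, D, C, w0, f, num_iters,
  //   sample_weight (nullptr => uniform), svr_eps.
  // ML::GLM::opg::qnFit<float>(handle, params, X_dev, /*X_col_major=*/false,
  //                            y_dev, N, D, C, w0_dev, &objective, &num_iters,
  //                            /*sample_weight=*/nullptr, /*svr_eps=*/0.0f);
}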
17 changes: 0 additions & 17 deletions cpp/src/glm/qn/glm_base_mg.cuh
@@ -3,7 +3,6 @@
#include <raft/core/handle.hpp>
#include <raft/linalg/multiply.cuh>

// the following are for PCA
#include "glm/qn/glm_base.cuh"
#include "glm/qn/glm_logistic.cuh"
#include "glm/qn/qn_util.cuh"
@@ -47,8 +46,6 @@ inline void linearBwdMG(const raft::handle_t& handle,
const SimpleDenseMat<T>* X_simple_p = (const SimpleDenseMat<T>*)(&X);
Gweights.assign_gemm(handle, 1.0 / n_samples, dZ, false, X, false, beta / n_ranks, stream);

// Stats::opg::mean(handle, mu_data, input_data, input_desc, streams, n_streams);

raft::stats::mean(Gbias.data, dZ.data, dZ.m, dZ.n, false, true, stream);
T bias_factor = 1.0 * dZ.n / n_samples;
raft::linalg::multiplyScalar(Gbias.data, Gbias.data, bias_factor, dZ.m, stream);
@@ -95,21 +92,8 @@ struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
raft::update_host(&reg_host, dev_scalar, 1, stream);

// apply linearFwd, getLossAndDz, linearBwd
/*
inline void loss_grad(T* loss_val,
Mat& G,
const Mat& W,
const SimpleMat<T>& Xb,
const Vec& yb,
Mat& Zb,
cudaStream_t stream,
bool initGradZero = true)
*/
// call linearFwd, linearBwd, getLossAndDz
ML::GLM::detail::linearFwd(lossFunc->handle, *(this->Z), *(this->X), W); // linear part: forward pass

//raft::interruptible::synchronize(stream);
// raft::comms::comms_t const& communicator = raft::resource::get_comms(handle);
raft::comms::comms_t const& communicator = raft::resource::get_comms(*(this->handle_p));

lossFunc->getLossAndDZ(dev_scalar, *(this->Z), *(this->y), stream); // loss specific part
@@ -126,7 +110,6 @@ struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
raft::resource::sync_stream(*(this->handle_p));
raft::interruptible::synchronize(stream);

//linearBwd(lossFunc->handle, G, *(this->X), *(this->Z), false); // linear part: backward pass
linearBwdMG(lossFunc->handle, G, *(this->X), *(this->Z), false, n_samples, n_ranks); // linear part: backward pass
raft::interruptible::synchronize(stream);

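A note on the scaling visible in linearBwdMG above: each rank takes the mean of its local dZ columns and rescales it by dZ.n / n_samples, and the gemm likewise uses 1.0 / n_samples with the regularization factor divided by n_ranks, so that when the per-rank gradients are summed across ranks (the allreduce the MG objective is expected to perform via the raft communicator retrieved in loss_grad) the result equals the gradient averaged over all n_samples, with the penalty counted exactly once. A minimal host-side sketch of that arithmetic with made-up per-rank data (not cuML code):

// Host-side illustration of the per-rank bias-gradient scaling used above:
// the local mean of dZ, scaled by local_n / n_samples and summed across
// ranks, equals the mean over all samples.
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

int main() {
  // Two hypothetical ranks holding 3 and 2 samples of a single-output dZ.
  std::vector<std::vector<double>> dz_per_rank = {{1.0, 2.0, 3.0}, {4.0, 5.0}};
  const double n_samples = 5.0;

  double allreduced = 0.0;  // stand-in for a SUM allreduce over ranks
  for (const auto& local : dz_per_rank) {
    double local_mean  = std::accumulate(local.begin(), local.end(), 0.0) / local.size();
    double bias_factor = local.size() / n_samples;  // mirrors dZ.n / n_samples
    allreduced += local_mean * bias_factor;         // this rank's contribution
  }

  // Mean over all five samples: (1 + 2 + 3 + 4 + 5) / 5 = 3.
  assert(std::fabs(allreduced - 3.0) < 1e-12);
  return 0;
}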
46 changes: 1 addition & 45 deletions cpp/src/glm/qn_mg.cu
@@ -1,16 +1,8 @@
#include <raft/core/comms.hpp>
#include <raft/util/cudart_utils.hpp>
#include <raft/core/handle.hpp>

#include <cuml/common/logger.hpp>

//#include "glm/qn/glm_base.cuh"
//#include "glm/qn/glm_logistic.cuh"
//#include "glm/qn/qn_util.cuh"
//#include "glm/qn/qn_solvers.cuh"
//#include "glm/qn/glm_regularizer.cuh"

#include <cuml/linear_model/qn.h> // to use qn_params
#include <cuml/linear_model/qn.h>
#include "qn/simple_mat/dense.hpp"
#include "qn/qn_util.cuh"
#include "qn/glm_logistic.cuh"
@@ -19,8 +11,6 @@
#include "qn/glm_base_mg.cuh"

#include <cuda_runtime.h>
#include <iostream>


namespace ML {
namespace GLM {
@@ -127,37 +117,3 @@ void qnFit(const raft::handle_t &handle,
}; // namespace OPG
}; // namespace GLM
}; // namespace ML

// #include <raft/core/device_mdarray.hpp>
//#include <iostream>
//#include <vector>

//#include <cuml/linear_model/qn_mg.hpp>

/*
namespace ML {
namespace GLM {
}; // namespace GLM
}; // namespace ML
*/

/*
template <typename T>
void qnFit(const raft::handle_t& handle,
const qn_params& pams,
T* X_data,
bool X_col_major,
T* y_data,
int N,
int D,
int C,
T* w0_data,
T* f,
int* num_iters,
T* sample_weight = nullptr,
T svr_eps = 0)
{
}
*/
1 change: 1 addition & 0 deletions python/cuml/linear_model/logistic_regression_mg.pyx
@@ -148,6 +148,7 @@ class LogisticRegressionMG(LogisticRegression):
coef_size, dtype=self.dtype, order='C')

def fit(self, X, y, rank, n_ranks, n_samples, n_classes, convert_dtype=False) -> "LogisticRegressionMG":
assert (n_classes == 2)

cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()

