Skip to content

Commit

Permalink
revise per comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lijinf2 committed Apr 2, 2024
1 parent 3818df5 commit 5069f4c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
12 changes: 3 additions & 9 deletions cpp/src/glm/qn/mg/standardization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ void mean_stddev(const raft::handle_t& handle,
T weight = T(1) / T(n_samples);
raft::linalg::multiplyScalar(mean_vector, mean_vector, weight, D, stream);
comm.allreduce(mean_vector, mean_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);

// calculate stdev.S
SimpleDenseMat<T> stddev_mat(stddev_vector, 1, D);
Expand Down Expand Up @@ -135,12 +134,6 @@ void mean_stddev(const raft::handle_t& handle,
raft::linalg::binaryOp(stddev_vector, stddev_vector, mean_vector, D, submean_no_neg_op, stream);

raft::linalg::sqrt(stddev_vector, stddev_vector, D, handle.get_stream());

ML::Logger::get().setLevel(6);
auto log_mean = raft::arr2Str(mean_vector, D, "", stream);
CUML_LOG_DEBUG("log_mean: %s", log_mean.c_str());
auto log_stddev = raft::arr2Str(stddev_vector, D, "", stream);
CUML_LOG_DEBUG("log_stddev: %s", log_stddev.c_str());
}

struct inverse_op {
Expand All @@ -164,7 +157,7 @@ struct Standardizer {
rmm::device_uvector<T>& mean_std_buff)
{
int D = X.n;
ASSERT(mean_std_buff.size() == 4 * D, "buff size must be four times the dimension");
ASSERT(mean_std_buff.size() == 4 * D, "mean_std_buff size must be four times the dimension");

auto stream = handle.get_stream();

Expand All @@ -188,7 +181,8 @@ struct Standardizer {
size_t vec_size)
{
int D = X.n;
ASSERT(mean_std_buff.size() == 4 * vec_size, "buff size must be four times the aligned size");
ASSERT(mean_std_buff.size() == 4 * vec_size,
"mean_std_buff size must be four times the aligned size");

auto stream = handle.get_stream();

Expand Down
20 changes: 10 additions & 10 deletions cpp/src/glm/qn_mg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ void qnFit_impl(const raft::handle_t& handle,
auto X_simple = SimpleDenseMat<T>(X, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);

rmm::device_uvector<T> mean_std_buff(4 * D, handle.get_stream());
Standardizer<T>* stder = NULL;
if (standardization) stder = new Standardizer(handle, X_simple, n_samples, mean_std_buff);
Standardizer<T>* std_obj = NULL;
if (standardization) std_obj = new Standardizer(handle, X_simple, n_samples, mean_std_buff);

ML::GLM::opg::qn_fit_x_mg(handle,
pams,
Expand All @@ -128,12 +128,12 @@ void qnFit_impl(const raft::handle_t& handle,
n_samples,
rank,
n_ranks,
stder); // ignore sample_weight, svr_eps
std_obj); // ignore sample_weight, svr_eps

if (standardization) {
int n_targets = ML::GLM::detail::qn_is_classification(pams.loss) && C == 2 ? 1 : C;
stder->adapt_model_for_linearFwd(handle, w0, n_targets, D, pams.fit_intercept);
delete stder;
std_obj->adapt_model_for_linearFwd(handle, w0, n_targets, D, pams.fit_intercept);
delete std_obj;
}

return;
Expand Down Expand Up @@ -242,10 +242,10 @@ void qnFitSparse_impl(const raft::handle_t& handle,

size_t vec_size = raft::alignTo<size_t>(sizeof(T) * D, ML::GLM::detail::qn_align);
rmm::device_uvector<T> mean_std_buff(4 * vec_size, handle.get_stream());
Standardizer<T>* stder = NULL;
Standardizer<T>* std_obj = NULL;

if (standardization)
stder = new Standardizer(handle, X_simple, n_samples, mean_std_buff, vec_size);
std_obj = new Standardizer(handle, X_simple, n_samples, mean_std_buff, vec_size);

ML::GLM::opg::qn_fit_x_mg(handle,
pams,
Expand All @@ -258,12 +258,12 @@ void qnFitSparse_impl(const raft::handle_t& handle,
n_samples,
rank,
n_ranks,
stder); // ignore sample_weight, svr_eps
std_obj); // ignore sample_weight, svr_eps

if (standardization) {
int n_targets = ML::GLM::detail::qn_is_classification(pams.loss) && C == 2 ? 1 : C;
stder->adapt_model_for_linearFwd(handle, w0, n_targets, D, pams.fit_intercept);
delete stder;
std_obj->adapt_model_for_linearFwd(handle, w0, n_targets, D, pams.fit_intercept);
delete std_obj;
}

return;
Expand Down
12 changes: 7 additions & 5 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,9 +998,9 @@ def make_classification_with_nnz(
computed_csr = X_da.compute()
assert isinstance(computed_csr, csr_matrix)
assert computed_csr.nnz == nnz and computed_csr.shape == (n_rows, n_cols)
assert array_equal(computed_csr.data, X.data)
assert array_equal(computed_csr.indices, X.indices)
assert array_equal(computed_csr.indptr, X.indptr)
assert array_equal(computed_csr.data, X.data, unit_tol=tolerance)
assert array_equal(computed_csr.indices, X.indices, unit_tol=tolerance)
assert array_equal(computed_csr.indptr, X.indptr, unit_tol=tolerance)

lr_on = cumlLBFGS_dask(standardization=True, verbose=True, **est_params)
lr_on.fit(X_da, y_da)
Expand All @@ -1018,5 +1018,7 @@ def make_classification_with_nnz(
sg = SG(**est_params)
sg.fit(X_scaled, y)

assert array_equal(lron_coef_origin, sg.coef_, tolerance)
assert array_equal(lron_intercept_origin, sg.intercept_, tolerance)
assert array_equal(lron_coef_origin, sg.coef_, unit_tol=tolerance)
assert array_equal(
lron_intercept_origin, sg.intercept_, unit_tol=tolerance
)

0 comments on commit 5069f4c

Please sign in to comment.