Skip to content

Commit

Permalink
support mean var calculation in chunks to avoid precision loss of add…
Browse files Browse the repository at this point in the history
…ing one to a large number
  • Loading branch information
lijinf2 committed Apr 2, 2024
1 parent 5069f4c commit a44edb9
Showing 1 changed file with 78 additions and 25 deletions.
103 changes: 78 additions & 25 deletions cpp/src/glm/qn/mg/standardization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <raft/linalg/divide.cuh>
#include <raft/linalg/multiply.cuh>
#include <raft/linalg/sqrt.cuh>
#include <raft/linalg/subtract.cuh>
#include <raft/matrix/math.hpp>
#include <raft/sparse/op/row_op.cuh>
#include <raft/stats/stddev.cuh>
Expand Down Expand Up @@ -80,60 +81,112 @@ void mean_stddev(const raft::handle_t& handle,
}

template <typename T>
void mean_stddev(const raft::handle_t& handle,
const SimpleSparseMat<T>& X,
size_t n_samples,
T* mean_vector,
T* stddev_vector)
/**
 * @brief Build a sparse matrix describing rows [start, min(end, mat.m)) of `mat`.
 *
 * The row-offset entries for the requested rows (n_rows + 1 of them) are copied into
 * `buff_row_ids` and rebased so that the first entry becomes 0. The values and column
 * pointers handed to the result are `mat.values` / `mat.cols` advanced by the original
 * first offset, i.e. no values or column indices are copied.
 *
 * @param handle       RAFT handle; its stream is used for the copies and the rebasing kernel
 * @param mat          source sparse matrix (taken by value)
 * @param start        first row of the slice (inclusive); must satisfy 0 <= start < end
 * @param end          one past the last row of the slice; clamped to mat.m
 * @param buff_row_ids caller-provided device buffer with at least (end - start + 1) entries.
 *                     The result is constructed from buff_row_ids.data(), so the buffer
 *                     should stay alive (and unmodified) while the result is in use.
 * @return a SimpleSparseMat<T> with (end - start) rows and mat.n columns
 *
 * @note The function synchronizes the handle's stream twice because the first row offset
 *       and the slice's nnz are read back to the host.
 */
SimpleSparseMat<T> get_sub_mat(const raft::handle_t& handle,
                               SimpleSparseMat<T> mat,
                               int start,
                               int end,
                               rmm::device_uvector<int>& buff_row_ids)
{
  end = end <= mat.m ? end : mat.m;
  int n_rows = end - start;
  int n_cols = mat.n;
  auto stream = handle.get_stream();

  // A negative start would make `mat.row_ids + start` point before the row-offset array.
  RAFT_EXPECTS(start >= 0, "start index must be non-negative");
  RAFT_EXPECTS(start < end, "start index must be smaller than end index");
  // n_rows > 0 is guaranteed by the checks above, so the cast cannot wrap.
  RAFT_EXPECTS(buff_row_ids.size() >= static_cast<size_t>(n_rows) + 1,
               "the size of buff_row_ids should be at least end - start + 1");
  raft::copy(buff_row_ids.data(), mat.row_ids + start, n_rows + 1, stream);

  // Read back the offset of the first selected row; it is both the rebasing constant
  // and the displacement into mat.values / mat.cols.
  int idx;
  raft::copy(&idx, buff_row_ids.data(), 1, stream);
  raft::resource::sync_stream(handle);

  auto subtract_op = [idx] __device__(const int a) { return a - idx; };
  raft::linalg::unaryOp(buff_row_ids.data(), buff_row_ids.data(), n_rows + 1, subtract_op, stream);

  // After rebasing, the last row offset equals the number of non-zeros in the slice.
  int nnz;
  raft::copy(&nnz, buff_row_ids.data() + n_rows, 1, stream);
  raft::resource::sync_stream(handle);

  SimpleSparseMat<T> res(
    mat.values + idx, mat.cols + idx, buff_row_ids.data(), nnz, n_rows, n_cols);
  return res;
}

// Accumulates a length-X.n result in mean_vector by processing X in row chunks, then
// scales it by 1 / n_samples and SUM-allreduces it over the handle's communicator.
// Presumably each gemmb call with the all-ones row vector yields per-column sums of the
// chunk, making the final result the per-column mean — confirm against gemmb's contract.
//
// NOTE(review): this listing is a rendered diff without +/- markers, so lines of the
// previous and of the new revision are interleaved. `mean_mat`, `ones` and `ones_vec`
// are each declared twice below; the num_rows-sized declarations and the un-chunked
// `X.gemmb(... mean_mat ...)` call look like the removed lines — confirm against the
// actual file before relying on this text.
//
// handle      : RAFT handle providing the stream and the communicator
// X           : sparse input matrix (X.m rows, X.n columns)
// n_samples   : divisor for the mean; passed separately from X.m and applied before a
//               SUM allreduce, so presumably the global sample count — TODO confirm
// mean_vector : output buffer of X.n elements; zero-filled, then accumulated in place
template <typename T>
void mean(const raft::handle_t& handle,
const SimpleSparseMat<T>& X,
size_t n_samples,
T* mean_vector)
{
int D = X.n;
int num_rows = X.m;
auto stream = handle.get_stream();
auto& comm = handle.get_comms();
SimpleDenseMat<T> mean_mat(mean_vector, 1, D);

// calculate mean
rmm::device_uvector<T> ones(num_rows, stream);
SimpleVec<T> ones_vec(ones.data(), num_rows);
int chunk_size = 500000; // split matrix by rows for better numeric precision
// Scratch space for the rebased row offsets of one chunk (chunk rows + 1 entries).
rmm::device_uvector<int> buff_row_ids(chunk_size + 1, stream);

// All-ones vector sized for a full chunk; shorter chunks use only a prefix of it.
rmm::device_uvector<T> ones(chunk_size, stream);
SimpleVec<T> ones_vec(ones.data(), chunk_size);
ones_vec.fill(1.0, stream);

SimpleDenseMat<T> ones_mat(ones.data(), 1, num_rows);
X.gemmb(handle, 1., ones_mat, false, false, 0., mean_mat, stream);
// Per-chunk result buffer (length D) that is added into mean_vector after each gemmb.
rmm::device_uvector<T> buff_D(D, stream);
SimpleDenseMat<T> buff_D_mat(buff_D.data(), 1, D);

// calculate mean
SimpleDenseMat<T> mean_mat(mean_vector, 1, D);
// Start the accumulation from zero.
mean_mat.fill(0., stream);

for (int i = 0; i < X.m; i += chunk_size) {
// get X[i:i + chunk_size]
// get_sub_mat clamps the end index to X.m, so the last chunk may have fewer rows.
SimpleSparseMat<T> X_sub = get_sub_mat(handle, X, i, i + chunk_size, buff_row_ids);
SimpleDenseMat<T> ones_mat(ones.data(), 1, X_sub.m);

// The 0. scaling argument presumably makes gemmb overwrite buff_D_mat; the chunk
// result is then added element-wise into mean_vector.
X_sub.gemmb(handle, 1., ones_mat, false, false, 0., buff_D_mat, stream);
raft::linalg::binaryOp(mean_vector, mean_vector, buff_D_mat.data, D, raft::add_op(), stream);
}

// Scale the local accumulation by 1 / n_samples, then sum the scaled vectors across
// the communicator and wait for the reduction to finish.
T weight = T(1) / T(n_samples);
raft::linalg::multiplyScalar(mean_vector, mean_vector, weight, D, stream);
comm.allreduce(mean_vector, mean_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);
}

// calculate stdev.S
SimpleDenseMat<T> stddev_mat(stddev_vector, 1, D);
// Fills mean_vector and stddev_vector (X.n elements each) for the sparse matrix X:
// mean() is applied to X and to a copy of X whose stored values are squared, the two
// results are combined element-wise into a variance, and the square root is taken.
//
// NOTE(review): this listing is a rendered diff without +/- markers, so lines of the
// previous and of the new revision are interleaved. The X_squared construction, the
// `weight` definition, the `res` statements inside the lambda, and the final
// binaryOp / sqrt calls each appear in two variants; the variants that reference
// `num_rows`, `ones_mat`, `stddev_mat` or `comm` (none of which are declared in this
// function) look like the removed lines — confirm against the actual file.
//
// handle        : RAFT handle providing the stream
// X             : sparse input matrix
// n_samples     : sample count forwarded to mean() and used for the n / (n - 1) factor
// mean_vector   : output, written by mean(handle, X, ...)
// stddev_vector : output, first holds mean() of the squared values, then the final result
template <typename T>
void mean_stddev(const raft::handle_t& handle,
const SimpleSparseMat<T>& X,
size_t n_samples,
T* mean_vector,
T* stddev_vector)
{
auto stream = handle.get_stream();
int D = X.n;
mean(handle, X, n_samples, mean_vector);

// calculate stdev.S
// Square the stored values into a temporary buffer; X.cols and X.row_ids are reused
// unchanged when building X_squared.
rmm::device_uvector<T> X_values_squared(X.nnz, stream);
raft::copy(X_values_squared.data(), X.values, X.nnz, stream);
auto square_op = [] __device__(const T a) { return a * a; };
raft::linalg::unaryOp(X_values_squared.data(), X_values_squared.data(), X.nnz, square_op, stream);

auto X_squared =
SimpleSparseMat<T>(X_values_squared.data(), X.cols, X.row_ids, X.nnz, num_rows, D);
X_squared.gemmb(handle, T(1.), ones_mat, false, false, T(0.), stddev_mat, stream);
auto X_squared = SimpleSparseMat<T>(X_values_squared.data(), X.cols, X.row_ids, X.nnz, X.m, X.n);

weight = n_samples < 1 ? T(0) : T(1) / T(n_samples - 1);
raft::linalg::multiplyScalar(stddev_vector, stddev_vector, weight, D, stream);
comm.allreduce(stddev_vector, stddev_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);
// Same chunked routine as for the mean, applied to the squared values; the allreduce
// and stream sync happen inside mean().
mean(handle, X_squared, n_samples, stddev_vector);

weight = n_samples * weight;
// n / (n - 1) converts the 1/n-normalised quantities to an (n - 1) denominator.
// NOTE(review): unlike the `n_samples < 1` guard a few lines above, this expression has
// no guard, and n_samples == 1 divides by zero — confirm callers never pass that.
T weight = n_samples / T(n_samples - 1);
// a: element of stddev_vector (from the squared values), b: element of mean_vector.
auto submean_no_neg_op = [weight] __device__(const T a, const T b) -> T {
T res = a - b * b * weight;
T res = weight * (a - b * b);
if (res < 0) {
// return sum(x^2) / (n - 1) if negative variance (due to precision loss of floating point
// arithmetic)
res = a;
res = weight * a;
}
return res;
};
raft::linalg::binaryOp(stddev_vector, stddev_vector, mean_vector, D, submean_no_neg_op, stream);
raft::linalg::binaryOp(stddev_vector, stddev_vector, mean_vector, X.n, submean_no_neg_op, stream);

// Element-wise square root turns the variance into the standard deviation, in place.
raft::linalg::sqrt(stddev_vector, stddev_vector, D, handle.get_stream());
raft::linalg::sqrt(stddev_vector, stddev_vector, X.n, handle.get_stream());
}

struct inverse_op {
Expand Down

0 comments on commit a44edb9

Please sign in to comment.