Support int64 index type in MG sparse LogisticRegression (#5962)
Authors:
  - Jinfeng Li (https://github.com/lijinf2)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5962
lijinf2 authored Jul 28, 2024
1 parent 3320895 commit a8fda19
Showing 8 changed files with 242 additions and 123 deletions.
8 changes: 4 additions & 4 deletions cpp/include/cuml/linear_model/qn_mg.hpp
@@ -88,12 +88,12 @@ void qnFit(raft::handle_t& handle,
 * @param[out] f: host pointer holding the final objective value
 * @param[out] num_iters: host pointer holding the actual number of iterations taken
 */
-template <typename T>
+template <typename T, typename I>
void qnFitSparse(raft::handle_t& handle,
                 std::vector<Matrix::Data<T>*>& input_values,
-                int* input_cols,
-                int* input_row_ids,
-                int X_nnz,
+                I* input_cols,
+                I* input_row_ids,
+                I X_nnz,
                 Matrix::PartDescriptor& input_desc,
                 std::vector<Matrix::Data<T>*>& labels,
                 T* coef,
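The API change above is just the new index-type template parameter `I`: the CSR column indices, row offsets, and nonzero count now share a caller-chosen integer type instead of being hard-coded to `int`, which caps a partition at 2,147,483,647 nonzeros. A minimal host-side sketch of the idea, using a hypothetical `CsrView` stand-in rather than the real `SimpleSparseMat`:

#include <cstdint>

// Hypothetical stand-in for SimpleSparseMat<T, I>: a non-owning CSR view
// whose index arrays and nonzero count use a caller-chosen integer type I.
template <typename T, typename I = int>
struct CsrView {
  T* values;   // nnz stored entries
  I* cols;     // nnz column indices
  I* row_ids;  // m + 1 row offsets; row_ids[m] == nnz
  I nnz;       // number of stored entries
  int m, n;    // matrix shape
};

int main()
{
  // With I = int64_t the same view type can describe a partition whose
  // nonzero count exceeds the int32 limit of 2'147'483'647.
  CsrView<float, std::int64_t> big{nullptr, nullptr, nullptr, std::int64_t{3'000'000'000}, 0, 0};
  return big.nnz > std::int64_t{0} ? 0 : 1;
}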
38 changes: 20 additions & 18 deletions cpp/src/glm/qn/mg/standardization.cuh
@@ -80,12 +80,12 @@ void mean_stddev(const raft::handle_t& handle,
  raft::linalg::sqrt(stddev_vector, stddev_vector, D, handle.get_stream());
}

-template <typename T>
-SimpleSparseMat<T> get_sub_mat(const raft::handle_t& handle,
-                               SimpleSparseMat<T> mat,
-                               int start,
-                               int end,
-                               rmm::device_uvector<int>& buff_row_ids)
+template <typename T, typename I = int>
+SimpleSparseMat<T, I> get_sub_mat(const raft::handle_t& handle,
+                                  SimpleSparseMat<T, I> mat,
+                                  int start,
+                                  int end,
+                                  rmm::device_uvector<I>& buff_row_ids)
{
  end        = end <= mat.m ? end : mat.m;
  int n_rows = end - start;
@@ -97,25 +97,25 @@ SimpleSparseMat<T> get_sub_mat(const raft::handle_t& handle,
         "the size of buff_row_ids should be at least end - start + 1");
  raft::copy(buff_row_ids.data(), mat.row_ids + start, n_rows + 1, stream);

-  int idx;
+  I idx;
  raft::copy(&idx, buff_row_ids.data(), 1, stream);
  raft::resource::sync_stream(handle);

-  auto subtract_op = [idx] __device__(const int a) { return a - idx; };
+  auto subtract_op = [idx] __device__(const I a) { return a - idx; };
  raft::linalg::unaryOp(buff_row_ids.data(), buff_row_ids.data(), n_rows + 1, subtract_op, stream);

-  int nnz;
+  I nnz;
  raft::copy(&nnz, buff_row_ids.data() + n_rows, 1, stream);
  raft::resource::sync_stream(handle);

-  SimpleSparseMat<T> res(
+  SimpleSparseMat<T, I> res(
    mat.values + idx, mat.cols + idx, buff_row_ids.data(), nnz, n_rows, n_cols);
  return res;
}

-template <typename T>
+template <typename T, typename I = int>
void mean(const raft::handle_t& handle,
-          const SimpleSparseMat<T>& X,
+          const SimpleSparseMat<T, I>& X,
          size_t n_samples,
          T* mean_vector)
{
@@ -125,7 +125,7 @@ void mean(const raft::handle_t& handle,
  auto& comm = handle.get_comms();

  int chunk_size = 500000;  // split matrix by rows for better numeric precision
-  rmm::device_uvector<int> buff_row_ids(chunk_size + 1, stream);
+  rmm::device_uvector<I> buff_row_ids(chunk_size + 1, stream);

rmm::device_uvector<T> ones(chunk_size, stream);
SimpleVec<T> ones_vec(ones.data(), chunk_size);
@@ -140,7 +140,7 @@

  for (int i = 0; i < X.m; i += chunk_size) {
    // get X[i:i + chunk_size]
-    SimpleSparseMat<T> X_sub = get_sub_mat(handle, X, i, i + chunk_size, buff_row_ids);
+    SimpleSparseMat<T, I> X_sub = get_sub_mat(handle, X, i, i + chunk_size, buff_row_ids);
SimpleDenseMat<T> ones_mat(ones.data(), 1, X_sub.m);

X_sub.gemmb(handle, 1., ones_mat, false, false, 0., buff_D_mat, stream);
@@ -153,9 +153,9 @@
  comm.sync_stream(stream);
}

-template <typename T>
+template <typename T, typename I = int>
void mean_stddev(const raft::handle_t& handle,
-                 const SimpleSparseMat<T>& X,
+                 const SimpleSparseMat<T, I>& X,
                 size_t n_samples,
                 T* mean_vector,
                 T* stddev_vector)
@@ -170,7 +170,8 @@ void mean_stddev(const raft::handle_t& handle,
  auto square_op = [] __device__(const T a) { return a * a; };
  raft::linalg::unaryOp(X_values_squared.data(), X_values_squared.data(), X.nnz, square_op, stream);

-  auto X_squared = SimpleSparseMat<T>(X_values_squared.data(), X.cols, X.row_ids, X.nnz, X.m, X.n);
+  auto X_squared =
+    SimpleSparseMat<T, I>(X_values_squared.data(), X.cols, X.row_ids, X.nnz, X.m, X.n);

mean(handle, X_squared, n_samples, stddev_vector);

@@ -227,8 +228,9 @@ struct Standardizer {
    raft::linalg::binaryOp(scaled_mean.data, std_inv.data, mean.data, D, raft::mul_op(), stream);
  }

+  template <typename I = int>
  Standardizer(const raft::handle_t& handle,
-               const SimpleSparseMat<T>& X,
+               const SimpleSparseMat<T, I>& X,
               size_t n_samples,
               rmm::device_uvector<T>& mean_std_buff,
               size_t vec_size)
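`get_sub_mat` slices rows [start, end) out of a CSR matrix without copying values or column indices: it copies the n_rows + 1 relevant row offsets, rebases them so the slice starts at zero, and reads the slice's nnz from the last rebased offset. `mean` then pushes each slice through `gemmb` against a vector of ones to accumulate per-column sums, and `mean_stddev` reuses the same path on squared values (via var(x) = E[x^2] - E[x]^2). A host-side sketch of the offset arithmetic, using std::vector in place of device buffers (not the CUDA code above):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Given CSR row offsets of an m-row matrix, produce the rebased offsets,
// the start of the slice's values/cols range, and the slice's nnz for
// the row range [start, end).
template <typename I>
std::vector<I> slice_row_ids(const std::vector<I>& row_ids, int start, int end,
                             I& first_val_idx, I& sub_nnz)
{
  assert(start >= 0 && end <= static_cast<int>(row_ids.size()) - 1);
  std::vector<I> sub(row_ids.begin() + start, row_ids.begin() + end + 1);
  first_val_idx = sub.front();          // where the slice's values/cols begin
  for (I& r : sub) r -= first_val_idx;  // re-base offsets so sub[0] == 0
  sub_nnz = sub.back();                 // nnz of the slice
  return sub;
}

int main()
{
  // 4x3 matrix with 6 nonzeros; slice rows [1, 3).
  std::vector<std::int64_t> row_ids = {0, 2, 3, 5, 6};
  std::int64_t off = 0, nnz = 0;
  auto sub = slice_row_ids(row_ids, 1, 3, off, nnz);
  std::cout << "values start at " << off << ", slice nnz = " << nnz << "\n";  // 2, 3
  return 0;
}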
32 changes: 16 additions & 16 deletions cpp/src/glm/qn/simple_mat/sparse.hpp
@@ -44,25 +44,25 @@ namespace ML {
 *
 * However, when the data comes from the outside, we cannot guarantee that.
 */
-template <typename T>
+template <typename T, typename I = int>
struct SimpleSparseMat : SimpleMat<T> {
  typedef SimpleMat<T> Super;
  T* values;
-  int* cols;
-  int* row_ids;
-  int nnz;
+  I* cols;
+  I* row_ids;
+  I nnz;

  SimpleSparseMat() : Super(0, 0), values(nullptr), cols(nullptr), row_ids(nullptr), nnz(0) {}

-  SimpleSparseMat(T* values, int* cols, int* row_ids, int nnz, int m, int n)
+  SimpleSparseMat(T* values, I* cols, I* row_ids, I nnz, int m, int n)
    : Super(m, n), values(values), cols(cols), row_ids(row_ids), nnz(nnz)
  {
    check_csr(*this, 0);
  }

  void print(std::ostream& oss) const override { oss << (*this) << std::endl; }

-  void operator=(const SimpleSparseMat<T>& other) = delete;
+  void operator=(const SimpleSparseMat<T, I>& other) = delete;

inline void gemmb(const raft::handle_t& handle,
const T alpha,
@@ -73,9 +73,9 @@ struct SimpleSparseMat : SimpleMat<T> {
                     SimpleDenseMat<T>& C,
                     cudaStream_t stream) const override
  {
-    const SimpleSparseMat<T>& B = *this;
-    int kA = A.n;
-    int kB = B.m;
+    const SimpleSparseMat<T, I>& B = *this;
+    int kA                         = A.n;
+    int kB                         = B.m;

if (transA) {
ASSERT(A.n == C.m, "GEMM invalid dims: m");
@@ -167,26 +167,26 @@
  }
};

-template <typename T>
-inline void check_csr(const SimpleSparseMat<T>& mat, cudaStream_t stream)
+template <typename T, typename I = int>
+inline void check_csr(const SimpleSparseMat<T, I>& mat, cudaStream_t stream)
{
-  int row_ids_nnz;
+  I row_ids_nnz;
  raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream);
  raft::interruptible::synchronize(stream);
  ASSERT(row_ids_nnz == mat.nnz,
         "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and "
         "the last element must be equal nnz.");
}

-template <typename T>
-std::ostream& operator<<(std::ostream& os, const SimpleSparseMat<T>& mat)
+template <typename T, typename I = int>
+std::ostream& operator<<(std::ostream& os, const SimpleSparseMat<T, I>& mat)
{
  check_csr(mat, 0);
  os << "SimpleSparseMat (CSR)"
     << "\n";
  std::vector<T> values(mat.nnz);
-  std::vector<int> cols(mat.nnz);
-  std::vector<int> row_ids(mat.m + 1);
+  std::vector<I> cols(mat.nnz);
+  std::vector<I> row_ids(mat.m + 1);
  raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default);
  raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default);
  raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default);
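`check_csr` and the stream operator now read the offsets back to the host using the matrix's own index type. For reference, the invariant being checked (row_ids has m + 1 entries and the last one equals nnz) on a concrete 2x3 matrix, as a host-side sketch assuming int64 indices:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
  // Dense:        CSR (row-major):
  // [ 1 0 2 ]     values  = {1, 2, 3}
  // [ 0 3 0 ]     cols    = {0, 2, 1}
  //               row_ids = {0, 2, 3}   (m + 1 = 3 entries)
  std::vector<float> values          = {1.f, 2.f, 3.f};
  std::vector<std::int64_t> cols     = {0, 2, 1};
  std::vector<std::int64_t> row_ids  = {0, 2, 3};
  const std::int64_t nnz = static_cast<std::int64_t>(values.size());
  const int m            = 2;

  // The invariant check_csr enforces: |row_ids| == m + 1 and row_ids[m] == nnz.
  bool ok = row_ids.size() == static_cast<std::size_t>(m + 1) && row_ids[m] == nnz;
  std::cout << (ok ? "valid CSR" : "malformed CSR") << "\n";
  return ok ? 0 : 1;
}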
126 changes: 77 additions & 49 deletions cpp/src/glm/qn_mg.cu
@@ -272,7 +272,7 @@ void qnFitSparse_impl(const raft::handle_t& handle,
                      int rank,
                      int n_ranks)
{
-  auto X_simple = SimpleSparseMat<T>(X_values, X_cols, X_row_ids, X_nnz, N, D);
+  auto X_simple = SimpleSparseMat<T, I>(X_values, X_cols, X_row_ids, X_nnz, N, D);

  size_t vec_size = raft::alignTo<size_t>(sizeof(T) * D, ML::GLM::detail::qn_align);
  rmm::device_uvector<T> mean_std_buff(4 * vec_size, handle.get_stream());
@@ -303,12 +303,12 @@ void qnFitSparse_impl(const raft::handle_t& handle,
  return;
}

-template <typename T>
+template <typename T, typename I = int>
void qnFitSparse(raft::handle_t& handle,
                 std::vector<Matrix::Data<T>*>& input_values,
-                int* input_cols,
-                int* input_row_ids,
-                int X_nnz,
+                I* input_cols,
+                I* input_row_ids,
+                I X_nnz,
                 Matrix::PartDescriptor& input_desc,
                 std::vector<Matrix::Data<T>*>& labels,
                 T* coef,
@@ -324,52 +324,80 @@ void qnFitSparse(raft::handle_t& handle,
  auto data_input_values = input_values[0];
  auto data_y            = labels[0];

-  qnFitSparse_impl<T, int>(handle,
-                           pams,
-                           data_input_values->ptr,
-                           input_cols,
-                           input_row_ids,
-                           X_nnz,
-                           standardization,
-                           data_y->ptr,
-                           input_desc.totalElementsOwnedBy(input_desc.rank),
-                           input_desc.N,
-                           n_classes,
-                           coef,
-                           f,
-                           num_iters,
-                           input_desc.M,
-                           input_desc.rank,
-                           input_desc.uniqueRanks().size());
+  qnFitSparse_impl(handle,
+                   pams,
+                   data_input_values->ptr,
+                   input_cols,
+                   input_row_ids,
+                   X_nnz,
+                   standardization,
+                   data_y->ptr,
+                   input_desc.totalElementsOwnedBy(input_desc.rank),
+                   input_desc.N,
+                   n_classes,
+                   coef,
+                   f,
+                   num_iters,
+                   input_desc.M,
+                   input_desc.rank,
+                   input_desc.uniqueRanks().size());
}

-template void qnFitSparse(raft::handle_t& handle,
-                          std::vector<Matrix::Data<float>*>& input_values,
-                          int* input_cols,
-                          int* input_row_ids,
-                          int X_nnz,
-                          Matrix::PartDescriptor& input_desc,
-                          std::vector<Matrix::Data<float>*>& labels,
-                          float* coef,
-                          const qn_params& pams,
-                          bool standardization,
-                          int n_classes,
-                          float* f,
-                          int* num_iters);
-
-template void qnFitSparse(raft::handle_t& handle,
-                          std::vector<Matrix::Data<double>*>& input_values,
-                          int* input_cols,
-                          int* input_row_ids,
-                          int X_nnz,
-                          Matrix::PartDescriptor& input_desc,
-                          std::vector<Matrix::Data<double>*>& labels,
-                          double* coef,
-                          const qn_params& pams,
-                          bool standardization,
-                          int n_classes,
-                          double* f,
-                          int* num_iters);
+template void qnFitSparse<float, int>(raft::handle_t& handle,
+                                      std::vector<Matrix::Data<float>*>& input_values,
+                                      int* input_cols,
+                                      int* input_row_ids,
+                                      int X_nnz,
+                                      Matrix::PartDescriptor& input_desc,
+                                      std::vector<Matrix::Data<float>*>& labels,
+                                      float* coef,
+                                      const qn_params& pams,
+                                      bool standardization,
+                                      int n_classes,
+                                      float* f,
+                                      int* num_iters);
+
+template void qnFitSparse<double, int>(raft::handle_t& handle,
+                                       std::vector<Matrix::Data<double>*>& input_values,
+                                       int* input_cols,
+                                       int* input_row_ids,
+                                       int X_nnz,
+                                       Matrix::PartDescriptor& input_desc,
+                                       std::vector<Matrix::Data<double>*>& labels,
+                                       double* coef,
+                                       const qn_params& pams,
+                                       bool standardization,
+                                       int n_classes,
+                                       double* f,
+                                       int* num_iters);
+
+template void qnFitSparse<float, int64_t>(raft::handle_t& handle,
+                                          std::vector<Matrix::Data<float>*>& input_values,
+                                          int64_t* input_cols,
+                                          int64_t* input_row_ids,
+                                          int64_t X_nnz,
+                                          Matrix::PartDescriptor& input_desc,
+                                          std::vector<Matrix::Data<float>*>& labels,
+                                          float* coef,
+                                          const qn_params& pams,
+                                          bool standardization,
+                                          int n_classes,
+                                          float* f,
+                                          int* num_iters);
+
+template void qnFitSparse<double, int64_t>(raft::handle_t& handle,
+                                           std::vector<Matrix::Data<double>*>& input_values,
+                                           int64_t* input_cols,
+                                           int64_t* input_row_ids,
+                                           int64_t X_nnz,
+                                           Matrix::PartDescriptor& input_desc,
+                                           std::vector<Matrix::Data<double>*>& labels,
+                                           double* coef,
+                                           const qn_params& pams,
+                                           bool standardization,
+                                           int n_classes,
+                                           double* f,
+                                           int* num_iters);

}; // namespace opg
}; // namespace GLM
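Because the `qnFitSparse` definition lives in this .cu translation unit rather than in the header, every (T, I) combination callers may request has to be instantiated explicitly, which is why the two original instantiations grow to four: {float, double} x {int, int64_t}. A minimal sketch of the pattern, with hypothetical names rather than the cuML sources:

// lib.cpp: the template body is visible only in this translation unit.
#include <cstdint>

template <typename T, typename I>
void fit_sparse(const T* values, const I* cols, I nnz)
{
  // ... body compiled once per instantiated (T, I) pair ...
  (void)values; (void)cols; (void)nnz;
}

// Explicit instantiations emit object code that other translation units
// can link against without ever seeing the template body.
template void fit_sparse<float, int>(const float*, const int*, int);
template void fit_sparse<double, int>(const double*, const int*, int);
template void fit_sparse<float, std::int64_t>(const float*, const std::int64_t*, std::int64_t);
template void fit_sparse<double, std::int64_t>(const double*, const std::int64_t*, std::int64_t);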
9 changes: 7 additions & 2 deletions python/cuml/cuml/dask/linear_model/logistic_regression.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -183,11 +183,16 @@ def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
        inp_X = scipy.sparse.vstack([X for X, _ in data])

    elif cupyx.scipy.sparse.isspmatrix(data[0][0]):
+        total_nnz = sum([X.nnz for X, _ in data])
+        if total_nnz > np.iinfo(np.int32).max:
+            raise ValueError(
+                f"please use scipy csr_matrix because cupyx uses int32 index dtype that does not support {total_nnz} non-zero values of a partition"
+            )
        inp_X = cupyx.scipy.sparse.vstack([X for X, _ in data])

    else:
        raise ValueError(
-            "input matrix must be dense, scipy sparse, or cupy sparse"
+            "input matrix must be dense, scipy sparse, or cupyx sparse"
        )

    inp_y = concatenate([y for _, y in data])
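The new guard exists because cupyx builds CSR index arrays with int32, so once the stacked partitions hold more than 2**31 - 1 nonzeros the indices cannot be represented, and the user must fall back to scipy CSR, whose indices may be int64. The check itself is just a bound on the summed nnz; the same logic as a small C++ sketch (not the Python code above):

#include <cstdint>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

// Sketch of the overflow guard: refuse int32 indexing when the combined
// nonzero count of all stacked partitions would not fit in int32.
void check_int32_indexable(const std::vector<std::int64_t>& nnz_per_part)
{
  const std::int64_t total =
    std::accumulate(nnz_per_part.begin(), nnz_per_part.end(), std::int64_t{0});
  if (total > std::numeric_limits<std::int32_t>::max()) {
    throw std::runtime_error("partition has " + std::to_string(total) +
                             " non-zero values; int32 CSR indices would overflow");
  }
}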
(The remaining three changed files are not shown.)
