-
Notifications
You must be signed in to change notification settings - Fork 197
/
svd.cuh
398 lines (355 loc) · 14.4 KB
/
svd.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "eig.cuh"
#include "gemm.cuh"
#include "transpose.h"
#include <raft/common/nvtx.hpp>
#include <raft/cuda_utils.cuh>
#include <raft/cudart_utils.h>
#include <raft/handle.hpp>
#include <raft/linalg/cublas_wrappers.h>
#include <raft/linalg/cusolver_wrappers.h>
#include <raft/matrix/math.hpp>
#include <raft/matrix/matrix.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>
namespace raft {
namespace linalg {
/**
 * @brief singular value decomposition (SVD) on the column major input matrix
 * using the QR method (cuSOLVER gesvd)
 *
 * Synchronizes `stream` before returning in order to read back the cuSOLVER
 * status, and throws if the solver did not converge.
 * @param handle: raft handle
 * @param in: input matrix (n_rows x n_cols, column major). It is handed to
 * gesvd as the A operand; per the cuSOLVER documentation its contents are
 * destroyed on exit.
 * @param n_rows: number rows of input matrix
 * @param n_cols: number columns of input matrix
 * @param sing_vals: singular values of input matrix
 * @param left_sing_vecs: left singular vectors of input matrix (leading
 * dimension n_rows; only the first min(n_rows, n_cols) columns are computed)
 * @param right_sing_vecs: right singular vectors of input matrix
 * (n_cols x n_cols). gesvd produces V^T here.
 * @param trans_right: if true, transpose the right singular vectors in place
 * so the buffer holds V instead of V^T. Only applied when gen_right_vec is true.
 * @param gen_left_vec: compute the left singular vectors (otherwise
 * left_sing_vecs is not written)
 * @param gen_right_vec: compute the right singular vectors (otherwise
 * right_sing_vecs is not written)
 * @param stream cuda stream
 */
template <typename T>
void svdQR(const raft::handle_t& handle,
           T* in,
           int n_rows,
           int n_cols,
           T* sing_vals,
           T* left_sing_vecs,
           T* right_sing_vecs,
           bool trans_right,
           bool gen_left_vec,
           bool gen_right_vec,
           cudaStream_t stream)
{
  common::nvtx::range<common::nvtx::domain::raft> fun_scope(
    "raft::linalg::svdQR(%d, %d)", n_rows, n_cols);
  cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();

#if CUDART_VERSION >= 10010 && CUDART_VERSION < 11000
  // 46340: sqrt of max int value
  ASSERT(n_rows <= 46340,
         "svd solver is not supported for the data that has more than 46340 "
         "samples (rows) "
         "if you are using CUDA version <11. Please use other solvers such as "
         "eig if it is available.");
#endif

  const int m = n_rows;
  const int n = n_cols;
  rmm::device_scalar<int> devInfo(stream);
  T* d_rwork = nullptr;

  int lwork = 0;
  RAFT_CUSOLVER_TRY(cusolverDngesvd_bufferSize<T>(cusolverH, n_rows, n_cols, &lwork));
  rmm::device_uvector<T> d_work(lwork, stream);

  // gesvd job codes: 'S' = first min(m, n) columns of U, 'A' = all n rows of
  // V^T, 'N' = do not compute. Plain character assignment; these are single
  // chars, not C strings.
  const char jobu  = gen_left_vec ? 'S' : 'N';
  const char jobvt = gen_right_vec ? 'A' : 'N';

  RAFT_CUSOLVER_TRY(cusolverDngesvd(cusolverH,
                                    jobu,
                                    jobvt,
                                    m,
                                    n,
                                    in,
                                    m,
                                    sing_vals,
                                    left_sing_vecs,
                                    m,
                                    right_sing_vecs,
                                    n,
                                    d_work.data(),
                                    lwork,
                                    d_rwork,
                                    devInfo.data(),
                                    stream));

  // Transpose the right singular vectors (V^T -> V), but only if gesvd
  // actually wrote them; otherwise the buffer may be unset or even null.
  if (trans_right && gen_right_vec) raft::linalg::transpose(right_sing_vecs, n_cols, stream);
  RAFT_CUDA_TRY(cudaGetLastError());

  int dev_info;
  raft::update_host(&dev_info, devInfo.data(), 1, stream);
  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
  ASSERT(dev_info == 0,
         "svd.cuh: svd couldn't converge to a solution. "
         "This usually occurs when some of the features do not vary enough.");
}
/**
 * @brief singular value decomposition (SVD) on the column major input matrix
 * via an eigen decomposition of the cross-product matrix in^T * in
 *
 * Computes G = in^T * in (n_cols x n_cols), eigen-decomposes it with eigDC, and
 * derives the singular values as square roots of the eigenvalues. Note that
 * forming in^T * in squares the condition number, so this is less accurate
 * than svdQR / svdJacobi for ill-conditioned inputs.
 * @param handle: raft handle
 * @param in: input matrix, n_rows x n_cols (only read here, as a gemm operand)
 * @param n_rows: number rows of input matrix
 * @param n_cols: number columns of input matrix
 * @param S: output singular values, n_cols entries. They are reversed after
 * eigDC, presumably to go from ascending to descending order (TODO confirm
 * eigDC's ordering).
 * @param U: output left singular vectors, n_rows x n_cols; written only when
 * gen_left_vec is true
 * @param V: output right singular vectors, n_cols x n_cols (eigenvectors of
 * in^T * in with their column order reversed to match S)
 * @param gen_left_vec: compute U = in * V scaled column-wise by 1 / S
 * @param stream cuda stream
 */
template <typename T>
void svdEig(const raft::handle_t& handle,
            T* in,
            int n_rows,
            int n_cols,
            T* S,
            T* U,
            T* V,
            bool gen_left_vec,
            cudaStream_t stream)
{
  common::nvtx::range<common::nvtx::domain::raft> fun_scope(
    "raft::linalg::svdEig(%d, %d)", n_rows, n_cols);

  // Gram matrix size computed in size_t so n_cols * n_cols cannot overflow int.
  const size_t len = static_cast<size_t>(n_cols) * static_cast<size_t>(n_cols);
  rmm::device_uvector<T> in_cross_mult(len, stream);

  T alpha = T(1);
  T beta  = T(0);

  // in_cross_mult = in^T * in
  raft::linalg::gemm(handle,
                     in,
                     n_rows,
                     n_cols,
                     in,
                     in_cross_mult.data(),
                     n_cols,
                     n_cols,
                     CUBLAS_OP_T,
                     CUBLAS_OP_N,
                     alpha,
                     beta,
                     stream);

  eigDC(handle, in_cross_mult.data(), n_cols, n_cols, V, S, stream);
  // Reverse eigenvector columns and eigenvalues together so they stay paired.
  raft::matrix::colReverse(V, n_cols, n_cols, stream);
  raft::matrix::rowReverse(S, n_cols, 1, stream);
  // Singular values = sqrt(eigenvalues). The trailing `true` presumably clamps
  // small negative eigenvalues (round-off) to zero first; verify in seqRoot.
  raft::matrix::seqRoot(S, S, alpha, n_cols, stream, true);

  if (gen_left_vec) {
    // U = in * V
    raft::linalg::gemm(handle,
                       in,
                       n_rows,
                       n_cols,
                       V,
                       U,
                       n_rows,
                       n_cols,
                       CUBLAS_OP_N,
                       CUBLAS_OP_N,
                       alpha,
                       beta,
                       stream);
    // Scale each column of U by 1 / S; by its name, the helper skips zero
    // singular values rather than dividing by zero.
    raft::matrix::matrixVectorBinaryDivSkipZero(U, S, n_rows, n_cols, false, true, stream);
  }
}
/**
 * @brief singular value decomposition (SVD) on the column major input matrix
 * using the Jacobi method (cuSOLVER gesvdj, economy size)
 *
 * The call is asynchronous with respect to `stream`: no synchronization is
 * performed. The convergence status written by gesvdj is not inspected.
 * @param handle: raft handle
 * @param in: input matrix (n_rows x n_cols, column major), passed to gesvdj as
 * the A operand; cuSOLVER documents A as overwritten on exit
 * @param n_rows: number rows of input matrix
 * @param n_cols: number columns of input matrix
 * @param sing_vals: singular values of input matrix
 * @param left_sing_vecs: left singular vectors of input matrix (leading
 * dimension n_rows; economy size)
 * @param right_sing_vecs: right singular vectors of input matrix (leading
 * dimension n_cols; economy size)
 * @param gen_left_vec: currently ignored; vectors are always computed
 * (CUSOLVER_EIG_MODE_VECTOR)
 * @param gen_right_vec: currently ignored; vectors are always computed
 * (CUSOLVER_EIG_MODE_VECTOR)
 * @param tol: error tolerance for the jacobi method. Algorithm stops when the
 * error is below tol
 * @param max_sweeps: maximum number of sweeps in the Jacobi algorithm. More
 * sweeps allow better accuracy.
 * @param stream cuda stream
 */
template <typename math_t>
void svdJacobi(const raft::handle_t& handle,
               math_t* in,
               int n_rows,
               int n_cols,
               math_t* sing_vals,
               math_t* left_sing_vecs,
               math_t* right_sing_vecs,
               bool gen_left_vec,
               bool gen_right_vec,
               math_t tol,
               int max_sweeps,
               cudaStream_t stream)
{
  common::nvtx::range<common::nvtx::domain::raft> fun_scope(
    "raft::linalg::svdJacobi(%d, %d)", n_rows, n_cols);
  cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();

  gesvdjInfo_t gesvdj_params = nullptr;
  RAFT_CUSOLVER_TRY(cusolverDnCreateGesvdjInfo(&gesvdj_params));

  // Every step between create and destroy can throw (RAFT_CUSOLVER_TRY, rmm
  // allocations). Release the gesvdj parameter object on those paths too so it
  // is not leaked, then rethrow the original error.
  try {
    RAFT_CUSOLVER_TRY(cusolverDnXgesvdjSetTolerance(gesvdj_params, tol));
    RAFT_CUSOLVER_TRY(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, max_sweeps));

    const int m = n_rows;
    const int n = n_cols;
    rmm::device_scalar<int> devInfo(stream);
    int lwork      = 0;
    const int econ = 1;  // economy-size U and V

    RAFT_CUSOLVER_TRY(raft::linalg::cusolverDngesvdj_bufferSize(cusolverH,
                                                                CUSOLVER_EIG_MODE_VECTOR,
                                                                econ,
                                                                m,
                                                                n,
                                                                in,
                                                                m,
                                                                sing_vals,
                                                                left_sing_vecs,
                                                                m,
                                                                right_sing_vecs,
                                                                n,
                                                                &lwork,
                                                                gesvdj_params));
    rmm::device_uvector<math_t> d_work(lwork, stream);

    // NOTE(review): devInfo (convergence status) is never read back; callers
    // get no signal if max_sweeps was exhausted before reaching tol.
    RAFT_CUSOLVER_TRY(raft::linalg::cusolverDngesvdj(cusolverH,
                                                     CUSOLVER_EIG_MODE_VECTOR,
                                                     econ,
                                                     m,
                                                     n,
                                                     in,
                                                     m,
                                                     sing_vals,
                                                     left_sing_vecs,
                                                     m,
                                                     right_sing_vecs,
                                                     n,
                                                     d_work.data(),
                                                     lwork,
                                                     devInfo.data(),
                                                     gesvdj_params,
                                                     stream));
  } catch (...) {
    // Best-effort cleanup; the in-flight exception is the error worth reporting.
    cusolverDnDestroyGesvdjInfo(gesvdj_params);
    throw;
  }

  RAFT_CUSOLVER_TRY(cusolverDnDestroyGesvdjInfo(gesvdj_params));
}
/**
 * @brief rebuild a matrix from its SVD factors: out = U * S * V^T
 *
 * Done with two gemm calls through a temporary k x n_cols buffer that holds
 * the intermediate product S * V^T.
 * @param handle: raft handle
 * @param U: left singular vectors, n_rows x k
 * @param S: k x k matrix carrying the singular values on its diagonal
 * @param V: right singular vectors, n_cols x k
 * @param out: n_rows x n_cols buffer receiving the reconstructed matrix
 * @param n_rows: number rows of output matrix
 * @param n_cols: number columns of output matrix
 * @param k: number of singular values / vectors used
 * @param stream cuda stream
 */
template <typename math_t>
void svdReconstruction(const raft::handle_t& handle,
                       math_t* U,
                       math_t* S,
                       math_t* V,
                       math_t* out,
                       int n_rows,
                       int n_cols,
                       int k,
                       cudaStream_t stream)
{
  const math_t one  = 1.0;
  const math_t zero = 0.0;

  // Step 1: s_vt (k x n_cols) = S (k x k) * V^T
  rmm::device_uvector<math_t> s_vt(k * n_cols, stream);
  math_t* s_vt_ptr = s_vt.data();
  raft::linalg::gemm(handle,
                     S,
                     k,
                     k,
                     V,
                     s_vt_ptr,
                     k,
                     n_cols,
                     CUBLAS_OP_N,
                     CUBLAS_OP_T,
                     one,
                     zero,
                     stream);

  // Step 2: out (n_rows x n_cols) = U (n_rows x k) * s_vt
  raft::linalg::gemm(
    handle, U, n_rows, k, s_vt_ptr, out, n_rows, n_cols, CUBLAS_OP_N, CUBLAS_OP_N, one, zero, stream);
}
/**
 * @brief check the quality of an SVD by reconstructing the input from its
 * factors and measuring the L2 norm of the residual
 *
 * Forms P = U * diag(S_vec) * V^T and compares ||A - P|| / ||A|| against tol.
 * The norms are L2 norms of the flattened m * n buffers. If ||A|| is zero the
 * relative error is undefined, so the absolute error ||A - P|| is compared
 * against tol instead of dividing by zero.
 * @param handle: raft handle
 * @param A_d: original input matrix, n_rows x n_cols
 * @param U: left singular vectors of size n_rows x k
 * @param S_vec: singular values as a vector of length k
 * @param V: right singular vectors of size n_cols x k
 * @param n_rows: number rows of A_d
 * @param n_cols: number columns of A_d
 * @param k: number of singular values / vectors used for the reconstruction
 * @param tol: tolerance on the relative reconstruction error
 * @param stream cuda stream
 * @return true if the relative reconstruction error (absolute error when A_d
 * has zero norm) is below tol
 */
template <typename math_t>
bool evaluateSVDByL2Norm(const raft::handle_t& handle,
                         math_t* A_d,
                         math_t* U,
                         math_t* S_vec,
                         math_t* V,
                         int n_rows,
                         int n_cols,
                         int k,
                         math_t tol,
                         cudaStream_t stream)
{
  cublasHandle_t cublasH = handle.get_cublas_handle();
  const int m = n_rows, n = n_cols;

  // form product matrix P = U * diag(S_vec) * V^T
  rmm::device_uvector<math_t> P_d(m * n, stream);
  rmm::device_uvector<math_t> S_mat(k * k, stream);
  RAFT_CUDA_TRY(cudaMemsetAsync(P_d.data(), 0, sizeof(math_t) * m * n, stream));
  // Zero S_mat first; initializeDiagonalMatrix is presumably responsible only
  // for the diagonal entries.
  RAFT_CUDA_TRY(cudaMemsetAsync(S_mat.data(), 0, sizeof(math_t) * k * k, stream));
  raft::matrix::initializeDiagonalMatrix(S_vec, S_mat.data(), k, k, stream);
  svdReconstruction(handle, U, S_mat.data(), V, P_d.data(), m, n, k, stream);

  const math_t normA = raft::matrix::getL2Norm(handle, A_d, m * n, stream);

  // A_minus_P = 1 * A + (-1) * P
  const math_t alpha = 1.0, beta = -1.0;
  rmm::device_uvector<math_t> A_minus_P(m * n, stream);
  RAFT_CUDA_TRY(cudaMemsetAsync(A_minus_P.data(), 0, sizeof(math_t) * m * n, stream));
  RAFT_CUBLAS_TRY(raft::linalg::cublasgeam(cublasH,
                                           CUBLAS_OP_N,
                                           CUBLAS_OP_N,
                                           m,
                                           n,
                                           &alpha,
                                           A_d,
                                           m,
                                           &beta,
                                           P_d.data(),
                                           m,
                                           A_minus_P.data(),
                                           m,
                                           stream));
  const math_t norm_A_minus_P = raft::matrix::getL2Norm(handle, A_minus_P.data(), m * n, stream);

  // Relative error is undefined for a zero matrix; fall back to absolute error.
  if (normA == math_t(0)) { return norm_A_minus_P < tol; }
  return norm_A_minus_P / normA < tol;
}
}; // end namespace linalg
}; // end namespace raft