-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmultiply.h
150 lines (127 loc) · 4.82 KB
/
multiply.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#include "CSR.h"
#include <omp.h>
#include <algorithm>
// #include <mkl_spblas.h>
/**
 * Symbolic pass (request = 1) of the single-precision MKL CSR * CSR product.
 *
 * Hands A and B to mkl_scsrmultcsr with no output value/column arrays, passing
 * only C.rowptr as the output row-pointer array. C.rowptr must already be
 * allocated by the caller.
 *
 * The sortOutput template parameter is not consulted here; sorting is always
 * disabled for this pass.
 *
 * @param A  left operand (read only)
 * @param B  right operand (read only)
 * @param C  result matrix; only C.rowptr is handed to MKL
 * @return   the info flag reported by MKL (initialised to 0 before the call)
 */
template <bool sortOutput, typename IT>
int MKLSpGEMM_symbolic(const CSR<IT,float> &A, const CSR<IT,float> &B, CSR<IT,float> &C)
{
    int status = 0;    // output info flag, written by MKL
    int sortMode = 7;  // don't sort anything
    int stage = 1;     // request code for this pass

    mkl_scsrmultcsr(const_cast<char*>("N"), &stage, &sortMode,
                    &(A.rows), &(A.cols), &(B.cols),
                    A.values, A.colids, A.rowptr,
                    B.values, B.colids, B.rowptr,
                    NULL, NULL, C.rowptr,
                    NULL, &status);

    return status;
}
/**
 * Symbolic pass (request = 1) of the double-precision MKL CSR * CSR product.
 *
 * Hands A and B to mkl_dcsrmultcsr with no output value/column arrays, passing
 * only C.rowptr as the output row-pointer array. C.rowptr must already be
 * allocated by the caller.
 *
 * The sortOutput template parameter is not consulted here; sorting is always
 * disabled for this pass.
 *
 * Returns the MKL info flag (like the single-precision overload) so that
 * callers can assign and check the result for either value type.
 *
 * @param A  left operand (read only)
 * @param B  right operand (read only)
 * @param C  result matrix; only C.rowptr is handed to MKL
 * @return   the info flag reported by MKL (initialised to 0 before the call)
 */
template <bool sortOutput, typename IT>
int MKLSpGEMM_symbolic(const CSR<IT,double> &A, const CSR<IT,double> &B, CSR<IT,double> &C)
{
    int request = 1;
    int sort = 7;  // don't sort anything
    int info = 0;  // output info flag
    mkl_dcsrmultcsr((char*)"N", &request, &sort,
                    &(A.rows), &(A.cols), &(B.cols),
                    A.values, A.colids, A.rowptr,
                    B.values, B.colids, B.rowptr,
                    NULL, NULL, C.rowptr,
                    NULL, &info);
    return info;
}
/**
 * Numeric pass (request = 2) of the single-precision MKL CSR * CSR product.
 *
 * Calls mkl_scsrmultcsr with C.values, C.colids and C.rowptr as the output
 * arrays; all three must already be allocated by the caller.
 *
 * @tparam sortOutput  when true, MKL is asked (sort = 8) to sort the nonzeros
 *                     within each row of C; otherwise nothing is sorted (sort = 7)
 * @param A  left operand (read only)
 * @param B  right operand (read only)
 * @param C  result matrix whose arrays are handed to MKL as outputs
 * @return   the info flag reported by MKL (initialised to 0 before the call)
 */
template <bool sortOutput, typename IT>
int MKLSpGEMM_numeric(const CSR<IT,float> &A, const CSR<IT,float> &B, CSR<IT,float> &C)
{
    // 8: sort nonzeroes in rows of C, leave A and B alone (they are already sorted)
    // 7: don't sort anything
    int sortMode = sortOutput ? 8 : 7;
    int stage = 2;
    int status = 0;  // output info flag

    mkl_scsrmultcsr(const_cast<char*>("N"), &stage, &sortMode,
                    &(A.rows), &(A.cols), &(B.cols),
                    A.values, A.colids, A.rowptr,
                    B.values, B.colids, B.rowptr,
                    C.values, C.colids, C.rowptr,
                    NULL, &status);

    return status;
}
/**
 * Numeric pass (request = 2) of the double-precision MKL CSR * CSR product.
 *
 * Calls mkl_dcsrmultcsr with C.values, C.colids and C.rowptr as the output
 * arrays; all three must already be allocated by the caller.
 *
 * @tparam sortOutput  when true, MKL is asked (sort = 8) to sort the nonzeros
 *                     within each row of C; otherwise nothing is sorted (sort = 7)
 * @param A  left operand (read only)
 * @param B  right operand (read only)
 * @param C  result matrix whose arrays are handed to MKL as outputs
 * @return   the info flag reported by MKL (initialised to 0 before the call)
 */
template <bool sortOutput, typename IT>
int MKLSpGEMM_numeric(const CSR<IT,double> &A, const CSR<IT,double> &B, CSR<IT,double> &C)
{
    // 8: sort nonzeroes in rows of C, leave A and B alone (they are already sorted)
    // 7: don't sort anything
    int sortMode = sortOutput ? 8 : 7;
    int stage = 2;
    int status = 0;  // output info flag

    mkl_dcsrmultcsr(const_cast<char*>("N"), &stage, &sortMode,
                    &(A.rows), &(A.cols), &(B.cols),
                    A.values, A.colids, A.rowptr,
                    B.values, B.colids, B.rowptr,
                    C.values, C.colids, C.rowptr,
                    NULL, &status);

    return status;
}
/**
 * Computes C = A * B through MKL's two-stage CSR multiplication.
 *
 * Stage 1 (MKLSpGEMM_symbolic) fills C.rowptr; the nonzero count of C is then
 * read from the last row-pointer entry minus one, and C.colids / C.values are
 * allocated to that size. Stage 2 (MKLSpGEMM_numeric) fills those arrays.
 *
 * C.rows, C.cols, C.rowptr, C.nnz, C.colids and C.values are all overwritten,
 * and C.zerobased is set to false. Any arrays C held before the call are not
 * released here.
 *
 * If IT is not int, a message is printed and C is left untouched. A non-zero
 * info value from either stage is reported and then asserted on.
 *
 * @tparam sortOutput  forwarded to both stages
 * @param A  left operand
 * @param B  right operand
 * @param C  receives the product
 */
template <bool sortOutput, typename IT, typename NT>
void MKLSpGEMM(const CSR<IT,NT> &A, const CSR<IT,NT> &B, CSR<IT,NT> &C)
{
    // Guard: only int indices are handed to MKL.
    if (!(typeid(IT) == typeid(int))) {
        cout << "MKL does not support non-int type indices." << endl;
        return;
    }

    // Shape the result and allocate its row-pointer array (rows + 1 entries),
    // which the request=1 call requires to exist beforehand.
    C.zerobased = false;
    C.cols = B.cols;
    C.rows = A.rows;
    C.rowptr = my_malloc<IT>(C.rows + 1);

    // Stage 1: only the row pointers of C are computed.
    int info = MKLSpGEMM_symbolic<sortOutput, IT>(A, B, C);
    if (info != 0) {
        cout << "MKL-Count Error: info returned " << info << endl;
        assert(info == 0);
    }

    // On exit from stage 1, rowptr[rows] - 1 is the number of entries needed
    // in the column-index and value arrays of C.
    C.nnz = C.rowptr[A.rows] - 1;
    C.values = my_malloc<NT>(C.nnz);
    C.colids = my_malloc<IT>(C.nnz);

    // Stage 2: requires stage 1 to have run and the output arrays to be
    // allocated by the caller with at least rowptr[rows] - 1 entries.
    info = MKLSpGEMM_numeric<sortOutput, IT>(A, B, C);
    if (info != 0) {
        printf("MKL-Calculation Error: info returned %d\n", info);
        assert(info == 0);
    }
}
/**
 * Counts the floating-point operations of the product A * B for two CSR
 * matrices.
 *
 * For every stored entry of A, the length of the row of B selected by that
 * entry's column index is added to a running total of scalar multiplications.
 * The return value is twice that total.
 *
 * Column indices of A are used directly as row indices into B.rowptr.
 *
 * @param A  left operand
 * @param B  right operand
 * @return   2 * (number of scalar multiplications)
 */
template <typename IT, typename NT>
long long int get_flop(const CSR<IT,NT> & A, const CSR<IT,NT> & B)
{
    long long int multiplications = 0;

    for (IT row = 0; row < A.rows; ++row) {
        // Walk the stored entries of this row of A.
        for (IT pos = A.rowptr[row]; pos < A.rowptr[row + 1]; ++pos) {
            // The column index of the A entry selects a row of B.
            long long int bRow = A.colids[pos];
            // Number of stored entries in that row of B.
            long long int bRowLength = B.rowptr[bRow + 1] - B.rowptr[bRow];
            multiplications += bRowLength;
        }
    }

    return multiplications * 2;
}
/**
 * Counts the floating-point operations of the product A * B when A is stored
 * by columns (CSC) and B by rows (CSR).
 *
 * For each index i in [0, A.cols), the number of stored entries in column i
 * of A is multiplied by the number of stored entries in row i of B, and the
 * products are summed with an OpenMP reduction. The return value is twice
 * that sum.
 *
 * B.rowptr is read at indices up to A.cols, so B is assumed to have at least
 * A.cols rows -- TODO confirm against callers.
 *
 * @param A  left operand in CSC form
 * @param B  right operand in CSR form
 * @return   2 * (number of scalar multiplications)
 */
template <typename IT, typename NT>
long long int get_flop(const CSC<IT, NT> &A, const CSR<IT, NT> &B)
{
    long long int flops = 0;
#pragma omp parallel for reduction(+ : flops)
    for (IT i = 0; i < A.cols; ++i)
    {
        // Widen both counts before multiplying: the product of two IT values
        // can exceed the range of IT even when each count fits.
        const long long int colnnz = A.colptr[i + 1] - A.colptr[i];
        const long long int rownnz = B.rowptr[i + 1] - B.rowptr[i];
        flops += colnnz * rownnz;
    }
    return (flops * 2);
}