This repository has been archived by the owner on Jul 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathkgemm_nn.hpp
212 lines (168 loc) · 5.37 KB
/
kgemm_nn.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#ifndef KGEMM_NN_HPP
#define KGEMM_NN_HPP 1
#include "kroncommon.hpp"
// -----------------------
// NotransA and NotransB case
// C = alpha*A*(B) + beta *C
// -----------------------
// ---------------------------------------------------------------------
// kgemm_nn:  C = alpha * A * B + beta * C   (neither A nor B transposed)
//
//   mm, nn, kk : A is read as mm-by-kk, B as kk-by-nn, C is updated as
//                mm-by-nn.
//   alpha,beta : scalars of the product and of the existing C.
//   A_, ldA    : base pointer and leading dimension of A (read only).
//   B_, ldB    : base pointer and leading dimension of B (read only).
//   C_, ldC    : base pointer and leading dimension of C (updated).
//
// Entries are addressed with 1-based (i,j) indices through
// indx2f(i,j,ld).  The inner product walks A with stride ldA and B with
// stride 1, i.e. it relies on indx2f being column-major
// -- NOTE(review): confirm against kroncommon.hpp.
//
// With USE_GPU the entries of each nb-by-nb tile of C are dealt out to
// the threads of the block: thread t starts at flat entry t and advances
// by blockDim.x.  The block is expected to be one dimensional with a
// thread count that is a multiple of 32 (see the expect() checks).
// Without USE_GPU a single caller processes every entry.
//
// Store policy for each C(i,j):
//   beta == 1 : atomicAdd( &C(i,j), alpha*cij )
//   beta == 0 : C(i,j) = alpha*cij            (old C is never read)
//   otherwise : C(i,j) = beta*C(i,j) + alpha*cij
// ---------------------------------------------------------------------
template<typename T>
DEVICE_FUNCTION
void kgemm_nn( int const mm, int const nn, int const kk,
               T const alpha,
               T const * const A_, int const ldA,
               T const * const B_, int const ldB,
               T const beta,
               T * C_, int const ldC)
{
#ifdef USE_LAMBDA
        auto min = []( int const x, int const y) {
                return( (x < y) ? x : y );
        };
        auto max = []( int const x, int const y) {
                return( (x > y) ? x : y );
        };
#else
#define min(x,y)  (((x) < (y)) ? (x) : (y) )
#define max(x,y)  (((x) > (y)) ? (x) : (y) )
#endif

        // edge length of the square tiles of C handled per outer iteration
        int constexpr nb = 2*32;

#ifdef USE_GPU
        // ---------------------------
        // use matlab 1 based indexing
        // ---------------------------
        int constexpr warpsize = 32;
        int const nthreads = blockDim.x;

        expect( blockDim.y == 1);
        expect( blockDim.z == 1);
        expect( (nthreads % warpsize) == 0);

        // -----------------------------------------
        // reorganize threads as nx_threads by ny_threads
        // (only the flat ij_start/ij_size pair drives the loop below;
        //  the 2D view is kept for the sanity checks)
        // -----------------------------------------
        int const nx_threads = warpsize;
        int const ny_threads = max(1,nthreads/nx_threads);

        int const ix_start = ( threadIdx.x % nx_threads ) + 1;
        int const iy_start = (threadIdx.x/nx_threads) + 1;

        int const ix_size = nx_threads;
        int const iy_size = ny_threads;

        int const ij_start = threadIdx.x + 1;
        int const ij_size = nthreads;
#else
        int const ix_start = 1;
        int const ix_size = 1;
        int const iy_start = 1;
        int const iy_size = 1;

        int const ij_start = 1;
        int const ij_size = 1;
#endif

        expect( ix_start >= 1);
        expect( iy_start >= 1);
        expect( ix_size >= 1 );
        expect( iy_size >= 1 );

        // ------------------------------------
        // NOTE(review): the original remark read "commonly nn is large,
        // but kk, nn are small"; presumably mm is the large dimension
        // -- confirm with the callers.
        //
        // consider increasing nb for more effective
        // use of shared cache
        // ------------------------------------

#ifdef USE_LAMBDA
        auto A = [&] (int const ia,
                      int const ja) -> T const & {
                return( A_[ indx2f(ia,ja,ldA) ] );
        };

        auto B = [&] (int const ib,
                      int const jb) -> T const & {
                return( B_[ indx2f(ib,jb,ldB) ] );
        };

        auto C = [&] (int const ic,
                      int const jc) -> T& {
                return( C_[ indx2f(ic,jc,ldC) ] );
        };
#else
#define A(ia,ja)  A_[indx2f(ia,ja,ldA)]
#define B(ib,jb)  B_[indx2f(ib,jb,ldB)]
#define C(ic,jc)  C_[indx2f(ic,jc,ldC)]
#endif

        for(int istart=1; istart <= mm;  istart += nb) {
            int const iend = min( mm, istart + nb-1);
            int const isize = iend - istart + 1;

            for(int jstart=1; jstart <= nn; jstart += nb) {
                int const jend = min(nn, jstart + nb-1);
                int const jsize = jend - jstart + 1;

                SYNCTHREADS;

                // ---------------------------
                // perform matrix calculations
                // ---------------------------
                // Flat loop over the isize*jsize entries of this tile,
                // equivalent to
                //   for(int j=iy_start; j <= jsize; j += iy_size)
                //   for(int i=ix_start; i <= isize; i += ix_size)
                // with i running fastest.
                for(int ij0 = ij_start-1; ij0 < (isize*jsize); ij0 += ij_size ) {
                    // recover the 1-based (i,j) position inside the tile
                    int const  i = (ij0 % isize) + 1;
                    int const  j = ((ij0 - (i-1))/isize) + 1;

                    // global row of A / C and global column of B / C
                    int const ia = (istart-1) + i;
                    int const jb = (jstart-1) + j;

                    // per-k strides: next column of A, next row of B
                    auto const inc_A = ldA;
                    auto const inc_B = 1;

                    T cij = 0;
                    bool constexpr use_pointer = true;
                    if (use_pointer) {
                        int k = 1;
                        T const * Ap = &(A(ia,k));
                        T const * Bp = &(B(k,jb));

// Inner product with a literal trip count nk; the switch below selects
// it for small kk (presumably so the compiler can fully unroll the loop).
#define case_code(nk)  { \
                       for(k=0; k < (nk); k++) { \
                            cij += (*Ap) * (*Bp); \
                            Ap += inc_A; \
                            Bp += inc_B; \
                            } \
                       break; \
                       }

                        switch(kk) {
                        case 1: case_code(1)
                        case 2: case_code(2)
                        case 3: case_code(3)
                        case 4: case_code(4)
                        case 5: case_code(5)
                        case 6: case_code(6)
                        case 7: case_code(7)
                        case 8: case_code(8)
                        default:
                            // generic length (also covers kk <= 0, where
                            // the loop body never runs and cij stays 0)
#ifdef USE_GPU
#pragma unroll
#endif
                            for(k=0; k < kk; k++) {
                                cij += (*Ap) * (*Bp);
                                Ap += inc_A;
                                Bp += inc_B;
                            }
                        }
                    }
                    else {
                        for(int k=1; k <= kk; k++) {
                            cij += A( ia, k) * B( k, jb);
                        }
                    }

                    // ------------------
                    // store results to C
                    // ------------------
                    int const ic = ia;
                    int const jc = jb;
                    T const alpha_cij = alpha * cij;
                    if (beta == 1) {
                        atomicAdd( &(C(ic,jc)), alpha_cij );
                    }
                    else if (beta == 0) {
                        // overwrite without reading the old value of C
                        C(ic,jc) = alpha_cij;
                    }
                    else {
                        C(ic,jc)  =  beta * C(ic,jc) + alpha_cij;
                    }
                } // end ij0
            } // end jstart
        } // end istart
}

// Drop every helper macro introduced for kgemm_nn so that none of them
// leaks into code that includes this header.
#undef case_code
#undef min
#undef max
#undef A
#undef B
#undef C
#endif