-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathrunner.cu
549 lines (502 loc) · 20 KB
/
runner.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
#include "kernels.cuh"
#include "runner.cuh"
#include <cmath>
#include <cstdio>
#include <fstream>
#include <iomanip>
// Returns a wall-clock timestamp in MICROSECONDS (despite the name), measured
// from the first call to get_sec() in this process. Meant to be used in pairs
// with cpu_elapsed_time(), which converts the difference of two readings to
// seconds.
//
// The value is relative rather than absolute because the return type is
// float: an absolute epoch timestamp in microseconds (~1.7e15) has a float
// resolution of 2^27 us (over two minutes), so the difference between two
// readings would be meaningless. Relative to the first call, a float keeps
// 1 us resolution for the first ~16.7 s and degrades gradually after that.
float get_sec() {
  // Microseconds since the Unix epoch, kept in double for full precision.
  auto now_us = []() -> double {
    struct timeval time;
    gettimeofday(&time, NULL);
    return 1e6 * time.tv_sec + time.tv_usec;
  };
  // Captured once, on the first call (thread-safe static initialization).
  static const double origin_us = now_us();
  return static_cast<float>(now_us() - origin_us);
}
// Converts the span between two get_sec() readings (microseconds) into
// seconds: (end - beg) * 1e-6.
float cpu_elapsed_time(float &beg, float &end) {
  const float elapsed_us = end - beg;
  return 1.0e-6 * elapsed_us;
}
// Aborts the process if `error` is anything other than cudaSuccess.
// On failure, prints the caller-supplied source location and the CUDA error
// string to stdout, then exits with EXIT_FAILURE. Typically invoked as
// cudaCheck(call, __FILE__, __LINE__).
void cudaCheck(cudaError_t error, const char *file, int line) {
  if (error == cudaSuccess)
    return; // nothing to report

  printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line,
         cudaGetErrorString(error));
  exit(EXIT_FAILURE);
}
// Prints a summary of the currently selected CUDA device's properties to
// stdout. If the device id or its properties cannot be queried (e.g. no
// usable GPU), the process is aborted through cudaCheck instead of printing a
// zero-initialized cudaDeviceProp as if it were real data.
// Memory sizes are reported in MB (global) or KB (shared/constant), using
// integer division.
void CudaDeviceInfo() {
  int deviceId = 0;
  cudaCheck(cudaGetDevice(&deviceId), __FILE__, __LINE__);

  cudaDeviceProp props{};
  cudaCheck(cudaGetDeviceProperties(&props, deviceId), __FILE__, __LINE__);

  printf("Device ID: %d\n"
         "Name: %s\n"
         "Compute Capability: %d.%d\n"
         "memoryBusWidth: %d\n"
         "maxThreadsPerBlock: %d\n"
         "maxThreadsPerMultiProcessor: %d\n"
         "maxRegsPerBlock: %d\n"
         "maxRegsPerMultiProcessor: %d\n"
         "totalGlobalMem: %zuMB\n"
         "sharedMemPerBlock: %zuKB\n"
         "sharedMemPerMultiprocessor: %zuKB\n"
         "totalConstMem: %zuKB\n"
         "multiProcessorCount: %d\n"
         "Warp Size: %d\n",
         deviceId, props.name, props.major, props.minor, props.memoryBusWidth,
         props.maxThreadsPerBlock, props.maxThreadsPerMultiProcessor,
         props.regsPerBlock, props.regsPerMultiprocessor,
         props.totalGlobalMem / 1024 / 1024, props.sharedMemPerBlock / 1024,
         props.sharedMemPerMultiprocessor / 1024, props.totalConstMem / 1024,
         props.multiProcessorCount, props.warpSize);
}
// Fills mat[0..N) with pseudo-random values of the form +/-(a + 0.01*b),
// where a and b are integers in [0, 4]. The C rand() generator is reseeded on
// every call.
void randomize_matrix(float *mat, int N) {
  // NOTICE: seed from the microsecond field of gettimeofday rather than
  // srand((unsigned)time(NULL)); time() only has one-second precision, so
  // back-to-back calls would otherwise produce identical matrices.
  struct timeval now {};
  gettimeofday(&now, nullptr);
  srand(now.tv_usec);

  int idx = 0;
  while (idx < N) {
    float value = (float)(rand() % 5) + 0.01 * (rand() % 5);
    // A third rand() draw picks the sign: odd -> negative.
    if (rand() % 2 != 0)
      value = -value;
    mat[idx++] = value;
  }
}
// Fills mat[0..N) with its own indices: mat[k] = k (converted to float).
void range_init_matrix(float *mat, int N) {
  int k = 0;
  while (k < N) {
    mat[k] = static_cast<float>(k);
    ++k;
  }
}
// Sets every element of mat[0..N) to 0. Does nothing when N <= 0.
void zero_init_matrix(float *mat, int N) {
  for (int k = 0; k < N; ++k)
    *(mat + k) = 0.0f;
}
// Copies N floats from src to dest, element by element in increasing index
// order (both pointers are dereferenced on the host).
//
// The inputs are validated up front: if either pointer is null or N is
// negative, nothing is copied and a diagnostic is printed to stdout.
// N == 0 is a silent no-op.
// The previous per-element test `src + i && dest + i` could only ever detect a
// null base pointer at i == 0, since pointer arithmetic on a non-null pointer
// never yields null.
void copy_matrix(const float *src, float *dest, int N) {
  if (N == 0)
    return;
  if (src == nullptr || dest == nullptr || N < 0) {
    printf("copy failed at %d while there are %d elements in total.\n", 0, N);
    return;
  }
  for (int i = 0; i < N; i++)
    dest[i] = src[i];
}
// Writes the row-major M x N matrix A to fs in a bracketed, MATLAB-like form.
// Each element is printed in fixed notation with 2 decimals and a field width
// of 5. Elements within a row are separated by ", " and rows by ";\n".
// Example: a 2x2 matrix {1, 2, 3, 4} is written as
//   "[ 1.00,  2.00;\n 3.00,  4.00]\n"
// Note: the precision and fixed-notation flags remain set on fs afterwards.
void print_matrix(const float *A, int M, int N, std::ofstream &fs) {
  const int count = M * N;
  fs << std::setprecision(2) << std::fixed;
  fs << "[";
  for (int idx = 0; idx < count; ++idx) {
    fs << std::setw(5) << A[idx];
    const bool lastInRow = ((idx + 1) % N) == 0;
    if (!lastInRow) {
      fs << ", ";
    } else if (idx + 1 < count) {
      fs << ";\n";
    }
  }
  fs << "]\n";
}
// Compares matOut against the reference matRef over N elements, using an
// absolute tolerance of 0.01.
// Returns true if every element agrees. Otherwise it prints the first
// divergent element (expected value, actual value, difference, index) to
// stdout and returns false.
//
// A NaN in either matrix counts as a divergence. The comparison is written as
// !(diff <= tol) because `NaN > tol` is false, and the previous `diff > tol`
// test silently accepted NaN outputs. Exactly equal values (including
// matching infinities, whose difference would be NaN) always agree.
bool verify_matrix(float *matRef, float *matOut, int N) {
  for (int i = 0; i < N; i++) {
    if (matRef[i] == matOut[i])
      continue;
    const double diff = std::fabs(matRef[i] - matOut[i]);
    if (!(diff <= 0.01)) {
      printf("Divergence! Should %5.2f, Is %5.2f (Diff %5.2f) at %d\n",
             matRef[i], matOut[i], diff, i);
      return false;
    }
  }
  return true;
}
// Returns ceil(numerator / denominator), computed exactly in integer
// arithmetic for operands of any sign.
// std::div truncates toward zero, so when the exact quotient is negative the
// truncated quotient already equals the ceiling. It is bumped by one only
// when there is a remainder AND the exact quotient is positive, i.e. the
// remainder (which has the numerator's sign) has the same sign as the
// denominator.
// Precondition: denominator != 0.
int div_ceil(int numerator, int denominator) {
  const std::div_t res = std::div(numerator, denominator);
  const bool positiveQuotient = (res.rem > 0) == (denominator > 0);
  return (res.rem != 0 && positiveQuotient) ? res.quot + 1 : res.quot;
}
// Reference GEMM through cuBLAS in full FP32 (CUBLAS_COMPUTE_32F):
//   C = alpha * A * B + beta * C
// with row-major A (M x K), B (K x N) and C (M x N).
// cuBLAS is column-major, so a row-major matrix appears transposed to it.
// Since (A*B)^T = B^T * A^T, we request the column-major N x M product of the
// swapped operands (B first with ld N, then A with ld K, output ld N). Its
// memory layout is exactly the row-major M x N result.
// Aborts the process if cuBLAS reports an error.
void runCublasFP32(cublasHandle_t handle, int M, int N, int K, float alpha,
                   float *A, float *B, float beta, float *C) {
  cublasStatus_t status = cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, B, CUDA_R_32F, N, A,
      CUDA_R_32F, K, &beta, C, CUDA_R_32F, N, CUBLAS_COMPUTE_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  if (status != CUBLAS_STATUS_SUCCESS) {
    printf("[cuBLAS ERROR] at file %s:%d: status %d\n", __FILE__, __LINE__,
           static_cast<int>(status));
    exit(EXIT_FAILURE);
  }
}
// GEMM through cuBLAS with mixed precision: FP32 inputs, outputs and
// accumulation, but the multiplications use operands downcast to bf16
// (CUBLAS_COMPUTE_32F_FAST_16BF). The original note reports this as ~4x
// faster than full FP32.
// Row-major A (M x K), B (K x N), C (M x N). The operands are swapped so that
// column-major cuBLAS produces the row-major result (see (A*B)^T = B^T*A^T).
// Aborts the process if cuBLAS reports an error.
void runCublasBF16(cublasHandle_t handle, int M, int N, int K, float alpha,
                   float *A, float *B, float beta, float *C) {
  cublasStatus_t status = cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, B, CUDA_R_32F, N, A,
      CUDA_R_32F, K, &beta, C, CUDA_R_32F, N, CUBLAS_COMPUTE_32F_FAST_16BF,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  if (status != CUBLAS_STATUS_SUCCESS) {
    printf("[cuBLAS ERROR] at file %s:%d: status %d\n", __FILE__, __LINE__,
           static_cast<int>(status));
    exit(EXIT_FAILURE);
  }
}
// GEMM through cuBLAS with mixed precision: FP32 inputs, outputs and
// accumulation, but the multiplications use operands downcast to TF32
// (CUBLAS_COMPUTE_32F_FAST_TF32), not bf16.
// Row-major A (M x K), B (K x N), C (M x N). The operands are swapped so that
// column-major cuBLAS produces the row-major result (see (A*B)^T = B^T*A^T).
// Aborts the process if cuBLAS reports an error.
void runCublasTF32(cublasHandle_t handle, int M, int N, int K, float alpha,
                   float *A, float *B, float beta, float *C) {
  cublasStatus_t status = cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, B, CUDA_R_32F, N, A,
      CUDA_R_32F, K, &beta, C, CUDA_R_32F, N, CUBLAS_COMPUTE_32F_FAST_TF32,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  if (status != CUBLAS_STATUS_SUCCESS) {
    printf("[cuBLAS ERROR] at file %s:%d: status %d\n", __FILE__, __LINE__,
           static_cast<int>(status));
    exit(EXIT_FAILURE);
  }
}
// Kernel 1 launcher (naive SGEMM). Uses 2D blocks of 32x32 threads. The grid's
// x dimension covers M and its y dimension covers N, each in chunks of 32.
void run_sgemm_naive(int M, int N, int K, float alpha, float *A, float *B,
                     float beta, float *C) {
  const dim3 threadsPerBlock(32, 32);
  const dim3 blocksPerGrid(CEIL_DIV(M, 32), CEIL_DIV(N, 32));
  sgemm_naive<<<blocksPerGrid, threadsPerBlock>>>(M, N, K, alpha, A, B, beta,
                                                  C);
}
// Kernel 2 launcher (global-memory coalescing). Uses a flat 1D block of
// 32*32 = 1024 threads, with the kernel instantiated for a block size of 32.
// The grid's x dimension covers M and its y dimension covers N, in chunks of
// 32.
void run_sgemm_coalesce(int M, int N, int K, float alpha, float *A, float *B,
                        float beta, float *C) {
  const dim3 threadsPerBlock(32 * 32);
  const dim3 blocksPerGrid(CEIL_DIV(M, 32), CEIL_DIV(N, 32));
  sgemm_global_mem_coalesce<32>
      <<<blocksPerGrid, threadsPerBlock>>>(M, N, K, alpha, A, B, beta, C);
}
// Kernel 3 launcher (shared-memory blocking). Uses a flat 1D block of 1024
// threads. The grid's x dimension covers M and its y dimension covers N, in
// chunks of 32.
void run_sgemm_shared_mem_block(int M, int N, int K, float alpha, float *A,
                                float *B, float beta, float *C) {
  // Global memory is only reached through shared memory here, so L1 adds
  // little; request the maximum shared-memory carveout. Per the original
  // note, this does not currently change performance (occupancy is limited by
  // registers and thread count), but it is harmless.
  cudaFuncSetAttribute(sgemm_shared_mem_block<32>,
                       cudaFuncAttributePreferredSharedMemoryCarveout,
                       cudaSharedmemCarveoutMaxShared);

  const dim3 threadsPerBlock(32 * 32);
  const dim3 blocksPerGrid(CEIL_DIV(M, 32), CEIL_DIV(N, 32));
  sgemm_shared_mem_block<32>
      <<<blocksPerGrid, threadsPerBlock>>>(M, N, K, alpha, A, B, beta, C);
}
// Kernel 4 launcher (1D block tiling). Each block has (BM*BN)/TM threads for
// a BM x BN output tile, so each thread accounts for TM outputs. The grid's x
// dimension covers N in BN-wide steps and its y dimension covers M in
// BM-wide steps. BK is passed through as the kernel's K-tile parameter.
void runSgemm1DBlocktiling(int M, int N, int K, float alpha, float *A, float *B,
                           float beta, float *C) {
  constexpr uint kBM = 64;
  constexpr uint kBN = 64;
  constexpr uint kBK = 8;
  constexpr uint kTM = 8;

  const dim3 threadsPerBlock((kBM * kBN) / kTM);
  const dim3 blocksPerGrid(CEIL_DIV(N, kBN), CEIL_DIV(M, kBM));
  sgemm1DBlocktiling<kBM, kBN, kBK, kTM>
      <<<blocksPerGrid, threadsPerBlock>>>(M, N, K, alpha, A, B, beta, C);
}
// Kernel 5 launcher (2D block tiling). Each block has (BM*BN)/(TM*TN) threads
// for a BM x BN output tile. The grid's x dimension covers N in BN-wide steps
// and its y dimension covers M in BM-wide steps.
//
// A 128x128 tile is used only when both M and N are at least 128; otherwise
// it falls back to 64x64. The original note calls this a hacky workaround for
// the kernel's missing bounds checks.
// NOTE(review): sizes that are not a multiple of the chosen tile presumably
// remain unsupported — confirm against the kernel.
void runSgemm2DBlocktiling(int M, int N, int K, float alpha, float *A, float *B,
                           float beta, float *C) {
  constexpr uint kBK = 8;
  constexpr uint kTM = 8;
  constexpr uint kTN = 8;

  const bool useLargeTile = (M >= 128) && (N >= 128);
  if (!useLargeTile) {
    constexpr uint kBM = 64;
    constexpr uint kBN = 64;
    sgemm2DBlocktiling<kBM, kBN, kBK, kTM, kTN>
        <<<dim3(CEIL_DIV(N, kBN), CEIL_DIV(M, kBM)),
           dim3((kBM * kBN) / (kTM * kTN))>>>(M, N, K, alpha, A, B, beta, C);
    return;
  }

  constexpr uint kBM = 128;
  constexpr uint kBN = 128;
  sgemm2DBlocktiling<kBM, kBN, kBK, kTM, kTN>
      <<<dim3(CEIL_DIV(N, kBN), CEIL_DIV(M, kBM)),
         dim3((kBM * kBN) / (kTM * kTN))>>>(M, N, K, alpha, A, B, beta, C);
}
// Kernel 6 launcher (vectorized loads). Each block has (BM*BN)/(TM*TN)
// threads for a BM x BN output tile. The grid's x dimension covers N in
// BN-wide steps and its y dimension covers M in BM-wide steps.
//
// A 128x128 tile is used only when both M and N are at least 128; otherwise
// it falls back to 64x64. The original note calls this a hacky workaround for
// the kernel's missing bounds checks.
// NOTE(review): sizes that are not a multiple of the chosen tile presumably
// remain unsupported — confirm against the kernel.
void runSgemmVectorize(int M, int N, int K, float alpha, float *A, float *B,
                       float beta, float *C) {
  constexpr uint kBK = 8;
  constexpr uint kTM = 8;
  constexpr uint kTN = 8;

  const bool useLargeTile = (M >= 128) && (N >= 128);
  if (!useLargeTile) {
    constexpr uint kBM = 64;
    constexpr uint kBN = 64;
    sgemmVectorize<kBM, kBN, kBK, kTM, kTN>
        <<<dim3(CEIL_DIV(N, kBN), CEIL_DIV(M, kBM)),
           dim3((kBM * kBN) / (kTM * kTN))>>>(M, N, K, alpha, A, B, beta, C);
    return;
  }

  constexpr uint kBM = 128;
  constexpr uint kBN = 128;
  sgemmVectorize<kBM, kBN, kBK, kTM, kTN>
      <<<dim3(CEIL_DIV(N, kBN), CEIL_DIV(M, kBM)),
         dim3((kBM * kBN) / (kTM * kTN))>>>(M, N, K, alpha, A, B, beta, C);
}
// Kernel 7 launcher (shared-memory bank-conflict resolution). Each block has
// (BM*BN)/(TM*TN) threads for a BM x BN output tile. The grid's x dimension
// covers N in BN-wide steps and its y dimension covers M in BM-wide steps.
//
// A 128x128 tile is used only when both M and N are at least 128; otherwise
// it falls back to 64x64. The original note calls this a hacky workaround for
// the kernel's missing bounds checks.
// NOTE(review): sizes that are not a multiple of the chosen tile presumably
// remain unsupported — confirm against the kernel.
void runSgemmResolveBankConflicts(int M, int N, int K, float alpha, float *A,
                                  float *B, float beta, float *C) {
  constexpr uint kBK = 8;
  constexpr uint kTM = 8;
  constexpr uint kTN = 8;

  const bool useLargeTile = (M >= 128) && (N >= 128);
  if (!useLargeTile) {
    constexpr uint kBM = 64;
    constexpr uint kBN = 64;
    sgemmResolveBankConflicts<kBM, kBN, kBK, kTM, kTN>
        <<<dim3(CEIL_DIV(N, kBN), CEIL_DIV(M, kBM)),
           dim3((kBM * kBN) / (kTM * kTN))>>>(M, N, K, alpha, A, B, beta, C);
    return;
  }

  constexpr uint kBM = 128;
  constexpr uint kBN = 128;
  sgemmResolveBankConflicts<kBM, kBN, kBK, kTM, kTN>
      <<<dim3(CEIL_DIV(N, kBN), CEIL_DIV(M, kBM)),
         dim3((kBM * kBN) / (kTM * kTN))>>>(M, N, K, alpha, A, B, beta, C);
}
// Kernel 8 launcher (bank-conflict resolution via an extra column). Each block
// has (BM*BN)/(TM*TN) threads for a BM x BN output tile. The grid's x
// dimension covers N in BN-wide steps and its y dimension covers M in
// BM-wide steps.
//
// A 128x128 tile is used only when both M and N are at least 128; otherwise
// it falls back to 64x64. The original note calls this a hacky workaround for
// the kernel's missing bounds checks.
// NOTE(review): sizes that are not a multiple of the chosen tile presumably
// remain unsupported — confirm against the kernel.
void runSgemmResolveBankExtraCol(int M, int N, int K, float alpha, float *A,
                                 float *B, float beta, float *C) {
  constexpr uint kBK = 8;
  constexpr uint kTM = 8;
  constexpr uint kTN = 8;

  const bool useLargeTile = (M >= 128) && (N >= 128);
  if (!useLargeTile) {
    constexpr uint kBM = 64;
    constexpr uint kBN = 64;
    sgemmResolveBankExtraCol<kBM, kBN, kBK, kTM, kTN>
        <<<dim3(CEIL_DIV(N, kBN), CEIL_DIV(M, kBM)),
           dim3((kBM * kBN) / (kTM * kTN))>>>(M, N, K, alpha, A, B, beta, C);
    return;
  }

  constexpr uint kBM = 128;
  constexpr uint kBN = 128;
  sgemmResolveBankExtraCol<kBM, kBN, kBK, kTM, kTN>
      <<<dim3(CEIL_DIV(N, kBN), CEIL_DIV(M, kBM)),
         dim3((kBM * kBN) / (kTM * kTN))>>>(M, N, K, alpha, A, B, beta, C);
}
// Kernel 9 launcher ("autotuned" block tiling). Launches K9_NUM_THREADS
// threads per block (K9_NUM_THREADS is defined outside this file). The grid's
// x dimension covers N in K9_BN-wide steps and its y dimension covers M in
// K9_BM-wide steps.
// The static_asserts encode the tile-shape constraints the kernel relies on.
// Their messages are phrased in terms of K9_NUM_THREADS rather than a
// hard-coded 256, so they stay accurate if the thread count is retuned.
void runSgemmAutotuned(int M, int N, int K, float alpha, float *A, float *B,
                       float beta, float *C) {
  // A100
  // const uint K9_BK = 16;
  // const uint K9_TM = 4;
  // const uint K9_TN = 4;
  // const uint K9_BM = 64;
  // const uint K9_BN = 64;
  // A6000
  const uint K9_BK = 16;
  const uint K9_TM = 8;
  const uint K9_TN = 8;
  const uint K9_BM = 128;
  const uint K9_BN = 128;
  dim3 blockDim(K9_NUM_THREADS);

  static_assert(
      (K9_NUM_THREADS * 4) % K9_BK == 0,
      "NUM_THREADS*4 must be multiple of K9_BK to avoid quantization issues "
      "during GMEM->SMEM tiling (loading only parts of the final row of Bs "
      "during each iteration)");
  static_assert(
      (K9_NUM_THREADS * 4) % K9_BN == 0,
      "NUM_THREADS*4 must be multiple of K9_BN to avoid quantization issues "
      "during GMEM->SMEM tiling (loading only parts of the final row of As "
      "during each iteration)");
  static_assert(
      K9_BN % (16 * K9_TN) == 0,
      "K9_BN must be a multiple of 16*K9_TN to avoid quantization effects");
  static_assert(
      K9_BM % (16 * K9_TM) == 0,
      "K9_BM must be a multiple of 16*K9_TM to avoid quantization effects");
  static_assert((K9_BM * K9_BK) % (4 * K9_NUM_THREADS) == 0,
                "K9_BM*K9_BK must be a multiple of 4*K9_NUM_THREADS to "
                "vectorize loads");
  static_assert((K9_BN * K9_BK) % (4 * K9_NUM_THREADS) == 0,
                "K9_BN*K9_BK must be a multiple of 4*K9_NUM_THREADS to "
                "vectorize loads");

  dim3 gridDim(CEIL_DIV(N, K9_BN), CEIL_DIV(M, K9_BM));
  sgemmAutotuned<K9_BM, K9_BN, K9_BK, K9_TM, K9_TN>
      <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}
// Kernel 10 launcher (warptiling). Launches K10_NUM_THREADS threads per block.
// The grid's x dimension covers N in K10_BN-wide steps and its y dimension
// covers M in K10_BM-wide steps.
// The static_asserts check that the warp tiles exactly partition the block
// tile and that the GMEM->SMEM loads quantize cleanly. The warp size is always
// taken from WARPSIZE (previously the literal 32 was mixed in). The messages
// name this kernel's own constants (they previously referred to K9_* and a
// hard-coded 4*256).
void runSgemmWarptiling(int M, int N, int K, float alpha, float *A, float *B,
                        float beta, float *C) {
  // Settings for A100
  // const uint K10_NUM_THREADS = 128;
  // const uint K10_BN = 128;
  // const uint K10_BM = 64;
  // const uint K10_BK = 16;
  // const uint K10_WN = 64;
  // const uint K10_WM = 32;
  // const uint K10_WNITER = 1;
  // const uint K10_TN = 4;
  // const uint K10_TM = 4;
  // Settings for A6000
  const uint K10_NUM_THREADS = 128;
  const uint K10_BN = 128;
  const uint K10_BM = 128;
  const uint K10_BK = 16;
  const uint K10_WN = 64;
  const uint K10_WM = 64;
  const uint K10_WNITER = 4;
  const uint K10_TN = 4;
  const uint K10_TM = 8;
  dim3 blockDim(K10_NUM_THREADS);

  constexpr uint NUM_WARPS = K10_NUM_THREADS / WARPSIZE;

  // warptile in threadblocktile
  static_assert((K10_BN % K10_WN == 0) and (K10_BM % K10_WM == 0));
  static_assert((K10_BN / K10_WN) * (K10_BM / K10_WM) == NUM_WARPS);

  // threads in warpsubtile
  static_assert((K10_WM * K10_WN) % (WARPSIZE * K10_TM * K10_TN * K10_WNITER) ==
                0);
  constexpr uint K10_WMITER =
      (K10_WM * K10_WN) / (WARPSIZE * K10_TM * K10_TN * K10_WNITER);
  // warpsubtile in warptile
  static_assert((K10_WM % K10_WMITER == 0) and (K10_WN % K10_WNITER == 0));

  static_assert((K10_NUM_THREADS * 4) % K10_BK == 0,
                "NUM_THREADS*4 must be multiple of BK to avoid quantization "
                "issues during GMEM->SMEM tiling (loading only parts of the "
                "final row of Bs during each iteration)");
  static_assert((K10_NUM_THREADS * 4) % K10_BN == 0,
                "NUM_THREADS*4 must be multiple of BN to avoid quantization "
                "issues during GMEM->SMEM tiling (loading only parts of the "
                "final row of As during each iteration)");
  static_assert(K10_BN % (16 * K10_TN) == 0,
                "BN must be a multiple of 16*TN to avoid quantization effects");
  static_assert(K10_BM % (16 * K10_TM) == 0,
                "BM must be a multiple of 16*TM to avoid quantization effects");
  static_assert((K10_BM * K10_BK) % (4 * K10_NUM_THREADS) == 0,
                "BM*BK must be a multiple of 4*NUM_THREADS to vectorize loads");
  static_assert((K10_BN * K10_BK) % (4 * K10_NUM_THREADS) == 0,
                "BN*BK must be a multiple of 4*NUM_THREADS to vectorize loads");

  dim3 gridDim(CEIL_DIV(N, K10_BN), CEIL_DIV(M, K10_BM));
  sgemmWarptiling<K10_BM, K10_BN, K10_BK, K10_WM, K10_WN, K10_WNITER, K10_TM,
                  K10_TN, K10_NUM_THREADS>
      <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}
// Kernel 11 launcher (warptiling with double buffering). Launches
// K11_NUM_THREADS threads per block. The grid's x dimension covers N in
// K11_BN-wide steps and its y dimension covers M in K11_BM-wide steps.
// The static_asserts check the warp-tile partitioning and the GMEM->SMEM load
// quantization. The load constraints use NUM_THREADS/2, and the messages now
// say so (they previously claimed NUM_THREADS*4 and a hard-coded 4*256).
void runSgemmDoubleBuffering(int M, int N, int K, float alpha, float *A,
                             float *B, float beta, float *C) {
  // Settings for A100
  // const uint K11_NUM_THREADS = 256;
  // const uint K11_BN = 128;
  // const uint K11_BM = 64;
  // const uint K11_BK = 16;
  // const uint K11_WN = 32;
  // const uint K11_WM = 32;
  // const uint K11_WNITER = 2;
  // const uint K11_TN = 4;
  // const uint K11_TM = 4;
  // Settings for A6000
  const uint K11_NUM_THREADS = 256;
  const uint K11_BN = 256;
  const uint K11_BM = 128;
  const uint K11_BK = 16;
  const uint K11_WN = 32;
  const uint K11_WM = 128;
  const uint K11_WNITER = 1;
  const uint K11_TN = 8;
  const uint K11_TM = 8;
  dim3 blockDim(K11_NUM_THREADS);

  constexpr uint NUM_WARPS = K11_NUM_THREADS / WARPSIZE;

  // warptile in threadblocktile
  static_assert((K11_BN % K11_WN == 0) and (K11_BM % K11_WM == 0));
  static_assert((K11_BN / K11_WN) * (K11_BM / K11_WM) == NUM_WARPS);

  // threads in warpsubtile
  static_assert((K11_WM * K11_WN) % (WARPSIZE * K11_TM * K11_TN * K11_WNITER) ==
                0);
  constexpr uint K11_WMITER =
      (K11_WM * K11_WN) / (WARPSIZE * K11_TM * K11_TN * K11_WNITER);
  // warpsubtile in warptile
  static_assert((K11_WM % K11_WMITER == 0) and (K11_WN % K11_WNITER == 0));

  // NOTE(review): the load constraints below use NUM_THREADS/2, presumably
  // because only half of the block issues GMEM->SMEM loads in this
  // double-buffered kernel — confirm against the kernel.
  static_assert((K11_NUM_THREADS / 2 * 4) % K11_BK == 0,
                "NUM_THREADS/2*4 must be multiple of BK to avoid quantization "
                "issues during GMEM->SMEM tiling (loading only parts of the "
                "final row of Bs during each iteration)");
  static_assert((K11_NUM_THREADS / 2 * 4) % K11_BN == 0,
                "NUM_THREADS/2*4 must be multiple of BN to avoid quantization "
                "issues during GMEM->SMEM tiling (loading only parts of the "
                "final row of As during each iteration)");
  static_assert(K11_BN % (16 * K11_TN) == 0,
                "BN must be a multiple of 16*TN to avoid quantization effects");
  static_assert(K11_BM % (16 * K11_TM) == 0,
                "BM must be a multiple of 16*TM to avoid quantization effects");
  static_assert((K11_BM * K11_BK) % (4 * K11_NUM_THREADS / 2) == 0,
                "BM*BK must be a multiple of 4*NUM_THREADS/2 to vectorize "
                "loads");
  static_assert((K11_BN * K11_BK) % (4 * K11_NUM_THREADS / 2) == 0,
                "BN*BK must be a multiple of 4*NUM_THREADS/2 to vectorize "
                "loads");

  dim3 gridDim(CEIL_DIV(N, K11_BN), CEIL_DIV(M, K11_BM));
  sgemmDoubleBuffering<K11_BM, K11_BN, K11_BK, K11_WM, K11_WN, K11_WNITER,
                       K11_TM, K11_TN, K11_NUM_THREADS>
      <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}
// Kernel 12 launcher (second double-buffering variant). Launches
// K12_NUM_THREADS threads per block. The grid's x dimension covers N in
// K12_BN-wide steps and its y dimension covers M in K12_BM-wide steps.
// The static_asserts check the warp-tile partitioning and the GMEM->SMEM load
// quantization. The warp size is always taken from WARPSIZE, and the messages
// name this kernel's own constants (they previously referred to K9_* and a
// hard-coded 4*256).
void runSgemmDoubleBuffering2(int M, int N, int K, float alpha, float *A,
                              float *B, float beta, float *C) {
  // Settings for A6000
  const uint K12_NUM_THREADS = 128;
  const uint K12_BN = 128;
  const uint K12_BM = 128;
  const uint K12_BK = 16;
  const uint K12_WN = 64;
  const uint K12_WM = 64;
  const uint K12_WNITER = 4;
  const uint K12_TN = 4;
  const uint K12_TM = 8;
  dim3 blockDim(K12_NUM_THREADS);

  constexpr uint NUM_WARPS = K12_NUM_THREADS / WARPSIZE;

  // warptile in threadblocktile
  static_assert((K12_BN % K12_WN == 0) and (K12_BM % K12_WM == 0));
  static_assert((K12_BN / K12_WN) * (K12_BM / K12_WM) == NUM_WARPS);

  // threads in warpsubtile
  static_assert((K12_WM * K12_WN) % (WARPSIZE * K12_TM * K12_TN * K12_WNITER) ==
                0);
  constexpr uint K12_WMITER =
      (K12_WM * K12_WN) / (WARPSIZE * K12_TM * K12_TN * K12_WNITER);
  // warpsubtile in warptile
  static_assert((K12_WM % K12_WMITER == 0) and (K12_WN % K12_WNITER == 0));

  static_assert((K12_NUM_THREADS * 4) % K12_BK == 0,
                "NUM_THREADS*4 must be multiple of BK to avoid quantization "
                "issues during GMEM->SMEM tiling (loading only parts of the "
                "final row of Bs during each iteration)");
  static_assert((K12_NUM_THREADS * 4) % K12_BN == 0,
                "NUM_THREADS*4 must be multiple of BN to avoid quantization "
                "issues during GMEM->SMEM tiling (loading only parts of the "
                "final row of As during each iteration)");
  static_assert(K12_BN % (16 * K12_TN) == 0,
                "BN must be a multiple of 16*TN to avoid quantization effects");
  static_assert(K12_BM % (16 * K12_TM) == 0,
                "BM must be a multiple of 16*TM to avoid quantization effects");
  static_assert((K12_BM * K12_BK) % (4 * K12_NUM_THREADS) == 0,
                "BM*BK must be a multiple of 4*NUM_THREADS to vectorize loads");
  static_assert((K12_BN * K12_BK) % (4 * K12_NUM_THREADS) == 0,
                "BN*BK must be a multiple of 4*NUM_THREADS to vectorize loads");

  dim3 gridDim(CEIL_DIV(N, K12_BN), CEIL_DIV(M, K12_BM));
  // NOTE(review): this launches a function template with the same name as
  // this host wrapper. That only compiles if kernels.cuh declares a
  // __global__ template named runSgemmDoubleBuffering2 (a template and a
  // non-template function may share a name as overloads) — confirm this
  // naming is intended.
  runSgemmDoubleBuffering2<K12_BM, K12_BN, K12_BK, K12_WM, K12_WN, K12_WNITER,
                           K12_TM, K12_TN, K12_NUM_THREADS>
      <<<gridDim, blockDim>>>(M, N, K, alpha, A, B, beta, C);
}
// Dispatches one GEMM implementation by number:
//   0 = cuBLAS FP32, 1 = naive, 2 = coalesced, 3 = shared-memory blocking,
//   4 = 1D block tiling, 5 = 2D block tiling, 6 = vectorized,
//   7 = bank-conflict fix, 8 = bank-conflict fix (extra column),
//   9 = autotuned, 10 = warptiling, 11 = double buffering,
//   12 = double buffering (variant 2).
// Throws std::invalid_argument for any other number.
// After dispatch, cudaGetLastError() is checked. Kernel launches do not return
// errors directly, so an invalid launch configuration (e.g. too many threads
// per block for the device) would otherwise go unnoticed and leave C
// untouched. The check does not synchronize; errors raised while a kernel is
// executing still surface only at the next synchronizing call.
void run_kernel(int kernel_num, int M, int N, int K, float alpha, float *A,
                float *B, float beta, float *C, cublasHandle_t handle) {
  switch (kernel_num) {
  case 0:
    runCublasFP32(handle, M, N, K, alpha, A, B, beta, C);
    break;
  case 1:
    run_sgemm_naive(M, N, K, alpha, A, B, beta, C);
    break;
  case 2:
    run_sgemm_coalesce(M, N, K, alpha, A, B, beta, C);
    break;
  case 3:
    run_sgemm_shared_mem_block(M, N, K, alpha, A, B, beta, C);
    break;
  case 4:
    runSgemm1DBlocktiling(M, N, K, alpha, A, B, beta, C);
    break;
  case 5:
    runSgemm2DBlocktiling(M, N, K, alpha, A, B, beta, C);
    break;
  case 6:
    runSgemmVectorize(M, N, K, alpha, A, B, beta, C);
    break;
  case 7:
    runSgemmResolveBankConflicts(M, N, K, alpha, A, B, beta, C);
    break;
  case 8:
    runSgemmResolveBankExtraCol(M, N, K, alpha, A, B, beta, C);
    break;
  case 9:
    runSgemmAutotuned(M, N, K, alpha, A, B, beta, C);
    break;
  case 10:
    runSgemmWarptiling(M, N, K, alpha, A, B, beta, C);
    break;
  case 11:
    runSgemmDoubleBuffering(M, N, K, alpha, A, B, beta, C);
    break;
  case 12:
    runSgemmDoubleBuffering2(M, N, K, alpha, A, B, beta, C);
    break;
  default:
    throw std::invalid_argument("Unknown kernel number");
  }
  // Surface launch-configuration errors from the kernel just launched.
  cudaCheck(cudaGetLastError(), __FILE__, __LINE__);
}