-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathtinyblas.cu
981 lines (909 loc) · 42.8 KB
/
tinyblas.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tinyblas.h"
//
//
// ██████╗ ██╗ █████╗ ██████╗
// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝
// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗
// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║
// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║
// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝
//
// BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// In this file you'll find GPU subroutines implementing general matrix
// multiplication, that are API compatible with NVIDIA's cuBLAS library
// and implement similar tricks[1] for performance.
//
// [1] S. Boehm, ‘How to Optimize a CUDA Matmul Kernel for cuBLAS-like
// Performance’, 2022. [Online]. Available:
// https://siboehm.com/articles/22/CUDA-MMM. [Accessed:
// 05-Mar-2024].
#include <algorithm>
#include <climits>
#include <cstdlib>
#include <type_traits>
#ifndef __HIP__
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#define __shfl_down(var, srcLane, warpSize) __shfl_down_sync(-1u, var, srcLane, warpSize)
#else
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#define cudaSuccess hipSuccess
#define cudaStream_t hipStream_t
#define cudaGetLastError hipGetLastError
#endif
// Lanes per warp. warpSum() and the matvec kernels are written for this width.
#define WARPSIZE 32
// Threads per block for matmul_block2d() kernels: one thread per TM x TN
// sub-tile of a BM x BN block tile. Refers to BM, BN, TM and TN, which must
// be in scope (template parameters or constexprs) wherever it is expanded.
#define THREAD_COUNT ((BM * BN) / (TM * TN))
// Launch-bounds annotation for kernels launched with THREAD_COUNT threads.
#define KERNEL __launch_bounds__(THREAD_COUNT)
// Integer division rounding up; used to size launch grids.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
// Bit flags for the CONFIG template parameter of the GEMM kernels. The
// launchers set a flag only after checking on the host that it holds, which
// lets the kernel drop the corresponding runtime work at compile time.
// Treat beta as zero: never read the old contents of C.
#define IGNORE_BETA 1
// Treat alpha as one: skip the alpha multiply.
#define IGNORE_ALPHA 2
// Force transa to TINYBLAS_OP_N.
#define ASSUME_A_OP_N 4
// Force transb to TINYBLAS_OP_T.
#define ASSUME_B_OP_T 8
// m is a multiple of the block tile height, so row bounds checks are skipped.
#define ASSUME_M_SAFE 16
// n is a multiple of the block tile width, so column bounds checks are skipped.
#define ASSUME_N_SAFE 32
// k is a multiple of the k step, so k bounds checks are skipped.
#define ASSUME_K_SAFE 64
// Force transa to TINYBLAS_OP_T.
#define ASSUME_A_OP_T 128
// Force transb to TINYBLAS_OP_N.
#define ASSUME_B_OP_N 256
// State behind a tinyblasHandle_t; allocated by tinyblasCreate() and
// released by tinyblasDestroy().
struct tinyblasContext {
// Stream on which every tinyBLAS kernel for this handle is launched;
// assigned by tinyblasSetStream().
cudaStream_t stream;
};
// Host-side test for a scaling factor that is exactly one, so the launchers
// can select kernels that skip the alpha multiply.
inline bool isone(float x) {
    const float one = 1.0f;
    return x == one;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// tinyBLAS specialized matrix vector product kernel
// Tree-reduces x across the lanes of a warp using shuffle-down steps of
// 16, 8, 4, 2 and 1 lanes. Only lane 0 is left holding the complete sum;
// the other lanes hold partial sums. All WARPSIZE lanes must participate.
__forceinline__ __device__ float warpSum(float x) {
    for (int offset = WARPSIZE >> 1; offset != 0; offset >>= 1)
        x += __shfl_down(x, offset, WARPSIZE);
    return x;
}
// Kahan-compensated multiply-accumulate: adds a*b (both converted to WORD)
// to the running sum *tally, carrying the rounding error lost by each
// addition in *kahan so it can be fed back on the next call.
// The exact order of these floating-point statements is what makes the
// compensation work; do not reassociate or merge them.
// NOTE(review): reassociating fast-math flags could cancel the compensation
// term — confirm the build does not enable them for this file.
template <typename WORD, typename SRC>
__device__ __forceinline__ void madd(WORD *tally, WORD *kahan, SRC a, SRC b) {
WORD x = a;
WORD y = b;
// product minus the error carried from the previous step
WORD z = x * y - *kahan;
WORD t = *tally + z;
// (t - *tally) is what was actually added; subtracting z leaves the rounding error
*kahan = (t - *tally) - z;
*tally = t;
}
// Computes WARPSIZE consecutive entries of C = op(A) * B for one warp, where
// entry (row, col) of op(A) is read at A[lda * row + col] and B is a vector
// of length k. The warp owns rows [blockIdx.y * WARPSIZE, +WARPSIZE); lane t
// accumulates columns t, t + WARPSIZE, ... using Kahan-compensated madd(),
// then each row's per-lane partials are reduced with warpSum() and lane 0
// stores the result. blockIdx.x is not read. m is unused: callers size the
// grid so every row handled here exists.
template <typename WORD, typename SRC, typename DST>
static __device__ void matvec(int m, int k, const SRC *A, int lda, const SRC *B, DST *C) {
    const int row0 = blockIdx.y * WARPSIZE;
    const int lane = threadIdx.x;
    WORD sum[WARPSIZE] = {0};
    WORD err[WARPSIZE] = {0};
    for (int col = lane; col < k; col += WARPSIZE) {
        const SRC b = B[col];
        for (int r = 0; r < WARPSIZE; ++r)
            madd(&sum[r], &err[r], A[lda * (row0 + r) + col], b);
    }
    for (int r = 0; r < WARPSIZE; ++r) {
        const WORD total = warpSum(sum[r]);
        if (lane == 0)
            C[row0 + r] = total;
    }
}
// Kernel entry point for a single (non-batched) matvec(); launched with
// exactly WARPSIZE threads per block.
template <typename WORD, typename SRC, typename DST>
static __global__ __launch_bounds__(WARPSIZE) void matvec_entry(int m, int k, const SRC *A, int lda,
                                                                const SRC *B, DST *C) {
    matvec<WORD, SRC, DST>(m, k, A, lda, B, C);
}
// Launches matvec_entry on handle's stream to compute the m entries of
// C = op(A) * B. Each block is one warp and produces WARPSIZE consecutive
// rows, so m must be a multiple of WARPSIZE (guaranteed by can_use_matvec()).
//
// matvec() only reads blockIdx.y, so the grid's x dimension is 1. A wider
// x dimension would launch identical blocks that redo the whole product and
// race to store the same values.
template <typename WORD, typename SRC, typename DST>
static tinyblasStatus_t matvec_launch(tinyblasHandle_t handle, int m, int k, const SRC *A, int lda,
                                      const SRC *B, DST *C) {
    // Empty result: a zero-sized grid would be rejected as a launch error.
    if (m == 0)
        return TINYBLAS_STATUS_SUCCESS;
    dim3 blocks(1, m / WARPSIZE);
    matvec_entry<WORD><<<blocks, WARPSIZE, 0, handle->stream>>>(m, k, A, lda, B, C);
    if (cudaGetLastError() != cudaSuccess)
        return TINYBLAS_STATUS_EXECUTION_FAILED;
    return TINYBLAS_STATUS_SUCCESS;
}
// Host-side test for whether a GEMM is really the matrix-vector product that
// matvec() handles: n == 1, k >= 4096, aT nonzero (A transposed), bT zero,
// m and k multiples of WARPSIZE, alpha exactly one and beta zero.
template <typename WORD>
static bool can_use_matvec(tinyblasOperation_t aT, tinyblasOperation_t bT, int m, int n, int k,
                           WORD alpha, WORD beta) {
    if (n != 1 || k < 4096)
        return false;
    if (!aT || bT)
        return false;
    if (m % WARPSIZE || k % WARPSIZE)
        return false;
    return isone(alpha) && !beta;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// tinyBLAS block tiling outer product GEMM kernel
// Block-tiled GEMM body: C = alpha * op(A) * op(B) + beta * C, column-major.
//
// Each thread block computes the BM x BN tile of C whose top-left element is
// (blockIdx.x * BM, blockIdx.y * BN); it expects blockDim.x == THREAD_COUNT.
// Each thread accumulates a TM x TN sub-tile of that tile in registers (Ct).
// The k dimension is walked in steps of BK == THREAD_COUNT. On each step,
// thread th copies k-index (ll + th) of the A and B panels into shared memory
// (As, Bs), the block synchronizes, and every thread accumulates outer
// products from the shared panels.
//
// CONFIG is a bitmask of IGNORE_* / ASSUME_* flags. ASSUME_{M,N,K}_SAFE drop
// bounds checks the launcher has proven unnecessary; ASSUME_{A,B}_OP_* fix
// transa / transb at compile time.
template <int CONFIG, int BM, int BN, int TM, int TN, typename WORD, typename SRC, typename DST>
static __device__ void matmul_block2d(tinyblasOperation_t transa, tinyblasOperation_t transb, int m,
int n, int k, WORD alpha, const SRC *A, int lda, const SRC *B,
int ldb, WORD beta, DST *C, int ldc) {
// one k-index of the A/B panels per thread
constexpr int BK = THREAD_COUNT;
static_assert(BM % TM == 0, "can't divide work for threads");
static_assert(BN % TN == 0, "can't divide work for threads");
static_assert(BM > 0 && BN > 0 && BK > 0 && TM > 0 && TN > 0,
"one of the constexpr configuration values was non-positive");
static_assert((BK * BM * sizeof(SRC)) + (BK * BN * sizeof(SRC)) <= 65536,
"you're almost almost certainly using too much shared memory");
constexpr bool msafe = !!(CONFIG & ASSUME_M_SAFE);
constexpr bool nsafe = !!(CONFIG & ASSUME_N_SAFE);
constexpr bool ksafe = !!(CONFIG & ASSUME_K_SAFE);
const int th = threadIdx.x;
// first row / column of C covered by this block
const int ii = blockIdx.x * BM;
const int jj = blockIdx.y * BN;
// first row / column of this thread's TM x TN sub-tile, relative to the block tile
const int ti = th / (BN / TN) * TM;
const int tj = th % (BN / TN) * TN;
// As[BM * l + i] holds op(A)(ii + i, ll + l); Bs[BN * l + j] holds op(B)(ll + l, jj + j)
__shared__ SRC As[BK * BM];
__shared__ SRC Bs[BK * BN];
WORD At[TM];
WORD Bt[TN];
WORD Ct[TM * TN] = {0};
if (CONFIG & ASSUME_A_OP_T)
transa = TINYBLAS_OP_T;
if (CONFIG & ASSUME_A_OP_N)
transa = TINYBLAS_OP_N;
if (CONFIG & ASSUME_B_OP_N)
transb = TINYBLAS_OP_N;
if (CONFIG & ASSUME_B_OP_T)
transb = TINYBLAS_OP_T;
for (int ll = 0; ll < k; ll += BK) {
// Zero-fill this thread's A panel row when the copy below may stop early
// at a ragged edge, so out-of-range entries contribute nothing.
if (!ksafe || !msafe)
for (int i = 0; i < BM; ++i)
As[BM * th + i] = 0;
// Stops at the first row past m (or immediately if ll + th >= k).
for (int i = 0; i < BM && (ll + th < k || ksafe) && (ii + i < m || msafe); ++i)
As[BM * th + i] = A[transa ? lda * (ii + i) + (ll + th) : lda * (ll + th) + (ii + i)];
// Same zero-fill / bounded copy for the B panel, bounded by n.
if (!ksafe || !nsafe)
for (int j = 0; j < BN; ++j)
Bs[BN * th + j] = 0;
for (int j = 0; j < BN && (ll + th < k || ksafe) && (jj + j < n || nsafe); ++j)
Bs[BN * th + j] = B[transb ? ldb * (ll + th) + (jj + j) : ldb * (jj + j) + (ll + th)];
// panels must be complete before any thread reads them
__syncthreads();
for (int l = 0; l < BK; ++l) {
for (int j = 0; j < TM; ++j)
At[j] = As[BM * l + ti + j];
for (int h = 0; h < TN; ++h)
Bt[h] = Bs[BN * l + tj + h];
// rank-1 update of this thread's TM x TN sub-tile
for (int j = 0; j < TM; ++j)
for (int h = 0; h < TN; ++h)
Ct[TN * j + h] += At[j] * Bt[h];
}
// no thread may overwrite the panels until all threads are done reading them
__syncthreads();
}
// Write back C = alpha * Ct + beta * C. When beta is zero, C is never read,
// so it may hold uninitialized values (BLAS semantics).
for (int j = 0; j < TN && (jj + tj + j < n || nsafe); ++j)
for (int i = 0; i < TM && (ii + ti + i < m || msafe); ++i) {
WORD r, d = Ct[TN * i + j];
if ((CONFIG & IGNORE_BETA) || !beta) {
if (CONFIG & IGNORE_ALPHA)
r = d;
else
r = alpha * d;
} else {
WORD c = C[ldc * (jj + tj + j) + (ii + ti + i)];
if (CONFIG & IGNORE_ALPHA)
r = beta * c + d;
else
r = alpha * d + beta * c;
}
C[ldc * (jj + tj + j) + (ii + ti + i)] = r;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// tinyBLAS warp block tiling outer product GEMM kernel
// Warp-tiled GEMM body. Operands are indexed in a row-major view:
//   A(i, l) = A[aT ? lda * l + i : lda * i + l]   (i < m, l < k)
//   B(l, j) = B[bT ? ldb * j + l : ldb * l + j]   (l < k, j < n)
//   C(i, j) = C[ldc * i + j]
// tinyblasGE_launch() passes B for A, A for B and n for m, so that this
// row-major view of C is the column-major C of the public API.
//
// Each block (expects blockDim.x == TT) computes the BM x BN tile of C starting
// at row blockIdx.y * BM and column blockIdx.x * BN. The TT / WARPSIZE warps
// are arranged as (BM / WM) x (BN / WN) warp tiles of WM x WN. Each warp tile
// is split into WMI x WNI sub-tiles of WSUBM x WSUBN, and each lane owns a
// TM x TN piece of every sub-tile. k is walked in steps of BK; global loads
// move VE consecutive elements per thread per pass.
template <int CONFIG, int BM, int BN, int BK, int VE, int WM, int WN, int WNI, int TM, int TN,
int TT, typename WORD, typename SRC, typename DST>
static __device__ void matmul_warp2d(tinyblasOperation_t aT, //
tinyblasOperation_t bT, //
int m, int n, int k, WORD alpha, //
const SRC *A, int lda, //
const SRC *B, int ldb, WORD beta, //
DST *C, int ldc) {
// stored into shared memory for out-of-range elements so they add nothing
const SRC zero = 0;
// position of this thread's warp within the block's grid of warp tiles
const int warpIdx = threadIdx.x / WARPSIZE;
const int warpCol = warpIdx % (BN / WN);
const int warpRow = warpIdx / (BN / WN);
constexpr int WARPS = TT / WARPSIZE;
// number of sub-tiles along the rows of a warp tile, and sub-tile sizes
constexpr int WMI = (WM * WN) / (WARPSIZE * TM * TN * WNI);
constexpr int WSUBM = WM / WMI;
constexpr int WSUBN = WN / WNI;
constexpr bool msafe = !!(CONFIG & ASSUME_M_SAFE);
constexpr bool nsafe = !!(CONFIG & ASSUME_N_SAFE);
constexpr bool ksafe = !!(CONFIG & ASSUME_K_SAFE);
// position of this lane's TM x TN piece within each WSUBM x WSUBN sub-tile
const int threadIdxInWarp = threadIdx.x % WARPSIZE;
const int threadColInWarp = threadIdxInWarp % (WSUBN / TN);
const int threadRowInWarp = threadIdxInWarp / (WSUBN / TN);
// want to tune these magic numbers?
// use llamafile/pick_a_warp_kernel.c
static_assert(!(BN % WN) && !(BM % WM), "");
static_assert(!(WM % WMI) && !(WN % WNI), "");
static_assert((BN / WN) * (BM / WM) == WARPS, "");
static_assert((WM * WN) % (WARPSIZE * TM * TN * WNI) == 0, "");
static_assert((BM * BK) % (VE * TT) == 0, "");
static_assert((BN * BK) % (VE * TT) == 0, "");
static_assert(BK % VE == 0, "");
static_assert(BN % VE == 0, "");
// As[BM * l + i] holds A(block row + i, ll + l); Bs[BN * l + j] holds
// B(ll + l, block col + j). Both are laid out with the k index outermost.
__shared__ SRC As[BK * BM];
__shared__ SRC Bs[BK * BN];
// per-thread register tiles: A column slice, B row slice, accumulators
WORD Ar[WMI * TM] = {0};
WORD Br[WNI * TN] = {0};
WORD Ct[WMI * TM * WNI * TN] = {0};
if (CONFIG & ASSUME_A_OP_T)
aT = TINYBLAS_OP_T;
if (CONFIG & ASSUME_A_OP_N)
aT = TINYBLAS_OP_N;
if (CONFIG & ASSUME_B_OP_N)
bT = TINYBLAS_OP_N;
if (CONFIG & ASSUME_B_OP_T)
bT = TINYBLAS_OP_T;
for (int ll = 0; ll < k; ll += BK) {
// Load the BM x BK panel of A: BK / VE threads share each row, each
// loading VE consecutive k elements; each pass covers (TT * VE) / BK rows.
for (int h = 0; h < BM; h += (TT * VE) / BK)
for (int v = 0; v < VE; ++v) {
int l = ll + threadIdx.x % (BK / VE) * VE + v;
int i = blockIdx.y * BM + threadIdx.x / (BK / VE) + h;
As[BM * (threadIdx.x % (BK / VE) * VE + v) + (threadIdx.x / (BK / VE) + h)] =
(((i < m || msafe) && //
(l < k || ksafe))
? A[aT ? lda * l + i : lda * i + l]
: zero);
}
// Load the BK x BN panel of B: BN / VE threads share each k row, each
// loading VE consecutive columns; each pass covers TT / (BN / VE) k rows.
for (int h = 0; h < BK; h += TT / (BN / VE))
for (int v = 0; v < VE; ++v) {
int l = ll + threadIdx.x / (BN / VE) + h;
int j = blockIdx.x * BN + threadIdx.x % (BN / VE) * VE + v;
Bs[BN * (threadIdx.x / (BN / VE) + h) + (threadIdx.x % (BN / VE) * VE + v)] =
(((j < n || nsafe) && //
(l < k || ksafe))
? B[bT ? ldb * j + l : ldb * l + j]
: zero);
}
// panels must be complete before any thread reads them
__syncthreads();
for (int l = 0; l < BK; ++l) {
// gather this lane's slices of the current k index into registers
for (int ii = 0; ii < WMI; ++ii)
for (int i = 0; i < TM; ++i)
Ar[TM * ii + i] =
As[BM * l + WM * warpRow + WSUBM * ii + TM * threadRowInWarp + i];
for (int jj = 0; jj < WNI; ++jj)
for (int j = 0; j < TN; ++j)
Br[TN * jj + j] =
Bs[BN * l + WN * warpCol + WSUBN * jj + TN * threadColInWarp + j];
// rank-1 update of every sub-tile piece this lane owns
for (int ii = 0; ii < WMI; ++ii)
for (int jj = 0; jj < WNI; ++jj)
for (int i = 0; i < TM; ++i)
for (int j = 0; j < TN; ++j)
Ct[(WNI * TN) * (TM * ii + i) + (TN * jj) + j] +=
Ar[TM * ii + i] * Br[TN * jj + j];
}
// no thread may overwrite the panels until all threads are done reading them
__syncthreads();
}
// Write back C = alpha * Ct + beta * C, skipping out-of-range elements.
// When beta is zero, C is never read (BLAS semantics).
for (int ii = 0; ii < WMI; ++ii)
for (int jj = 0; jj < WNI; ++jj)
for (int i = 0; i < TM; i += 1)
for (int j = 0; j < TN; j += 1) {
int row = (BM * blockIdx.y + WM * warpRow) + (WSUBM * ii) +
(threadRowInWarp * TM + i);
int col = (BN * blockIdx.x + WN * warpCol) + (WSUBN * jj) +
(threadColInWarp * TN + j);
if ((row < m || msafe) && (col < n || nsafe)) {
WORD r, d = Ct[(WNI * TN) * (TM * ii + i) + (TN * jj + j)];
if ((CONFIG & IGNORE_BETA) || !beta) {
if (CONFIG & IGNORE_ALPHA)
r = d;
else
r = alpha * d;
} else {
WORD c = C[ldc * row + col];
if (CONFIG & IGNORE_ALPHA)
r = beta * c + d;
else
r = alpha * d + beta * c;
}
C[ldc * row + col] = r;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// tinyBLAS canonical cuBLAS-like interface
/**
 * Creates new tinyBLAS handle.
 *
 * Before calling tinyBLAS GEMM functions a handle must first be
 * created, using this function. It should be freed later, using
 * tinyblasDestroy(). A new handle launches kernels on the default
 * stream (0) until tinyblasSetStream() selects another one.
 *
 * @param out_handle receives pointer to newly created handle
 * @return TINYBLAS_STATUS_SUCCESS on success,
 *     TINYBLAS_STATUS_INVALID_VALUE if out_handle is NULL, or
 *     TINYBLAS_STATUS_ALLOC_FAILED if the handle could not be allocated
 */
tinyblasStatus_t tinyblasCreate(tinyblasHandle_t *out_handle) {
    if (!out_handle)
        return TINYBLAS_STATUS_INVALID_VALUE;
    tinyblasHandle_t handle = (tinyblasHandle_t)malloc(sizeof(struct tinyblasContext));
    if (!handle)
        return TINYBLAS_STATUS_ALLOC_FAILED;
    // malloc() leaves the struct uninitialized; launching on an indeterminate
    // stream value is undefined, so start out on the default stream.
    handle->stream = 0;
    *out_handle = handle;
    return TINYBLAS_STATUS_SUCCESS;
}
/**
 * Destroys tinyBLAS handle.
 *
 * Releases the memory backing a handle obtained from tinyblasCreate().
 * The CUDA stream associated with the handle is not destroyed. A NULL
 * handle is accepted and ignored.
 *
 * @param handle is pointer to handle created by tinyblasCreate()
 * @return TINYBLAS_STATUS_SUCCESS always
 */
tinyblasStatus_t tinyblasDestroy(tinyblasHandle_t handle) {
    const tinyblasStatus_t status = TINYBLAS_STATUS_SUCCESS;
    std::free(handle);
    return status;
}
/**
 * Associates a CUDA stream with a tinyBLAS handle.
 *
 * Every kernel tinyBLAS subsequently launches through this handle is
 * enqueued on the given stream.
 *
 * @param handle is pointer to handle created by tinyblasCreate()
 * @param stream is a cudaStream_t (e.g. from cudaStreamCreate()) passed as void *
 * @return TINYBLAS_STATUS_SUCCESS
 */
tinyblasStatus_t tinyblasSetStream(tinyblasHandle_t handle, void *stream) {
    cudaStream_t s = static_cast<cudaStream_t>(stream);
    handle->stream = s;
    return TINYBLAS_STATUS_SUCCESS;
}
/**
 * Retrieves the CUDA stream currently associated with a tinyBLAS handle.
 *
 * @param handle is pointer to handle created by tinyblasCreate()
 * @param out_stream receives the handle's cudaStream_t, as a void *
 * @return TINYBLAS_STATUS_SUCCESS
 */
tinyblasStatus_t tinyblasGetStream(tinyblasHandle_t handle, void **out_stream) {
    void *stream = handle->stream;
    *out_stream = stream;
    return TINYBLAS_STATUS_SUCCESS;
}
/**
 * Maps a tinyBLAS status code to a short human-readable description.
 * Codes this function does not recognize yield "Unknown error".
 */
const char *tinyblasGetStatusString(tinyblasStatus_t err) {
    const char *msg;
    switch (err) {
    case TINYBLAS_STATUS_SUCCESS:
        msg = "Success";
        break;
    case TINYBLAS_STATUS_ALLOC_FAILED:
        msg = "Alloc failed";
        break;
    case TINYBLAS_STATUS_INVALID_VALUE:
        msg = "Invalid value";
        break;
    case TINYBLAS_STATUS_NOT_SUPPORTED:
        msg = "Not supported";
        break;
    case TINYBLAS_STATUS_EXECUTION_FAILED:
        msg = "Execution failed";
        break;
    case TINYBLAS_STATUS_DIMENSION_OVERLAP:
        msg = "Dimension overlap";
        break;
    case TINYBLAS_STATUS_DIMENSION_OVERFLOW:
        msg = "Dimension overflow";
        break;
    default:
        msg = "Unknown error";
        break;
    }
    return msg;
}
/**
 * Performs single-precision general matrix multiplication.
 *
 * Column-major GEMM computing C = α*op(A)*op(B) + β*C with every operand
 * and the accumulator in 32-bit float. It forwards to tinyblasGemmEx().
 *
 * @param handle was created by tinyblasCreate()
 * @param transa if `A` should be transposed
 * @param transb if `B` should be transposed
 * @param m is rows in `op(A)` and `C`
 * @param n is cols in `op(B)` and `C`
 * @param k is cols in `op(A)` and rows in `op(B)`
 * @param alpha points to host scalar that's multiplied against input
 * @param A is input array of first matrix
 * @param lda is leading dimension (column stride) of `A`
 * @param B is input array of second matrix
 * @param ldb is leading dimension (column stride) of `B`
 * @param beta points to host scalar that's multiplied against the existing
 *     output, but this multiplication only happens if beta is nonzero
 * @param C is input/output array of output matrix
 * @param ldc is leading dimension (column stride) of `C`
 * @return status from tinyblasGemmEx()
 */
tinyblasStatus_t tinyblasSgemm(tinyblasHandle_t handle, tinyblasOperation_t transa,
                               tinyblasOperation_t transb, int m, int n, int k, const float *alpha,
                               const float *A, int lda, const float *B, int ldb, const float *beta,
                               float *C, int ldc) {
    const tinyblasDataType_t f32 = TINYBLAS_R_32F;
    return tinyblasGemmEx(handle, transa, transb, m, n, k, //
                          alpha,                            //
                          A, f32, lda,                      //
                          B, f32, ldb,                      //
                          beta,                             //
                          C, f32, ldc,                      //
                          TINYBLAS_COMPUTE_32F, TINYBLAS_GEMM_DEFAULT);
}
// Kernel entry point for the warp-tiled GEMM; launched with TT threads per
// block. All tiling parameters are forwarded to matmul_warp2d().
template <int CONFIG, int BM, int BN, int BK, int VE, int WM, int WN, int WNI, int TM, int TN,
          int TT, typename WORD, typename SRC, typename DST>
static __global__ void __launch_bounds__(TT) tinyblasGE_entry(tinyblasOperation_t aT, //
                                                              tinyblasOperation_t bT, //
                                                              int m, int n, int k, WORD alpha, //
                                                              const SRC *A, int lda, //
                                                              const SRC *B, int ldb, //
                                                              WORD beta, DST *C, int ldc) {
    matmul_warp2d<CONFIG, BM, BN, BK, VE, WM, WN, WNI, TM, TN, TT, WORD, SRC, DST>(
        aT, bT, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
// Launches tinyblasGE_entry over a grid of CEIL_DIV(n, BN) x CEIL_DIV(m, BM)
// blocks on handle's stream. When alpha == 1, beta == 0, n and k are multiples
// of the tile sizes and (aT, bT) == (N, T), a specialization compiled with
// those facts baked into CONFIG is used; otherwise the generic kernel runs.
template <int BM, int BN, int BK, int VE, int WM, int WN, int WNI, int TM, int TN, int TT,
          typename WORD, typename SRC, typename DST>
static tinyblasStatus_t tinyblasGE_launcher(tinyblasHandle_t handle, tinyblasOperation_t aT,
                                            tinyblasOperation_t bT, int m, int n, int k, WORD alpha,
                                            const SRC *A, int lda, const SRC *B, int ldb, WORD beta,
                                            DST *C, int ldc) {
    dim3 grid(CEIL_DIV(n, BN), CEIL_DIV(m, BM));
    const bool specialized = !beta && isone(alpha) && //
                             n % BN == 0 && k % BK == 0 && //
                             aT == TINYBLAS_OP_N && bT == TINYBLAS_OP_T;
    if (specialized) {
        constexpr int kFastConfig = IGNORE_BETA | IGNORE_ALPHA | ASSUME_A_OP_N | ASSUME_B_OP_T |
                                    ASSUME_N_SAFE | ASSUME_K_SAFE;
        tinyblasGE_entry<kFastConfig, BM, BN, BK, VE, WM, WN, WNI, TM, TN, TT>
            <<<grid, TT, 0, handle->stream>>>(aT, bT, m, n, k, alpha, A, lda, B, ldb, beta, C,
                                              ldc);
    } else {
        tinyblasGE_entry<0, BM, BN, BK, VE, WM, WN, WNI, TM, TN, TT>
            <<<grid, TT, 0, handle->stream>>>(aT, bT, m, n, k, alpha, A, lda, B, ldb, beta, C,
                                              ldc);
    }
    return cudaGetLastError() == cudaSuccess ? TINYBLAS_STATUS_SUCCESS
                                             : TINYBLAS_STATUS_EXECUTION_FAILED;
}
// Dispatches one column-major GEMM. Qualifying matrix-vector products go to
// matvec; everything else goes to the warp-tiled kernel. That kernel indexes
// its operands row-major, so it is handed B for A, A for B and n for m:
// computing C^T = op(B)^T * op(A)^T row-major writes exactly the column-major
// C = op(A) * op(B).
template <typename WORD, typename SRC, typename DST>
tinyblasStatus_t tinyblasGE_launch(tinyblasHandle_t handle, tinyblasOperation_t aT,
                                   tinyblasOperation_t bT, int m, int n, int k, WORD alpha,
                                   const SRC *A, int lda, const SRC *B, int ldb, WORD beta, DST *C,
                                   int ldc) {
    if (can_use_matvec(aT, bT, m, n, k, alpha, beta))
        return matvec_launch<WORD>(handle, m, k, A, lda, B, C);
    // Tiling parameters; to retune them see llamafile/pick_a_warp_kernel.c.
    constexpr int TT = 256;                  // threads per block
    constexpr int BM = 128, BN = 64, BK = 64; // block tile and k step
    constexpr int VE = 16;                   // elements per thread per load pass
    constexpr int WM = 32, WN = 32, WNI = 1; // warp tile
    constexpr int TM = 8, TN = 4;            // per-thread tile
    return tinyblasGE_launcher<BM, BN, BK, VE, WM, WN, WNI, TM, TN, TT>(
        handle, bT, aT, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc);
}
/**
 * Performs extended general matrix multiplication.
 *
 * This is a column major GEMM subroutine for computing
 * C = α*op(A)*op(B) + β*C, where op(A) is m×k and op(B) is k×n.
 *
 * @param handle was created by tinyblasCreate()
 * @param transa if `A` should be transposed
 * @param transb if `B` should be transposed
 * @param m is rows in `op(A)` and `C`
 * @param n is cols in `op(B)` and `C`
 * @param k is cols in `op(A)` and rows in `op(B)`
 * @param alpha points to host scalar that's multiplied against input
 * @param A is input array of first matrix
 * @param Atype is data type of `A`
 * @param lda is leading dimension (column stride) of `A`
 * @param B is input array of second matrix
 * @param Btype is data type of `B`
 * @param ldb is leading dimension (column stride) of `B`
 * @param beta points to host scalar that's multiplied against the existing
 *     output, but this multiplication only happens if beta is nonzero
 * @param C is input/output array of output matrix
 * @param Ctype is data type of `C`
 * @param ldc is leading dimension (column stride) of `C`
 * @param computeType is data type of `alpha`, `beta`, and dot product
 * @param algo specifies algorithm to use
 * @return TINYBLAS_STATUS_SUCCESS on success (immediately, without launching
 *     anything, when m or n is zero); TINYBLAS_STATUS_INVALID_VALUE for bad
 *     dimensions, leading dimensions, algo or types;
 *     TINYBLAS_STATUS_DIMENSION_OVERFLOW if some element offset would not fit
 *     in an int; TINYBLAS_STATUS_NOT_SUPPORTED for unsupported type
 *     combinations; TINYBLAS_STATUS_EXECUTION_FAILED if a kernel launch failed
 */
tinyblasStatus_t tinyblasGemmEx(tinyblasHandle_t handle, //
                                tinyblasOperation_t transa, //
                                tinyblasOperation_t transb, //
                                int m, int n, int k, //
                                const void *alpha, //
                                const void *A, tinyblasDataType_t Atype, int lda, //
                                const void *B, tinyblasDataType_t Btype, int ldb, //
                                const void *beta, //
                                void *C, tinyblasDataType_t Ctype, int ldc, //
                                tinyblasComputeType_t computeType, //
                                tinyblasGemmAlgo_t algo) {
    if (m < 0 || n < 0 || k < 0)
        return TINYBLAS_STATUS_INVALID_VALUE;
    // Column-major storage: op(A) is m×k, so A itself is stored as
    // a_rows×a_cols with a_rows = (transa ? k : m); likewise for B.
    const int a_rows = transa ? k : m;
    const int a_cols = transa ? m : k;
    const int b_rows = transb ? n : k;
    const int b_cols = transb ? k : n;
    if (lda < std::max(1, a_rows))
        return TINYBLAS_STATUS_INVALID_VALUE;
    if (ldb < std::max(1, b_rows))
        return TINYBLAS_STATUS_INVALID_VALUE;
    if (ldc < std::max(1, m))
        return TINYBLAS_STATUS_INVALID_VALUE;
    // The kernels index with 32-bit ints. The largest offset touched in a
    // column-major rows×cols matrix is ld * (cols - 1) + (rows - 1), and it
    // must fit in an int. The leading dimension multiplies the column count.
    if (1ll * lda * (a_cols - 1) + (a_rows - 1) > INT_MAX)
        return TINYBLAS_STATUS_DIMENSION_OVERFLOW;
    if (1ll * ldb * (b_cols - 1) + (b_rows - 1) > INT_MAX)
        return TINYBLAS_STATUS_DIMENSION_OVERFLOW;
    if (1ll * ldc * (n - 1) + (m - 1) > INT_MAX)
        return TINYBLAS_STATUS_DIMENSION_OVERFLOW;
    if (algo != TINYBLAS_GEMM_DEFAULT)
        return TINYBLAS_STATUS_INVALID_VALUE;
    if (Atype != Btype)
        return TINYBLAS_STATUS_NOT_SUPPORTED;
    // C is empty: nothing to compute (reference BLAS quick-returns here too).
    // Launching would request a zero-sized grid, which CUDA rejects.
    if (m == 0 || n == 0)
        return TINYBLAS_STATUS_SUCCESS;
    switch (Atype) {
    case TINYBLAS_R_16F:
        switch (Ctype) {
        case TINYBLAS_R_16F:
            switch (computeType) {
            case TINYBLAS_COMPUTE_16F:
                return tinyblasGE_launch(
                    handle, transa, transb, m, n, k, (float)*(const half *)alpha, (const half *)A,
                    lda, (const half *)B, ldb, (float)*(const half *)beta, (half *)C, ldc);
            case TINYBLAS_COMPUTE_32F:
                return tinyblasGE_launch(handle, transa, transb, m, n, k, *(const float *)alpha,
                                         (const half *)A, lda, (const half *)B, ldb,
                                         *(const float *)beta, (half *)C, ldc);
            default:
                return TINYBLAS_STATUS_INVALID_VALUE;
            }
        case TINYBLAS_R_32F:
            switch (computeType) {
            case TINYBLAS_COMPUTE_16F:
                return TINYBLAS_STATUS_NOT_SUPPORTED;
            case TINYBLAS_COMPUTE_32F:
                return tinyblasGE_launch(handle, transa, transb, m, n, k, *(const float *)alpha,
                                         (const half *)A, lda, (const half *)B, ldb,
                                         *(const float *)beta, (float *)C, ldc);
            default:
                return TINYBLAS_STATUS_INVALID_VALUE;
            }
        default:
            return TINYBLAS_STATUS_INVALID_VALUE;
        }
    case TINYBLAS_R_32F:
        switch (Ctype) {
        case TINYBLAS_R_16F:
            return TINYBLAS_STATUS_NOT_SUPPORTED;
        case TINYBLAS_R_32F:
            switch (computeType) {
            case TINYBLAS_COMPUTE_16F:
                return TINYBLAS_STATUS_NOT_SUPPORTED;
            case TINYBLAS_COMPUTE_32F:
                return tinyblasGE_launch(handle, transa, transb, m, n, k, *(const float *)alpha,
                                         (const float *)A, lda, (const float *)B, ldb,
                                         *(const float *)beta, (float *)C, ldc);
            default:
                return TINYBLAS_STATUS_INVALID_VALUE;
            }
        default:
            return TINYBLAS_STATUS_INVALID_VALUE;
        }
    default:
        return TINYBLAS_STATUS_INVALID_VALUE;
    }
}
// Batched matvec kernel entry: blockIdx.z selects which entries of the
// pointer arrays (read on the device) this block works on. Launched with
// exactly WARPSIZE threads per block.
template <typename WORD, typename SRC, typename DST>
static __global__ __launch_bounds__(WARPSIZE) void matvecGBE_entry(int m, int k, //
                                                                   const SRC *const A[], int lda,
                                                                   const SRC *const B[],
                                                                   DST *const C[]) {
    const unsigned batch = blockIdx.z;
    matvec<WORD, SRC, DST>(m, k, A[batch], lda, B[batch], C[batch]);
}
// Batched block-tiled GEMM kernel entry: blockIdx.z selects the A/B/C
// matrices from the (device-readable) pointer arrays, and the generic
// (CONFIG == 0) matmul_block2d() does the work. batchCount is not read here;
// the grid's z extent determines how many batches run.
template <int BM, int BN, int TM, int TN, typename WORD, typename SRC, typename DST>
static __global__ void KERNEL tinyblasGBE_entry(tinyblasOperation_t transa,
                                                tinyblasOperation_t transb, int m, int n, int k,
                                                WORD alpha, const SRC *const Aarray[], int lda,
                                                const SRC *const Barray[], int ldb, WORD beta,
                                                DST *const Carray[], int ldc, int batchCount) {
    const unsigned z = blockIdx.z;
    const SRC *a = Aarray[z];
    const SRC *b = Barray[z];
    DST *c = Carray[z];
    matmul_block2d<0, BM, BN, TM, TN, WORD, SRC, DST>(transa, transb, m, n, k, alpha, a, lda, b,
                                                      ldb, beta, c, ldc);
}
// Launches a pointer-array batched GEMM on handle's stream, one grid z slice
// per batch entry. Qualifying matrix-vector products use matvec; all other
// shapes use the generic 16x16 block-tiled kernel.
template <typename WORD, typename SRC, typename DST>
static tinyblasStatus_t tinyblasGBE_launch(tinyblasHandle_t handle, tinyblasOperation_t transa,
                                           tinyblasOperation_t transb, int m, int n, int k,
                                           WORD alpha, const SRC *const *Aarray, int lda,
                                           const SRC *const *Barray, int ldb, WORD beta,
                                           DST *const *Carray, int ldc, int batchCount) {
    // Nothing to compute; a zero-sized grid would be rejected as a launch error.
    if (m == 0 || n == 0 || batchCount == 0)
        return TINYBLAS_STATUS_SUCCESS;
    if (can_use_matvec(transa, transb, m, n, k, alpha, beta)) {
        // matvec() reads only blockIdx.y (row group) and blockIdx.z (batch), so
        // the x extent is 1; a wider grid only launches redundant duplicate blocks.
        dim3 blocks(1, m / WARPSIZE, batchCount);
        matvecGBE_entry<WORD>
            <<<blocks, WARPSIZE, 0, handle->stream>>>(m, k, Aarray, lda, Barray, Carray);
    } else {
        constexpr int BM = 16;
        constexpr int BN = 16;
        constexpr int TM = 4;
        constexpr int TN = 4;
        dim3 blocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), batchCount);
        tinyblasGBE_entry<BM, BN, TM, TN><<<blocks, THREAD_COUNT, 0, handle->stream>>>(
            transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc,
            batchCount);
    }
    if (cudaGetLastError() != cudaSuccess)
        return TINYBLAS_STATUS_EXECUTION_FAILED;
    return TINYBLAS_STATUS_SUCCESS;
}
/**
 * Multiplies matrices.
 *
 * This is a column major GEMM subroutine for computing
 * C[i] = α*op(A[i])*op(B[i]) + β*C[i] for every i < batchCount, where
 * op(A[i]) is m×k and op(B[i]) is k×n.
 *
 * @param handle was created by tinyblasCreate()
 * @param transa if `A` should be transposed
 * @param transb if `B` should be transposed
 * @param m is rows in `op(A)` and `C`
 * @param n is cols in `op(B)` and `C`
 * @param k is cols in `op(A)` and rows in `op(B)`
 * @param alpha points to host scalar that's multiplied against input
 * @param Aarray is array in device memory of pointers to the first matrices
 * @param Atype is data type of the `A` matrices
 * @param lda is leading dimension (column stride) of each `A`
 * @param Barray is array in device memory of pointers to the second matrices
 * @param Btype is data type of the `B` matrices
 * @param ldb is leading dimension (column stride) of each `B`
 * @param beta points to host scalar that's multiplied against the existing
 *     output, but this multiplication only happens if beta is nonzero
 * @param Carray is array in device memory of pointers to the output matrices
 * @param Ctype is data type of the `C` matrices
 * @param ldc is leading dimension (column stride) of each `C`
 * @param batchCount is number of elements in `Aarray`, `Barray`, and `Carray`
 * @param computeType is data type of `alpha`, `beta`, and dot product
 * @param algo specifies algorithm to use
 * @return TINYBLAS_STATUS_SUCCESS on success (immediately, without launching
 *     anything, when m, n or batchCount is zero); TINYBLAS_STATUS_INVALID_VALUE
 *     for bad dimensions, leading dimensions, batchCount, algo or types;
 *     TINYBLAS_STATUS_DIMENSION_OVERFLOW if some element offset would not fit
 *     in an int; TINYBLAS_STATUS_NOT_SUPPORTED for unsupported type
 *     combinations; TINYBLAS_STATUS_EXECUTION_FAILED if a kernel launch failed
 */
tinyblasStatus_t tinyblasGemmBatchedEx(tinyblasHandle_t handle, tinyblasOperation_t transa,
                                       tinyblasOperation_t transb, int m, int n, int k,
                                       const void *alpha, const void *const Aarray[],
                                       tinyblasDataType_t Atype, int lda,
                                       const void *const Barray[], tinyblasDataType_t Btype,
                                       int ldb, const void *beta, void *const Carray[],
                                       tinyblasDataType_t Ctype, int ldc, int batchCount,
                                       tinyblasComputeType_t computeType, tinyblasGemmAlgo_t algo) {
    if (m < 0 || n < 0 || k < 0 || batchCount < 0)
        return TINYBLAS_STATUS_INVALID_VALUE;
    // Column-major storage: op(A) is m×k, so each A is stored as
    // a_rows×a_cols with a_rows = (transa ? k : m); likewise for B.
    const int a_rows = transa ? k : m;
    const int a_cols = transa ? m : k;
    const int b_rows = transb ? n : k;
    const int b_cols = transb ? k : n;
    if (lda < std::max(1, a_rows))
        return TINYBLAS_STATUS_INVALID_VALUE;
    if (ldb < std::max(1, b_rows))
        return TINYBLAS_STATUS_INVALID_VALUE;
    if (ldc < std::max(1, m))
        return TINYBLAS_STATUS_INVALID_VALUE;
    // The kernels index with 32-bit ints. The largest offset touched in a
    // column-major rows×cols matrix is ld * (cols - 1) + (rows - 1), and it
    // must fit in an int. The leading dimension multiplies the column count.
    if (1ll * lda * (a_cols - 1) + (a_rows - 1) > INT_MAX)
        return TINYBLAS_STATUS_DIMENSION_OVERFLOW;
    if (1ll * ldb * (b_cols - 1) + (b_rows - 1) > INT_MAX)
        return TINYBLAS_STATUS_DIMENSION_OVERFLOW;
    if (1ll * ldc * (n - 1) + (m - 1) > INT_MAX)
        return TINYBLAS_STATUS_DIMENSION_OVERFLOW;
    if (algo != TINYBLAS_GEMM_DEFAULT)
        return TINYBLAS_STATUS_INVALID_VALUE;
    if (Atype != Btype)
        return TINYBLAS_STATUS_NOT_SUPPORTED;
    // Nothing to compute. Launching would request a zero-sized grid, which
    // CUDA rejects.
    if (m == 0 || n == 0 || batchCount == 0)
        return TINYBLAS_STATUS_SUCCESS;
    switch (Atype) {
    case TINYBLAS_R_16F:
        switch (Ctype) {
        case TINYBLAS_R_16F:
            switch (computeType) {
            case TINYBLAS_COMPUTE_16F:
                return tinyblasGBE_launch(
                    handle, transa, transb, m, n, k, (float)*(const half *)alpha,
                    (const half *const *)Aarray, lda, (const half *const *)Barray, ldb,
                    (float)*(const half *)beta, (half *const *)Carray, ldc, batchCount);
            case TINYBLAS_COMPUTE_32F:
                return tinyblasGBE_launch(handle, transa, transb, m, n, k, *(const float *)alpha,
                                          (const half *const *)Aarray, lda,
                                          (const half *const *)Barray, ldb, *(const float *)beta,
                                          (half *const *)Carray, ldc, batchCount);
            default:
                return TINYBLAS_STATUS_INVALID_VALUE;
            }
        case TINYBLAS_R_32F:
            switch (computeType) {
            case TINYBLAS_COMPUTE_16F:
                return TINYBLAS_STATUS_NOT_SUPPORTED;
            case TINYBLAS_COMPUTE_32F:
                return tinyblasGBE_launch(handle, transa, transb, m, n, k, *(const float *)alpha,
                                          (const half *const *)Aarray, lda,
                                          (const half *const *)Barray, ldb, *(const float *)beta,
                                          (float *const *)Carray, ldc, batchCount);
            default:
                return TINYBLAS_STATUS_INVALID_VALUE;
            }
        default:
            return TINYBLAS_STATUS_INVALID_VALUE;
        }
    case TINYBLAS_R_32F:
        switch (Ctype) {
        case TINYBLAS_R_16F:
            return TINYBLAS_STATUS_NOT_SUPPORTED;
        case TINYBLAS_R_32F:
            switch (computeType) {
            case TINYBLAS_COMPUTE_16F:
                return TINYBLAS_STATUS_NOT_SUPPORTED;
            case TINYBLAS_COMPUTE_32F:
                return tinyblasGBE_launch(handle, transa, transb, m, n, k, *(const float *)alpha,
                                          (const float *const *)Aarray, lda,
                                          (const float *const *)Barray, ldb, *(const float *)beta,
                                          (float *const *)Carray, ldc, batchCount);
            default:
                return TINYBLAS_STATUS_INVALID_VALUE;
            }
        default:
            return TINYBLAS_STATUS_INVALID_VALUE;
        }
    default:
        return TINYBLAS_STATUS_INVALID_VALUE;
    }
}
// Strided-batched matvec kernel entry: batch blockIdx.z operates on the
// matrices at A + z*strideA, B + z*strideB and C + z*strideC (64-bit offsets).
// Launched with exactly WARPSIZE threads per block.
template <typename WORD, typename SRC, typename DST>
static __global__ __launch_bounds__(WARPSIZE) void matvecGSBE_entry(int m, int k, const SRC *A,
                                                                    int lda, long long strideA,
                                                                    const SRC *B, long long strideB,
                                                                    DST *C, long long strideC) {
    const long long z = blockIdx.z;
    matvec<WORD, SRC, DST>(m, k, A + z * strideA, lda, B + z * strideB, C + z * strideC);
}
// Strided-batched block-tiled GEMM kernel entry: batch blockIdx.z operates on
// the matrices at A + z*strideA, B + z*strideB and C + z*strideC (64-bit
// offsets). batchCount is not read here; the grid's z extent determines how
// many batches run.
template <int CONFIG, int BM, int BN, int TM, int TN, typename SRC, typename DST, typename WORD>
static __global__ void KERNEL tinyblasGSBE_entry(tinyblasOperation_t transa,
                                                 tinyblasOperation_t transb, int m, int n, int k,
                                                 WORD alpha, const SRC *A, int lda,
                                                 long long strideA, const SRC *B, int ldb,
                                                 long long strideB, WORD beta, DST *C, int ldc,
                                                 long long strideC, int batchCount) {
    const long long z = blockIdx.z;
    const SRC *a = A + strideA * z;
    const SRC *b = B + strideB * z;
    DST *c = C + strideC * z;
    matmul_block2d<CONFIG, BM, BN, TM, TN, WORD, SRC, DST>(transa, transb, m, n, k, alpha, a, lda,
                                                           b, ldb, beta, c, ldc);
}
// Launches a strided-batched GEMM on handle's stream, one grid z slice per
// batch entry. Qualifying matrix-vector products use matvec. Otherwise the
// 16x16 block-tiled kernel is used, with a CONFIG specialization when
// alpha == 1, beta == 0, m and k are tile multiples and (transa, transb) == (T, N).
template <typename WORD, typename SRC, typename DST>
static tinyblasStatus_t tinyblasGSBE_launch(tinyblasHandle_t handle, tinyblasOperation_t transa,
                                            tinyblasOperation_t transb, int m, int n, int k,
                                            WORD alpha, const SRC *A, int lda, long long strideA,
                                            const SRC *B, int ldb, long long strideB, WORD beta,
                                            DST *C, int ldc, long long strideC, int batchCount) {
    // Nothing to compute; a zero-sized grid would be rejected as a launch error.
    if (m == 0 || n == 0 || batchCount == 0)
        return TINYBLAS_STATUS_SUCCESS;
    if (can_use_matvec(transa, transb, m, n, k, alpha, beta)) {
        // matvec() reads only blockIdx.y (row group) and blockIdx.z (batch), so
        // the x extent is 1; a wider grid only launches redundant duplicate blocks.
        dim3 blocks(1, m / WARPSIZE, batchCount);
        matvecGSBE_entry<WORD><<<blocks, WARPSIZE, 0, handle->stream>>>(m, k, A, lda, strideA, B,
                                                                         strideB, C, strideC);
    } else {
        constexpr int BM = 16;
        constexpr int BN = 16;
        constexpr int TM = 4;
        constexpr int TN = 4;
        constexpr int BK = THREAD_COUNT;
        dim3 blocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), batchCount);
        if ((!beta && //
             isone(alpha) && //
             m % BM == 0 && //
             k % BK == 0 && //
             transa == TINYBLAS_OP_T && //
             transb == TINYBLAS_OP_N)) {
            constexpr int CONFIG = IGNORE_BETA | IGNORE_ALPHA | ASSUME_A_OP_T | ASSUME_B_OP_N |
                                   ASSUME_M_SAFE | ASSUME_K_SAFE;
            tinyblasGSBE_entry<CONFIG, BM, BN, TM, TN><<<blocks, THREAD_COUNT, 0, handle->stream>>>(
                transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc,
                strideC, batchCount);
        } else {
            tinyblasGSBE_entry<0, BM, BN, TM, TN><<<blocks, THREAD_COUNT, 0, handle->stream>>>(
                transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc,
                strideC, batchCount);
        }
    }
    if (cudaGetLastError() != cudaSuccess)
        return TINYBLAS_STATUS_EXECUTION_FAILED;
    return TINYBLAS_STATUS_SUCCESS;
}
/**
 * Multiplies a strided batch of matrices.
 *
 * This is a column major GEMM subroutine. For each batch index i in
 * [0, batchCount) it computes
 *
 *     C[i] = α*op(A[i])*op(B[i]) + β*C[i]
 *
 * where X[i] is the matrix starting `i*strideX` elements past `X`, and op(X)
 * is X or Xᵀ as selected by `transa` / `transb`.
 *
 * Supported type combinations (A and B must share a type):
 *   - A/B half,  C half,  computeType TINYBLAS_COMPUTE_16F or _32F
 *   - A/B half,  C float, computeType TINYBLAS_COMPUTE_32F
 *   - A/B float, C float, computeType TINYBLAS_COMPUTE_32F
 *
 * @param handle was created by tinyblasCreate()
 * @param transa if `A` should be transposed
 * @param transb if `B` should be transposed
 * @param m is rows in op(`A`) and `C`
 * @param n is cols in op(`B`) and `C`
 * @param k is cols in op(`A`) and rows in op(`B`)
 * @param alpha points to scalar that's multiplied against the product. It is
 *     dereferenced on the host before launching: it is read as `half` when
 *     computeType is TINYBLAS_COMPUTE_16F and as `float` when it is
 *     TINYBLAS_COMPUTE_32F
 * @param A points to the first matrix of the first input batch
 * @param Atype is data type of `A`
 * @param lda is leading dimension of `A`, i.e. the number of elements between
 *     the starts of consecutive columns; must be >= max(1, rows of `A` as
 *     stored), which is `m` if not transposed and `k` if transposed
 * @param strideA is distance in elements between matrices in `A`
 * @param B points to the first matrix of the second input batch
 * @param Btype is data type of `B`
 * @param ldb is leading dimension of `B`; must be >= max(1, rows of `B` as
 *     stored), which is `k` if not transposed and `n` if transposed
 * @param strideB is distance in elements between matrices in `B`
 * @param beta points to scalar that's multiplied against the existing
 *     output, read with the same type and on the host like `alpha`; this
 *     multiplication only happens if beta is nonzero
 * @param C points to the first input/output matrix
 * @param Ctype is data type of `C`
 * @param ldc is leading dimension of `C`; must be >= max(1, m)
 * @param strideC is distance in elements between matrices in `C`, which must
 *     not overlap
 * @param batchCount is number of matrices to multiply; must be >= 0
 * @param computeType is data type of `alpha`, `beta`, and dot product
 * @param algo specifies algorithm to use; only TINYBLAS_GEMM_DEFAULT is
 *     accepted
 * @return TINYBLAS_STATUS_INVALID_VALUE for negative sizes, bad leading
 *     dimensions, a non-default algo, or an unknown type;
 *     TINYBLAS_STATUS_NOT_SUPPORTED for valid but unimplemented type
 *     combinations; TINYBLAS_STATUS_DIMENSION_OVERLAP if positively strided
 *     output matrices overlap; TINYBLAS_STATUS_DIMENSION_OVERFLOW if a matrix
 *     extends past INT_MAX elements; otherwise the status of the launch
 */
tinyblasStatus_t tinyblasGemmStridedBatchedEx(tinyblasHandle_t handle, //
                                              tinyblasOperation_t transa, //
                                              tinyblasOperation_t transb, //
                                              int m, int n, int k, //
                                              const void *alpha, //
                                              const void *A, tinyblasDataType_t Atype, int lda,
                                              long long strideA, //
                                              const void *B, tinyblasDataType_t Btype, int ldb,
                                              long long strideB, //
                                              const void *beta, //
                                              void *C, tinyblasDataType_t Ctype, int ldc,
                                              long long strideC, //
                                              int batchCount, //
                                              tinyblasComputeType_t computeType, //
                                              tinyblasGemmAlgo_t algo) {
    // A negative batchCount would otherwise wrap to a huge unsigned grid
    // dimension and be misreported as an execution failure.
    if (m < 0 || n < 0 || k < 0 || batchCount < 0)
        return TINYBLAS_STATUS_INVALID_VALUE;

    // Shapes of A and B as they are laid out in memory (column major),
    // i.e. before op() is applied.
    const int a_rows = transa ? k : m;
    const int a_cols = transa ? m : k;
    const int b_rows = transb ? n : k;
    const int b_cols = transb ? k : n;

    // A leading dimension must span at least one full stored column.
    if (lda < std::max(1, a_rows))
        return TINYBLAS_STATUS_INVALID_VALUE;
    if (ldb < std::max(1, b_rows))
        return TINYBLAS_STATUS_INVALID_VALUE;
    if (ldc < std::max(1, m))
        return TINYBLAS_STATUS_INVALID_VALUE;

    // Positively strided output matrices must not overlap. This is exactly
    // the previous `max(0, strideC) < min(ldc*n, 2*strideC)` test, written so
    // that it cannot overflow.
    // NOTE(review): strideC <= 0 is never rejected, even with batchCount > 1,
    // where every batch member would write the same or overlapping memory —
    // confirm whether that is intended.
    if (strideC > 0 && strideC < 1ll * ldc * n)
        return TINYBLAS_STATUS_DIMENSION_OVERLAP;

    // The last element of a column-major matrix is at offset
    // ld*(cols-1) + (rows-1). It must be addressable with an int.
    // (Previously the rows/cols roles were swapped for A and B, which both
    // missed real overflows and reported spurious ones.)
    if (1ll * lda * (a_cols - 1) + (a_rows - 1) > INT_MAX)
        return TINYBLAS_STATUS_DIMENSION_OVERFLOW;
    if (1ll * ldb * (b_cols - 1) + (b_rows - 1) > INT_MAX)
        return TINYBLAS_STATUS_DIMENSION_OVERFLOW;
    if (1ll * ldc * (n - 1) + (m - 1) > INT_MAX)
        return TINYBLAS_STATUS_DIMENSION_OVERFLOW;

    if (algo != TINYBLAS_GEMM_DEFAULT)
        return TINYBLAS_STATUS_INVALID_VALUE;
    if (Atype != Btype)
        return TINYBLAS_STATUS_NOT_SUPPORTED;

    // Dispatch on (input type, output type, compute type). alpha/beta are
    // read as half for COMPUTE_16F and widened to float for the kernels.
    switch (Atype) {
    case TINYBLAS_R_16F:
        switch (Ctype) {
        case TINYBLAS_R_16F:
            switch (computeType) {
            case TINYBLAS_COMPUTE_16F:
                return tinyblasGSBE_launch(
                    handle, transa, transb, m, n, k, (float)*(const half *)alpha, (const half *)A,
                    lda, strideA, (const half *)B, ldb, strideB, (float)*(const half *)beta,
                    (half *)C, ldc, strideC, batchCount);
            case TINYBLAS_COMPUTE_32F:
                return tinyblasGSBE_launch(handle, transa, transb, m, n, k, *(const float *)alpha,
                                           (const half *)A, lda, strideA, (const half *)B, ldb,
                                           strideB, *(const float *)beta, (half *)C, ldc, strideC,
                                           batchCount);
            default:
                return TINYBLAS_STATUS_INVALID_VALUE;
            }
        case TINYBLAS_R_32F:
            switch (computeType) {
            case TINYBLAS_COMPUTE_16F:
                return TINYBLAS_STATUS_NOT_SUPPORTED;
            case TINYBLAS_COMPUTE_32F:
                return tinyblasGSBE_launch(handle, transa, transb, m, n, k, *(const float *)alpha,
                                           (const half *)A, lda, strideA, (const half *)B, ldb,
                                           strideB, *(const float *)beta, (float *)C, ldc, strideC,
                                           batchCount);
            default:
                return TINYBLAS_STATUS_INVALID_VALUE;
            }
        default:
            return TINYBLAS_STATUS_INVALID_VALUE;
        }
    case TINYBLAS_R_32F:
        switch (Ctype) {
        case TINYBLAS_R_16F:
            return TINYBLAS_STATUS_NOT_SUPPORTED;
        case TINYBLAS_R_32F:
            switch (computeType) {
            case TINYBLAS_COMPUTE_16F:
                return TINYBLAS_STATUS_NOT_SUPPORTED;
            case TINYBLAS_COMPUTE_32F:
                return tinyblasGSBE_launch(handle, transa, transb, m, n, k, *(const float *)alpha,
                                           (const float *)A, lda, strideA, (const float *)B, ldb,
                                           strideB, *(const float *)beta, (float *)C, ldc, strideC,
                                           batchCount);
            default:
                return TINYBLAS_STATUS_INVALID_VALUE;
            }
        default:
            return TINYBLAS_STATUS_INVALID_VALUE;
        }
    default:
        return TINYBLAS_STATUS_INVALID_VALUE;
    }
}