Skip to content

Commit

Permalink
Add varlen MHA fp16 slen=384 kernels for sm_86
Browse files Browse the repository at this point in the history
1. add varlen mha fp16 slen=384 kernel for sm_86
2. refresh all sm_86 kernels to now use NVCC -gencode=arch=compute_86,code=\"sm_86\"
3. use unfused kernel for fixed len s=384 fp16

Signed-off-by: Rajeev Rao <[email protected]>
  • Loading branch information
rajeevsrao committed Apr 12, 2021
1 parent 0fa021a commit 3ea9099
Show file tree
Hide file tree
Showing 19 changed files with 53,684 additions and 88 deletions.
5 changes: 5 additions & 0 deletions plugin/bertQKVToContextPlugin/fused_multihead_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ extern unsigned char fused_multihead_attention_int8_384_64_kernel_sm80_cu_o[];
extern unsigned char fused_multihead_attention_int8_128_64_kernel_sm80_cu_o[];
extern unsigned char fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o[];
extern unsigned char fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o[];
// The only symbol in this visible run with an sm86 suffix; it is the one the
// kSM_86 fp16 384/64 row of the metadata table below points at.
extern unsigned char fused_multihead_attention_fp16_384_64_kernel_sm86_cu_o[];

extern unsigned int fused_multihead_attention_fp16_64_64_kernel_sm75_cu_o_len;
extern unsigned int fused_multihead_attention_fp16_96_64_kernel_sm75_cu_o_len;
Expand All @@ -111,6 +112,7 @@ extern unsigned int fused_multihead_attention_int8_384_64_kernel_sm80_cu_o_len;
extern unsigned int fused_multihead_attention_int8_128_64_kernel_sm80_cu_o_len;
extern unsigned int fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o_len;
extern unsigned int fused_multihead_attention_fp16_384_64_kernel_sm80_cu_o_len;
// Companion to fused_multihead_attention_fp16_384_64_kernel_sm86_cu_o; the
// metadata table passes the two side by side (presumably the array's byte
// length -- confirm against the definition).
extern unsigned int fused_multihead_attention_fp16_384_64_kernel_sm86_cu_o_len;

static const struct FusedMultiHeadAttentionKernelMetaInfoV1
{
Expand Down Expand Up @@ -175,6 +177,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV1
{DATA_TYPE_FP16, 128, 64, kSM_86, fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o,
fused_multihead_attention_fp16_128_64_kernel_sm80_cu_o_len, "fused_multihead_attention_fp16_128_64_kernel_sm80",
49152, 128},
{DATA_TYPE_FP16, 384, 64, kSM_86, fused_multihead_attention_fp16_384_64_kernel_sm86_cu_o,
fused_multihead_attention_fp16_384_64_kernel_sm86_cu_o_len, "fused_multihead_attention_fp16_384_64_kernel_sm80",
65536, 256},
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_int8_128_64_kernel_sm80_cu_o,
fused_multihead_attention_int8_128_64_kernel_sm80_cu_o_len, "fused_multihead_attention_int8_128_64_kernel_sm80",
24576, 128},
Expand Down

Large diffs are not rendered by default.

114 changes: 69 additions & 45 deletions plugin/bertQKVToContextPlugin/fused_multihead_attention_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,49 +111,67 @@ struct Fused_multihead_attention_params_v2
////////////////////////////////////////////////////////////////////////////////////////////////////
// Byte-array symbols for the v2 fused multi-head attention kernels. These are
// declarations only (extern, no initializer); the definitions are not in the
// visible part of this header.
// Naming: <precision>_<N>_<M>_kernel[_sm<arch>]_cubin. The two numbers match the
// second and third fields of the rows in the metadata table further down
// (presumably sequence length and head size -- confirm against the struct
// definition). Symbols without an _sm<arch> suffix exist for int8 only.
extern unsigned char fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_64_64_kernel_sm86_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin[];
extern unsigned char fused_multihead_attention_v2_fp16_96_64_kernel_sm86_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_128_64_kernel_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_192_64_kernel_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_256_64_kernel_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin[];

// One `_len` symbol per `_cubin` array declared above, with the same base name.
// The metadata table passes each pair together (array, then its `_len`);
// presumably this is the array's size in bytes -- confirm against the definitions,
// which are not in the visible part of this header.
extern unsigned int fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_256_64_kernel_sm75_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_384_64_kernel_sm75_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_384_64_kernel_sm80_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_64_64_kernel_sm75_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_64_64_kernel_sm86_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_96_64_kernel_sm75_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin_len;
extern unsigned int fused_multihead_attention_v2_fp16_96_64_kernel_sm86_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_128_64_kernel_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_128_64_kernel_sm75_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_192_64_kernel_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_192_64_kernel_sm75_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_256_64_kernel_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_256_64_kernel_sm75_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_384_64_kernel_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len;
extern unsigned int fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len;

static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{
Expand Down Expand Up @@ -348,72 +366,78 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2

// GA10x
// Note: For GA10X keep only kernels whose sharedMemBytes < 100KiB
{DATA_TYPE_FP16, 64, 64, kSM_86, fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin,
fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin_len,
{DATA_TYPE_FP16, 64, 64, kSM_86, fused_multihead_attention_v2_fp16_64_64_kernel_sm86_cubin,
fused_multihead_attention_v2_fp16_64_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_fp16_64_64_kernel_sm80", 32768, 128, 0, false},
{DATA_TYPE_FP16, 96, 64, kSM_86, fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin,
fused_multihead_attention_v2_fp16_96_64_kernel_sm80_cubin_len,
{DATA_TYPE_FP16, 96, 64, kSM_86, fused_multihead_attention_v2_fp16_96_64_kernel_sm86_cubin,
fused_multihead_attention_v2_fp16_96_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_fp16_96_64_kernel_sm80", 49152, 128, 0, false},
{DATA_TYPE_FP16, 128, 64, kSM_86, fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin,
fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len,
{DATA_TYPE_FP16, 128, 64, kSM_86, fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin,
fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_fp16_128_64_kernel_sm80_noloop", 40960, 128, 32, false},
{DATA_TYPE_FP16, 128, 64, kSM_86, fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin,
fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len,
{DATA_TYPE_FP16, 128, 64, kSM_86, fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin,
fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_fp16_128_64_kernel_sm80", 65536, 128, 0, false},
{DATA_TYPE_FP16, 256, 64, kSM_86, fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin,
fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin_len,
{DATA_TYPE_FP16, 256, 64, kSM_86, fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin,
fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_fp16_256_64_kernel_sm80_noloop", 73728, 128, 32, false},
{DATA_TYPE_FP16, 256, 64, kSM_86, fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin,
fused_multihead_attention_v2_fp16_256_64_kernel_sm80_cubin_len,
{DATA_TYPE_FP16, 256, 64, kSM_86, fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin,
fused_multihead_attention_v2_fp16_256_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_fp16_256_64_kernel_sm80", 73728, 128, 0, false},

{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len,
{DATA_TYPE_FP16, 384, 64, kSM_86, fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin,
fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_fp16_384_64_kernel_sm80_noloop", 65536, 256, 48, false},
{DATA_TYPE_FP16, 384, 64, kSM_86, fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin,
fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_fp16_384_64_kernel_sm80", 65536, 256, 0, false},

{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_128_64_kernel_sm80_interleaved_noloop", 20480, 128, 16, true},
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_128_64_kernel_sm80_noloop", 20480, 128, 16, false},
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_128_64_kernel_sm80_interleaved", 24576, 128, 0, true},
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_128_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_128_64_kernel_sm80", 32768, 128, 0, false},
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_192_64_kernel_sm80_interleaved_noloop", 28672, 128, 32, true},
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_192_64_kernel_sm80_noloop", 28672, 128, 32, false},
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_192_64_kernel_sm80_interleaved", 32768, 128, 0, true},
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_192_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 192, 64, kSM_86, fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_192_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_192_64_kernel_sm80", 32768, 128, 0, false},
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_256_64_kernel_sm80_interleaved_noloop", 36864, 128, 32, true},
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_256_64_kernel_sm80_noloop", 36864, 128, 32, false},
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_256_64_kernel_sm80_interleaved", 36864, 128, 0, true},
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_256_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 256, 64, kSM_86, fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_256_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_256_64_kernel_sm80", 36864, 128, 0, false},
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_384_64_kernel_sm80_interleaved_noloop", 53248, 128, 32, true},
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_384_64_kernel_sm80_noloop", 53248, 128, 32, false},
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_384_64_kernel_sm80_interleaved", 51200, 128, 0, true},
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin,
fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len,
{DATA_TYPE_INT8, 384, 64, kSM_86, fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_384_64_kernel_sm80", 53248, 128, 0, false},
#endif
};
Expand Down
Loading

0 comments on commit 3ea9099

Please sign in to comment.