padding the length of input for vit_attention #45506
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
@@ -310,6 +349,11 @@ int QkvToContextPluginDynamic::enqueue(
  // input[0], (B, S, 3 * N * H, 1, 1)
  int batch = input_dims.d[0];
  int seq_len = input_dims.d[1];
  int real_seq_len = seq_len;
  if (input_desc[0].type == nvinfer1::DataType::kHALF) {
Reviewer: Please add a comment noting that fp16 requires padding here.
Author: OK.
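For context, a minimal sketch of what the requested comment could document; the round-up helper below is illustrative, not the PR's exact code:

// fp16 needs padding: the fused fp16 attention kernels are fastest when the
// sequence length is a multiple of 8, so round seq_len up before running them.
if (input_desc[0].type == nvinfer1::DataType::kHALF) {
  seq_len = (seq_len + 7) / 8 * 8;  // illustrative round-up to a multiple of 8
}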
__global__ void reset_qk_bias(T *input, int real_seq_len, int seq_len) {
  if (threadIdx.x < seq_len) {
    int id = threadIdx.x + blockIdx.x * seq_len;
    input[id] = threadIdx.x >= real_seq_len ? (T)-1e20f : (T)0.0f;
Reviewer: Be careful with -1e20f; mind what low-precision types can represent.
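The concern: half's largest finite magnitude is 65504, so casting -1e20f to half overflows to -inf. A hedged alternative (the constant name is illustrative, not from the PR):

// IEEE half tops out at a finite magnitude of 65504, so use a mask value
// that stays finite instead of relying on overflow to -inf.
constexpr float kHalfLowest = -65504.0f;
input[id] = threadIdx.x >= real_seq_len ? (T)kHalfLowest : (T)0.0f;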
if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
  qk_bias = reinterpret_cast<float *>(workspace);
  auto size = batch * head_number_ * seq_len * seq_len;
  cudaMemset(qk_bias, 0, sizeof(float) * size);
Reviewer: Use cudaMemsetAsync; the same applies to the calls below.
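A minimal sketch of the suggestion, assuming the enqueue stream is in scope:

// Stream-ordered zeroing avoids the implicit synchronization of cudaMemset.
cudaMemsetAsync(qk_bias, 0, sizeof(float) * size, stream);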
@@ -373,6 +423,35 @@ int QkvToContextPluginDynamic::enqueue(
  } else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
    VLOG(1) << "TRT Plugin DataType selected. QkvToContext-->fp16";
    int *padding_offset = nullptr;
    half *padding_input = nullptr;
    framework::Tensor padding_offset_tensor;
Reviewer: Use the workspace or a member variable instead, to avoid allocating device memory on every enqueue.
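A sketch of that suggestion under stated assumptions; the buffer layout and sizes below are illustrative, and getWorkspaceSize() would have to report the extra bytes:

// Carve scratch buffers out of the TensorRT-provided workspace instead of
// allocating a framework::Tensor inside enqueue (sizes are assumptions).
char *ws = reinterpret_cast<char *>(workspace);
int *padding_offset = reinterpret_cast<int *>(ws);
ws += sizeof(int) * (batch + 1);  // assumed length of the offset array
half *padding_input = reinterpret_cast<half *>(ws);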
        0,
        sizeof(half) * batch * seq_len * 3 * head_number_ * head_size_);

    set_padding_offset<<<1, 1, 0, stream>>>(
Reviewer: Concurrency here could be improved further.
Author: OK.
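For illustration, one way to raise the parallelism of a <<<1, 1>>> launch; the kernel body and signature here are assumptions, not the PR's actual set_padding_offset:

// Assumed: one offset per sequence in the batch, computed independently,
// so each thread can handle one batch entry instead of a single thread
// doing all the work.
__global__ void set_padding_offset(int *padding_offset, int real_seq_len,
                                   int seq_len, int batch) {
  int b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b < batch) padding_offset[b] = b * (seq_len - real_seq_len);
}

// Launch with one thread per batch entry rather than <<<1, 1>>>.
int threads = 256;
int blocks = (batch + threads - 1) / threads;
set_padding_offset<<<blocks, threads, 0, stream>>>(padding_offset,
                                                   real_seq_len, seq_len,
                                                   batch);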
@@ -1105,6 +1113,9 @@ def generate_trt_nodes_num():
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(), (1e-3,
                                                                         1e-3)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(), (1e-3,
                                                                         1e-3)
Reviewer: The fp32 tolerance should be tighter than 1e-3.
@@ -342,6 +427,12 @@ int QkvToContextPluginDynamic::enqueue(
        head_number_);
    qk_bias = temp_qk_bias;
  }
  // fake qk_bias
  if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
Reviewer: This can be determined at configure time; there is no need to check it on every enqueue. The same applies to the checks below.
Author: OK, I'll move this, together with the memset below, into the configure step.
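A sketch of the agreed change; the cached member name fake_qk_bias_ is an assumption for illustration:

// Decide once in configurePlugin whether input[1] is a fake qk_bias; enqueue
// then reads the cached flag instead of comparing dimension products.
void QkvToContextPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *in, int nb_inputs,
    const nvinfer1::DynamicPluginTensorDesc *out, int nb_outputs) noexcept {
  fake_qk_bias_ = ProductDim(in[1].desc.dims) == ProductDim(in[0].desc.dims);
}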
if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
  qk_bias = reinterpret_cast<float *>(workspace);
  auto size = batch * head_number_ * seq_len * seq_len;
  cudaMemset(qk_bias, 0, sizeof(float) * size);
Reviewer: Use the async interface (cudaMemsetAsync).
PR types
Others
PR changes
Others
Describe
When the attention input length is not a multiple of 8, fp16 performance is poor. This PR pads the input of the multihead plugin. For the vit_384 model with batch=1, latency drops from 13.5 ms to 10.5 ms.
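To make the length claim concrete, a worked example under an assumed ViT-384 configuration with 16x16 patches (the patch size is an assumption, not stated in the PR):

// Assumed ViT-384 config: 16x16 patches plus one class token.
int tokens = (384 / 16) * (384 / 16) + 1;  // = 577, not a multiple of 8
int padded = (tokens + 7) / 8 * 8;         // = 584, the padded length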