Support (Bias)SkipLayerNormalization fusion in GPT2 #13988
Conversation
…crosoft/onnxruntime into hari/skip_layer_normalization
@@ -537,7 +537,14 @@ def fuse_gpt2(self, layernorm, add_before_layernorm, input_name_to_nodes, output
        if two_gather is None:
            return False

        # If the add_before_layernorm node is an Add node, then the add_output output is the first index
nit: It is better to also update the comment at the beginning of this function. That comment only covers the first case, not the second.
I don't see any other comment at the beginning of this function. Are you referring to the args description of is_embedding_sum_needed() before this, by any chance?
Streamlined some logic around optional_embedding_sum_output and add_output, which both need the comment. It should be easier to understand the code now.
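For context, here is a hedged sketch of the branch this thread is about. Only the names add_before_layernorm, add_output, and optional_embedding_sum_output come from the diff and discussion; the surrounding control flow is an assumption, not the actual fusion code:

```python
# Illustrative sketch only -- the real logic lives in the fusion script under review.
if add_before_layernorm.op_type == "Add":
    # If add_before_layernorm is a plain Add node, its first (index 0) output is the
    # residual sum that downstream nodes consume.
    add_output = add_before_layernorm.output[0]
    optional_embedding_sum_output = True
else:
    # Otherwise the node is already a fused SkipLayerNormalization, and the sum is
    # exposed through its optional output instead (assumed here to be the last output).
    add_output = add_before_layernorm.output[-1]
    optional_embedding_sum_output = False
```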
### Description
Add support for ONNX conversion of GPT-2 in two stages:
* Stage 1 is the initial stage that has an empty past state.
* Stage 2 has a non-empty past state and sequence_length is 1.

Add a parameter --stage to specify the stage. For stage 1, we enable mask_index for Attention so that fused attention can be used in CUDA.

Other changes:
1. Use int32 inputs as the default (otherwise there is an error in inference).
2. Update gpt2_parity to include SkipLayerNormalization (see #13988) and EmbedLayerNormalization.
3. Collect all environment variables that might impact GPT-2 latency in benchmark_gpt2.

### Motivation and Context
To test fused attention for the GPT-2 model for #13953.
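For illustration, a minimal sketch of what the two stages imply for the model inputs. The input names, shapes, and helper function below are assumptions for illustration only, not the converter's actual code; the converter exposes the choice through the --stage parameter described above:

```python
import numpy as np

def make_dummy_gpt2_inputs(stage, batch_size=1, num_layers=12, num_heads=12,
                           head_size=64, past_sequence_length=8, sequence_length=4):
    """Illustrative only: build int32/float32 dummy inputs for the two stages."""
    if stage == 1:
        # Stage 1: the first decoding step has an empty past state.
        past_shape = (2, batch_size, num_heads, 0, head_size)
        input_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
    else:
        # Stage 2: later steps carry a non-empty past state and sequence_length == 1.
        past_shape = (2, batch_size, num_heads, past_sequence_length, head_size)
        input_ids = np.zeros((batch_size, 1), dtype=np.int32)

    past = {f"past_{i}": np.zeros(past_shape, dtype=np.float32) for i in range(num_layers)}
    return {"input_ids": input_ids, **past}
```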
### Description
1. SkipLayerNormalization has a new output (#13988), and the symbolic shape inference script needs corresponding updates.
2. The greedy sampling op (#13426) shouldn't reuse the logits buffer, as its corresponding kernel doesn't seem to support that yet.

### Motivation and Context
Fix some transformer issues.
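A hedged sketch of the kind of shape-inference update item 1 above calls for: the new optional SkipLayerNormalization output (the residual sum) has the same shape and type as the first input. The helper names below are illustrative, not the actual symbolic_shape_infer.py internals:

```python
def infer_skip_layer_normalization(node, get_shape_and_type, set_shape_and_type):
    """Illustrative only: propagate shapes for SkipLayerNormalization outputs."""
    input_shape, input_dtype = get_shape_and_type(node.input[0])
    # Output 0: the normalized output, same shape as the input.
    set_shape_and_type(node.output[0], input_shape, input_dtype)
    # Output 3 (optional, added in #13988): the input + skip (+ bias) sum fed to the
    # next layer; it also has the same shape and type as the input.
    if len(node.output) > 3 and node.output[3]:
        set_shape_and_type(node.output[3], input_shape, input_dtype)
```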
Description
The GPT2 model has a slightly different SkipLayerNormalization pattern from BERT: the residual passed to the next layer is not the LayerNorm output, as in BERT, but the sum of the residual and the input that feeds the LayerNorm node. Because that intermediate Add output has another consumer, the Add and LayerNorm could not be fused. This also means that any Add after a MatMul feeding into SkipLayerNormalization cannot be fused either, since the prerequisite for that fusion is that SkipLayerNormalization be fused first. This change adds support for this variant of SkipLayerNormalization (SLN) by adjusting the SLN schema to include an optional output that exposes the sum of the residual and the input.
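A minimal sketch of the pattern described above, built with onnx.helper; the tensor and node names are illustrative assumptions, not the exact names produced by the fusion script:

```python
from onnx import helper

# Unfused GPT-2 pattern: the Add output ("add_out") feeds both the LayerNormalization
# node and the next layer's residual path, so Add + LayerNorm could not be fused before.
add = helper.make_node("Add", inputs=["matmul_out", "residual"], outputs=["add_out"])
layernorm = helper.make_node(
    "LayerNormalization", inputs=["add_out", "gamma", "beta"], outputs=["ln_out"]
)

# Fused form enabled by this change: SkipLayerNormalization exposes the residual sum
# through a new optional output, so the extra consumer of "add_out" no longer blocks fusion.
fused = helper.make_node(
    "SkipLayerNormalization",
    inputs=["matmul_out", "residual", "gamma", "beta"],
    outputs=["ln_out", "", "", "add_out"],  # 4th (optional) output = input + skip
    domain="com.microsoft",
)
```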
TODO: Add kernel test
Motivation and Context
Improve fusion coverage for GPT2, helping improve its performance and that of any language model using this pattern.