Add linear layer and ffn config to enable TransformerEngine layers (with FP8) #432
Conversation
This PR adds a config for linear layers and FFN modules which allows the use of TransformerEngine's te.Linear and te.LayerNormMLP modules (which support FP8 training with amp.fp8).
+ I did a little cleanup
This PR is built on top of #271
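For context, here is a minimal sketch (not code from this PR) of the TransformerEngine modules these configs expose; it assumes `transformer_engine` is installed and a GPU with FP8 support (e.g. H100) is available:

```python
import torch
import transformer_engine.pytorch as te

hidden_size, ffn_hidden_size = 768, 3072

# te.Linear is a drop-in replacement for torch.nn.Linear.
linear = te.Linear(hidden_size, hidden_size, bias=True,
                   params_dtype=torch.bfloat16).cuda()

# te.LayerNormMLP fuses the pre-FFN LayerNorm with the two-layer MLP.
ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size,
                         params_dtype=torch.bfloat16).cuda()

x = torch.randn(128, 16, hidden_size, device='cuda', dtype=torch.bfloat16)

# FP8 execution is enabled by running the forward pass under fp8_autocast;
# outside this context the same modules run in BF16.
with te.fp8_autocast(enabled=True):
    y = ln_mlp(linear(x))
```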
In the future, this will also allow us to add and prototype other linear layers and FFN blocks. It also enables us to configure TP/SP for the MLP block in the `build_ffn` util fn (see the sketch below).

AMP FP8 training gets results that are nearly identical to AMP BF16:

[loss curves comparing AMP FP8 and AMP BF16 training runs]

but has a faster runtime.
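As a rough illustration of that extensibility point, a `build_ffn`-style factory can dispatch on `ffn_type`; the function below is an illustrative sketch, not the exact implementation in this PR, and its argument names are assumptions:

```python
import torch.nn as nn

def build_ffn(ffn_type: str, d_model: int, expansion_ratio: int = 4, **kwargs) -> nn.Module:
    if ffn_type == 'mptmlp':
        # Plain two-layer MLP (stand-in for the existing FFN block).
        return nn.Sequential(
            nn.Linear(d_model, expansion_ratio * d_model),
            nn.GELU(),
            nn.Linear(expansion_ratio * d_model, d_model),
        )
    elif ffn_type == 'te_ln_mlp':
        import transformer_engine.pytorch as te
        # TransformerEngine's fused LayerNorm + MLP; extra kwargs can carry
        # TP/SP settings straight through to the module.
        return te.LayerNormMLP(d_model, expansion_ratio * d_model, **kwargs)
    raise ValueError(f'Unknown ffn_type: {ffn_type}')
```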
Furthermore, setting `ffn_config_defaults: ffn_type: te_ln_mlp` allows us to use TransformerEngine's LayerNormMLP layer, which has SP and TP support if configured correctly.
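For reference, the same setting written out as YAML (only the `ffn_config_defaults` / `ffn_type` keys come from this PR; the nesting shown is illustrative):

```yaml
ffn_config_defaults:
  ffn_type: te_ln_mlp  # use TransformerEngine's fused LayerNorm + MLP
```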