How about torch.compile in TransformerEngine? #1241

Open
south-ocean opened this issue Oct 11, 2024 · 4 comments
Labels
question Further information is requested

Comments

@south-ocean

In PyTorch, we know that torch.compile can bring significant benefits, and TransformerEngine also improves performance through strategies such as fused Transformer kernels. Does TransformerEngine support torch.compile? Is there any documentation on whether combining torch.compile with TE gives better results than using torch.compile without TE?
Do you have any suggestions for using torch.compile with TransformerEngine?

In Llama 2, we found that torch.compile gives good gains on RMSNorm and SwiGLU, but with TE it is not possible to directly apply torch.compile to RMSNorm and SwiGLU. Is there a good way to do this?
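For reference, a minimal sketch of the non-TE approach being described here: standalone RMSNorm and SwiGLU functions wrapped in torch.compile so their pointwise operations can fuse into single kernels. The function names, shapes, and eps value are illustrative assumptions, not TE or Llama 2 code.

```python
# Illustrative sketch (not TE code): compiling standalone RMSNorm and SwiGLU
# as one might in a non-TE Llama implementation.
import torch
import torch.nn.functional as F

@torch.compile
def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the hidden dimension.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight

@torch.compile
def swiglu(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # SwiGLU activation: SiLU(gate) * up, fused by the compiler into one kernel.
    return F.silu(gate) * up

# Example usage (requires a CUDA device); shapes are arbitrary.
x = torch.randn(4, 2048, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
y = rmsnorm(x, w)
```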

@timmoon10 timmoon10 added the question Further information is requested label Oct 12, 2024
@timmoon10
Collaborator

We have used torch.compile to fuse some operations like bias+GeLU in LayerNormMLP (see bias_gelu_fused_). However, we have not yet done serious work applying torch.compile to FP8 kernels since we're not sure how well they can accommodate the extra logic for FP8 scaling factors and absmax reductions. It's something we've kept in mind though, especially as a means to work around the CPU overheads of running PyTorch in eager mode.
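The sketch below illustrates the general pattern of such a bias+GeLU fusion with torch.compile; it is not a copy of TE's bias_gelu_fused_ and may differ from it in details.

```python
# Illustrative sketch of a bias+GeLU fusion via torch.compile.
import torch

@torch.compile
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # The bias add and the tanh-approximated GeLU compile into a single kernel,
    # avoiding an extra round trip to global memory for the intermediate result.
    x = inp + bias
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))
```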

For the moment we manually identify fusion opportunities and incorporate them into our modules, e.g. LayerNormLinear might call a LayerNorm kernel that outputs in FP8. For more flexibility, we are experimenting with a modular operation-based API that can automatically identify some of these fusion opportunities. I believe Lightning Thunder has also been working on automatic kernel fusion with TE.

@south-ocean
Author

Yeah, but right now I am running in bfloat16. By applying torch.compile to RMSNorm and SwiGLU for Llama 2 7B in legacy (non-TE) mode, I get more benefit than with TE; legacy mode with torch.compile gives roughly a 10% performance improvement over TE. So can I combine the benefits of TE and torch.compile? I think that would be even faster.

@MaciejBalaNV

@timmoon10
Please also consider the following: It's quite popular to use torch.compile on the entire model and fuse all inefficient operations on tensors. Moving all such operations into small compilable functions so that we don't have to compile the entire model, is often not practical and not easy to maintain. However, currently TE modules do not work with torch.compile and a graph break is introduced at every TE module usage. In most of the cases it nullifies all possible performance benefits from torch.compile. Registering TE modules and custom kernels with torch.compile so that they do not introduce graph breaks would be a huge improvement. FlashAttention repository did this recently: Dao-AILab/flash-attention#1139

@south-ocean
Author

south-ocean commented Oct 23, 2024

@timmoon10 Yeah, I have now found that for Llama 2 7B the performance of TE does not exceed that of non-TE + torch.compile. Aside from FP8 support, both the TE and non-TE paths call the same FlashAttention functions, and the linear layers in both go through cuBLASLt. So for the remaining parts, even though TE does some fusion, can its benefits exceed the improvement brought by torch.compile? Do you have any suggestions for how we can push performance even further?
