RMSNorm precision different from HF implementation #1132
Comments
That is correct: both RMSNorm and LayerNorm in TE perform all internal computation in FP32, so TE's LayerNorm is roughly equivalent to

```python
x = x.to(torch.float32)
y = torch.nn.functional.layer_norm(x, x.shape[-1:])  # normalization done entirely in FP32
y = y.to(torch.bfloat16)
```

The reason for that is to preserve the precision of the computation, especially since RMSNorm/LayerNorm weights are typically close to 1:

```python
>>> import torch
>>> a = torch.Tensor([0.003]).to(torch.bfloat16)
>>> a
tensor([0.0030], dtype=torch.bfloat16)
>>> a + 1
tensor([1.], dtype=torch.bfloat16)
```

Based on this, I would argue that it is actually the HF implementation that is wrong here.
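To make the effect of that ordering concrete, here is a minimal sketch (toy shapes and values, an illustration rather than code from this thread) comparing a weight multiplication done in FP32 before the cast against one done after casting to bfloat16:

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(4, 64, dtype=torch.float32)
weight = 1.0 + 0.01 * torch.randn(64)        # RMSNorm weights are typically close to 1

# The RMS normalization itself is done in FP32 in both variants
variance = hidden.pow(2).mean(-1, keepdim=True)
normed = hidden * torch.rsqrt(variance + 1e-6)

# TE-style: multiply by the weight in FP32, cast the result to bfloat16 once at the end
te_style = (weight * normed).to(torch.bfloat16)

# HF-style: cast the normalized activations (and weight) to bfloat16 first, multiply in bfloat16
hf_style = weight.to(torch.bfloat16) * normed.to(torch.bfloat16)

# The outputs differ by roughly one bfloat16 rounding step
print((te_style.float() - hf_style.float()).abs().max())
```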
@ptrendx Thanks for your reply. I totally agree that, in theory, we should use float32 for all the calculations. However, we're not training from scratch; we're doing continued training of existing open-source models. That's why I believe we should at least provide an option to align the RMSNorm with the HF transformers implementation.
Yeah, I figured that was the probable reason for this ask. Could you open an issue in the HF transformers repo as well, then? It would be interesting to hear their opinion on the topic, and it would also raise their awareness so that, hopefully, the implementations converge on the right precisions for new models going forward. I need to think about how to expose that option. In the meantime, if you want to change the TE implementation yourself to do the multiplication in the lower precision, you would need to change the relevant line in the TE source.
Great, thank you @ptrendx! I'll try to change the code myself. Also, here's the issue I opened on HF: huggingface/transformers#33133
We just stumbled upon this issue and compared the RMSNorm implementations in TransformerEngine and TensorRT-LLM. It looks like TensorRT-LLM does the weight multiplication in the lower precision, consistent with the HF transformers implementation. This likely means that a model trained with TransformerEngine will produce (at least slightly) different outputs when run for inference with TensorRT-LLM. I agree with @ptrendx that performing the operation in higher precision sounds sensible, but I think it would be useful to have the option to align the implementations across NVIDIA's stack.
We noticed there's a tiny implementation difference that makes `transformer_engine.pytorch.module.rmsnorm` (and also `TELayerNormColumnParallelLinear`) generate results that differ from the HF implementation. The tiny difference is in when the hidden_states are converted back to `bfloat16`. Here's the gap:

[Screenshot: CleanShot 2024-08-23 at 22 40 51@2x, comparing the outputs of the two implementations]

We wonder if TE could provide an option to match HF's implementation, which converts hidden_states to `bfloat16` before multiplying by the weights. Thanks.
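For reference, a sketch of what the HF-style forward pass looks like (written here as a standalone function for illustration; the final line matches the one quoted in the reproduction steps below):

```python
import torch

def hf_style_rmsnorm(hidden_states: torch.Tensor,
                     weight: torch.Tensor,
                     eps: float = 1e-6) -> torch.Tensor:
    input_dtype = hidden_states.dtype                   # e.g. torch.bfloat16
    hidden_states = hidden_states.to(torch.float32)     # the normalization itself is in FP32
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    # HF casts back to the input dtype *before* the weight multiplication,
    # so that multiplication happens in bfloat16; TE keeps it in FP32.
    return weight * hidden_states.to(input_dtype)
```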
How to reproduce

Version: transformer-engine 1.7.0+4e7caa1

Code to reproduce: first define `HFRMSNorm` with the native HF implementation and compare its output against TE's RMSNorm (see the sketch below). The assertion should fail when we run the code with this implementation.
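The original code blocks are not shown above; a minimal sketch of what the comparison could look like is given here (the shapes, the weight perturbation, and the zero-tolerance assertion are assumptions, not the exact script from the report):

```python
import torch
import torch.nn as nn
import transformer_engine.pytorch as te

class HFRMSNorm(nn.Module):
    """RMSNorm in the HF transformers style: normalize in FP32, cast back
    to the input dtype, then multiply by the weight in that lower precision."""
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

hidden_size = 4096
x = torch.randn(2, 8, hidden_size, dtype=torch.bfloat16, device="cuda")

hf_norm = HFRMSNorm(hidden_size).to(device="cuda", dtype=torch.bfloat16)
te_norm = te.RMSNorm(hidden_size, eps=1e-6).to(device="cuda", dtype=torch.bfloat16)

# Both modules are initialized with weights equal to 1; perturb them slightly
# (and keep them identical) so the weight multiplication actually matters.
with torch.no_grad():
    hf_norm.weight.add_(0.01 * torch.randn_like(hf_norm.weight))
    te_norm.weight.copy_(hf_norm.weight)

# With the HF-style forward above, this comparison is expected to fail,
# because TE performs the weight multiplication in FP32.
torch.testing.assert_close(hf_norm(x), te_norm(x), rtol=0.0, atol=0.0)
```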
Now, if we change the last line of `HFRMSNorm.forward` from `return self.weight * hidden_states.to(input_dtype)` to `return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)`, the assertion should pass.