Input upcast is missing in Thunder's implementation of torch.nn.functional.rms_norm #1713
Comments
Great catch! PyTorch does upcast (introduced in pytorch/pytorch#134106) and so should we.
Could be related to #1678, cc: @riccardofelluga
Do any of our Q4-Q1 models use RMSNorm in this way, such that we need to prioritize this?
All models use the RMS norm, but none use it through `torch.nn.functional.rms_norm`.
Should Thunder upcast the inputs to fp32 even for bf16 inputs? The linked PRs and issues from PyTorch suggest that the upcast is done primarily for fp16 to avoid overflow.
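For context, a toy illustration (not taken from the issue) of why the upcast matters most for fp16: the mean-of-squares at the core of RMSNorm overflows fp16's maximum of about 65504 for moderately large values, whereas bf16 shares fp32's exponent range and mostly loses precision rather than overflowing.

```python
import torch

# Illustrative only: squaring values around 300 overflows fp16
# (300**2 = 90000 > ~65504), but not fp32.
x = torch.full((8,), 300.0, dtype=torch.float16)
print((x * x).mean())                  # tensor(inf, dtype=torch.float16)
print((x.float() * x.float()).mean())  # tensor(90000.)
```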
Update! Thanks @t-vi for the link to the cpp implementation and @IvanYashchuk for the suggestion! I've tested and written #1751 to do all the computation exactly the way torch does it under the hood. The resulting trace with the nvFuser executor is as follows (note that the torch executor suffers from the same problem, even though I haven't pasted its trace here):

```python
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_2):
  # t_0: "cuda:0 bf16[1, 2, 3]"
  # t_2: "cuda:0 bf16[1, 2, 3]"
  [t15] = nvFusion0(t_0, t_2)
    # t4 = prims.convert_element_type(t_0, dtypes.float32) # t4: "cuda:0 f32[1, 2, 3]"
    # t5 = prims.mul(t4, t4) # t5: "cuda:0 f32[1, 2, 3]"
    # t6 = prims.sum(t5, (0, 1, 2)) # t6: "cuda:0 f32[]"
    # t7 = prims.broadcast_in_dim(t6, [1, 1, 1], []) # t7: "cuda:0 f32[1, 1, 1]"
    # t8 = prims.div(t7, 6.0) # t8: "cuda:0 f32[1, 1, 1]"
    # t9 = prims.add(t8, 0.5) # t9: "cuda:0 f32[1, 1, 1]"
    # t10 = prims.rsqrt(t9) # t10: "cuda:0 f32[1, 1, 1]"
    # t11 = prims.broadcast_in_dim(t10, (1, 2, 3), (0, 1, 2)) # t11: "cuda:0 f32[1, 2, 3]"
    # t12 = prims.mul(t4, t11) # t12: "cuda:0 f32[1, 2, 3]"
    # t13 = prims.convert_element_type(t_2, dtypes.float32) # t13: "cuda:0 f32[1, 2, 3]"
    # t14 = prims.mul(t12, t13) # t14: "cuda:0 f32[1, 2, 3]"
    # t15 = prims.convert_element_type(t14, dtypes.bfloat16) # t15: "cuda:0 bf16[1, 2, 3]"
  return (t15,)
```

As you can see, the casts are kept to a minimum, but the issue still stands (t_0 is the input tensor and t_2 is the weight). In the CI test I purposely set the tolerance to 0.0 to check the magnitude of the mismatch. Also, on A6000 Ada or H100 the mismatch does not always show up. More investigation is needed.
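For reference, a minimal repro sketch along these lines; the shapes and eps=0.5 are read off the trace above, while the use of `thunder.jit` and the zero tolerances are my assumptions, not the actual CI test:

```python
import torch
import torch.nn.functional as F
import thunder

def fn(x, w):
    # eps=0.5 and the full-tensor normalized_shape mirror the trace above
    return F.rms_norm(x, normalized_shape=(1, 2, 3), weight=w, eps=0.5)

jfn = thunder.jit(fn)

x = torch.randn(1, 2, 3, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1, 2, 3, device="cuda", dtype=torch.bfloat16)

# Zero tolerance surfaces even single-ulp bf16 differences between the
# Thunder-compiled function and eager PyTorch.
torch.testing.assert_close(jfn(x, w), fn(x, w), rtol=0.0, atol=0.0)
```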
Update pt. 2! It looks like the cause of the mismatch is the multiplication with the weight. Updating the patch makes it work for the torch executor on both CUDA and CPU, but nvFuser still seems to upcast to fp32, probably due to nvFuser's own internal upcasting (big question here, this needs to be investigated).
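A toy illustration (made-up values, not from the patch) of the kind of low-order-bit difference described here: multiplying by the weight in fp32 and downcasting afterwards is generally not bit-identical to downcasting first and multiplying in bf16.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4096)  # stands in for the fp32 normalized activations
w = torch.randn(4096)  # stands in for the fp32-upcast weight

weight_mul_in_fp32 = (x * w).to(torch.bfloat16)              # weight applied before the downcast
weight_mul_in_bf16 = x.to(torch.bfloat16) * w.to(torch.bfloat16)  # weight applied after the downcast

# The two orderings usually differ in the last bit of a few elements,
# which is exactly what a tolerance of 0.0 catches.
print((weight_mul_in_fp32 != weight_mul_in_bf16).any())
```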
🐛 Bug
PyTorch has recently added a new operation, `torch.nn.functional.rms_norm`, and Thunder support was added in #1390. The current implementation in Thunder is different from the one in LitGPT.
In LitGPT the input tensor `a` is upcast to fp32 and the result is then downcast to the input's dtype: https://github.com/Lightning-AI/litgpt/blob/1d93671269de04fc96a63a795e137de5b13dc99b/litgpt/model.py#L819
Another example is the Hugging Face implementation of Qwen 2, where the input is also upcast: https://github.com/huggingface/transformers/blob/ec7afad60909dd97d998c1f14681812d69a15728/src/transformers/models/qwen2/modeling_qwen2.py#L220
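The upcast pattern referenced above looks roughly like this; a sketch, not the verbatim LitGPT or Hugging Face code:

```python
import torch

def rms_norm_with_upcast(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Sketch of the upcast pattern: compute the normalization in fp32,
    # downcast to the input dtype, then apply the learned weight.
    dtype = x.dtype
    x = x.float()
    norm_x = torch.mean(x * x, dim=-1, keepdim=True)
    x_normed = x * torch.rsqrt(norm_x + eps)
    return x_normed.to(dtype) * weight
```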
Does PyTorch upcast the inputs in its implementation?
Should the inputs be upcast in Thunder?