Input upcast is missing in Thunder's implementation of torch.nn.functional.rms_norm #1713

Open
IvanYashchuk opened this issue Jan 29, 2025 · 7 comments

@IvanYashchuk
Collaborator

🐛 Bug

PyTorch has recently added a new operation, torch.nn.functional.rms_norm, and Thunder support was added in #1390.

The current implementation in Thunder differs from the one in LitGPT.
In LitGPT the input tensor is upcast to fp32 and the result is then downcast back to the input's dtype: https://github.com/Lightning-AI/litgpt/blob/1d93671269de04fc96a63a795e137de5b13dc99b/litgpt/model.py#L819
Another example is Hugging Face's implementation of Qwen 2, where the input is also upcast: https://github.com/huggingface/transformers/blob/ec7afad60909dd97d998c1f14681812d69a15728/src/transformers/models/qwen2/modeling_qwen2.py#L220
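For reference, the pattern in those implementations looks roughly like this (a minimal sketch of the idea, not the exact LitGPT or Hugging Face code; the eps value and the exact weight handling differ between them):

import torch

def rms_norm_with_upcast(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Normalize in fp32, then cast back to the input dtype before applying the weight.
    input_dtype = x.dtype
    x = x.to(torch.float32)
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return weight * x.to(input_dtype)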

Does PyTorch upcast the inputs in its implementation?
Should the inputs be upcast in Thunder?

@t-vi
Collaborator

t-vi commented Jan 29, 2025

@kshitij12345
Collaborator

Could be related to #1678

cc: @riccardofelluga

@tfogal
Collaborator

tfogal commented Jan 29, 2025

Do any of our Q4–Q1 models use RMSNorm in this way, such that we need to prioritize this?

@IvanYashchuk
Collaborator Author

All models use the RMS norm, but none use it through torch.nn.functional.rms_norm.

@IvanYashchuk
Collaborator Author

Should Thunder upcast the inputs to fp32 even for bf16 inputs? The linked PRs and issues from PyTorch suggest that the upcast is done primarily for fp16 to avoid overflow.
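A quick illustration of why fp16 is the case that needs the upcast (my own example, not taken from the linked PyTorch discussion): squaring even moderately large activations already overflows fp16, while bf16 keeps roughly fp32's exponent range.

import torch

x = torch.tensor(300.0)
print(x.to(torch.float16) ** 2)   # inf, since fp16 tops out around 65504 and 300**2 = 90000
print(x.to(torch.bfloat16) ** 2)  # finite (~90112), since bf16 has an fp32-sized exponent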

riccardofelluga self-assigned this Feb 5, 2025
@riccardofelluga
Collaborator

Update! Thanks @t-vi for the link to the cpp implementation and @IvanYashchuk for the suggestion! I've tested and written #1751 to do all the computation in exactly the way torch does it under the hood. The resulting trace with the nvFuser executor is as follows (note that the torch executor suffers from the same problem, even though I haven't pasted its trace here):

# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_2):
  # t_0: "cuda:0 bf16[1, 2, 3]"
  # t_2: "cuda:0 bf16[1, 2, 3]"
  [t15] = nvFusion0(t_0, t_2)
    # t4 = prims.convert_element_type(t_0, dtypes.float32)  # t4: "cuda:0 f32[1, 2, 3]"
    # t5 = prims.mul(t4, t4)  # t5: "cuda:0 f32[1, 2, 3]"
    # t6 = prims.sum(t5, (0, 1, 2))  # t6: "cuda:0 f32[]"
    # t7 = prims.broadcast_in_dim(t6, [1, 1, 1], [])  # t7: "cuda:0 f32[1, 1, 1]"
    # t8 = prims.div(t7, 6.0)  # t8: "cuda:0 f32[1, 1, 1]"
    # t9 = prims.add(t8, 0.5)  # t9: "cuda:0 f32[1, 1, 1]"
    # t10 = prims.rsqrt(t9)  # t10: "cuda:0 f32[1, 1, 1]"
    # t11 = prims.broadcast_in_dim(t10, (1, 2, 3), (0, 1, 2))  # t11: "cuda:0 f32[1, 2, 3]"
    # t12 = prims.mul(t4, t11)  # t12: "cuda:0 f32[1, 2, 3]"
    # t13 = prims.convert_element_type(t_2, dtypes.float32)  # t13: "cuda:0 f32[1, 2, 3]"
    # t14 = prims.mul(t12, t13)  # t14: "cuda:0 f32[1, 2, 3]"
    # t15 = prims.convert_element_type(t14, dtypes.bfloat16)  # t15: "cuda:0 bf16[1, 2, 3]"
  return (t15,)

As you can see, the casts are kept to a minimum, but the issue still stands (t_0 is the input tensor and t_2 is the weight). In the CI test I've purposely set the tolerance to 0.0 to check the magnitude of the mismatch.

Also, on A6000 Ada or H100 the mismatch does not always show up. More investigation is needed.
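For context, a zero-tolerance comparison along these lines surfaces any element-wise difference (an illustrative snippet only, not the actual CI test; it compares eager bf16 rms_norm against an fp32 reference, using the shapes and eps from the trace above):

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 3, dtype=torch.bfloat16)
w = torch.randn(1, 2, 3, dtype=torch.bfloat16)
# fp32 reference with a single downcast at the end.
ref = F.rms_norm(x.float(), (1, 2, 3), w.float(), eps=0.5).to(torch.bfloat16)
out = F.rms_norm(x, (1, 2, 3), w, eps=0.5)
try:
    torch.testing.assert_close(out, ref, rtol=0.0, atol=0.0)
    print("bitwise identical")
except AssertionError as e:
    print(e)  # reports the magnitude of any mismatch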

@riccardofelluga
Collaborator

riccardofelluga commented Feb 7, 2025

Update pt. 2! It looks like the cause of the mismatch is the multiplication with the weight tensor, which torch does in bf16 but which my patch (and Thunder before it) does in fp32. From pytorch/pytorch#134106:

https://github.com/pytorch/pytorch/blob/501c5972f02cc1902f0a060f709b6a2a4ebeb102/aten/src/ATen/native/layer_norm.cpp#L318

Updating the patch makes it work with the torch executor on both CUDA and CPU, but nvFuser still seems to upcast to fp32, probably because nvFuser itself performs the upcast (big question here; this needs to be investigated).
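If I'm reading this right, the two variants differ only in where the downcast happens relative to the weight multiplication. A rough sketch of the difference (my own illustration, not the actual Thunder or ATen code):

import torch

def rms_norm_weight_in_fp32(x, weight, eps):
    # What the trace above does: the weight is upcast and the multiplication
    # happens in fp32, with a single downcast at the very end.
    xf = x.float()
    y = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (y * weight.float()).to(x.dtype)

def rms_norm_weight_in_input_dtype(x, weight, eps):
    # What the linked layer_norm.cpp path appears to do: normalize in fp32,
    # downcast, then multiply by the weight in the input dtype (e.g. bf16).
    xf = x.float()
    y = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return y.to(x.dtype) * weight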
