
nvFuser has a faster RMSNorm fusion definition than thunder's RMSNorm decomposition #1582

Open
mruberry opened this issue Dec 23, 2024 · 3 comments

Comments


mruberry commented Dec 23, 2024

See NVIDIA/Fuser#3629 for details

cc @apaz-cli


tfogal commented Jan 10, 2025

@kevinstephano is RMSNorm used in our Q4 models?

mruberry commented

> @kevinstephano is RMSNorm used in our Q4 models?

fyi @kiya00, it would be a good feature for the report tooling to describe all the operators seen: which go to thunder and which are sent to torch.compile or eager, and, of those that go to thunder, whether they took a fallback path or were executed by eager, torch.compile, nvFuser, etc. Questions like this could then be answered by querying that list of operators. Not the highest-priority feature, but something to note for the future.
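
As a rough illustration of the kind of query such tooling could answer, here is a minimal sketch (not the proposed report feature) that walks the bound symbols of thunder's execution trace and flags nvFuser regions by name. It assumes the public thunder.last_traces API and the bound_symbols / subsymbols attributes of the trace; attributing operators to an executor purely from symbol names like "nvFusion0" is only a heuristic.

import thunder

def summarize_operators(jitted_fn):
    # The last trace in the list is the one that is actually executed.
    execution_trace = thunder.last_traces(jitted_fn)[-1]
    for bsym in execution_trace.bound_symbols:
        name = bsym.sym.name
        if name.startswith("nvFusion"):
            # Fusion regions record the prims they cover as subsymbols.
            fused = [sub.sym.name for sub in bsym.subsymbols]
            print(f"{name}: claimed by nvFuser, covering {fused}")
        else:
            print(name)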


kshitij12345 commented Jan 14, 2025

> is RMSNorm used in our Q4 models?

This is the RMSNorm definition for each of the models:

import torch
from torch import nn

class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

https://github.com/huggingface/transformers/blob/a11041ffad285b13d578127cc304b90c2f12ce1f/src/transformers/models/qwen2/modeling_qwen2.py#L208-L225

The same definition is used for Phi3, except for the class name:
https://github.com/huggingface/transformers/blob/a11041ffad285b13d578127cc304b90c2f12ce1f/src/transformers/models/phi3/modeling_phi3.py#L224-L241

The same definition is used for Mistral, except for the class name:
https://github.com/huggingface/transformers/blob/a11041ffad285b13d578127cc304b90c2f12ce1f/src/transformers/models/mistral/modeling_mistral.py#L200-L218

Sample script to see the traces generated from this implementation:

import torch
import torch.utils.benchmark
import thunder
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm

with torch.device("cuda"):
    rms_norm = Qwen2RMSNorm((3584,), eps=1e-06)
    x = torch.randn(1, 4096, 3584, requires_grad=True, dtype=torch.bfloat16)

rms_norm(x)  # eager reference run

jfn = thunder.jit(rms_norm)
o = jfn(x)
print(o.shape)

grad_o = torch.rand_like(o)
o.backward(grad_o)  # run the backward pass so its trace is executed as well

traces = thunder.last_traces(jfn)
bwd_traces = thunder.last_backward_traces(jfn)

print(traces[-1])      # final (execution) forward trace
print(bwd_traces[-1])  # final (execution) backward trace

Forward Trace:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(hidden_states, t_weight):
  # hidden_states: "cuda:0 bf16[1, 4096, 3584]"
  # t_weight: "cuda:0 f32[3584]"
  [t7, t15] = nvFusion0(hidden_states, t_weight)
    # t0 = prims.convert_element_type(hidden_states, dtypes.float32)  # t0: "cuda:0 f32[1, 4096, 3584]"
    # t1 = prims.pow(t0, 2.0)  # t1: "cuda:0 f32[1, 4096, 3584]"
    # t18 = prims.sum(t1, (2,))  # t18: "cuda:0 f32[1, 4096]"
    # t19 = prims.broadcast_in_dim(t18, [1, 4096, 1], [0, 1])  # t19: "cuda:0 f32[1, 4096, 1]"
    # variance = prims.div(t19, 3584.0)  # variance: "cuda:0 f32[1, 4096, 1]"
    # t6 = prims.add(variance, 1e-06)  # t6: "cuda:0 f32[1, 4096, 1]"
    # t7 = prims.rsqrt(t6)  # t7: "cuda:0 f32[1, 4096, 1]"
    # t23 = prims.broadcast_in_dim(t7, (1, 4096, 3584), (0, 1, 2))  # t23: "cuda:0 f32[1, 4096, 3584]"
    # t9 = prims.mul(t0, t23)  # t9: "cuda:0 f32[1, 4096, 3584]"
    # t26 = prims.broadcast_in_dim(t_weight, (1, 4096, 3584), (2,))  # t26: "cuda:0 f32[1, 4096, 3584]"
    # t15 = prims.mul(t26, t9)  # t15: "cuda:0 f32[1, 4096, 3584]"
  return {'output': (t15,), 'flat_args': [hidden_states, t_weight], 'flat_output': (t15,)}, ((hidden_states, t7, t_weight), ())

Backward Trace:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t5, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  hidden_states, t7, t_weight, = C0
  clear_mutable_collection(C0)
  del C0
  [bw_t60, bw_t84] = nvFusion0(t_weight, t5, hidden_states, t7)
    # bw_t26 = prims.broadcast_in_dim(t_weight, (1, 4096, 3584), (2,))  # bw_t26: "cuda:0 f32[1, 4096, 3584]"
    # bw_t58 = prims.mul(bw_t26, t5)  # bw_t58: "cuda:0 f32[1, 4096, 3584]"
    # t0 = prims.convert_element_type(hidden_states, dtypes.float32)  # t0: "cuda:0 f32[1, 4096, 3584]"
    # bw_t63 = prims.mul(t0, bw_t58)  # bw_t63: "cuda:0 f32[1, 4096, 3584]"
    # bw_t64 = prims.sum(bw_t63, (0, 2))  # bw_t64: "cuda:0 f32[4096]"
    # bw_t65 = prims.broadcast_in_dim(bw_t64, [1, 4096, 1], [1])  # bw_t65: "cuda:0 f32[1, 4096, 1]"
    # bw_t67 = prims.pow(t7, 3.0)  # bw_t67: "cuda:0 f32[1, 4096, 1]"
    # bw_t66 = prims.mul(-0.5, bw_t65)  # bw_t66: "cuda:0 f32[1, 4096, 1]"
    # bw_t68 = prims.mul(bw_t66, bw_t67)  # bw_t68: "cuda:0 f32[1, 4096, 1]"
    # bw_t71 = prims.div(bw_t68, 3584.0)  # bw_t71: "cuda:0 f32[1, 4096, 1]"
    # bw_t72 = prims.sum(bw_t71, (0, 2))  # bw_t72: "cuda:0 f32[4096]"
    # bw_t73 = prims.broadcast_in_dim(bw_t72, [1, 4096], [1])  # bw_t73: "cuda:0 f32[1, 4096]"
    # bw_t75 = prims.broadcast_in_dim(bw_t73, [1, 4096, 1], [0, 1])  # bw_t75: "cuda:0 f32[1, 4096, 1]"
    # t23 = prims.broadcast_in_dim(t7, (1, 4096, 3584), (0, 1, 2))  # t23: "cuda:0 f32[1, 4096, 3584]"
    # bw_t76 = prims.broadcast_in_dim(bw_t75, (1, 4096, 3584), (0, 1, 2))  # bw_t76: "cuda:0 f32[1, 4096, 3584]"
    # t9 = prims.mul(t0, t23)  # t9: "cuda:0 f32[1, 4096, 3584]"
    # bw_t79 = prims.pow(t0, 1.0)  # bw_t79: "cuda:0 f32[1, 4096, 3584]"
    # bw_t78 = prims.mul(bw_t76, 2.0)  # bw_t78: "cuda:0 f32[1, 4096, 3584]"
    # bw_t23 = prims.broadcast_in_dim(t7, (1, 4096, 3584), (0, 1, 2))  # bw_t23: "cuda:0 f32[1, 4096, 3584]"
    # t12 = prims.convert_element_type(t9, dtypes.bfloat16)  # t12: "cuda:0 bf16[1, 4096, 3584]"
    # bw_t80 = prims.mul(bw_t78, bw_t79)  # bw_t80: "cuda:0 f32[1, 4096, 3584]"
    # bw_t62 = prims.mul(bw_t23, bw_t58)  # bw_t62: "cuda:0 f32[1, 4096, 3584]"
    # bw_t27 = prims.convert_element_type(t12, dtypes.float32)  # bw_t27: "cuda:0 f32[1, 4096, 3584]"
    # bw_t83 = prims.add(bw_t62, bw_t80)  # bw_t83: "cuda:0 f32[1, 4096, 3584]"
    # bw_t57 = prims.mul(bw_t27, t5)  # bw_t57: "cuda:0 f32[1, 4096, 3584]"
    # bw_t84 = prims.convert_element_type(bw_t83, dtypes.bfloat16)  # bw_t84: "cuda:0 bf16[1, 4096, 3584]"
    # bw_t60 = prims.sum(bw_t57, (0, 1))  # bw_t60: "cuda:0 f32[3584]"
  del t_weight, t5, hidden_states, t7
  return (bw_t84, bw_t60)
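
As an aside, the script above imports torch.utils.benchmark but never uses it. A minimal sketch of how the eager module and the thunder-jitted version could be timed with that utility (assuming the rms_norm, x, and jfn objects from the script are still in scope) might look like this:

import torch.utils.benchmark as benchmark

# Assumes rms_norm, x, and jfn from the sample script above are in scope.
eager_timer = benchmark.Timer(
    stmt="rms_norm(x)",
    globals={"rms_norm": rms_norm, "x": x},
    label="RMSNorm forward",
    sub_label="eager",
)
thunder_timer = benchmark.Timer(
    stmt="jfn(x)",
    globals={"jfn": jfn, "x": x},
    label="RMSNorm forward",
    sub_label="thunder + nvFuser",
)
# blocked_autorange picks a suitable number of measurement runs automatically.
print(eager_timer.blocked_autorange())
print(thunder_timer.blocked_autorange())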

cc: @tfogal @kevinstephano
