nvFuser has a faster RMSNorm fusion definition than thunder's RMSNorm decomposition #1582
Comments
@kevinstephano is RMSNorm used in our Q4 models?
fyi @kiya00, it would be a good feature for the report tooling to describe all the operators seen: which go to thunder and which are sent to torch.compile or eager, and, of those that go to thunder, whether they took a fallback path or were executed by eager, torch.compile, nvFuser, etc. Questions like this could then be answered by querying that list of operators, too. Not the highest-priority feature, but something to note for the future.
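A minimal sketch of that kind of query, assuming thunder's current trace internals (`bound_symbols`, `sym.name`, and `sym.executor` are internal attributes that may change between releases; `operator_summary` is a hypothetical helper, not existing report tooling):

```python
from collections import Counter

import thunder

def operator_summary(jitted_fn):
    # Walk the final execution trace and tally operators by the executor that
    # claimed each bound symbol (e.g. nvfuser, torch); fall back to "unclaimed"
    # when no executor is recorded on the symbol.
    execution_trace = thunder.last_traces(jitted_fn)[-1]
    summary = Counter()
    for bsym in execution_trace.bound_symbols:
        executor = getattr(bsym.sym, "executor", None)
        executor_name = getattr(executor, "name", "unclaimed") if executor is not None else "unclaimed"
        summary[(executor_name, bsym.sym.name)] += 1
    return summary
```

For the RMSNorm example below, this would roughly show a single nvFuser-claimed fusion region plus a few bookkeeping symbols.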
This is the RMSNorm definition for each model:

```python
import torch
from torch import nn

class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
```

Same for Phi3 and Mistral, except for the class name.

Sample script to see the traces generated from this implementation:

```python
import torch.utils.benchmark
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
import torch
import thunder

with torch.device("cuda"):
    rms_norm = Qwen2RMSNorm((3584,), eps=1e-06)
    x = torch.randn(1, 4096, 3584, requires_grad=True, dtype=torch.bfloat16)
    rms_norm(x)

    jfn = thunder.jit(rms_norm)
    o = jfn(x)
    print(o.shape)

    grad_o = torch.rand_like(o)

    traces = thunder.last_traces(jfn)
    bwd_traces = thunder.last_backward_traces(jfn)

    print(traces[-1])
    print(bwd_traces[-1])
```
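A possible follow-up, not part of the original script: use torch.utils.benchmark (imported above but unused) to compare the eager module against the thunder-jitted one. This only times the forward pass, and the numbers depend on the GPU and software stack.

```python
import torch.utils.benchmark as benchmark

# Compare eager vs. thunder.jit forward latency for the same input,
# reusing rms_norm, x, and jfn from the script above.
eager_timer = benchmark.Timer(stmt="rms_norm(x)", globals={"rms_norm": rms_norm, "x": x})
thunder_timer = benchmark.Timer(stmt="jfn(x)", globals={"jfn": jfn, "x": x})

print(eager_timer.blocked_autorange())
print(thunder_timer.blocked_autorange())
```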
Forward Trace

```python
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(hidden_states, t_weight):
# hidden_states: "cuda:0 bf16[1, 4096, 3584]"
# t_weight: "cuda:0 f32[3584]"
[t7, t15] = nvFusion0(hidden_states, t_weight)
# t0 = prims.convert_element_type(hidden_states, dtypes.float32) # t0: "cuda:0 f32[1, 4096, 3584]"
# t1 = prims.pow(t0, 2.0) # t1: "cuda:0 f32[1, 4096, 3584]"
# t18 = prims.sum(t1, (2,)) # t18: "cuda:0 f32[1, 4096]"
# t19 = prims.broadcast_in_dim(t18, [1, 4096, 1], [0, 1]) # t19: "cuda:0 f32[1, 4096, 1]"
# variance = prims.div(t19, 3584.0) # variance: "cuda:0 f32[1, 4096, 1]"
# t6 = prims.add(variance, 1e-06) # t6: "cuda:0 f32[1, 4096, 1]"
# t7 = prims.rsqrt(t6) # t7: "cuda:0 f32[1, 4096, 1]"
# t23 = prims.broadcast_in_dim(t7, (1, 4096, 3584), (0, 1, 2)) # t23: "cuda:0 f32[1, 4096, 3584]"
# t9 = prims.mul(t0, t23) # t9: "cuda:0 f32[1, 4096, 3584]"
# t26 = prims.broadcast_in_dim(t_weight, (1, 4096, 3584), (2,)) # t26: "cuda:0 f32[1, 4096, 3584]"
# t15 = prims.mul(t26, t9) # t15: "cuda:0 f32[1, 4096, 3584]"
return {'output': (t15,), 'flat_args': [hidden_states, t_weight], 'flat_output': (t15,)}, ((hidden_states, t7, t_weight), ())
```

Backward Trace

```python
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, _, = saved_for_backward
clear_mutable_collection(saved_for_backward)
del saved_for_backward
t5, = cotangents
clear_mutable_collection(cotangents)
del cotangents
hidden_states, t7, t_weight, = C0
clear_mutable_collection(C0)
del C0
[bw_t60, bw_t84] = nvFusion0(t_weight, t5, hidden_states, t7)
# bw_t26 = prims.broadcast_in_dim(t_weight, (1, 4096, 3584), (2,)) # bw_t26: "cuda:0 f32[1, 4096, 3584]"
# bw_t58 = prims.mul(bw_t26, t5) # bw_t58: "cuda:0 f32[1, 4096, 3584]"
# t0 = prims.convert_element_type(hidden_states, dtypes.float32) # t0: "cuda:0 f32[1, 4096, 3584]"
# bw_t63 = prims.mul(t0, bw_t58) # bw_t63: "cuda:0 f32[1, 4096, 3584]"
# bw_t64 = prims.sum(bw_t63, (0, 2)) # bw_t64: "cuda:0 f32[4096]"
# bw_t65 = prims.broadcast_in_dim(bw_t64, [1, 4096, 1], [1]) # bw_t65: "cuda:0 f32[1, 4096, 1]"
# bw_t67 = prims.pow(t7, 3.0) # bw_t67: "cuda:0 f32[1, 4096, 1]"
# bw_t66 = prims.mul(-0.5, bw_t65) # bw_t66: "cuda:0 f32[1, 4096, 1]"
# bw_t68 = prims.mul(bw_t66, bw_t67) # bw_t68: "cuda:0 f32[1, 4096, 1]"
# bw_t71 = prims.div(bw_t68, 3584.0) # bw_t71: "cuda:0 f32[1, 4096, 1]"
# bw_t72 = prims.sum(bw_t71, (0, 2)) # bw_t72: "cuda:0 f32[4096]"
# bw_t73 = prims.broadcast_in_dim(bw_t72, [1, 4096], [1]) # bw_t73: "cuda:0 f32[1, 4096]"
# bw_t75 = prims.broadcast_in_dim(bw_t73, [1, 4096, 1], [0, 1]) # bw_t75: "cuda:0 f32[1, 4096, 1]"
# t23 = prims.broadcast_in_dim(t7, (1, 4096, 3584), (0, 1, 2)) # t23: "cuda:0 f32[1, 4096, 3584]"
# bw_t76 = prims.broadcast_in_dim(bw_t75, (1, 4096, 3584), (0, 1, 2)) # bw_t76: "cuda:0 f32[1, 4096, 3584]"
# t9 = prims.mul(t0, t23) # t9: "cuda:0 f32[1, 4096, 3584]"
# bw_t79 = prims.pow(t0, 1.0) # bw_t79: "cuda:0 f32[1, 4096, 3584]"
# bw_t78 = prims.mul(bw_t76, 2.0) # bw_t78: "cuda:0 f32[1, 4096, 3584]"
# bw_t23 = prims.broadcast_in_dim(t7, (1, 4096, 3584), (0, 1, 2)) # bw_t23: "cuda:0 f32[1, 4096, 3584]"
# t12 = prims.convert_element_type(t9, dtypes.bfloat16) # t12: "cuda:0 bf16[1, 4096, 3584]"
# bw_t80 = prims.mul(bw_t78, bw_t79) # bw_t80: "cuda:0 f32[1, 4096, 3584]"
# bw_t62 = prims.mul(bw_t23, bw_t58) # bw_t62: "cuda:0 f32[1, 4096, 3584]"
# bw_t27 = prims.convert_element_type(t12, dtypes.float32) # bw_t27: "cuda:0 f32[1, 4096, 3584]"
# bw_t83 = prims.add(bw_t62, bw_t80) # bw_t83: "cuda:0 f32[1, 4096, 3584]"
# bw_t57 = prims.mul(bw_t27, t5) # bw_t57: "cuda:0 f32[1, 4096, 3584]"
# bw_t84 = prims.convert_element_type(bw_t83, dtypes.bfloat16) # bw_t84: "cuda:0 bf16[1, 4096, 3584]"
# bw_t60 = prims.sum(bw_t57, (0, 1)) # bw_t60: "cuda:0 f32[3584]"
del t_weight, t5, hidden_states, t7
return (bw_t84, bw_t60)
```
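For context on what a more fusion-friendly formulation could look like, here is a hypothetical sketch (not thunder's or nvFuser's actual definition) that expresses the whole normalization as the single composite op torch.nn.functional.rms_norm, available in PyTorch 2.4+. The Hugging Face module upcasts to float32 and applies the weight after downcasting, so numerical parity and output dtype would need to be checked against the reference.

```python
import torch
import torch.nn.functional as F

def rms_norm_composite(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize over the last dimension with the composite rms_norm op
    # (PyTorch >= 2.4), upcasting to float32 like the Hugging Face module.
    out = F.rms_norm(hidden_states.float(), (hidden_states.shape[-1],), weight, eps)
    # The reference multiplies the weight after downcasting the normalized
    # activations, so low-order bits (and the output dtype) can differ here.
    return out.to(hidden_states.dtype)
```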
See NVIDIA/Fuser#3629 for details
cc @apaz-cli