Improve the precision of the FusedAddRMSNormKernel function #587
Conversation
```python
import torch

def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x.to(orig_dtype) * weight  # x is rounded to orig_dtype before the weight multiply
    return x, residual
```

If the function is modified as follows, the output result changes:

```python
def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps) * weight.to(orig_dtype)  # weight multiply happens while x is still fp32
    return x.to(orig_dtype), residual
```
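To make the difference between the two versions concrete, here is a minimal comparison sketch. It is my own scaffolding, not part of the PR: `fused_add_rms_norm_old` and `fused_add_rms_norm_new` are assumed renames of the first and second snippets above, and `rms_norm_ref` is a pure-fp32 reference used to measure each version's error.

```python
import torch

def rms_norm_ref(x, residual, weight, eps):
    # Pure fp32 reference: no intermediate rounding to the original dtype.
    x = x.to(torch.float32) + residual.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight.to(torch.float32)

torch.manual_seed(0)
x = torch.randn(4, 4096, dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.randn(4096, dtype=torch.float16)
eps = 1e-6

ref = rms_norm_ref(x, residual, weight, eps)
# Hypothetical renames of the two snippets above.
for fn in (fused_add_rms_norm_old, fused_add_rms_norm_new):
    out, _ = fn(x.clone(), residual.clone(), weight, eps)
    print(fn.__name__, (out.to(torch.float32) - ref).abs().max().item())
```

The only behavioral difference is that the second version keeps `x` in fp32 through the weight multiply and rounds once at the end, instead of rounding to `orig_dtype` before multiplying by `weight`.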
Nice contribution, thank you @Abatom!
Left some comments for discussion.
It would be better to add benchmark results for the new kernel, @Abatom.
@zhyncs we haven't set up a standard benchmark for normalization kernels, so I think we can leave it for future work. One interesting feature to have in flashinfer would be a benchmarking class that reports bandwidth and FLOP utilization, like proton. Ideally we could port nvbench to Python, but I don't have a concrete idea of the amount of work involved.
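As a rough illustration of the kind of benchmarking class discussed above, here is a minimal sketch. The helper name `bench_norm` and the bandwidth accounting are my own assumptions (flashinfer does not ship this); it times a kernel with CUDA events and estimates achieved bandwidth, assuming each tensor is touched once per call.

```python
import torch

def bench_norm(fn, *tensors, warmup=10, iters=100):
    # Warm up to exclude JIT/caching effects from the measurement.
    for _ in range(warmup):
        fn(*tensors)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*tensors)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    # Crude traffic model: every tensor read or written exactly once.
    nbytes = sum(t.numel() * t.element_size() for t in tensors)
    return ms, nbytes / (ms * 1e-3) / 1e9  # latency (ms), GB/s
```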
@yzh119 Shared memory has already been used in place of global memory, and a global memory read has also been eliminated.
LGTM, I think this PR is ready to be merged.
Brief note (to remind myself what this PR is doing): keep residual in fp32 in shared memory to increase the numerical accuracy of rmsnorm.
When `sizeof(T) == 2`, the sum of the read `input` and `residual` (the fp32 value `x`) is split into two parts, the high and low 16 bits, which are saved to `input` and `residual` respectively. Later, `input` and `residual` are read out and combined back into `x`, with the aim of improving the precision of the subsequent `x * rms_rcp` operation. This increases precision from 1e-2 to 1e-3.
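A minimal Python sketch of that bit-splitting idea (the real kernel does this in CUDA with the halves held in shared memory; the tensor names here are illustrative): splitting an fp32 value into its high and low 16 bits and recombining them is lossless, so the fp32 sum can be reconstructed exactly for the `x * rms_rcp` multiply.

```python
import torch

x = torch.randn(8, dtype=torch.float32)  # fp32 sum of input + residual
bits = x.view(torch.int32)               # reinterpret the fp32 bit pattern

hi = (bits >> 16).to(torch.int16)        # high 16 bits (kernel stores these in `input`)
lo = bits.to(torch.int16)                # low 16 bits (kernel stores these in `residual`)

# Recombine the two 16-bit halves into the original fp32 value.
combined = ((hi.to(torch.int32) << 16) | (lo.to(torch.int32) & 0xFFFF)).view(torch.float32)
assert torch.equal(x, combined)          # the round trip is bit-exact
```

Storing only the high 16 bits (roughly a bfloat16) would lose the low half; keeping both halves in the two existing 16-bit buffers recovers full fp32 precision without extra storage.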