
Improve the precision of the FusedAddRMSNormKernel function #587

Merged — 4 commits, Nov 6, 2024

Conversation

@Abatom (Contributor) commented Nov 6, 2024

When sizeof(T) == 2, the sum of the loaded input and residual (a float x) is split into its high and low 16 bits, which are stored back into input and residual respectively. Later, input and residual are read out and recombined into x, with the aim of improving the precision of the subsequent x * rms_rcp operation.

This tightens the achievable error tolerance from 1e-2 to 1e-3.
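As a rough illustration of the bit-splitting idea described above (a NumPy sketch, not the actual kernel code; the helper names are hypothetical), an fp32 value can be split into its high and low 16-bit halves and later recombined without any loss:

```python
import numpy as np

def split_f32(x):
    # Reinterpret the fp32 values as raw 32-bit integers and split
    # them into two 16-bit halves (in the kernel, these halves would
    # be written back to `input` and `residual`).
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    hi = (bits >> 16).astype(np.uint16)
    lo = (bits & 0xFFFF).astype(np.uint16)
    return hi, lo

def combine_f32(hi, lo):
    # Recombine the two halves into the original fp32 bit pattern.
    # This round-trip is exact: no precision is lost.
    bits = (hi.astype(np.uint32) << 16) | lo.astype(np.uint32)
    return bits.view(np.float32)

x = np.array([3.14159, -0.001, 1e4], dtype=np.float32)
hi, lo = split_f32(x)
assert np.array_equal(combine_f32(hi, lo), x)  # bit-exact round trip
```

Note that keeping only the high 16 bits (as a plain bf16-style cast would) discards the low half; storing both halves is what lets the kernel reconstruct the full fp32 sum before the x * rms_rcp multiply.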

@Abatom (Contributor, Author) commented Nov 6, 2024

For reference, the current fused_add_rms_norm:

def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)

    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x.to(orig_dtype) * weight
    return x, residual

If the function is modified as follows, its output is almost identical to that of FusedAddRMSNormKernel; the precision can reach 1e-20.

def fused_add_rms_norm(x, residual, weight, eps):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    residual = x.to(orig_dtype)

    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps) * weight.to(orig_dtype)
    return x.to(orig_dtype), residual
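The two reference versions above can be compared with a small harness (a sketch; the tensor shapes, seed, and tolerances are illustrative assumptions, not from the PR):

```python
import torch

def fused_add_rms_norm_v1(x, residual, weight, eps):
    # Original reference: cast x back to orig_dtype *before* the weight multiply.
    orig_dtype = x.dtype
    x = x.to(torch.float32) + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return x.to(orig_dtype) * weight, residual

def fused_add_rms_norm_v2(x, residual, weight, eps):
    # Modified reference: multiply by weight while x is still fp32,
    # then cast once at the end.
    orig_dtype = x.dtype
    x = x.to(torch.float32) + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps) * weight.to(orig_dtype)
    return x.to(orig_dtype), residual

torch.manual_seed(0)
x = torch.randn(4, 64, dtype=torch.float16)
res = torch.randn(4, 64, dtype=torch.float16)
w = torch.randn(64, dtype=torch.float16)

y1, r1 = fused_add_rms_norm_v1(x, res, w, 1e-6)
y2, r2 = fused_add_rms_norm_v2(x, res, w, 1e-6)

# The residual path is identical in both versions, so it matches bit-for-bit.
assert torch.equal(r1, r2)
# The normalized outputs differ only by rounding in the final multiply.
assert torch.allclose(y1.float(), y2.float(), atol=1e-2, rtol=1e-2)
```

The only difference is where the cast to orig_dtype happens relative to the weight multiply; v2 keeps the multiply in fp32 (via type promotion) and rounds once at the end, which is what makes it track the fused kernel so closely.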

@yzh119 (Collaborator) left a review comment:

Nice contribution, thank you @Abatom !
Left some comments for discussion.

Review thread on include/flashinfer/norm.cuh (outdated, resolved).
@zhyncs (Member) commented Nov 6, 2024

It would be better to add benchmark results for the new kernel, @Abatom.

@yzh119 (Collaborator) commented Nov 6, 2024

@zhyncs we haven't set up a standard benchmark for normalization kernels, so I think we can leave that for future work.

One interesting feature to have in flashinfer would be a benchmarking class that reports bandwidth and FLOP utilization, like proton. Ideally we could port nvbench to Python, but I don't have a concrete idea of the amount of work involved.

@Abatom (Contributor, Author) commented Nov 6, 2024

@yzh119 Shared memory is now used in place of global memory, and one global memory read has been eliminated.

@Abatom Abatom requested a review from yzh119 November 6, 2024 11:21
@yzh119 (Collaborator) left a review comment:

LGTM, I think this PR is ready to be merged.

Brief note (to remind myself what this PR does): keep the residual in fp32 in shared memory to increase the numerical accuracy of rmsnorm.

@yzh119 merged commit c7dc921 into flashinfer-ai:main Nov 6, 2024
yzh119 added a commit that referenced this pull request Nov 24, 2024
The gemma-style rmsnorm kernels (introduced in #477) are similar to the original rmsnorm kernel, and we should use the same kernel for both. That commit cleans up duplicate code and unifies the gemma-style and original rmsnorm kernels.

The precision improvements (#587, #592) are kept in that PR.