Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the precision of the FusedAddRMSNormKernel function #587

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions include/flashinfer/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
x += float(residual_vec[j]);
sum_sq += x * x;
residual_vec[j] = (T)x;
smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j] = x;
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
Expand Down Expand Up @@ -173,17 +174,15 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
vec_t<T, VEC_SIZE> weight_vec;
vec_t<T, VEC_SIZE> residual_vec;
input_vec.fill(0.f);
weight_vec.fill(0.f);
residual_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
float x = smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j];
input_vec[j] = x * rms_rcp * float(weight_vec[j]);
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
Expand All @@ -200,7 +199,7 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
const uint32_t num_warps = ceil_div(block_size, 32);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
const uint32_t smem_size = (num_warps + d) * sizeof(float);
void* args[] = {&input, &residual, &weight, &d, &eps};

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
Expand Down
4 changes: 2 additions & 2 deletions tests/test_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
residual_fused = residual.clone()
flashinfer.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)

torch.testing.assert_close(x_fused, x_native, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
Expand Down