
feat: improve the precision of the FusedAddRMSNormKernel function (#587)
When `sizeof(T) == 2`, the sum of the loaded `input` and `residual` values (the float
`x`) is split into its high and low 16 bits, which are stored back to `input` and
`residual` respectively. Later, `input` and `residual` are read again and recombined
into `x`, improving the precision of the subsequent `x * rms_rcp` multiplication.

This tightens the test tolerances from 1e-2 to 1e-3.
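
To illustrate why this matters, here is a minimal sketch (not part of the commit; the element type, sample value, and `rms_rcp` factor are made up) comparing a multiply that uses the full-precision float sum against one that first round-trips the sum through a 16-bit type:

// Sketch only: demonstrates the precision gap the commit targets.
// Assumes a 16-bit element type (here __half); names and values are illustrative.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void precision_demo(float rms_rcp) {
  float x = 3.14159265f;                               // fused sum of input + residual, kept in float
  float x_roundtrip = __half2float(__float2half(x));   // what survives a 16-bit round trip

  float out_full = x * rms_rcp;             // multiply using the full float sum
  float out_lossy = x_roundtrip * rms_rcp;  // multiply after the 16-bit round trip
  printf("full=%.9f lossy=%.9f diff=%g\n", out_full, out_lossy, out_full - out_lossy);
}

int main() {
  precision_demo<<<1, 1>>>(0.125f);  // arbitrary stand-in for rms_rcp
  cudaDeviceSynchronize();
  return 0;
}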
Abatom authored Nov 6, 2024
1 parent d7300c4 commit c7dc921
Showing 2 changed files with 6 additions and 7 deletions.
9 changes: 4 additions & 5 deletions include/flashinfer/norm.cuh
@@ -143,6 +143,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
       x += float(residual_vec[j]);
       sum_sq += x * x;
       residual_vec[j] = (T)x;
+      smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j] = x;
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -173,17 +174,15 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
   for (uint32_t i = 0; i < rounds; i++) {
     vec_t<T, VEC_SIZE> input_vec;
     vec_t<T, VEC_SIZE> weight_vec;
-    vec_t<T, VEC_SIZE> residual_vec;
     input_vec.fill(0.f);
     weight_vec.fill(0.f);
-    residual_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
+      float x = smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j];
+      input_vec[j] = x * rms_rcp * float(weight_vec[j]);
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -200,7 +199,7 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = num_warps * sizeof(float);
+  const uint32_t smem_size = (num_warps + d) * sizeof(float);
   void* args[] = {&input, &residual, &weight, &d, &eps};

   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
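For reference, a minimal sketch (generic kernel and launcher names, not FlashInfer's actual API) of how a dynamic shared-memory buffer of `(num_warps + d)` floats, sized as in the hunk above, is declared in a kernel and supplied at launch:

// Sketch only: dynamic shared memory sized as (num_warps + d) floats.
// Kernel and launcher names are illustrative assumptions.
#include <cstdint>

__global__ void rmsnorm_like_kernel(const float* in, float* out, uint32_t d) {
  extern __shared__ float smem[];  // [0, num_warps): per-warp partial sums
                                   // [num_warps, num_warps + d): cached float values of x
  // ... kernel body elided ...
}

void launch(const float* in, float* out, uint32_t d, uint32_t batch_size,
            uint32_t num_warps, cudaStream_t stream) {
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = (num_warps + d) * sizeof(float);
  rmsnorm_like_kernel<<<nblks, nthrs, smem_size, stream>>>(in, out, d);
}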
4 changes: 2 additions & 2 deletions tests/test_norm.py
@@ -98,8 +98,8 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     residual_fused = residual.clone()
     flashinfer.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)

-    torch.testing.assert_close(x_fused, x_native, rtol=1e-2, atol=1e-2)
-    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
