perf: reduce the read and write of shared memory in the FusedAddRMSNormKernel #592
@@ -133,6 +133,8 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
input_vec.fill(0.f);
vec_t<T, VEC_SIZE> residual_vec;
residual_vec.fill(0.f);
vec_t<float, VEC_SIZE> x_vec;
I wrote this kernel in August (#419). You can use PyTorch's benchmark utilities (https://pytorch.org/docs/stable/benchmark_utils.html) to add a benchmark, so you can tell whether performance actually improves before and after the change.
Okay, I'll look into this. I've already profiled this PR with Nsight Compute and found that its performance is about the same as the code before the precision improvement (#587).
Something like the scripts in this directory would be great:
https://github.com/flashinfer-ai/flashinfer/tree/main/benchmarks
Okay, I'll try to write a benchmark like these.

On A800:

Original
# python3 benchmarks/bench_fused_add_rmsnorm.py
batch_size: 1, hidden_size: 111, dtype: float16 , latency: 9us, throughput: 0.127GB/s
batch_size: 1, hidden_size: 500, dtype: float16 , latency: 9us, throughput: 0.556GB/s
batch_size: 1, hidden_size: 1024, dtype: float16 , latency: 8us, throughput: 1.212GB/s
batch_size: 1, hidden_size: 3072, dtype: float16 , latency: 9us, throughput: 3.593GB/s
batch_size: 1, hidden_size: 4096, dtype: float16 , latency: 8us, throughput: 5.004GB/s
batch_size: 1, hidden_size: 8192, dtype: float16 , latency: 9us, throughput: 8.731GB/s
---
batch_size: 19, hidden_size: 111, dtype: float16 , latency: 8us, throughput: 2.138GB/s
batch_size: 19, hidden_size: 500, dtype: float16 , latency: 8us, throughput: 10.063GB/s
batch_size: 19, hidden_size: 1024, dtype: float16 , latency: 8us, throughput: 20.380GB/s
batch_size: 19, hidden_size: 3072, dtype: float16 , latency: 9us, throughput: 54.253GB/s
batch_size: 19, hidden_size: 4096, dtype: float16 , latency: 9us, throughput: 71.412GB/s
batch_size: 19, hidden_size: 8192, dtype: float16 , latency: 9us, throughput: 136.994GB/s
---
batch_size: 99, hidden_size: 111, dtype: float16 , latency: 8us, throughput: 11.466GB/s
batch_size: 99, hidden_size: 500, dtype: float16 , latency: 9us, throughput: 45.655GB/s
batch_size: 99, hidden_size: 1024, dtype: float16 , latency: 8us, throughput: 98.419GB/s
batch_size: 99, hidden_size: 3072, dtype: float16 , latency: 9us, throughput: 261.573GB/s
batch_size: 99, hidden_size: 4096, dtype: float16 , latency: 10us, throughput: 337.045GB/s
batch_size: 99, hidden_size: 8192, dtype: float16 , latency: 12us, throughput: 533.967GB/s
---
batch_size: 989, hidden_size: 111, dtype: float16 , latency: 9us, throughput: 97.215GB/s
batch_size: 989, hidden_size: 500, dtype: float16 , latency: 10us, throughput: 392.762GB/s
batch_size: 989, hidden_size: 1024, dtype: float16 , latency: 12us, throughput: 670.266GB/s
batch_size: 989, hidden_size: 3072, dtype: float16 , latency: 24us, throughput: 1025.302GB/s
batch_size: 989, hidden_size: 4096, dtype: float16 , latency: 28us, throughput: 1155.931GB/s
batch_size: 989, hidden_size: 8192, dtype: float16 , latency: 48us, throughput: 1351.178GB/s
---

v1
# python3 benchmarks/bench_fused_add_rmsnorm.py
batch_size: 1, hidden_size: 111, dtype: float16 , latency: 9us, throughput: 0.124GB/s
batch_size: 1, hidden_size: 500, dtype: float16 , latency: 8us, throughput: 0.592GB/s
batch_size: 1, hidden_size: 1024, dtype: float16 , latency: 7us, throughput: 1.369GB/s
batch_size: 1, hidden_size: 3072, dtype: float16 , latency: 8us, throughput: 3.775GB/s
batch_size: 1, hidden_size: 4096, dtype: float16 , latency: 8us, throughput: 4.885GB/s
batch_size: 1, hidden_size: 8192, dtype: float16 , latency: 9us, throughput: 8.759GB/s
---
batch_size: 19, hidden_size: 111, dtype: float16 , latency: 8us, throughput: 2.244GB/s
batch_size: 19, hidden_size: 500, dtype: float16 , latency: 8us, throughput: 9.345GB/s
batch_size: 19, hidden_size: 1024, dtype: float16 , latency: 8us, throughput: 20.047GB/s
batch_size: 19, hidden_size: 3072, dtype: float16 , latency: 8us, throughput: 56.730GB/s
batch_size: 19, hidden_size: 4096, dtype: float16 , latency: 9us, throughput: 73.456GB/s
batch_size: 19, hidden_size: 8192, dtype: float16 , latency: 10us, throughput: 129.034GB/s
---
batch_size: 99, hidden_size: 111, dtype: float16 , latency: 8us, throughput: 11.182GB/s
batch_size: 99, hidden_size: 500, dtype: float16 , latency: 9us, throughput: 45.244GB/s
batch_size: 99, hidden_size: 1024, dtype: float16 , latency: 8us, throughput: 97.719GB/s
batch_size: 99, hidden_size: 3072, dtype: float16 , latency: 10us, throughput: 255.074GB/s
batch_size: 99, hidden_size: 4096, dtype: float16 , latency: 10us, throughput: 309.801GB/s
batch_size: 99, hidden_size: 8192, dtype: float16 , latency: 13us, throughput: 498.308GB/s
---
batch_size: 989, hidden_size: 111, dtype: float16 , latency: 9us, throughput: 100.372GB/s
batch_size: 989, hidden_size: 500, dtype: float16 , latency: 11us, throughput: 346.132GB/s
batch_size: 989, hidden_size: 1024, dtype: float16 , latency: 14us, throughput: 586.412GB/s
batch_size: 989, hidden_size: 3072, dtype: float16 , latency: 25us, throughput: 971.120GB/s
batch_size: 989, hidden_size: 4096, dtype: float16 , latency: 30us, throughput: 1094.190GB/s
batch_size: 989, hidden_size: 8192, dtype: float16 , latency: 50us, throughput: 1295.277GB/s
---

v2: Improve the precision of the FusedAddRMSNormKernel function (#587)
# python3 benchmarks/bench_fused_add_rmsnorm.py
batch_size: 1, hidden_size: 111, dtype: float16 , latency: 9us, throughput: 0.127GB/s
batch_size: 1, hidden_size: 500, dtype: float16 , latency: 9us, throughput: 0.560GB/s
batch_size: 1, hidden_size: 1024, dtype: float16 , latency: 9us, throughput: 1.121GB/s
batch_size: 1, hidden_size: 3072, dtype: float16 , latency: 9us, throughput: 3.494GB/s
batch_size: 1, hidden_size: 4096, dtype: float16 , latency: 9us, throughput: 4.588GB/s
batch_size: 1, hidden_size: 8192, dtype: float16 , latency: 11us, throughput: 7.656GB/s
---
batch_size: 19, hidden_size: 111, dtype: float16 , latency: 8us, throughput: 2.120GB/s
batch_size: 19, hidden_size: 500, dtype: float16 , latency: 8us, throughput: 9.461GB/s
batch_size: 19, hidden_size: 1024, dtype: float16 , latency: 8us, throughput: 18.881GB/s
batch_size: 19, hidden_size: 3072, dtype: float16 , latency: 9us, throughput: 54.199GB/s
batch_size: 19, hidden_size: 4096, dtype: float16 , latency: 10us, throughput: 64.051GB/s
batch_size: 19, hidden_size: 8192, dtype: float16 , latency: 12us, throughput: 106.248GB/s
---
batch_size: 99, hidden_size: 111, dtype: float16 , latency: 8us, throughput: 11.392GB/s
batch_size: 99, hidden_size: 500, dtype: float16 , latency: 8us, throughput: 49.162GB/s
batch_size: 99, hidden_size: 1024, dtype: float16 , latency: 9us, throughput: 92.030GB/s
batch_size: 99, hidden_size: 3072, dtype: float16 , latency: 10us, throughput: 235.945GB/s
batch_size: 99, hidden_size: 4096, dtype: float16 , latency: 11us, throughput: 303.909GB/s
batch_size: 99, hidden_size: 8192, dtype: float16 , latency: 13us, throughput: 483.708GB/s
---
batch_size: 989, hidden_size: 111, dtype: float16 , latency: 8us, throughput: 103.845GB/s
batch_size: 989, hidden_size: 500, dtype: float16 , latency: 10us, throughput: 382.028GB/s
batch_size: 989, hidden_size: 1024, dtype: float16 , latency: 14us, throughput: 563.498GB/s
batch_size: 989, hidden_size: 3072, dtype: float16 , latency: 26us, throughput: 927.479GB/s
batch_size: 989, hidden_size: 4096, dtype: float16 , latency: 32us, throughput: 1009.666GB/s
batch_size: 989, hidden_size: 8192, dtype: float16 , latency: 54us, throughput: 1207.263GB/s
---

v3: This PR
# python3 benchmarks/bench_fused_add_rmsnorm.py
batch_size: 1, hidden_size: 111, dtype: float16 , latency: 9us, throughput: 0.126GB/s
batch_size: 1, hidden_size: 500, dtype: float16 , latency: 9us, throughput: 0.584GB/s
batch_size: 1, hidden_size: 1024, dtype: float16 , latency: 9us, throughput: 1.123GB/s
batch_size: 1, hidden_size: 3072, dtype: float16 , latency: 9us, throughput: 3.578GB/s
batch_size: 1, hidden_size: 4096, dtype: float16 , latency: 9us, throughput: 4.644GB/s
batch_size: 1, hidden_size: 8192, dtype: float16 , latency: 10us, throughput: 8.458GB/s
---
batch_size: 19, hidden_size: 111, dtype: float16 , latency: 8us, throughput: 2.136GB/s
batch_size: 19, hidden_size: 500, dtype: float16 , latency: 8us, throughput: 9.398GB/s
batch_size: 19, hidden_size: 1024, dtype: float16 , latency: 8us, throughput: 19.087GB/s
batch_size: 19, hidden_size: 3072, dtype: float16 , latency: 9us, throughput: 54.074GB/s
batch_size: 19, hidden_size: 4096, dtype: float16 , latency: 9us, throughput: 70.508GB/s
batch_size: 19, hidden_size: 8192, dtype: float16 , latency: 10us, throughput: 124.474GB/s
---
batch_size: 99, hidden_size: 111, dtype: float16 , latency: 8us, throughput: 10.703GB/s
batch_size: 99, hidden_size: 500, dtype: float16 , latency: 8us, throughput: 48.536GB/s
batch_size: 99, hidden_size: 1024, dtype: float16 , latency: 9us, throughput: 93.118GB/s
batch_size: 99, hidden_size: 3072, dtype: float16 , latency: 10us, throughput: 247.778GB/s
batch_size: 99, hidden_size: 4096, dtype: float16 , latency: 10us, throughput: 316.723GB/s
batch_size: 99, hidden_size: 8192, dtype: float16 , latency: 12us, throughput: 543.516GB/s
---
batch_size: 989, hidden_size: 111, dtype: float16 , latency: 8us, throughput: 104.044GB/s
batch_size: 989, hidden_size: 500, dtype: float16 , latency: 11us, throughput: 369.634GB/s
batch_size: 989, hidden_size: 1024, dtype: float16 , latency: 12us, throughput: 658.708GB/s
batch_size: 989, hidden_size: 3072, dtype: float16 , latency: 24us, throughput: 1003.270GB/s
batch_size: 989, hidden_size: 4096, dtype: float16 , latency: 29us, throughput: 1127.658GB/s
batch_size: 989, hidden_size: 8192, dtype: float16 , latency: 49us, throughput: 1317.664GB/s
---

This PR performs essentially the same as the original, only slightly slower, while improving precision by an order of magnitude (from 1e-2 to 1e-3). I believe this trade-off is acceptable.
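To put a number on "slightly lower", a small sketch computing the relative throughput change at the largest A800 shape (batch_size 989, hidden_size 8192). The figures are copied from the runs above; `relative_change` is a hypothetical helper, not part of the benchmark script.

```python
def relative_change(baseline: float, candidate: float) -> float:
    """Fractional change of candidate vs. baseline (negative means lower throughput)."""
    return (candidate - baseline) / baseline


# A800 throughput in GB/s at batch_size=989, hidden_size=8192, copied from the runs above.
A800_GBPS = {
    "original": 1351.178,
    "v1": 1295.277,
    "v2 (#587)": 1207.263,
    "v3 (this PR)": 1317.664,
}

for name, gbps in A800_GBPS.items():
    change = relative_change(A800_GBPS["original"], gbps)
    print(f"{name:<13} {change:+.2%} vs original")
```

At this shape, #587 costs roughly 10.7% of throughput, and this PR brings it back to within about 2.5% of the original. A single shape and a single run per version are not enough to separate a 2.5% gap from run-to-run noise.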
I tested #587, and with dtype=torch.bfloat16 a tolerance of 1e-3 is too strict.
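One likely reason follows from the formats themselves: float16 stores 10 fraction bits and bfloat16 only 7, so merely rounding a result near 1.0 to bfloat16 can introduce a relative error of up to 2^-8 ≈ 3.9e-3, well above 1e-3. A stdlib-only sketch (no torch); the helper names are hypothetical, though the eps values match what `torch.finfo` reports for these dtypes.

```python
import struct

# Number of explicitly stored fraction (mantissa) bits per format.
FRACTION_BITS = {"float16": 10, "bfloat16": 7, "float32": 23}


def machine_epsilon(dtype: str) -> float:
    """Spacing between 1.0 and the next larger representable value."""
    return 2.0 ** -FRACTION_BITS[dtype]


def round_to_fp16(x: float) -> float:
    # struct's "e" format is IEEE 754 binary16 with round-to-nearest-even.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_to_bf16(x: float) -> float:
    # bfloat16 is the upper 16 bits of a float32; round-to-nearest-even on the
    # 16 dropped bits (finite inputs only; NaN/overflow handling omitted).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for dtype in ("float16", "bfloat16"):
    # Worst-case relative rounding error for values in [1, 2) is half the spacing.
    eps = machine_epsilon(dtype)
    print(f"{dtype}: eps={eps:.3e}, max rounding error near 1.0={eps / 2:.3e}")

print(round_to_fp16(1.001), round_to_bf16(1.001))
```

In float16 the value 1.001 survives almost intact, while in bfloat16 it collapses to 1.0, so a bfloat16 output compared against a float reference at rtol=1e-3 can fail on rounding alone.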
@Abatom Could you clarify the main purpose of this PR, since it doesn't appear to focus on performance improvements? I'm also curious about any other potential benefits, as the accuracy improvements don't seem significant.
The achieved bandwidth might have been underestimated because we didn't count the memory accesses to residual.
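A minimal byte-accounting sketch for that point. It assumes the in-place semantics residual = input + residual followed by input = RMSNorm(residual) * weight, that each element is read or written exactly once from global memory, and that weight is read once per launch (caches ignored). The helper names are hypothetical and this is not the benchmark script's actual formula.

```python
def fused_add_rmsnorm_bytes(batch_size: int, hidden_size: int, elem_size: int,
                            count_residual: bool = True) -> int:
    """Minimum global-memory traffic under the assumptions stated above."""
    rows = batch_size * hidden_size
    elems = 2 * rows + hidden_size  # read input, write output (in place), read weight once
    if count_residual:
        elems += 2 * rows           # read residual, write the updated residual
    return elems * elem_size


def throughput_gbps(num_bytes: int, latency_us: float) -> float:
    """Achieved bandwidth in GB/s (1 GB = 1e9 bytes)."""
    return num_bytes / (latency_us * 1e-6) / 1e9


with_res = fused_add_rmsnorm_bytes(989, 8192, elem_size=2)
without_res = fused_add_rmsnorm_bytes(989, 8192, elem_size=2, count_residual=False)
print(with_res, without_res, with_res / without_res)
```

For batch_size 989 and hidden_size 8192 in float16 this gives 64,831,488 bytes with residual traffic and 32,423,936 without, so for large batches leaving the residual out of the count roughly halves the computed bandwidth for the same latency.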
LGTM, updated results on H100:
---
batch_size: 1, hidden_size: 111, dtype: float16 , latency: 6us, throughput: 0.180GB/s
batch_size: 1, hidden_size: 500, dtype: float16 , latency: 6us, throughput: 0.850GB/s
batch_size: 1, hidden_size: 1024, dtype: float16 , latency: 6us, throughput: 1.706GB/s
batch_size: 1, hidden_size: 3072, dtype: float16 , latency: 6us, throughput: 4.962GB/s
batch_size: 1, hidden_size: 4096, dtype: float16 , latency: 6us, throughput: 6.540GB/s
batch_size: 1, hidden_size: 8192, dtype: float16 , latency: 7us, throughput: 11.960GB/s
---
batch_size: 19, hidden_size: 111, dtype: float16 , latency: 6us, throughput: 2.968GB/s
batch_size: 19, hidden_size: 500, dtype: float16 , latency: 6us, throughput: 13.038GB/s
batch_size: 19, hidden_size: 1024, dtype: float16 , latency: 6us, throughput: 26.839GB/s
batch_size: 19, hidden_size: 3072, dtype: float16 , latency: 6us, throughput: 76.400GB/s
batch_size: 19, hidden_size: 4096, dtype: float16 , latency: 6us, throughput: 98.126GB/s
batch_size: 19, hidden_size: 8192, dtype: float16 , latency: 7us, throughput: 182.439GB/s
---
batch_size: 99, hidden_size: 111, dtype: float16 , latency: 6us, throughput: 15.074GB/s
batch_size: 99, hidden_size: 500, dtype: float16 , latency: 6us, throughput: 65.834GB/s
batch_size: 99, hidden_size: 1024, dtype: float16 , latency: 6us, throughput: 131.731GB/s
batch_size: 99, hidden_size: 3072, dtype: float16 , latency: 7us, throughput: 364.698GB/s
batch_size: 99, hidden_size: 4096, dtype: float16 , latency: 7us, throughput: 463.593GB/s
batch_size: 99, hidden_size: 8192, dtype: float16 , latency: 8us, throughput: 769.438GB/s
---
batch_size: 989, hidden_size: 111, dtype: float16 , latency: 6us, throughput: 137.252GB/s
batch_size: 989, hidden_size: 500, dtype: float16 , latency: 7us, throughput: 547.651GB/s
batch_size: 989, hidden_size: 1024, dtype: float16 , latency: 8us, throughput: 964.887GB/s
batch_size: 989, hidden_size: 3072, dtype: float16 , latency: 14us, throughput: 1701.064GB/s
batch_size: 989, hidden_size: 4096, dtype: float16 , latency: 17us, throughput: 1914.022GB/s
batch_size: 989, hidden_size: 8192, dtype: float16 , latency: 29us, throughput: 2242.014GB/s
---
…apes (#636)

This PR fixes issue #634, which was introduced by #592. To use 16-byte vectorized reads/writes, the address must be 16-byte aligned. When `num_warps` is not a multiple of 4 (4 * sizeof(float) = 16 bytes), the address `smem + num_warps` might not be aligned to 16 bytes. We fix this by shifting the start offset of the vectorized reads/writes to `smem + ceil_div(num_warps, 4) * 4`, which forces the alignment. cc @ovowei @Abatom
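The offset arithmetic behind this fix can be checked in isolation. A minimal sketch, assuming the shared-memory base address is itself 16-byte aligned and that the vectorized region holds floats; the function names are hypothetical, not flashinfer's.

```python
FLOAT_BYTES = 4                          # sizeof(float)
VEC_BYTES = 16                           # width of one vectorized shared-memory access
FLOATS_PER_VEC = VEC_BYTES // FLOAT_BYTES


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def offset_before_fix(num_warps: int) -> int:
    # Vectorized region started right after the num_warps per-warp partial sums.
    return num_warps


def offset_after_fix(num_warps: int) -> int:
    # Round the start up to the next multiple of 4 floats (16 bytes).
    return ceil_div(num_warps, FLOATS_PER_VEC) * FLOATS_PER_VEC


def is_16b_aligned(float_offset: int, smem_base: int = 0) -> bool:
    # Assumes smem_base (a byte address) is itself 16-byte aligned.
    return (smem_base + float_offset * FLOAT_BYTES) % VEC_BYTES == 0


for num_warps in (3, 4, 5, 8, 12, 31):
    before, after = offset_before_fix(num_warps), offset_after_fix(num_warps)
    print(f"num_warps={num_warps:2d}: before={before:2d} aligned={is_16b_aligned(before)}, "
          f"after={after:2d} aligned={is_16b_aligned(after)}")
```

The padding costs at most 3 floats of shared memory per block, which is why rounding the start offset up is a cheap way to restore alignment.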
Use `vec_t<float, VEC_SIZE> x_vec` to reduce the number of read and write operations to shared memory.
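A serial CPU sketch of the data flow this describes: x = input + residual is computed in float, staged in a float buffer that stands in for shared memory, and read back for normalization. It assumes, as the #636 fix note indicates, that each VEC_SIZE-wide chunk of x is stored and loaded as one vector rather than element by element. It ignores the thread/warp split and the block-wide reduction, and keeps residual in float where the kernel casts to T. `fused_add_rmsnorm_sketch` is a hypothetical name, not flashinfer code.

```python
import math

VEC_SIZE = 8  # assumed elements per chunk (e.g. 16 bytes of float16)


def fused_add_rmsnorm_sketch(inp, residual, weight, eps=1e-6, vec_size=VEC_SIZE):
    """Return (output, new_residual, vec_accesses); vec_accesses counts
    vec_size-wide accesses to the float staging buffer."""
    hidden = len(inp)
    smem_x = [0.0] * hidden          # stands in for the float staging area in shared memory
    new_residual = [0.0] * hidden
    sum_sq = 0.0
    vec_accesses = 0

    # Pass 1: x = input + residual in float; accumulate sum of squares; stage x.
    for start in range(0, hidden, vec_size):
        stop = min(start + vec_size, hidden)
        x_vec = [float(inp[i]) + float(residual[i]) for i in range(start, stop)]
        sum_sq += sum(x * x for x in x_vec)
        new_residual[start:stop] = x_vec  # the kernel would cast to T here
        smem_x[start:stop] = x_vec        # one vector store per chunk
        vec_accesses += 1

    rms_rcp = 1.0 / math.sqrt(sum_sq / hidden + eps)

    # Pass 2: load each staged chunk as one vector and normalize it.
    out = [0.0] * hidden
    for start in range(0, hidden, vec_size):
        stop = min(start + vec_size, hidden)
        x_vec = smem_x[start:stop]        # one vector load per chunk
        vec_accesses += 1
        for i, x in zip(range(start, stop), x_vec):
            out[i] = x * rms_rcp * float(weight[i])
    return out, new_residual, vec_accesses
```

With vectorized staging the count is 2 * ceil(hidden_size / VEC_SIZE) accesses, whereas element-by-element staging would need 2 * hidden_size; the hardware may still split a wide vector into several 16-byte instructions, which is exactly where the alignment requirement fixed in #636 comes from.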