feat: correct casts in RMSNorm to match references #92
Conversation
I aim to exactly match the references in terms of dtypes. As you can see, doing so (especially in the backward pass) considerably increases the complexity. We also need to store in fp32 some buffers that were previously stored in lower precision. I am not sure the tradeoff is worth it (the backward pass in particular becomes more complex), or whether it would be better not to follow the references one-to-one in terms of casting (and in storing the cached inv_rms buffer). Let me know which parts you think are worth keeping close to the reference.
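For context, the Llama-style reference (sketched here from memory of the Hugging Face transformers implementation; exact details may differ) upcasts to float32, normalizes, and casts back to the input dtype before applying the weight:

```python
import torch
import torch.nn as nn

# Sketch of the Llama-style reference RMSNorm, paraphrased from the
# Hugging Face transformers implementation.
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)  # upcast before reducing
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # cast back to the input dtype *before* multiplying by the weight
        return self.weight * hidden_states.to(input_dtype)
```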
Would something like Kahan summation help with the lower-precision accumulation?
Hadn't heard of it before. Just did a quick search, and if I am not mistaken it should be implementable in Triton with the generic reduce function. However, given that the relative difference is usually very low (and accumulation is done in fp32), it might not be worth the hassle. Definitely an interesting idea, though.
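For reference, Kahan (compensated) summation carries a small correction term that recovers the low-order bits lost in each addition; a minimal Python sketch:

```python
def kahan_sum(values):
    total = 0.0
    compensation = 0.0  # running estimate of the lost low-order bits
    for v in values:
        y = v - compensation             # apply the correction to the new value
        t = total + y                    # low-order bits of y may be lost here
        compensation = (t - total) - y   # recover what was just lost
        total = t
    return total
```

In a Triton kernel this would mean carrying the compensation term through the reduction, which is the extra hassle mentioned above.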
You can verify the perf using benchmark/.
Benchmarks on NVIDIA L4 (this branch vs. master; results were posted as plots): Forward Speed Benchmark, Backward Speed Benchmark, Full Speed Benchmark, Full Memory Benchmark.
Just benchmarked it. As expected, memory usage and speed are slightly worse with these changes: the extra memory comes from storing the cached RMS values in float32, and the slowdown from computing/reducing in float32. On Gemma, memory usage will be even higher, since weight grads are stored in fp32 (and then possibly cast back to lower precision). All of this is needed to match the reference.
I'd go with the second one if you are willing to have the complexity in the kernels. What do you think?
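To illustrate the memory cost, here is a hypothetical sketch of the backward dtype handling (not the actual kernel; the names X, dY, and inv_rms are illustrative): matching the reference means caching inv_rms in float32 and accumulating the weight gradient in float32 before any downcast.

```python
import torch

# Hypothetical sketch, not the kernel's actual buffers or code path.
def rmsnorm_weight_grad(dY, X, inv_rms, weight_dtype):
    # inv_rms is cached in float32 from the forward pass (extra memory)
    x_hat = X.float() * inv_rms.unsqueeze(-1)  # normalized input, computed in fp32
    dW = (dY.float() * x_hat).sum(dim=0)       # reduce over rows in fp32
    return dW.to(weight_dtype)                 # optionally cast back at the end
```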
I prefer the third; exactness is a deal breaker.
Run all tests again; if they pass, we can merge.
Done. Also made the float16 tolerance a bit less tight (to the same level as bfloat16), since a very low percentage of elements was sometimes failing.
@davidgonmar can you resolve the conflicts? Thanks
Code logic LGTM! Some minor nit comments; the remaining issues are Triton 2.3.0 compatibility (the way we use the flag doesn't work there) and the test env var.
LGTM! Checking why the CI is failing.
There seems to be some issue with the GPU CI setup, so I'm temporarily making it optional. @davidgonmar could you help rebase on the latest main branch, and we're good to go!
Done!
Summary
Aims to fix #89.
Details
Performs the casts to float32 at the correct places to match the Gemma and Llama references, in both the forward and backward passes.
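For comparison, the Gemma reference (again sketched from memory of the Hugging Face transformers implementation; exact details may differ) keeps the whole computation in float32 through the weight multiplication and applies a (1 + weight) scaling, so the cast sites differ from Llama's:

```python
import torch
import torch.nn as nn

# Sketch of the Gemma-style reference RMSNorm, paraphrased from the
# Hugging Face transformers implementation.
class GemmaRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # stored as an offset from 1

    def forward(self, x):
        output = x.float()  # do the entire computation in float32
        output = output * torch.rsqrt(output.pow(2).mean(-1, keepdim=True) + self.eps)
        output = output * (1.0 + self.weight.float())  # weight applied in float32
        return output.type_as(x)  # cast back only at the very end
```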
Also modified the RMSNorm tests to use tighter tolerances and added fp16 tests.
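A sketch of what per-dtype tolerances in such a test might look like (the values here are illustrative only, not the ones used in the repo):

```python
import torch

# Illustrative tolerances only; the actual test values may differ.
TOLERANCES = {
    torch.float32: dict(atol=1e-5, rtol=1e-5),
    torch.bfloat16: dict(atol=2e-2, rtol=2e-2),
    torch.float16: dict(atol=2e-2, rtol=2e-2),  # loosened to the bf16 level
}

def check_close(actual, expected):
    torch.testing.assert_close(actual, expected, **TOLERANCES[actual.dtype])
```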
Testing Done
Ran tests for convergence and RMSNorm.
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence