[MOE] Try to optimize moe align block size multiblocks cuda kernel #3137
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Motivation
This is follow up of PR#2970 which enables efficient mutli-blocks execution.
The codes pass correctness test in many simple configurations (seq_len=16384):
The enbles distributed cusum for a single GPU (cudaLaunchCooperativeKernel) at large scale workload.
To make it possible, note Sum(F(a + b)) != Sum(F(a)) + Sum(b) , where F = floor((a + c - 1) / c) (we can make a = m*c + p, b = n * c + q, where q, p are both integers and ranging from [1, c) to prove this mathmatically ).
Algorithm Structure
The codes in kernel moe_align_block_size_multiblocks_kernel organized in the following manner :
The overal steps are very similar to what in this triton version : PR#2913
Considering ROCM support, the stategy is using Ampere technique to develop this multi-blocks codes, no fancy features from arch above sm90 needed after my simple evaluation.
Modifications
We introduce cooperative_groups control valid since from Volta arch and typical strategy in Ampere arch :
After careful study, I believe 16x16 warp fragment is suitable for distributed cumsum computation. The numbers will benchmarked later.
Correctness
Benchmark for large scale data
(WIP)
Next Steps
Checklist