
[MOE] Try to optimize moe align block size multiblocks cuda kernel #3137

Draft
wants to merge 3 commits into main

Conversation

@yiakwy-xpu-ml-framework-team (Contributor) commented Jan 26, 2025

Motivation

This is a follow-up to PR#2970, which enables efficient multi-block execution.

The code passes correctness tests in many simple configurations (seq_len=16384):

(Screenshot 2025-01-26 11:05:29: correctness test results)

This enables distributed cumsum on a single GPU (via cudaLaunchCooperativeKernel) for large-scale workloads.

To make this possible, note that Sum(F(a + b)) != Sum(F(a)) + Sum(F(b)) in general, where F(x) = floor((x + c - 1) / c); writing a = m*c + p and b = n*c + q, with integers p, q in [1, c), proves this. The per-expert counts therefore have to be accumulated in unaligned form before the alignment F is applied.
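
A minimal host-side C++ sketch of this property (the block size value, variable names, and the ceil_div helper are illustrative, not taken from the kernel):

#include <cstdio>

// F(x) = ceil(x / c), the block-size alignment function discussed above.
static int ceil_div(int x, int c) { return (x + c - 1) / c; }

int main() {
  const int c = 4;   // hypothetical block_size
  int a = 1, b = 1;  // partial token counts for one expert from two blocks
  // Aligning the partial counts and then summing over-counts the padding:
  std::printf("F(a) + F(b) = %d\n", ceil_div(a, c) + ceil_div(b, c));  // prints 2
  // Summing the raw counts first and then aligning gives the correct value:
  std::printf("F(a + b)    = %d\n", ceil_div(a + b, c));               // prints 1
  return 0;
}

This is why the kernel keeps the per-block cumsum unaligned until the counts have been combined globally.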

Algorithm Structure

The code in the kernel moe_align_block_size_multiblocks_kernel is organized as follows:

// stage 1: compute local shared_counts
...
// stage 2: compute local unaligned cumsum (because of the bad property of the 'F' function noted above) using 16x16 fragments across many warps, and cache the results in tokens_cnt
...
    __threadfence_system();
    grid.sync();
// stage 3: compute global unaligned cumsum using our newly introduced distributed cumsum algorithm in https://github.com/sgl-project/sglang/pull/2970
{
   ...
    __threadfence_system();
    grid.sync();
}
// stage 4: compute global aligned cumsum and store back to cumsum_ptr
...
// stage 5: compute expert_ids, sorted_ids
...
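
For context, a bare-bones sketch of the grid-synchronized staging pattern above (the kernel name, parameters, and stage bodies are placeholders, not the PR's actual code):

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Sketch only: each stage stands in for the corresponding step of
// moe_align_block_size_multiblocks_kernel; real bodies are omitted.
__global__ void staged_kernel_sketch(int* scratch, int n) {
  cg::grid_group grid = cg::this_grid();

  // stages 1-2: per-block partial results written to global scratch memory
  // ...

  __threadfence_system();  // make per-block writes visible device-wide
  grid.sync();             // every block reaches here before stage 3 starts

  // stage 3: combine the per-block partials (distributed cumsum)
  // ...

  __threadfence_system();
  grid.sync();

  // stages 4-5: consume the globally combined results
  // ...
}

Note that grid.sync() is only valid when the kernel is launched cooperatively; see the launch sketch under Modifications below.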

The overall steps are very similar to those in the Triton version: PR#2913

Considering ROCm support, the strategy is to develop this multi-block code with Ampere-era techniques; based on my quick evaluation, no features from sm90 or newer architectures are needed.

Modifications

We introduce cooperative_groups control, available since the Volta architecture and a typical strategy on Ampere:

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif // USE_ROCM
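
Grid-wide synchronization additionally requires a cooperative launch on the host side. A minimal sketch (it reuses the hypothetical staged_kernel_sketch from the sketch above; the PR's real launch code differs):

#include <cuda_runtime.h>

__global__ void staged_kernel_sketch(int* scratch, int n);  // defined in the sketch above

// Hypothetical cooperative launch; grid.sync() is only legal for kernels
// launched through cudaLaunchCooperativeKernel.
cudaError_t launch_cooperative_sketch(int* scratch, int n) {
  int dev = 0, sm_count = 0, blocks_per_sm = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  // The whole grid must be co-resident on the device or the launch fails.
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &blocks_per_sm, staged_kernel_sketch, /*blockSize=*/256, /*dynSmemSize=*/0);
  dim3 grid(sm_count * blocks_per_sm), block(256);
  void* args[] = {&scratch, &n};
  return cudaLaunchCooperativeKernel((void*)staged_kernel_sketch, grid, block,
                                     args, /*sharedMem=*/0, /*stream=*/0);
}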

After careful study, I believe a 16x16 warp fragment is suitable for the distributed cumsum computation. The numbers will be benchmarked later.
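
As background only (this is a generic warp-level inclusive scan, not the PR's 16x16-fragment layout), the kind of primitive such a tiled cumsum builds on looks roughly like:

// Warp-level inclusive scan via shuffle intrinsics; assumes NVIDIA's 32-lane
// warps (ROCm wavefronts are 64 lanes and would need a wider loop).
__device__ __forceinline__ int warp_inclusive_scan(int value) {
  const unsigned full_mask = 0xffffffffu;
  #pragma unroll
  for (int offset = 1; offset < 32; offset <<= 1) {
    int neighbor = __shfl_up_sync(full_mask, value, offset);
    if ((threadIdx.x & 31) >= offset) value += neighbor;
  }
  return value;  // lane i now holds the sum of values from lanes 0..i
}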

Correctness

(Screenshot 2025-01-26 11:05:29: correctness test results)

Benchmark for large-scale data

(WIP)

Next Steps

  • Verify on MI300X @HaiShaw
  • Refactor the code to make it simpler and cleaner:
    • use flashinfer vec_t for 128-bit copies of int32_t (int16_t may also be possible); see the sketch after this list
    • consider cutlass as a replacement
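
For reference, and independent of flashinfer's vec_t API, a plain CUDA sketch of a 128-bit vectorized copy for int32_t data (the kernel name and layout assumptions are illustrative):

#include <cstdint>

// Copies n int32_t values with 128-bit (int4) loads/stores; sketch only.
// Assumes src and dst are 16-byte aligned and n is a multiple of 4.
__global__ void copy_int32_vec4(const int32_t* __restrict__ src,
                                int32_t* __restrict__ dst, int n) {
  int n_vec = n / 4;  // number of int4 (4 x int32_t) elements
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n_vec;
       i += gridDim.x * blockDim.x) {
    reinterpret_cast<int4*>(dst)[i] = reinterpret_cast<const int4*>(src)[i];
  }
}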

Checklist

}

// NOTE (yiakwy) : step 2, loop tail
if (tid == active_threads - 1) {
Collaborator

Can you remove your balance threads optimization? The current version of the code is too complex, and premature optimization is the root of all evil. I suggest removing this part first to expose the core modifications. Additionally, the kernel for a single block should be retained and the current code should be enabled when the number of tokens is greater than or equal to 32768, unless this multi-block kernel outperforms in all cases.
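
A minimal sketch of the dispatch being suggested here (the threshold constant and function name are placeholders, not actual sglang code):

// Hypothetical host-side selection between the existing single-block kernel
// and the new multi-block cooperative kernel; names and threshold are illustrative.
constexpr int kMultiBlockTokenThreshold = 32768;

inline bool use_multiblock_moe_align(int num_tokens) {
  // Keep the proven single-block kernel for small workloads and switch to
  // the multi-block cooperative kernel only at large token counts.
  return num_tokens >= kMultiBlockTokenThreshold;
}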

Contributor Author

Yep, enabling multi-block execution on top of the existing kernel implementation is on the way.
