feat: modify group-gemm stage number #497

Merged
1 commit merged into flashinfer-ai:main on Sep 13, 2024

Conversation

jeejeelee (Contributor):

The current group-gemm configuration raises the following error on an NVIDIA 3090:

RuntimeError: cutlass group_gemm.initialize failed: Error Internal

This PR changes the group-gemm stage count to 4, which reduces the dynamic shared-memory footprint so the kernel can run on GPUs like the 3090.

I also ran a quick comparison on an A800: with the stage count at 4, group-gemm performance still improves slightly.

Refer to: https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/gemm_grouped_sm80.cu

@@ -85,7 +85,7 @@ cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType*
     cutlass::gemm::GemmShape<16, 8, 16>,                                  // Instruction Shape
     cutlass::epilogue::thread::LinearCombination<DType, 8, float, float>, // Epilogue
     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,    // Swizzling Operator
-    8                                                                     // Stages
+    4                                                                     // Stages
Member:

Rather than hard-coding the change from 8 to 4, adjust the stage count based on the available shared memory. cc @yzh119
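One way to implement this suggestion is to query the device's per-block shared-memory budget at runtime and pick the largest stage count that fits. This is a sketch, not flashinfer's actual code: the tile sizes, the 3–8 stage range, and the `PickStages` helper are all assumptions for illustration.

```cpp
#include <algorithm>

// Sketch: derive the stage count from the device's shared-memory budget.
// smem_limit_bytes would come from cudaDeviceGetAttribute with
// cudaDevAttrMaxSharedMemoryPerBlockOptin; tile sizes here are illustrative.
int PickStages(int smem_limit_bytes, int tile_m = 128, int tile_n = 128,
               int tile_k = 32, int elem_bytes = 2) {
  // Each pipeline stage buffers one A tile (M x K) and one B tile (K x N).
  int per_stage_bytes = (tile_m * tile_k + tile_k * tile_n) * elem_bytes;
  // SM80 multistage mainloops typically want at least 3 stages; cap at 8.
  return std::max(3, std::min(8, smem_limit_bytes / per_stage_bytes));
}
```

With these assumed tile sizes, a 99 KB limit (SM86) would yield 6 stages and a 163 KB limit (SM80) would yield 8, so the kernel could keep deeper pipelining on data-center GPUs instead of dropping to 4 everywhere.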

jeejeelee (Contributor, Author):

The value 4 follows the configuration in the linked cutlass test, so it should be compatible with most scenarios. 😄

Member:


The cutlass example you linked targets SM80, not SM90.

@yzh119 (Collaborator) left a comment:


LGTM, @jeejeelee thanks for the PR!

I'll merge this for now; some TODO items:

  1. Add benchmarks for group gemm.
  2. Select different configurations according to input shapes and CUDA arch.
  3. Add SM90 cutlass group gemm (I have a WIP branch but don't have time to work on it at the moment; it would be great if someone in the community could take it over).

@yzh119 yzh119 merged commit 52dab1d into flashinfer-ai:main Sep 13, 2024
@jeejeelee jeejeelee deleted the modify-group-gemm branch September 14, 2024 03:13
3 participants