feat: modify group-gemm stage number (#497)
The current group-gemm configuration raises the following error on an
NVIDIA RTX 3090:
```shell
RuntimeError: cutlass group_gemm.initialize failed: Error Internal
``` 
Change the group-gemm stage count from 8 to 4. This reduces the amount of
dynamic shared memory (smem) the kernel requests, so that it can run on
GPUs like the 3090.
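
To illustrate why the stage count matters, here is a rough back-of-the-envelope sketch of the mainloop shared-memory footprint. It assumes a 128x128x32 threadblock tile and 2-byte (half-precision) operands; neither is visible in this diff, so treat both as assumptions. The helper `mainloop_smem_bytes` is a hypothetical name, not a CUTLASS API, and the estimate ignores any epilogue storage. The per-block limits are the opt-in maximums listed in the CUDA programming guide for compute capability 8.6 (3090) and 8.0 (A800).

```cpp
#include <cstddef>

// Hypothetical helper (not part of CUTLASS): a multistage mainloop keeps one
// A tile (M x K) and one B tile (K x N) in shared memory per pipeline stage.
constexpr std::size_t mainloop_smem_bytes(int tile_m, int tile_n, int tile_k,
                                          int stages, std::size_t elem_bytes) {
  return static_cast<std::size_t>(stages) *
         static_cast<std::size_t>(tile_m * tile_k + tile_k * tile_n) * elem_bytes;
}

// Assumption: 128x128x32 threadblock tile with 2-byte elements.
constexpr std::size_t kSmem8Stages = mainloop_smem_bytes(128, 128, 32, 8, 2);
constexpr std::size_t kSmem4Stages = mainloop_smem_bytes(128, 128, 32, 4, 2);

// Maximum opt-in shared memory per thread block, per the CUDA programming guide.
constexpr std::size_t kLimitSm86 = 99 * 1024;   // compute capability 8.6 (RTX 3090)
constexpr std::size_t kLimitSm80 = 163 * 1024;  // compute capability 8.0 (A800)

static_assert(kSmem8Stages > kLimitSm86, "8 stages exceed the sm_86 limit");
static_assert(kSmem8Stages <= kLimitSm80, "8 stages fit within the sm_80 limit");
static_assert(kSmem4Stages <= kLimitSm86, "4 stages fit within the sm_86 limit");
```

Under these assumptions, 8 stages need 128 KiB, which fits on the A800 but not on the 3090, while 4 stages need 64 KiB and fit on both. That would be consistent with the initialization error above.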

Additionally, I ran a simple comparison on the A800: setting the stage
count to 4 still slightly improves group-gemm performance there.

Refer to:
https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/gemm_grouped_sm80.cu
jeejeelee authored Sep 13, 2024
1 parent 2de16b0 commit 52dab1d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion include/flashinfer/group_gemm/wrapper.cuh
@@ -85,7 +85,7 @@ cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType*
 cutlass::gemm::GemmShape<16, 8, 16>, // Instruction Shape
 cutlass::epilogue::thread::LinearCombination<DType, 8, float, float>, // Epilogue
 cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, // Swizzling Operator
-8 // Stages
+4 // Stages
 >::GemmKernel;
 
 using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
