
warp_group_dot lowering crashes for specific instruction shape #5102

Closed
gflegar opened this issue Nov 8, 2024 · 3 comments · Fixed by #5105
gflegar commented Nov 8, 2024

Somewhere between 68aa962 and 8aedb5e the following IR started to crash:

// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm --debug-only=dialect-conversion

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 32]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma = #mma0  // this fails
// #mma = #mma1  // this works
#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 20608 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @mma_crash(
    %arg0: tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>,
    %arg1: !tt.memdesc<64x64xbf16, #shared2, #triton_gpu.shared_memory>,
    %arg2: tensor<64x64xf32, #mma>) {
    %out = triton_nvidia_gpu.warp_group_dot %arg0, %arg1, %arg2 {isAsync = true}
      : tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      * !tt.memdesc<64x64xbf16, #shared2, #triton_gpu.shared_memory>
      -> tensor<64x64xf32, #mma>
    tt.return
  }
}

(Note that on earlier Triton versions, kWidth = 2 needs to be replaced with kWidth = 0 in the example, since this parameter changed from being required to be 0 to being required to be non-zero.)

It seems to be related to the instruction shape, since it works with instrShape = [16, 64, 16], but fails with instrShape = [16, 64, 32].

This is the stack trace:

assert.h assertion failed at llvm/include/llvm/ADT/SmallVector.h:295 in const_reference llvm::SmallVectorTemplateCommon<mlir::Value>::operator[](size_type) const [T = mlir::Value]: idx < size()
*** Check failure stack trace: ***
    @     0x555d376ae474  __assert_fail
    @     0x555d358e2258  loadReg()
    @     0x555d358e3327  convertDot()
    @     0x555d358e4ea2  convertWGMMA()
    @     0x555d358d78ed  (anonymous namespace)::WarpGroupDotOpConversion::matchAndRewrite()
    @     0x555d358d75d0  mlir::ConvertOpToLLVMPattern<>::matchAndRewrite()
    @     0x555d370de26f  mlir::ConversionPattern::matchAndRewrite()
    @     0x555d3717ec3c  llvm::function_ref<>::callback_fn<>()
    @     0x555d3717bfe9  mlir::PatternApplicator::matchAndRewrite()
    @     0x555d370df201  (anonymous namespace)::OperationLegalizer::legalize()
    @     0x555d370de2e9  mlir::OperationConverter::convert()
    @     0x555d370df658  mlir::OperationConverter::convertOperations()
    @     0x555d370e530c  mlir::applyPartialConversion()
    @     0x555d358b30e6  (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation()

My hunch is that it could be related to #5009? CC @ggengnv


gflegar commented Nov 8, 2024

I can confirm that the crash goes away if I bring back this line: cfddb09#diff-c05cf3aed297bf0c5f1296cc40c522b00fb300c7a4340a1f6be5b0bbe2c42039L2048

Though I have no idea at the moment whether what we then end up producing is correct at all, that does strongly imply that #5009 is the culprit.


ggengnv commented Nov 8, 2024

It shouldn't be possible for f16 MMA to have instrShape K = 32, since `instrShapeK` is calculated as 256 / elemWidthBits (see `mmaVersionToInstrShape()`). In your case it should have K = 16.
Similarly, the `kWidth` value is calculated as 32 / elemWidthBits.
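
For illustration, here is a minimal Python sketch (not Triton's actual code) of the two relationships above; for 16-bit elements they give K = 16 and kWidth = 2, so the repro's instrShape = [16, 64, 32] combined with kWidth = 2 breaks the assumed invariant.

```python
# A minimal sketch (not Triton's actual code) of the relationships described
# above: for MMAv3, instrShape K and kWidth are both derived from the element
# bit-width, so a 16-bit input (f16/bf16) should never see instrShape K = 32.

def expected_instr_shape_k(elem_width_bits: int) -> int:
    # instrShapeK = 256 / elemWidthBits (per mmaVersionToInstrShape)
    return 256 // elem_width_bits

def expected_k_width(elem_width_bits: int) -> int:
    # kWidth = 32 / elemWidthBits
    return 32 // elem_width_bits

print(expected_instr_shape_k(16), expected_k_width(16))  # prints: 16 2
# So instrShape = [16, 64, 32] together with kWidth = 2 (as in the repro
# above) violates the assumed relationship.
```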

As long as the above relationships are true, the current logic in `getTotalElemsPerThreadForOperand` would be correct. So AFAIK the code you provided shouldn't ever be generated by Triton.

But code-quality-wise I think it makes sense to have `getTotalElemsPerThreadForOperand` use `getRepForOperand`, to match the code in `SharedToDotOperandMMAv2OrV3` (which would fix the above case). So I created a PR for that: #5105


gflegar commented Nov 11, 2024

Thanks for looking into this so quickly! Yes, the TTGIR came from an even worse example where the result layout doesn't match the LHS's parent layout; there's a Slack thread where I asked about this. Triton somehow hit this edge case while trying to optimize an internal flash attention implementation of ours.

The first step here would be to implement a better verifier for the warp_group_dot / dot_op layout, which would make it easier to root-cause the offending optimization pass; then we can fix it. I was planning to try to find some time for that in the next couple of weeks (though I'm also happy to get help if anyone else has spare cycles).
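
As a rough, hypothetical illustration (the names and structure below are illustrative only, not Triton's actual verifier API), such a verifier could check the element-width relationships discussed above, as well as that the accumulator's layout matches the LHS operand's parent layout:

```python
# Hypothetical sketch of the consistency checks a warp_group_dot / dot_op
# layout verifier could perform; names and structure are illustrative only.

from dataclasses import dataclass
from typing import List

@dataclass
class NvidiaMmaLayout:
    version_major: int
    instr_shape: List[int]          # [M, N, K]

@dataclass
class DotOperandLayout:
    op_idx: int
    k_width: int
    parent: NvidiaMmaLayout

def verify_warp_group_dot(lhs: DotOperandLayout,
                          acc_parent: NvidiaMmaLayout,
                          elem_width_bits: int) -> List[str]:
    errors = []
    # instrShape K should be 256 / elemWidthBits for MMAv3 (per the discussion above).
    if lhs.parent.instr_shape[2] != 256 // elem_width_bits:
        errors.append("instrShape K does not match 256 / elemWidthBits")
    # kWidth should be 32 / elemWidthBits.
    if lhs.k_width != 32 // elem_width_bits:
        errors.append("kWidth does not match 32 / elemWidthBits")
    # The accumulator's MMA layout should be the LHS operand's parent layout.
    if acc_parent != lhs.parent:
        errors.append("accumulator layout differs from the LHS parent layout")
    return errors

# The repro above: bf16 (16-bit) with instrShape = [16, 64, 32] and kWidth = 2.
mma = NvidiaMmaLayout(version_major=3, instr_shape=[16, 64, 32])
print(verify_warp_group_dot(DotOperandLayout(0, 2, mma), mma, 16))
# -> flags the instrShape K mismatch
```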

lezcano pushed a commit that referenced this issue Nov 13, 2024
Fixes #5102

The logic in `getTotalElemsPerThreadForOperand` should now directly
match that in `SharedToDotOperandMMAv2OrV3`