diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp index dd9b4ad139f5..2111d6241c6a 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp @@ -65,6 +65,8 @@ std::optional> findZeroInitOp(Value accUse, return std::nullopt; } if (auto selOp = dyn_cast(defOp)) { + if (!selOp.getCondition().getType().isInteger(1)) + return std::nullopt; if (isConstantZeroTensor(selOp.getTrueValue()) || isConstantZeroTensor(selOp.getFalseValue())) { return std::make_pair(selOp, 0); diff --git a/test/TritonGPU/accumulator-init.mlir b/test/TritonGPU/accumulator-init.mlir index 72ef11dcafd7..2f56d0e97511 100644 --- a/test/TritonGPU/accumulator-init.mlir +++ b/test/TritonGPU/accumulator-init.mlir @@ -348,4 +348,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } tt.return %17 : tensor<128x16xf32, #mma1> } + +// If the condition is a tensor skip the optimization. +// CHECK-LABEL: @negative_sel_tensor +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @negative_sel_tensor(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } }