[BACKEND] Add missing precondition in optimize acc init (#5184)
We need the select feeding the accumulator to have a scalar condition for this optimization to be valid; if the condition is a tensor, skip it (see the sketch below).
ThomasRaoux authored Nov 18, 2024
1 parent 1fc3269 commit c76b342
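The new guard in findZeroInitOp works because isInteger(1) is true only for the scalar i1 type; a tensor<128x16xi1, #mma1> condition fails the check, so the function returns std::nullopt and the pass leaves the loop untouched. For context, the two select forms side by side, using the same types as the new test (a hypothetical sketch, not part of this commit; value names are invented):

// Scalar i1 condition: the form the pass can fold into the dot's accumulator init.
%acc0 = arith.select %cnd_i1, %cst_zero, %acc_prev : tensor<128x16xf32, #mma1>
// Tensor condition, as in the negative test below: the pass must skip it.
%acc1 = arith.select %cnd_vec, %cst_zero, %acc_prev : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1>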
Showing 2 changed files with 18 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp
@@ -65,6 +65,8 @@ std::optional<std::pair<Operation *, int>> findZeroInitOp(Value accUse,
       return std::nullopt;
     }
     if (auto selOp = dyn_cast<arith::SelectOp>(defOp)) {
+      if (!selOp.getCondition().getType().isInteger(1))
+        return std::nullopt;
       if (isConstantZeroTensor(selOp.getTrueValue()) ||
           isConstantZeroTensor(selOp.getFalseValue())) {
         return std::make_pair(selOp, 0);
16 changes: 16 additions & 0 deletions test/TritonGPU/accumulator-init.mlir
@@ -348,4 +348,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   }
   tt.return %17 : tensor<128x16xf32, #mma1>
 }
+
+// If the condition is a tensor, skip the optimization.
+// CHECK-LABEL: @negative_sel_tensor
+// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc
+tt.func @negative_sel_tensor(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> {
+  %c0_i32 = arith.constant 0 : i32
+  %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
+  %c1_i32 = arith.constant 1 : i32
+  %c8_i32 = arith.constant 8 : i32
+  %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
+    %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1>
+    %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1>
+    scf.yield %acc : tensor<128x16xf32, #mma1>
+  }
+  tt.return %17 : tensor<128x16xf32, #mma1>
+}
 }
