Skip to content

Commit

Permalink
[LoopUnroll] Do not pipeline epilog loops generated by loop unrolling (
Browse files Browse the repository at this point in the history
…triton-lang#5027)

The epilog loop created by the loop unroller may not run at all if the main
unrolled loop covers all original loop iterations, so pipelining it
non-speculatively may not be beneficial. It can also cause
correctness issues when combined with the downstream PTXAS optimizer.
  • Loading branch information
htyu authored Nov 5, 2024
1 parent 3c296ab commit d2b8659
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
20 changes: 14 additions & 6 deletions lib/Dialect/Triton/Transforms/LoopUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@

namespace mlir::triton {

static const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";

namespace {

class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {

/// Returns the unroll factor requested for `forOp`.
///
/// The factor is read from the integer attribute whose name is held in
/// `loopUnrollFactorAttrName`. If the loop carries no such IntegerAttr, the
/// function returns 1, which suppresses unrolling for that loop.
int getUnrollFactorOrDefault(scf::ForOp forOp) {
  // Look the attribute up exactly once. A second, nested lookup of the same
  // attribute would only shadow `factor` without changing the result.
  if (auto factor =
          forOp->getAttrOfType<IntegerAttr>(loopUnrollFactorAttrName))
    return factor.getInt();
  return 1;
}

const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
const char *pipelineStagesAttrName = "tt.num_stages";

public:
LoopUnrollPass() = default;
LoopUnrollPass(const LoopUnrollPass &) {}
Expand All @@ -49,11 +50,18 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
loops.push_back(forOp);
});

auto ctx = getOperation()->getContext();
for (auto loop : loops) {
auto unrollFactor = getUnrollFactorOrDefault(loop);
loop->removeAttr(mlir::triton::loopUnrollFactorAttrName);
loop->removeAttr(loopUnrollFactorAttrName);
LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop);
(void)loopUnrollByFactor(loop, unrollFactor);
auto resultLoops = loopUnrollByFactor(loop, unrollFactor);
// Do not pipeline the epilog loop.
if (succeeded(resultLoops) && resultLoops->epilogueLoopOp) {
(*resultLoops->epilogueLoopOp)
->setAttr(pipelineStagesAttrName,
mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1));
}
}
}
};
Expand Down
1 change: 1 addition & 0 deletions test/Triton/loop-unroll.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ tt.func @add_kernel_unroll(%arg0: tensor<256x!tt.ptr<f32>>, %arg1: i32) {
// CHECK: scf.for
// CHECK: tt.load
// CHECK-NOT: tt.load
// CHECK: tt.num_stages = 1 : i32
%2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>) : i32 {
%3 = tt.load %arg5 : tensor<256x!tt.ptr<f32>>
%4 = arith.addf %arg4, %3 : tensor<256xf32>
Expand Down

0 comments on commit d2b8659

Please sign in to comment.