Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Extend SCF loopUnrollByFactor to return the result loops #114573

Merged
merged 4 commits into from
Nov 4, 2024

Conversation

htyu
Copy link
Contributor

@htyu htyu commented Nov 1, 2024

There is a need of accessing the resulted epilog loop from the SC loop unroller. It'd clean and convenient to get that directly from the loop unroller instead of rescanning the whole function, as discussed in triton-lang/triton#5027 . I'm changing the result type of loopUnrollByFactor for that.

@llvmbot
Copy link
Member

llvmbot commented Nov 1, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Hongtao Yu (htyu)

Changes

There is a need of accessing the resulted epilog loop from the SC loop unroller. It'd clean and convenient to get that directly from the loop unroller instead of rescanning the whole function, as discussed in triton-lang/triton#5027 . I'm changing the result type of loopUnrollByFactor for that.


Full diff: https://github.com/llvm/llvm-project/pull/114573.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+7-5)
  • (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+4-2)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+12-6)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 4001ba3fc84c9d..eda64ea69f81d1 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -111,11 +111,13 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
 void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
                            ArrayRef<std::vector<unsigned>> combinedDimensions);
 
-/// Unrolls this for operation by the specified unroll factor. Returns failure
-/// if the loop cannot be unrolled either due to restrictions or due to invalid
-/// unroll factors. Requires positive loop bounds and step. If specified,
-/// annotates the Ops in each unrolled iteration by applying `annotateFn`.
-LogicalResult loopUnrollByFactor(
+/// Unrolls this for operation by the specified unroll factor. Returns the
+/// unrolled main loop and the eplilog loop in sequence, if the loop is
+/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
+/// either due to restrictions or due to invalid unroll factors. Requires
+/// positive loop bounds and step. If specified, annotates the Ops in each
+/// unrolled iteration by applying `annotateFn`.
+SmallVector<scf::ForOp> loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 551411bb147653..c84cb13f8b6bb2 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -353,8 +353,10 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
                                     transform::ApplyToEachResultList &results,
                                     transform::TransformState &state) {
   LogicalResult result(failure());
-  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
-    result = loopUnrollByFactor(scfFor, getFactor());
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
+    auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
+    result = resultLoops.empty() ? failure() : success();
+  }
   else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
     result = loopUnrollByFactor(affineFor, getFactor());
   else
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 43fcc595af0f7e..8394ac47888100 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -372,15 +372,17 @@ static void generateUnrolledLoop(
   loopBodyBlock->getTerminator()->setOperands(lastYielded);
 }
 
-/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
-LogicalResult mlir::loopUnrollByFactor(
+/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
+/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
+/// vector.
+SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
   assert(unrollFactor > 0 && "expected positive unroll factor");
 
   // Return if the loop body is empty.
   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
-    return success();
+    return {forOp};
 
   // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
   // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -401,8 +403,8 @@ LogicalResult mlir::loopUnrollByFactor(
     if (unrollFactor == 1) {
       if (*constTripCount == 1 &&
           failed(forOp.promoteIfSingleIteration(rewriter)))
-        return failure();
-      return success();
+        return {};
+      return {forOp};
     }
 
     int64_t tripCountEvenMultiple =
@@ -450,6 +452,9 @@ LogicalResult mlir::loopUnrollByFactor(
         boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
   }
 
+  SmallVector<scf::ForOp, 2> resultLoops;
+  resultLoops.push_back(forOp);
+
   // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
   if (generateEpilogueLoop) {
     OpBuilder epilogueBuilder(forOp->getContext());
@@ -468,6 +473,7 @@ LogicalResult mlir::loopUnrollByFactor(
     epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                                epilogueForOp.getInitArgs().size(), results);
     (void)epilogueForOp.promoteIfSingleIteration(rewriter);
+    resultLoops.push_back(epilogueForOp);
   }
 
   // Create unrolled loop.
@@ -490,7 +496,7 @@ LogicalResult mlir::loopUnrollByFactor(
       annotateFn, iterArgs, yieldedValues);
   // Promote the loop body up if this has turned into a single iteration loop.
   (void)forOp.promoteIfSingleIteration(rewriter);
-  return success();
+  return resultLoops;
 }
 
 /// Check if bounds of all inner loops are defined outside of `forOp`

Copy link

github-actions bot commented Nov 1, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, although I don't think you have addressed Mahesh's comment. Please make sure to address it before merging

@htyu htyu merged commit fa57c7a into llvm:main Nov 4, 2024
8 checks passed
htyu added a commit to triton-lang/triton that referenced this pull request Nov 5, 2024
#5064)

Bumping llvm to include a loop unroller fix:
llvm/llvm-project#114573. This is needed for
subsequent loop unroller upstreaming work.
PhilippRados pushed a commit to PhilippRados/llvm-project that referenced this pull request Nov 6, 2024
…#114573)

There is a need of accessing the resulted epilog loop from the SC loop
unroller. It'd clean and convenient to get that directly from the loop
unroller instead of rescanning the whole function, as discussed in
triton-lang/triton#5027 . I'm changing the
result type of `loopUnrollByFactor` for that.
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
triton-lang#5064)

Bumping llvm to include a loop unroller fix:
llvm/llvm-project#114573. This is needed for
subsequent loop unroller upstreaming work.
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
triton-lang#5064)

Bumping llvm to include a loop unroller fix:
llvm/llvm-project#114573. This is needed for
subsequent loop unroller upstreaming work.
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 4, 2024
triton-lang#5064)

Bumping llvm to include a loop unroller fix:
llvm/llvm-project#114573. This is needed for
subsequent loop unroller upstreaming work.

(cherry picked from commit 3c296ab)
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 5, 2024
triton-lang#5064)

Bumping llvm to include a loop unroller fix:
llvm/llvm-project#114573. This is needed for
subsequent loop unroller upstreaming work.

(cherry picked from commit 3c296ab)
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 6, 2024
triton-lang#5064)

Bumping llvm to include a loop unroller fix:
llvm/llvm-project#114573. This is needed for
subsequent loop unroller upstreaming work.

(cherry picked from commit 3c296ab)
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 11, 2024
triton-lang#5064)

Bumping llvm to include a loop unroller fix:
llvm/llvm-project#114573. This is needed for
subsequent loop unroller upstreaming work.

(cherry picked from commit 3c296ab)
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 12, 2024
triton-lang#5064)

Bumping llvm to include a loop unroller fix:
llvm/llvm-project#114573. This is needed for
subsequent loop unroller upstreaming work.

(cherry picked from commit 3c296ab)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants