-
Notifications
You must be signed in to change notification settings - Fork 12.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Extend SCF loopUnrollByFactor to return the result loops #114573
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: Hongtao Yu (htyu) ChangesThere is a need of accessing the resulted epilog loop from the SC loop unroller. It'd clean and convenient to get that directly from the loop unroller instead of rescanning the whole function, as discussed in triton-lang/triton#5027 . I'm changing the result type of Full diff: https://github.com/llvm/llvm-project/pull/114573.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 4001ba3fc84c9d..eda64ea69f81d1 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -111,11 +111,13 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);
-/// Unrolls this for operation by the specified unroll factor. Returns failure
-/// if the loop cannot be unrolled either due to restrictions or due to invalid
-/// unroll factors. Requires positive loop bounds and step. If specified,
-/// annotates the Ops in each unrolled iteration by applying `annotateFn`.
-LogicalResult loopUnrollByFactor(
+/// Unrolls this for operation by the specified unroll factor. Returns the
+/// unrolled main loop and the eplilog loop in sequence, if the loop is
+/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
+/// either due to restrictions or due to invalid unroll factors. Requires
+/// positive loop bounds and step. If specified, annotates the Ops in each
+/// unrolled iteration by applying `annotateFn`.
+SmallVector<scf::ForOp> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 551411bb147653..c84cb13f8b6bb2 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -353,8 +353,10 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LogicalResult result(failure());
- if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
- result = loopUnrollByFactor(scfFor, getFactor());
+ if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
+ auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
+ result = resultLoops.empty() ? failure() : success();
+ }
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
else
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 43fcc595af0f7e..8394ac47888100 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -372,15 +372,17 @@ static void generateUnrolledLoop(
loopBodyBlock->getTerminator()->setOperands(lastYielded);
}
-/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
-LogicalResult mlir::loopUnrollByFactor(
+/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
+/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
+/// vector.
+SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");
// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
- return success();
+ return {forOp};
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -401,8 +403,8 @@ LogicalResult mlir::loopUnrollByFactor(
if (unrollFactor == 1) {
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
- return failure();
- return success();
+ return {};
+ return {forOp};
}
int64_t tripCountEvenMultiple =
@@ -450,6 +452,9 @@ LogicalResult mlir::loopUnrollByFactor(
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
}
+ SmallVector<scf::ForOp, 2> resultLoops;
+ resultLoops.push_back(forOp);
+
// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
if (generateEpilogueLoop) {
OpBuilder epilogueBuilder(forOp->getContext());
@@ -468,6 +473,7 @@ LogicalResult mlir::loopUnrollByFactor(
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
+ resultLoops.push_back(epilogueForOp);
}
// Create unrolled loop.
@@ -490,7 +496,7 @@ LogicalResult mlir::loopUnrollByFactor(
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)forOp.promoteIfSingleIteration(rewriter);
- return success();
+ return resultLoops;
}
/// Check if bounds of all inner loops are defined outside of `forOp`
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, although I don't think you have addressed Mahesh's comment. Please make sure to address it before merging
#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work.
…#114573) There is a need of accessing the resulted epilog loop from the SC loop unroller. It'd clean and convenient to get that directly from the loop unroller instead of rescanning the whole function, as discussed in triton-lang/triton#5027 . I'm changing the result type of `loopUnrollByFactor` for that.
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work.
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work.
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work. (cherry picked from commit 3c296ab)
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work. (cherry picked from commit 3c296ab)
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work. (cherry picked from commit 3c296ab)
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work. (cherry picked from commit 3c296ab)
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work. (cherry picked from commit 3c296ab)
There is a need of accessing the resulted epilog loop from the SC loop unroller. It'd clean and convenient to get that directly from the loop unroller instead of rescanning the whole function, as discussed in triton-lang/triton#5027 . I'm changing the result type of
loopUnrollByFactor
for that.