From 6e779e649aee2ebcdf7594e469dc94da6d544380 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Fri, 1 Nov 2024 09:52:11 -0700 Subject: [PATCH 1/4] [mlir] Extend SCF loopUnrollByFactor to return the result loops --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 12 +++++++----- .../SCF/TransformOps/SCFTransformOps.cpp | 6 ++++-- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 18 ++++++++++++------ 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 4001ba3fc84c9d..eda64ea69f81d1 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -111,11 +111,13 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op); void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef> combinedDimensions); -/// Unrolls this for operation by the specified unroll factor. Returns failure -/// if the loop cannot be unrolled either due to restrictions or due to invalid -/// unroll factors. Requires positive loop bounds and step. If specified, -/// annotates the Ops in each unrolled iteration by applying `annotateFn`. -LogicalResult loopUnrollByFactor( +/// Unrolls this for operation by the specified unroll factor. Returns the +/// unrolled main loop and the eplilog loop in sequence, if the loop is +/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled +/// either due to restrictions or due to invalid unroll factors. Requires +/// positive loop bounds and step. If specified, annotates the Ops in each +/// unrolled iteration by applying `annotateFn`. +SmallVector loopUnrollByFactor( scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn = nullptr); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 551411bb147653..c84cb13f8b6bb2 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -353,8 +353,10 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); - if (scf::ForOp scfFor = dyn_cast(op)) - result = loopUnrollByFactor(scfFor, getFactor()); + if (scf::ForOp scfFor = dyn_cast(op)) { + auto resultLoops = loopUnrollByFactor(scfFor, getFactor()); + result = resultLoops.empty() ? failure() : success(); + } else if (AffineForOp affineFor = dyn_cast(op)) result = loopUnrollByFactor(affineFor, getFactor()); else diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 43fcc595af0f7e..8394ac47888100 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -372,15 +372,17 @@ static void generateUnrolledLoop( loopBodyBlock->getTerminator()->setOperands(lastYielded); } -/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled. -LogicalResult mlir::loopUnrollByFactor( +/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the +/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty +/// vector. +SmallVector mlir::loopUnrollByFactor( scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn) { assert(unrollFactor > 0 && "expected positive unroll factor"); // Return if the loop body is empty. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) - return success(); + return {forOp}; // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases. @@ -401,8 +403,8 @@ LogicalResult mlir::loopUnrollByFactor( if (unrollFactor == 1) { if (*constTripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter))) - return failure(); - return success(); + return {}; + return {forOp}; } int64_t tripCountEvenMultiple = @@ -450,6 +452,9 @@ LogicalResult mlir::loopUnrollByFactor( boundsBuilder.create(loc, step, unrollFactorCst); } + SmallVector resultLoops; + resultLoops.push_back(forOp); + // Create epilogue clean up loop starting at 'upperBoundUnrolled'. if (generateEpilogueLoop) { OpBuilder epilogueBuilder(forOp->getContext()); @@ -468,6 +473,7 @@ LogicalResult mlir::loopUnrollByFactor( epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(), epilogueForOp.getInitArgs().size(), results); (void)epilogueForOp.promoteIfSingleIteration(rewriter); + resultLoops.push_back(epilogueForOp); } // Create unrolled loop. @@ -490,7 +496,7 @@ LogicalResult mlir::loopUnrollByFactor( annotateFn, iterArgs, yieldedValues); // Promote the loop body up if this has turned into a single iteration loop. (void)forOp.promoteIfSingleIteration(rewriter); - return success(); + return resultLoops; } /// Check if bounds of all inner loops are defined outside of `forOp` From ec4600241bf8f1acac2c403b9c4b2a43a68802f7 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Fri, 1 Nov 2024 11:06:51 -0700 Subject: [PATCH 2/4] make return value more structured. --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 17 +++++++++++------ .../SCF/TransformOps/SCFTransformOps.cpp | 5 ++--- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 17 ++++++++--------- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index eda64ea69f81d1..c3bd6d86864186 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -111,13 +111,18 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op); void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef> combinedDimensions); +struct UnrolledLoopInfo { + scf::ForOp mainLoopOp; + scf::ForOp epilogueLoopOp; +}; + /// Unrolls this for operation by the specified unroll factor. Returns the -/// unrolled main loop and the eplilog loop in sequence, if the loop is -/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled -/// either due to restrictions or due to invalid unroll factors. Requires -/// positive loop bounds and step. If specified, annotates the Ops in each -/// unrolled iteration by applying `annotateFn`. -SmallVector loopUnrollByFactor( +/// unrolled main loop and the eplilog loop, if the loop is unrolled. Otherwise +/// returns a strucutre of null fields if the loop cannot be unrolled either due +/// to restrictions or due to invalid unroll factors. Requires positive loop +/// bounds and step. If specified, annotates the Ops in each unrolled iteration +/// by applying `annotateFn`. +UnrolledLoopInfo loopUnrollByFactor( scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn = nullptr); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index c84cb13f8b6bb2..cefd023c40d96c 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -355,9 +355,8 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, LogicalResult result(failure()); if (scf::ForOp scfFor = dyn_cast(op)) { auto resultLoops = loopUnrollByFactor(scfFor, getFactor()); - result = resultLoops.empty() ? failure() : success(); - } - else if (AffineForOp affineFor = dyn_cast(op)) + result = resultLoops.mainLoopOp ? success() : failure(); + } else if (AffineForOp affineFor = dyn_cast(op)) result = loopUnrollByFactor(affineFor, getFactor()); else return emitSilenceableError() diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 8394ac47888100..a50e90af3af658 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -373,16 +373,15 @@ static void generateUnrolledLoop( } /// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the -/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty -/// vector. -SmallVector mlir::loopUnrollByFactor( +/// eplilog loop, if the loop is unrolled. Otherwise return null. +UnrolledLoopInfo mlir::loopUnrollByFactor( scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn) { assert(unrollFactor > 0 && "expected positive unroll factor"); // Return if the loop body is empty. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) - return {forOp}; + return {forOp, nullptr}; // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases. @@ -403,8 +402,8 @@ SmallVector mlir::loopUnrollByFactor( if (unrollFactor == 1) { if (*constTripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter))) - return {}; - return {forOp}; + return {nullptr, nullptr}; + return {forOp, nullptr}; } int64_t tripCountEvenMultiple = @@ -452,8 +451,7 @@ SmallVector mlir::loopUnrollByFactor( boundsBuilder.create(loc, step, unrollFactorCst); } - SmallVector resultLoops; - resultLoops.push_back(forOp); + UnrolledLoopInfo resultLoops; // Create epilogue clean up loop starting at 'upperBoundUnrolled'. if (generateEpilogueLoop) { @@ -473,7 +471,7 @@ SmallVector mlir::loopUnrollByFactor( epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(), epilogueForOp.getInitArgs().size(), results); (void)epilogueForOp.promoteIfSingleIteration(rewriter); - resultLoops.push_back(epilogueForOp); + resultLoops.epilogueLoopOp = epilogueForOp; } // Create unrolled loop. @@ -496,6 +494,7 @@ SmallVector mlir::loopUnrollByFactor( annotateFn, iterArgs, yieldedValues); // Promote the loop body up if this has turned into a single iteration loop. (void)forOp.promoteIfSingleIteration(rewriter); + resultLoops.mainLoopOp = forOp; return resultLoops; } From a6abc7a1b01d1bb13f74608544ff98be2698612f Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Sun, 3 Nov 2024 12:06:24 -0800 Subject: [PATCH 3/4] Changed return type to FailureOr --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 14 +++++++------- .../SCF/TransformOps/SCFTransformOps.cpp | 7 +++---- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 18 +++++++++--------- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index c3bd6d86864186..dfb2e1e6c90aba 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -112,17 +112,17 @@ void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef> combinedDimensions); struct UnrolledLoopInfo { - scf::ForOp mainLoopOp; - scf::ForOp epilogueLoopOp; + scf::ForOp mainLoopOp = nullptr; + scf::ForOp epilogueLoopOp = nullptr; }; /// Unrolls this for operation by the specified unroll factor. Returns the /// unrolled main loop and the eplilog loop, if the loop is unrolled. Otherwise -/// returns a strucutre of null fields if the loop cannot be unrolled either due -/// to restrictions or due to invalid unroll factors. Requires positive loop -/// bounds and step. If specified, annotates the Ops in each unrolled iteration -/// by applying `annotateFn`. -UnrolledLoopInfo loopUnrollByFactor( +/// returns failure if the loop cannot be unrolled either due to restrictions or +/// due to invalid unroll factors. Requires positive loop bounds and step. If +/// specified, annotates the Ops in each unrolled iteration by applying +/// `annotateFn`. +FailureOr loopUnrollByFactor( scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn = nullptr); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index cefd023c40d96c..551411bb147653 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -353,10 +353,9 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); - if (scf::ForOp scfFor = dyn_cast(op)) { - auto resultLoops = loopUnrollByFactor(scfFor, getFactor()); - result = resultLoops.mainLoopOp ? success() : failure(); - } else if (AffineForOp affineFor = dyn_cast(op)) + if (scf::ForOp scfFor = dyn_cast(op)) + result = loopUnrollByFactor(scfFor, getFactor()); + else if (AffineForOp affineFor = dyn_cast(op)) result = loopUnrollByFactor(affineFor, getFactor()); else return emitSilenceableError() diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index a50e90af3af658..e591ca49bccb8f 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -373,15 +373,15 @@ static void generateUnrolledLoop( } /// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the -/// eplilog loop, if the loop is unrolled. Otherwise return null. -UnrolledLoopInfo mlir::loopUnrollByFactor( +/// eplilog loop, if the loop is unrolled. +FailureOr mlir::loopUnrollByFactor( scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn) { assert(unrollFactor > 0 && "expected positive unroll factor"); // Return if the loop body is empty. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) - return {forOp, nullptr}; + return UnrolledLoopInfo{forOp, nullptr}; // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases. @@ -402,8 +402,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor( if (unrollFactor == 1) { if (*constTripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter))) - return {nullptr, nullptr}; - return {forOp, nullptr}; + return failure(); + return UnrolledLoopInfo{forOp, nullptr}; } int64_t tripCountEvenMultiple = @@ -470,8 +470,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor( } epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(), epilogueForOp.getInitArgs().size(), results); - (void)epilogueForOp.promoteIfSingleIteration(rewriter); - resultLoops.epilogueLoopOp = epilogueForOp; + if (epilogueForOp.promoteIfSingleIteration(rewriter).failed()) + resultLoops.epilogueLoopOp = epilogueForOp; } // Create unrolled loop. @@ -493,8 +493,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor( }, annotateFn, iterArgs, yieldedValues); // Promote the loop body up if this has turned into a single iteration loop. - (void)forOp.promoteIfSingleIteration(rewriter); - resultLoops.mainLoopOp = forOp; + if (forOp.promoteIfSingleIteration(rewriter).failed()) + resultLoops.mainLoopOp = forOp; return resultLoops; } From 0f2c6ec3675a1e5f02961b06f641b6ad145a1aee Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Mon, 4 Nov 2024 10:49:55 -0800 Subject: [PATCH 4/4] Use std::nullopt for field type. --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 4 ++-- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index dfb2e1e6c90aba..02ffa0da7a8b86 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -112,8 +112,8 @@ void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef> combinedDimensions); struct UnrolledLoopInfo { - scf::ForOp mainLoopOp = nullptr; - scf::ForOp epilogueLoopOp = nullptr; + std::optional mainLoopOp = std::nullopt; + std::optional epilogueLoopOp = std::nullopt; }; /// Unrolls this for operation by the specified unroll factor. Returns the diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index e591ca49bccb8f..247311d66ff949 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -381,7 +381,7 @@ FailureOr mlir::loopUnrollByFactor( // Return if the loop body is empty. if (llvm::hasSingleElement(forOp.getBody()->getOperations())) - return UnrolledLoopInfo{forOp, nullptr}; + return UnrolledLoopInfo{forOp, std::nullopt}; // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases. @@ -403,7 +403,7 @@ FailureOr mlir::loopUnrollByFactor( if (*constTripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter))) return failure(); - return UnrolledLoopInfo{forOp, nullptr}; + return UnrolledLoopInfo{forOp, std::nullopt}; } int64_t tripCountEvenMultiple =