Skip to content

Commit

Permalink
Changed return type to FailureOr<UnrolledLoopInfo>
Browse files Browse the repository at this point in the history
  • Loading branch information
htyu committed Nov 3, 2024
1 parent ec46002 commit 61ce324
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 18 deletions.
14 changes: 7 additions & 7 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,17 @@ void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);

struct UnrolledLoopInfo {
scf::ForOp mainLoopOp;
scf::ForOp epilogueLoopOp;
scf::ForOp mainLoopOp = nullptr;
scf::ForOp epilogueLoopOp = nullptr;
};

/// Unrolls this for operation by the specified unroll factor. Returns the
/// unrolled main loop and the eplilog loop, if the loop is unrolled. Otherwise
/// returns a strucutre of null fields if the loop cannot be unrolled either due
/// to restrictions or due to invalid unroll factors. Requires positive loop
/// bounds and step. If specified, annotates the Ops in each unrolled iteration
/// by applying `annotateFn`.
UnrolledLoopInfo loopUnrollByFactor(
/// returns failure if the loop cannot be unrolled either due to restrictions or
/// due to invalid unroll factors. Requires positive loop bounds and step. If
/// specified, annotates the Ops in each unrolled iteration by applying
/// `annotateFn`.
FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,7 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
transform::TransformState &state) {
LogicalResult result(failure());
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
result = resultLoops.mainLoopOp ? success() : failure();
result = loopUnrollByFactor(scfFor, getFactor());
} else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
else
Expand Down
18 changes: 9 additions & 9 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,15 +373,15 @@ static void generateUnrolledLoop(
}

/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
/// eplilog loop, if the loop is unrolled. Otherwise return null.
UnrolledLoopInfo mlir::loopUnrollByFactor(
/// eplilog loop, if the loop is unrolled.
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");

// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
return {forOp, nullptr};
return UnrolledLoopInfo{forOp, nullptr};

// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
Expand All @@ -402,8 +402,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
if (unrollFactor == 1) {
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
return {nullptr, nullptr};
return {forOp, nullptr};
return failure();
return UnrolledLoopInfo{forOp, nullptr};
}

int64_t tripCountEvenMultiple =
Expand Down Expand Up @@ -470,8 +470,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
}
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
resultLoops.epilogueLoopOp = epilogueForOp;
if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
resultLoops.epilogueLoopOp = epilogueForOp;
}

// Create unrolled loop.
Expand All @@ -493,8 +493,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
},
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)forOp.promoteIfSingleIteration(rewriter);
resultLoops.mainLoopOp = forOp;
if (forOp.promoteIfSingleIteration(rewriter).failed())
resultLoops.mainLoopOp = forOp;
return resultLoops;
}

Expand Down

0 comments on commit 61ce324

Please sign in to comment.