Skip to content

Commit

Permalink
[Flow] Improve canonicalizer to remove redundantly returned results. (#…
Browse files Browse the repository at this point in the history
…20065)

When a `flow.dispatch.region` returns the same value multiple times from
within the body, the result can be dropped to return the value only
once.

---------

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar authored Feb 23, 2025
1 parent 3b7d9dd commit 80e3fd0
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 21 deletions.
63 changes: 42 additions & 21 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,35 +599,49 @@ LogicalResult DispatchRegionOp::reifyResultShapes(

/// Canonicalizes a DispatchRegionOp: Drop all unused results. Returns `true`
/// if the IR was modified.
bool dropUnusedDispatchRegionResults(RewriterBase &rewriter,
Flow::DispatchRegionOp regionOp) {
bool dropUnusedAndRedundantDispatchRegionResults(
RewriterBase &rewriter, Flow::DispatchRegionOp regionOp) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(regionOp);
if (!llvm::hasSingleElement(regionOp.getBody())) {
// Bail on case where there are more than one blocks in the dispatch.
return false;
}

// Determine unused results and result types + dynamic dimensions of the new
// op.
llvm::DenseSet<unsigned> unusedResults;
// Determine unused/redunduant results and result types + dynamic dimensions
// of the new op. If the result is redundant, record which result number it is
// redundant with. If the result is dropped, record `std::nullopt` to indicate
// that.
llvm::DenseMap<Value, std::optional<unsigned>> droppedResultValues;
llvm::SetVector<Value> yieldedResultsSet;
SmallVector<Type> resultTypes;
SmallVector<Value> dynamicDims;
unsigned dimOffset = 0;
for (const auto &it : llvm::enumerate(regionOp.getResults())) {
Type type = it.value().getType();

auto returnOp =
cast<Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
for (const auto &[index, value] : llvm::enumerate(regionOp.getResults())) {
Type type = value.getType();
auto shapedType = llvm::dyn_cast<ShapedType>(type);
if (it.value().use_empty()) {
unusedResults.insert(it.index());
OpOperand &yieldedVal = returnOp->getOpOperand(index);
if (value.use_empty()) {
droppedResultValues[value] = std::nullopt;
} else if (yieldedResultsSet.contains(yieldedVal.get())) {
droppedResultValues[value] = yieldedVal.getOperandNumber();
} else {
resultTypes.push_back(type);
ValueRange dims = regionOp.getResultDims().slice(
dimOffset, shapedType.getNumDynamicDims());
dynamicDims.append(dims.begin(), dims.end());
yieldedResultsSet.insert(yieldedVal.get());
}
dimOffset += shapedType.getNumDynamicDims();
}
assert(dimOffset == regionOp.getResultDims().size() &&
"expected that all dynamic dims were processed");

// Nothing to do if all results are used.
if (unusedResults.empty())
if (droppedResultValues.empty())
return false;

// Create new region and move over the body.
Expand All @@ -636,21 +650,27 @@ bool dropUnusedDispatchRegionResults(RewriterBase &rewriter,
newRegionOp.getBody().takeBody(regionOp.getBody());

// Update terminator.
auto returnOp =
cast<Flow::ReturnOp>(newRegionOp.getBody().front().getTerminator());
SmallVector<Value> yieldedValues;
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
if (!unusedResults.contains(it.index()))
yieldedValues.push_back(it.value());
ValueRange yieldedVals = yieldedResultsSet.getArrayRef();
rewriter.modifyOpInPlace(
returnOp, [&]() { returnOp.getOperandsMutable().assign(yieldedValues); });
returnOp, [&]() { returnOp.getOperandsMutable().assign(yieldedVals); });

// Replace all uses of the old op.
SmallVector<Value> replacements(regionOp->getNumResults(), nullptr);
unsigned resultCounter = 0;
for (const auto &it : llvm::enumerate(regionOp.getResults()))
if (!unusedResults.contains(it.index()))
replacements[it.index()] = newRegionOp->getResult(resultCounter++);
llvm::SmallDenseMap<unsigned, unsigned> oldResultNumToNewResultNum;
for (const auto &[index, value] : llvm::enumerate(regionOp.getResults())) {
if (droppedResultValues.contains(value)) {
std::optional<unsigned> resultNumber = droppedResultValues.lookup(value);
if (resultNumber) {
replacements[index] = newRegionOp->getResult(
oldResultNumToNewResultNum[resultNumber.value()]);
}
} else {
oldResultNumToNewResultNum[index] = resultCounter;
replacements[index] = newRegionOp->getResult(resultCounter);
resultCounter++;
}
}
rewriter.replaceOp(regionOp, replacements);

return true;
Expand All @@ -662,7 +682,8 @@ struct DispatchRegionDropUnusedResults

LogicalResult matchAndRewrite(DispatchRegionOp regionOp,
PatternRewriter &rewriter) const final {
return success(dropUnusedDispatchRegionResults(rewriter, regionOp));
return success(
dropUnusedAndRedundantDispatchRegionResults(rewriter, regionOp));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,20 @@ util.func public @dont_fold_not_full_static_insert_into_empty(
// CHECK-LABEL: util.func public @dont_fold_not_full_static_insert_into_empty
// CHECK: %[[INSERT:.+]] = tensor.insert_slice
// CHECK: util.return %[[INSERT]]

// -----

util.func @remove_redundant_results(%arg0 : tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
%c0 = arith.constant 0 : index
%d0 = tensor.dim %arg0, %c0 : tensor<?xf32>
%0:3 = flow.dispatch.region -> (tensor<?xf32>{%d0}, tensor<?xf32>{%d0}, tensor<?xf32>{%d0}) {
flow.return %arg0, %arg0, %arg0 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
}
util.return %0#0, %0#2 : tensor<?xf32>, tensor<?xf32>
}
// CHECK-LABEL: func public @remove_redundant_results
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]]
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region -> (tensor<?xf32>{%[[DIM]]}
// CHECK-NEXT: flow.return %[[ARG0]] : tensor<?xf32>
// CHECK: util.return %[[DISPATCH]], %[[DISPATCH]]

0 comments on commit 80e3fd0

Please sign in to comment.