Skip to content

Commit

Permalink
Add util.assume.int folder. (#18805)
Browse files Browse the repository at this point in the history
Signed-off-by: Stella Laurenzo <[email protected]>
  • Loading branch information
stellaraccident authored Oct 17, 2024
1 parent 1500641 commit 8da6ba2
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 1 deletion.
109 changes: 109 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,115 @@

namespace mlir::iree_compiler::IREE::Util {

//===----------------------------------------------------------------------===//
// util.assume.int
//===----------------------------------------------------------------------===//

LogicalResult AssumeIntOp::canonicalize(AssumeIntOp op,
PatternRewriter &rewriter) {
bool needsRewrite = false;
ArrayAttr assumptions = op.getAssumptions();

// We do a fast check for the canonical form here, making any in-place updates
// we can and signalling needsRewrite=true when the op needs to be updated
// to a new canonical form.
SmallPtrSet<Value, 4> seenOperands;
seenOperands.reserve(op.getNumOperands());
for (auto [idx, operand] : llvm::enumerate(op.getOperands())) {
// Match constant.
if (matchPattern(operand, m_Constant())) {
needsRewrite = true;
rewriter.replaceAllUsesWith(op.getResult(idx), operand);
continue;
}

// Check for a duplicate.
auto [foundIt, inserted] = seenOperands.insert(operand);
if (!inserted) {
// This should be the non-common path: find the original index number
// and rewrite.
for (auto [seenIdx, seenOperand] : llvm::enumerate(op.getOperands())) {
if (seenOperand == operand) {
needsRewrite = true;
rewriter.replaceAllUsesWith(op.getResult(idx), op.getResult(seenIdx));
break;
}
}
continue;
}

// Detect whether assumptions need to be normalized.
ArrayAttr assumptionRow = llvm::cast<ArrayAttr>(assumptions[idx]);
if (assumptionRow.size() > 1) {
bool allAssumptionsSame = true;
for (unsigned i = 1; i < assumptionRow.size(); ++i) {
if (assumptionRow[i] != assumptionRow[0]) {
allAssumptionsSame = false;
break;
}
}
if (allAssumptionsSame) {
needsRewrite = true;
}
}
}
if (!needsRewrite)
return failure();

// Need to rewrite the assumption.
auto normalizeAssumptions = [](Attribute row, bool &madeChange) {
auto rowArray = llvm::cast<ArrayAttr>(row);
if (rowArray.size() <= 1)
return rowArray;

bool allSame = true;
for (unsigned i = 1; i < rowArray.size(); ++i) {
if (rowArray[0] != rowArray[i]) {
allSame = false;
break;
}
}

if (!allSame)
return rowArray;

// All entries are the same: compress down to a single column.
madeChange = true;
return ArrayAttr::get(row.getContext(), {rowArray[0]});
};
SmallVector<ArrayAttr> newAssumptions;
SmallVector<Value> newOperands;
SmallVector<Value> retainedResults;
bool madeChange = false;
for (auto [idx, operand] : llvm::enumerate(op.getOperands())) {
// If the result has no uses, do not retain it.
if (op.getResult(idx).use_empty()) {
madeChange = true;
continue;
}

newAssumptions.push_back(
normalizeAssumptions(assumptions[idx], madeChange));
newOperands.push_back(operand);
retainedResults.push_back(op.getResult(idx));
}

// It is important to avoid canonicalizer looping that if we determined at
// the top that a rewrite was needed, that we actually made a change.
(void)madeChange;
assert(madeChange && "util.assume.int canonicalizer signaled a rewrite was "
"needed but it produced the same op");

if (!newOperands.empty()) {
auto newOp =
rewriter.create<AssumeIntOp>(op.getLoc(), newOperands, newAssumptions);
rewriter.replaceAllUsesWith(retainedResults, newOp.getResults());
}

rewriter.eraseOp(op);
return success();
}

//===----------------------------------------------------------------------===//
// util.null
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1264,6 +1264,9 @@ LogicalResult AssumeIntOp::verify() {
for (auto [index, operandAssumptionsAttr] :
llvm::enumerate(allOperandAssumptions)) {
auto operandAssumptions = cast<ArrayAttr>(operandAssumptionsAttr);
// We always allow a single row to broadcast to any requested size.
if (operandAssumptions.size() == 1)
continue;
if (rank && *rank != operandAssumptions.size())
return emitOpError() << "expected operand #" << index << " to have "
<< *rank << " assumptions but it has "
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ def Util_AssumeIntOp : Util_PureOp<"assume.int", [
std::optional<uint64_t> getUnionedUnsignedDivisor(unsigned operandIndex);
}];

let hasCanonicalizeMethod = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ iree_lit_test_suite(
"alignment_ops.mlir",
"assignment_folding.mlir",
"assignment_ops.mlir",
"assume_folding.mlir",
"assume_ops.mlir",
"attributes.mlir",
"buffer_folding.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ iree_lit_test_suite(
"alignment_ops.mlir"
"assignment_folding.mlir"
"assignment_ops.mlir"
"assume_folding.mlir"
"assume_ops.mlir"
"attributes.mlir"
"buffer_folding.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s

// CHECK-LABEL: @already_canonical
util.func public @already_canonical(%arg0 : index) -> index {
// CHECK: util.assume.int
%0 = util.assume.int %arg0<umin=0> : index
util.return %0 : index
}

// -----

// CHECK-LABEL: @elide_constant_assumption
util.func public @elide_constant_assumption() -> index {
%cst = arith.constant 1 : index
%0 = util.assume.int %cst<umin=0> : index
// CHECK: %[[CST:.*]] = arith.constant 1 : index
// CHECK: util.return %[[CST]]
util.return %0 : index
}

// -----
// CHECK-LABEL: @elide_multi_constant_assumption
util.func public @elide_multi_constant_assumption(%arg0 : index, %arg1 : index) -> index, index, index {
%cst = arith.constant 1 : index
// CHECK: %[[CST:.*]] = arith.constant 1 : index
// CHECK: %[[ASSUME:.*]]:2 = util.assume.int
// CHECK-NEXT: %arg0<udiv = 2>,
// CHECK-NEXT: %arg1<udiv = 4>
// CHECK-NEXT: : index, index
%0:3 = util.assume.int %arg0<udiv=2>, %cst<umin=0>, %arg1<udiv=4> : index, index, index
// CHECK: util.return %[[ASSUME]]#0, %[[CST]], %[[ASSUME]]#1
util.return %0#0, %0#1, %0#2 : index, index, index
}

// -----
// CHECK-LABEL: @broadcast_duplicate_assumptions
util.func public @broadcast_duplicate_assumptions(%arg0 : index) -> index {
// CHECK: util.assume.int %arg0<umin = 0>
%0 = util.assume.int %arg0[<umin=0>, <umin=0>] : index
util.return %0 : index
}

// -----
// CHECK-LABEL: @dedup_duplicate_operands
util.func public @dedup_duplicate_operands(%arg0 : index) -> index, index {
// CHECK: %[[ASSUME:.*]] = util.assume.int %arg0<umax = 2> : index
%0:2 = util.assume.int %arg0[<umax=2>, <umax=2>], %arg0<umin=0> : index, index
// CHECK: util.return %[[ASSUME]], %[[ASSUME]]
util.return %0#0, %0#1 : index, index
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
// RUN: iree-opt --split-input-file --verify-diagnostics %s

util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64 {
// expected-error @+1 {{expected operand #1 to have 1 assumptions but it has 2}}
// expected-error @+1 {{expected operand #1 to have 3 assumptions but it has 2}}
%0:2 = util.assume.int %arg0[<umin=0>, <umax=2>, <udiv=16>], %arg1[<umax=10>, <udiv=6>] : index, i64
util.return %0#0, %0#1 : index, i64
}

// -----

util.func public @assume.int.multi_operand_broadcast(%arg0 : index, %arg1 : i64) -> index, i64 {
// It is legal to have a mismatched arity if 1.
%0:2 = util.assume.int %arg0[<umin=0>], %arg1[<umax=10>, <udiv=6>] : index, i64
util.return %0#0, %0#1 : index, i64
}
Expand Down

0 comments on commit 8da6ba2

Please sign in to comment.