diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp index fb466bf0efd0..4e13877ec404 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp @@ -26,6 +26,115 @@ namespace mlir::iree_compiler::IREE::Util { +//===----------------------------------------------------------------------===// +// util.assume.int +//===----------------------------------------------------------------------===// + +LogicalResult AssumeIntOp::canonicalize(AssumeIntOp op, + PatternRewriter &rewriter) { + bool needsRewrite = false; + ArrayAttr assumptions = op.getAssumptions(); + + // We do a fast check for the canonical form here, making any in-place updates + // we can and signalling needsRewrite=true when the op needs to be updated + // to a new canonical form. + SmallPtrSet seenOperands; + seenOperands.reserve(op.getNumOperands()); + for (auto [idx, operand] : llvm::enumerate(op.getOperands())) { + // Match constant. + if (matchPattern(operand, m_Constant())) { + needsRewrite = true; + rewriter.replaceAllUsesWith(op.getResult(idx), operand); + continue; + } + + // Check for a duplicate. + auto [foundIt, inserted] = seenOperands.insert(operand); + if (!inserted) { + // This should be the non-common path: find the original index number + // and rewrite. + for (auto [seenIdx, seenOperand] : llvm::enumerate(op.getOperands())) { + if (seenOperand == operand) { + needsRewrite = true; + rewriter.replaceAllUsesWith(op.getResult(idx), op.getResult(seenIdx)); + break; + } + } + continue; + } + + // Detect whether assumptions need to be normalized. + ArrayAttr assumptionRow = llvm::cast(assumptions[idx]); + if (assumptionRow.size() > 1) { + bool allAssumptionsSame = true; + for (unsigned i = 1; i < assumptionRow.size(); ++i) { + if (assumptionRow[i] != assumptionRow[0]) { + allAssumptionsSame = false; + break; + } + } + if (allAssumptionsSame) { + needsRewrite = true; + } + } + } + if (!needsRewrite) + return failure(); + + // Need to rewrite the assumption. + auto normalizeAssumptions = [](Attribute row, bool &madeChange) { + auto rowArray = llvm::cast(row); + if (rowArray.size() <= 1) + return rowArray; + + bool allSame = true; + for (unsigned i = 1; i < rowArray.size(); ++i) { + if (rowArray[0] != rowArray[i]) { + allSame = false; + break; + } + } + + if (!allSame) + return rowArray; + + // All entries are the same: compress down to a single column. + madeChange = true; + return ArrayAttr::get(row.getContext(), {rowArray[0]}); + }; + SmallVector newAssumptions; + SmallVector newOperands; + SmallVector retainedResults; + bool madeChange = false; + for (auto [idx, operand] : llvm::enumerate(op.getOperands())) { + // If the result has no uses, do not retain it. + if (op.getResult(idx).use_empty()) { + madeChange = true; + continue; + } + + newAssumptions.push_back( + normalizeAssumptions(assumptions[idx], madeChange)); + newOperands.push_back(operand); + retainedResults.push_back(op.getResult(idx)); + } + + // It is important to avoid canonicalizer looping that if we determined at + // the top that a rewrite was needed, that we actually made a change. + (void)madeChange; + assert(madeChange && "util.assume.int canonicalizer signaled a rewrite was " + "needed but it produced the same op"); + + if (!newOperands.empty()) { + auto newOp = + rewriter.create(op.getLoc(), newOperands, newAssumptions); + rewriter.replaceAllUsesWith(retainedResults, newOp.getResults()); + } + + rewriter.eraseOp(op); + return success(); +} + //===----------------------------------------------------------------------===// // util.null //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp index 3de051be39dd..1a009f10946e 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp @@ -1264,6 +1264,9 @@ LogicalResult AssumeIntOp::verify() { for (auto [index, operandAssumptionsAttr] : llvm::enumerate(allOperandAssumptions)) { auto operandAssumptions = cast(operandAssumptionsAttr); + // We always allow a single row to broadcast to any requested size. + if (operandAssumptions.size() == 1) + continue; if (rank && *rank != operandAssumptions.size()) return emitOpError() << "expected operand #" << index << " to have " << *rank << " assumptions but it has " diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td index aaa10da27005..b1c17bdea68f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td @@ -518,6 +518,7 @@ def Util_AssumeIntOp : Util_PureOp<"assume.int", [ std::optional getUnionedUnsignedDivisor(unsigned operandIndex); }]; + let hasCanonicalizeMethod = 1; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel index 2df2bfd4ee16..a1c60408471f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel @@ -20,6 +20,7 @@ iree_lit_test_suite( "alignment_ops.mlir", "assignment_folding.mlir", "assignment_ops.mlir", + "assume_folding.mlir", "assume_ops.mlir", "attributes.mlir", "buffer_folding.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt index b6ac5d8e87d4..2dad4d1a449f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt @@ -18,6 +18,7 @@ iree_lit_test_suite( "alignment_ops.mlir" "assignment_folding.mlir" "assignment_ops.mlir" + "assume_folding.mlir" "assume_ops.mlir" "attributes.mlir" "buffer_folding.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_folding.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_folding.mlir new file mode 100644 index 000000000000..ffc2aaeb9a75 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_folding.mlir @@ -0,0 +1,50 @@ +// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s + +// CHECK-LABEL: @already_canonical +util.func public @already_canonical(%arg0 : index) -> index { + // CHECK: util.assume.int + %0 = util.assume.int %arg0 : index + util.return %0 : index +} + +// ----- + +// CHECK-LABEL: @elide_constant_assumption +util.func public @elide_constant_assumption() -> index { + %cst = arith.constant 1 : index + %0 = util.assume.int %cst : index + // CHECK: %[[CST:.*]] = arith.constant 1 : index + // CHECK: util.return %[[CST]] + util.return %0 : index +} + +// ----- +// CHECK-LABEL: @elide_multi_constant_assumption +util.func public @elide_multi_constant_assumption(%arg0 : index, %arg1 : index) -> index, index, index { + %cst = arith.constant 1 : index + // CHECK: %[[CST:.*]] = arith.constant 1 : index + // CHECK: %[[ASSUME:.*]]:2 = util.assume.int + // CHECK-NEXT: %arg0, + // CHECK-NEXT: %arg1 + // CHECK-NEXT: : index, index + %0:3 = util.assume.int %arg0, %cst, %arg1 : index, index, index + // CHECK: util.return %[[ASSUME]]#0, %[[CST]], %[[ASSUME]]#1 + util.return %0#0, %0#1, %0#2 : index, index, index +} + +// ----- +// CHECK-LABEL: @broadcast_duplicate_assumptions +util.func public @broadcast_duplicate_assumptions(%arg0 : index) -> index { + // CHECK: util.assume.int %arg0 + %0 = util.assume.int %arg0[, ] : index + util.return %0 : index +} + +// ----- +// CHECK-LABEL: @dedup_duplicate_operands +util.func public @dedup_duplicate_operands(%arg0 : index) -> index, index { + // CHECK: %[[ASSUME:.*]] = util.assume.int %arg0 : index + %0:2 = util.assume.int %arg0[, ], %arg0 : index, index + // CHECK: util.return %[[ASSUME]], %[[ASSUME]] + util.return %0#0, %0#1 : index, index +} diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir index 2be8dc549226..e9d1708e2e00 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir @@ -1,7 +1,15 @@ // RUN: iree-opt --split-input-file --verify-diagnostics %s util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64 { - // expected-error @+1 {{expected operand #1 to have 1 assumptions but it has 2}} + // expected-error @+1 {{expected operand #1 to have 3 assumptions but it has 2}} + %0:2 = util.assume.int %arg0[, , ], %arg1[, ] : index, i64 + util.return %0#0, %0#1 : index, i64 +} + +// ----- + +util.func public @assume.int.multi_operand_broadcast(%arg0 : index, %arg1 : i64) -> index, i64 { + // It is legal to have a mismatched arity if 1. %0:2 = util.assume.int %arg0[], %arg1[, ] : index, i64 util.return %0#0, %0#1 : index, i64 }