diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp index 4e13877ec404..7153080d87af 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp @@ -505,7 +505,7 @@ static bool isAlignedTo(Value value, Value alignment) { APInt staticAlignment; bool hasStaticAlignment = matchPattern(alignment, m_ConstantInt(&staticAlignment)); - if (hasStaticValue && hasStaticAlignment) { + if (hasStaticValue && hasStaticAlignment && !staticAlignment.isZero()) { // If this value is itself a multiple of the alignment then we can fold. if (staticValue.urem(staticAlignment).isZero()) { return true; // value % alignment == 0 diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp index 1a009f10946e..01d0e4229645 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp @@ -1110,7 +1110,7 @@ void AlignOp::inferResultRanges(ArrayRef argRanges, auto constantAlignment = argRanges[1].getConstantValue(); // Note that for non constant alignment, there may still be something we // want to infer, but this is left for the future. - if (constantAlignment) { + if (constantAlignment && !constantAlignment->isZero()) { // We can align the range directly. // (value + (alignment - 1)) & ~(alignment - 1) // https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding @@ -1119,11 +1119,19 @@ void AlignOp::inferResultRanges(ArrayRef argRanges, APInt one(constantAlignment->getBitWidth(), 1); APInt alignmentM1 = *constantAlignment - one; APInt alignmentM1Inv = ~alignmentM1; - auto align = [&](APInt value) -> APInt { - return (value + alignmentM1) & alignmentM1Inv; + auto align = [&](APInt value, bool &invalid) -> APInt { + APInt aligned = (value + alignmentM1) & alignmentM1Inv; + // Detect overflow, which commonly happens at max range. + if (aligned.ult(value)) + invalid = true; + return aligned; }; - setResultRange(getResult(), - ConstantIntRanges::fromUnsigned(align(umin), align(umax))); + bool invalid = false; + auto alignedUmin = align(umin, invalid); + auto alignedUmax = align(umax, invalid); + if (!invalid) + setResultRange(getResult(), + ConstantIntRanges::fromUnsigned(alignedUmin, alignedUmax)); } } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp index 022beaac439b..1049f3950bc3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp @@ -216,12 +216,16 @@ struct RemUIDivisibilityByConstant : public OpRewritePattern { return failure(); uint64_t rhsValue = rhsConstant.getZExtValue(); - if (rhsValue > 0 && lhsDiv.udiv() > 0 && lhsDiv.udiv() % rhsValue != 0) - return rewriter.notifyMatchFailure(op, "rhs does not divide lhs"); + if (rhsValue > 0 && lhsDiv.udiv() > 0) { + if (lhsDiv.udiv() % rhsValue != 0) + return rewriter.notifyMatchFailure(op, "rhs does not divide lhs"); - rewriter.replaceOpWithNewOp( - op, rewriter.getZeroAttr(op.getResult().getType())); - return success(); + rewriter.replaceOpWithNewOp( + op, rewriter.getZeroAttr(op.getResult().getType())); + return success(); + } + + return failure(); } DataFlowSolver &solver; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir index 41b304a89c1f..1924f423ef66 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir @@ -462,3 +462,34 @@ util.func @util_align_bounds_div(%arg0 : index, %arg1 : index) -> index, index, // CHECK: util.return %[[ALIGN]], %[[ZERO]], %[[REM128]], %[[TRUE]], %[[FALSE]] util.return %2, %rem64, %rem128, %in_bounds, %out_bounds : index, index, index, i1, i1 } + +// ----- +// Unbounded lhs of util.align technically has a range that extends to the max +// value of the bit width. Attempting to align this overflows (to zero). If not +// caught, this will most likely lead the optimizer to conclude that the +// aligned result is a constant zero. This code is verified by checking for +// overflow generally and should handle this case. +// CHECK-LABEL: @util_align_overflow +util.func @util_align_overflow(%arg0 : i64) -> i64 { + %c64 = arith.constant 64 : i64 + // CHECK: util.align + %0 = util.align %arg0, %c64 : i64 + util.return %0 : i64 +} + +// ----- +// Aligning to an alignment of zero doesn't make a lot of sense but it isn't +// numerically an error. We don't fold or optimize this case and we verify +// it as such (and that other division by zero errors don't come up). +// CHECK-LABEL: @util_align_zero +util.func @util_align_zero(%arg0 : i64) -> i64 { + %c0 = arith.constant 0 : i64 + %c16 = arith.constant 16 : i64 + %assume = util.assume.int %arg0 : i64 + %c128 = arith.constant 128 : i64 + // CHECK: util.align + // CHECK: arith.remui + %0 = util.align %assume, %c0 : i64 + %rem16 = arith.remui %0, %c16 : i64 + util.return %rem16 : i64 +}