Skip to content

Commit

Permalink
Fixes a range inference overflow with util.align. (#18808)
Browse files Browse the repository at this point in the history
* In the present state we were folding a non-analyzable util.align lhs
to a constant zero because the next power of two of the maximal range is
zero.
* Detects overflow and will not infer a range.
* Fixes some issues with a RHS of zero that were discovered when writing
tests for this case (which isn't really valid but was asserting the
compiler).

Signed-off-by: Stella Laurenzo <[email protected]>
  • Loading branch information
stellaraccident authored Oct 17, 2024
1 parent 8da6ba2 commit 929a7da
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ static bool isAlignedTo(Value value, Value alignment) {
APInt staticAlignment;
bool hasStaticAlignment =
matchPattern(alignment, m_ConstantInt(&staticAlignment));
if (hasStaticValue && hasStaticAlignment) {
if (hasStaticValue && hasStaticAlignment && !staticAlignment.isZero()) {
// If this value is itself a multiple of the alignment then we can fold.
if (staticValue.urem(staticAlignment).isZero()) {
return true; // value % alignment == 0
Expand Down
18 changes: 13 additions & 5 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ void AlignOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
auto constantAlignment = argRanges[1].getConstantValue();
// Note that for non constant alignment, there may still be something we
// want to infer, but this is left for the future.
if (constantAlignment) {
if (constantAlignment && !constantAlignment->isZero()) {
// We can align the range directly.
// (value + (alignment - 1)) & ~(alignment - 1)
// https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
Expand All @@ -1119,11 +1119,19 @@ void AlignOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
APInt one(constantAlignment->getBitWidth(), 1);
APInt alignmentM1 = *constantAlignment - one;
APInt alignmentM1Inv = ~alignmentM1;
auto align = [&](APInt value) -> APInt {
return (value + alignmentM1) & alignmentM1Inv;
auto align = [&](APInt value, bool &invalid) -> APInt {
APInt aligned = (value + alignmentM1) & alignmentM1Inv;
// Detect overflow, which commonly happens at max range.
if (aligned.ult(value))
invalid = true;
return aligned;
};
setResultRange(getResult(),
ConstantIntRanges::fromUnsigned(align(umin), align(umax)));
bool invalid = false;
auto alignedUmin = align(umin, invalid);
auto alignedUmax = align(umax, invalid);
if (!invalid)
setResultRange(getResult(),
ConstantIntRanges::fromUnsigned(alignedUmin, alignedUmax));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,16 @@ struct RemUIDivisibilityByConstant : public OpRewritePattern<arith::RemUIOp> {
return failure();

uint64_t rhsValue = rhsConstant.getZExtValue();
if (rhsValue > 0 && lhsDiv.udiv() > 0 && lhsDiv.udiv() % rhsValue != 0)
return rewriter.notifyMatchFailure(op, "rhs does not divide lhs");
if (rhsValue > 0 && lhsDiv.udiv() > 0) {
if (lhsDiv.udiv() % rhsValue != 0)
return rewriter.notifyMatchFailure(op, "rhs does not divide lhs");

rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, rewriter.getZeroAttr(op.getResult().getType()));
return success();
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, rewriter.getZeroAttr(op.getResult().getType()));
return success();
}

return failure();
}

DataFlowSolver &solver;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,34 @@ util.func @util_align_bounds_div(%arg0 : index, %arg1 : index) -> index, index,
// CHECK: util.return %[[ALIGN]], %[[ZERO]], %[[REM128]], %[[TRUE]], %[[FALSE]]
util.return %2, %rem64, %rem128, %in_bounds, %out_bounds : index, index, index, i1, i1
}

// -----
// Unbounded lhs of util.align technically has a range that extends to the max
// value of the bit width. Attempting to align this overflows (to zero). If not
// caught, this will most likely lead the optimizer to conclude that the
// aligned result is a constant zero. This code is verified by checking for
// overflow generally and should handle this case.
// CHECK-LABEL: @util_align_overflow
util.func @util_align_overflow(%arg0 : i64) -> i64 {
%c64 = arith.constant 64 : i64
// CHECK: util.align
%0 = util.align %arg0, %c64 : i64
util.return %0 : i64
}

// -----
// Aligning to an alignment of zero doesn't make a lot of sense but it isn't
// numerically an error. We don't fold or optimize this case and we verify
// it as such (and that other division by zero errors don't come up).
// CHECK-LABEL: @util_align_zero
util.func @util_align_zero(%arg0 : i64) -> i64 {
%c0 = arith.constant 0 : i64
%c16 = arith.constant 16 : i64
%assume = util.assume.int %arg0<umin=0, umax=15> : i64
%c128 = arith.constant 128 : i64
// CHECK: util.align
// CHECK: arith.remui
%0 = util.align %assume, %c0 : i64
%rem16 = arith.remui %0, %c16 : i64
util.return %rem16 : i64
}

0 comments on commit 929a7da

Please sign in to comment.