Skip to content

Commit

Permalink
[InstCombine] Enable saturated add canonicalization in more variants
Browse files Browse the repository at this point in the history
LLVM is not evaluating X u > C, a, b the same way it evaluates X <= C, b, a.

To fix this, let's move the folds to after the canonicalization of -1 to TrueVal.

Alive2 Proof:
https://alive2.llvm.org/ce/z/8QbZfx
  • Loading branch information
AZero13 committed Jul 22, 2024
1 parent c83ade4 commit 55f372d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 32 deletions.
33 changes: 25 additions & 8 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,14 +977,7 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
Value *Cmp1 = Cmp->getOperand(1);
ICmpInst::Predicate Pred = Cmp->getPredicate();
Value *X;
const APInt *C, *CmpC;
if (Pred == ICmpInst::ICMP_ULT &&
match(TVal, m_Add(m_Value(X), m_APInt(C))) && X == Cmp0 &&
match(FVal, m_AllOnes()) && match(Cmp1, m_APInt(CmpC)) && *CmpC == ~*C) {
// (X u< ~C) ? (X + C) : -1 --> uadd.sat(X, C)
return Builder.CreateBinaryIntrinsic(
Intrinsic::uadd_sat, X, ConstantInt::get(X->getType(), *C));
}
const APInt *C;

// Match unsigned saturated add of 2 variables with an unnecessary 'not'.
// There are 8 commuted variants.
Expand All @@ -996,6 +989,30 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
if (!match(TVal, m_AllOnes()))
return nullptr;

if ((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) &&
match(FVal, m_Add(m_Specific(Cmp0), m_APInt(C))) &&
match(Cmp1, m_SpecificInt(~*C))) {
// (X u> ~C) ? -1 : (X + C) --> uadd.sat(X, C)
return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp0,
ConstantInt::get(Cmp0->getType(), *C));
}

if (Pred == ICmpInst::ICMP_UGE &&
match(FVal, m_Add(m_Specific(Cmp0), m_APInt(C))) &&
match(Cmp1, m_SpecificInt(-*C))) {
// (X u> -C) ? -1 : (X + C) --> uadd.sat(X, C)
return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp0,
ConstantInt::get(Cmp0->getType(), *C));
}

if (Pred == ICmpInst::ICMP_UGT &&
match(FVal, m_Add(m_Specific(Cmp0), m_APInt(C))) &&
match(Cmp1, m_SpecificInt(~*C - 1))) {
// (X u> ~C - 1) ? -1 : (X + C) --> uadd.sat(X, C)
return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp0,
ConstantInt::get(Cmp0->getType(), *C));
}

// Canonicalize predicate to less-than or less-or-equal-than.
if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
std::swap(Cmp0, Cmp1);
Expand Down
32 changes: 8 additions & 24 deletions llvm/test/Transforms/InstCombine/saturating-add-sub.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1398,9 +1398,7 @@ define i32 @uadd_sat(i32 %x, i32 %y) {

define i32 @uadd_sat_flipped(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -11
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 9)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp uge i32 %x, -10
Expand All @@ -1411,9 +1409,7 @@ define i32 @uadd_sat_flipped(i32 %x) {

define i32 @uadd_sat_flipped2(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped2(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -10
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 9)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp ugt i32 %x, -10
Expand All @@ -1424,9 +1420,7 @@ define i32 @uadd_sat_flipped2(i32 %x) {

define i32 @uadd_sat_flipped3(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped3(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -11
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 9)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp ugt i32 %x, -11
Expand All @@ -1437,9 +1431,7 @@ define i32 @uadd_sat_flipped3(i32 %x) {

define i32 @uadd_sat_flipped4(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped4(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -10
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 9)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp uge i32 %x, -9
Expand All @@ -1450,9 +1442,7 @@ define i32 @uadd_sat_flipped4(i32 %x) {

define i32 @uadd_sat_flipped5(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped5(
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X:%.*]], -9
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[ADD]], i32 -1
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 9)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp ult i32 %x, -9
Expand All @@ -1474,9 +1464,7 @@ define i32 @uadd_sat_flipped6(i32 %x) {

define i32 @uadd_sat_flipped7(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped7(
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X:%.*]], -9
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[ADD]], i32 -1
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 9)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp ule i32 %x, -10
Expand Down Expand Up @@ -1957,9 +1945,7 @@ define i32 @uadd_sat_not_commute_select_uge_commute_add(i32 %x, i32 %y) {

define i32 @uadd_sat_constant(i32 %x) {
; CHECK-LABEL: @uadd_sat_constant(
; CHECK-NEXT: [[A:%.*]] = add i32 [[X:%.*]], 42
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[X]], -43
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 -1, i32 [[A]]
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 42)
; CHECK-NEXT: ret i32 [[R]]
;
%a = add i32 %x, 42
Expand Down Expand Up @@ -2025,9 +2011,7 @@ define i32 @uadd_sat_canon_y_nuw(i32 %x, i32 %y) {

define <4 x i32> @uadd_sat_constant_vec(<4 x i32> %x) {
; CHECK-LABEL: @uadd_sat_constant_vec(
; CHECK-NEXT: [[A:%.*]] = add <4 x i32> [[X:%.*]], <i32 42, i32 42, i32 42, i32 42>
; CHECK-NEXT: [[C:%.*]] = icmp ugt <4 x i32> [[X]], <i32 -43, i32 -43, i32 -43, i32 -43>
; CHECK-NEXT: [[R:%.*]] = select <4 x i1> [[C]], <4 x i32> <i32 -1, i32 -1, i32 -1, i32 -1>, <4 x i32> [[A]]
; CHECK-NEXT: [[R:%.*]] = call <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32> [[X:%.*]], <4 x i32> <i32 42, i32 42, i32 42, i32 42>)
; CHECK-NEXT: ret <4 x i32> [[R]]
;
%a = add <4 x i32> %x, <i32 42, i32 42, i32 42, i32 42>
Expand Down

0 comments on commit 55f372d

Please sign in to comment.