Skip to content

Commit

Permalink
[InstCombine] Canonicalize more saturated-add variants
Browse files Browse the repository at this point in the history
LLVM does not fold `X u> C ? a : b` the same way it folds `X u<= C ? b : a`, even though the two selects are equivalent.

To fix this, let's move the folds to after the canonicalization of -1 to TrueVal.

Let's allow splat vectors with poison elements to be recognized too!

Finally, for completeness, handle the one case that isn't caught by the above checks because it is canonicalized to eq:
X == -1 ? -1 : X + 1 -> uadd.sat(X, 1)

Alive2 Proof:
https://alive2.llvm.org/ce/z/WEcgYH
  • Loading branch information
AZero13 committed Aug 3, 2024
1 parent 16226ed commit 9542f3f
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 116 deletions.
49 changes: 41 additions & 8 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,14 +977,7 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
Value *Cmp1 = Cmp->getOperand(1);
ICmpInst::Predicate Pred = Cmp->getPredicate();
Value *X;
const APInt *C, *CmpC;
if (Pred == ICmpInst::ICMP_ULT &&
match(TVal, m_Add(m_Value(X), m_APInt(C))) && X == Cmp0 &&
match(FVal, m_AllOnes()) && match(Cmp1, m_APInt(CmpC)) && *CmpC == ~*C) {
// (X u< ~C) ? (X + C) : -1 --> uadd.sat(X, C)
return Builder.CreateBinaryIntrinsic(
Intrinsic::uadd_sat, X, ConstantInt::get(X->getType(), *C));
}
const APInt *C;

// Match unsigned saturated add of 2 variables with an unnecessary 'not'.
// There are 8 commuted variants.
Expand All @@ -996,6 +989,46 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
if (!match(TVal, m_AllOnes()))
return nullptr;

// uge -1 is canonicalized to eq -1 and requires special handling
// (a == -1) ? -1 : a + 1 -> uadd.sat(a, 1)
if (Pred == ICmpInst::ICMP_EQ) {
if (match(FVal, m_Add(m_Specific(Cmp0), m_One())) &&
match(Cmp1, m_AllOnes())) {
return Builder.CreateBinaryIntrinsic(
Intrinsic::uadd_sat, Cmp0, ConstantInt::get(Cmp0->getType(), 1));
}
return nullptr;
}

if ((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) &&
match(FVal, m_Add(m_Specific(Cmp0), m_APIntAllowPoison(C))) &&
match(Cmp1, m_SpecificIntAllowPoison(~*C))) {
// (X u> ~C) ? -1 : (X + C) --> uadd.sat(X, C)
// (X u>= ~C)? -1 : (X + C) --> uadd.sat(X, C)
return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp0,
ConstantInt::get(Cmp0->getType(), *C));
}

// Negative one does not work here because X u> -1 ? -1 : X + -1 is not a
// saturated add.
if (Pred == ICmpInst::ICMP_UGT &&
match(FVal, m_Add(m_Specific(Cmp0), m_APIntAllowPoison(C))) &&
match(Cmp1, m_SpecificIntAllowPoison(~*C - 1)) && !C->isAllOnes()) {
// (X u> ~C - 1) ? -1 : (X + C) --> uadd.sat(X, C)
return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp0,
ConstantInt::get(Cmp0->getType(), *C));
}

// Zero does not work here because X u>= 0 ? -1 : X -> is always -1, which is
// not a saturated add.
if (Pred == ICmpInst::ICMP_UGE &&
match(FVal, m_Add(m_Specific(Cmp0), m_APIntAllowPoison(C))) &&
match(Cmp1, m_SpecificIntAllowPoison(-*C)) && !C->isZero()) {
// (X u >= -C) ? -1 : (X + C) --> uadd.sat(X, C)
return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp0,
ConstantInt::get(Cmp0->getType(), *C));
}

// Canonicalize predicate to less-than or less-or-equal-than.
if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
std::swap(Cmp0, Cmp1);
Expand Down
116 changes: 8 additions & 108 deletions llvm/test/Transforms/InstCombine/saturating-add-sub.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1398,9 +1398,7 @@ define i32 @uadd_sat(i32 %x, i32 %y) {

define i32 @uadd_sat_flipped(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -11
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 9)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp ugt i32 %x, -11
Expand All @@ -1411,9 +1409,7 @@ define i32 @uadd_sat_flipped(i32 %x) {

define i32 @uadd_sat_flipped2(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped2(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -10
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 9)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp ugt i32 %x, -10
Expand Down Expand Up @@ -1450,80 +1446,9 @@ define i32 @uadd_sat_flipped3_neg_no_nuw(i32 %x) {
ret i32 %cond
}

define i32 @uadd_sat_flipped4(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped4(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -10
; CHECK-NEXT: [[ADD:%.*]] = add nuw i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[ADD]], i32 -1
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp uge i32 %x, -9
%add = add nuw i32 %x, 9
%cond = select i1 %cmp, i32 %add, i32 -1
ret i32 %cond
}

; Negative Test

define i32 @uadd_sat_flipped4_neg_no_nuw(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped4_neg_no_nuw(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -10
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[ADD]], i32 -1
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp uge i32 %x, -9
%add = add i32 %x, 9
%cond = select i1 %cmp, i32 %add, i32 -1
ret i32 %cond
}

define i32 @uadd_sat_flipped5(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped5(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -9
; CHECK-NEXT: [[ADD:%.*]] = add nuw i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp uge i32 %x, -8
%add = add nuw i32 %x, 9
%cond = select i1 %cmp, i32 -1, i32 %add
ret i32 %cond
}

; Negative test

define i32 @uadd_sat_flipped5_neg_no_nuw(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped5_neg_no_nuw(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -9
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp uge i32 %x, -8
%add = add i32 %x, 9
%cond = select i1 %cmp, i32 -1, i32 %add
ret i32 %cond
}

define i32 @uadd_sat_flipped6(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped6(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -9
; CHECK-NEXT: [[ADD:%.*]] = add nuw i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp ugt i32 %x, -9
%add = add nuw i32 %x, 9
%cond = select i1 %cmp, i32 -1, i32 %add
ret i32 %cond
}

define i32 @uadd_sat_negative_one(i32 %x) {
; CHECK-LABEL: @uadd_sat_negative_one(
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], -1
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 1
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 1)
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp eq i32 %x, -1
Expand Down Expand Up @@ -1552,21 +1477,6 @@ define i32 @uadd_sat_negative_one_poison_all(i32 %x) {
ret i32 %cond
}

; Negative test

define i32 @uadd_sat_flipped_neg_no_nuw(i32 %x) {
; CHECK-LABEL: @uadd_sat_flipped_neg_no_nuw(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], -9
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 9
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[ADD]]
; CHECK-NEXT: ret i32 [[COND]]
;
%cmp = icmp ugt i32 %x, -9
%add = add i32 %x, 9
%cond = select i1 %cmp, i32 -1, i32 %add
ret i32 %cond
}

define i32 @uadd_sat_poison(i32 %x, i32 %y) {
; CHECK-LABEL: @uadd_sat_poison(
; CHECK-NEXT: ret i32 poison
Expand Down Expand Up @@ -1651,9 +1561,7 @@ define <2 x i8> @uadd_sat_flipped4_vector(<2 x i8> %x) {

define <2 x i8> @uadd_sat_flipped4_poison_vector(<2 x i8> %x) {
; CHECK-LABEL: @uadd_sat_flipped4_poison_vector(
; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i8> [[X:%.*]], <i8 -10, i8 poison>
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i8> [[X]], <i8 9, i8 9>
; CHECK-NEXT: [[COND:%.*]] = select <2 x i1> [[CMP]], <2 x i8> [[ADD]], <2 x i8> <i8 -1, i8 -1>
; CHECK-NEXT: [[COND:%.*]] = call <2 x i8> @llvm.uadd.sat.v2i8(<2 x i8> [[X:%.*]], <2 x i8> <i8 9, i8 9>)
; CHECK-NEXT: ret <2 x i8> [[COND]]
;
%cmp = icmp ult <2 x i8> %x, <i8 -10, i8 poison>
Expand All @@ -1664,9 +1572,7 @@ define <2 x i8> @uadd_sat_flipped4_poison_vector(<2 x i8> %x) {

define <2 x i8> @uadd_sat_flipped4_poison_vector_compare(<2 x i8> %x) {
; CHECK-LABEL: @uadd_sat_flipped4_poison_vector_compare(
; CHECK-NEXT: [[CMP:%.*]] = icmp ult <2 x i8> [[X:%.*]], <i8 -10, i8 poison>
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i8> [[X]], <i8 9, i8 poison>
; CHECK-NEXT: [[COND:%.*]] = select <2 x i1> [[CMP]], <2 x i8> [[ADD]], <2 x i8> <i8 -1, i8 -1>
; CHECK-NEXT: [[COND:%.*]] = call <2 x i8> @llvm.uadd.sat.v2i8(<2 x i8> [[X:%.*]], <2 x i8> <i8 9, i8 9>)
; CHECK-NEXT: ret <2 x i8> [[COND]]
;
%cmp = icmp ult <2 x i8> %x, <i8 -10, i8 poison>
Expand Down Expand Up @@ -2162,9 +2068,7 @@ define i32 @uadd_sat_not_commute_select_uge_commute_add(i32 %x, i32 %y) {

define i32 @uadd_sat_constant(i32 %x) {
; CHECK-LABEL: @uadd_sat_constant(
; CHECK-NEXT: [[A:%.*]] = add i32 [[X:%.*]], 42
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[X]], -43
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 -1, i32 [[A]]
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 42)
; CHECK-NEXT: ret i32 [[R]]
;
%a = add i32 %x, 42
Expand Down Expand Up @@ -2230,9 +2134,7 @@ define i32 @uadd_sat_canon_y_nuw(i32 %x, i32 %y) {

define <4 x i32> @uadd_sat_constant_vec(<4 x i32> %x) {
; CHECK-LABEL: @uadd_sat_constant_vec(
; CHECK-NEXT: [[A:%.*]] = add <4 x i32> [[X:%.*]], <i32 42, i32 42, i32 42, i32 42>
; CHECK-NEXT: [[C:%.*]] = icmp ugt <4 x i32> [[X]], <i32 -43, i32 -43, i32 -43, i32 -43>
; CHECK-NEXT: [[R:%.*]] = select <4 x i1> [[C]], <4 x i32> <i32 -1, i32 -1, i32 -1, i32 -1>, <4 x i32> [[A]]
; CHECK-NEXT: [[R:%.*]] = call <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32> [[X:%.*]], <4 x i32> <i32 42, i32 42, i32 42, i32 42>)
; CHECK-NEXT: ret <4 x i32> [[R]]
;
%a = add <4 x i32> %x, <i32 42, i32 42, i32 42, i32 42>
Expand All @@ -2254,9 +2156,7 @@ define <4 x i32> @uadd_sat_constant_vec_commute(<4 x i32> %x) {

define <4 x i32> @uadd_sat_constant_vec_commute_undefs(<4 x i32> %x) {
; CHECK-LABEL: @uadd_sat_constant_vec_commute_undefs(
; CHECK-NEXT: [[A:%.*]] = add <4 x i32> [[X:%.*]], <i32 42, i32 42, i32 42, i32 poison>
; CHECK-NEXT: [[C:%.*]] = icmp ult <4 x i32> [[X]], <i32 -43, i32 -43, i32 poison, i32 -43>
; CHECK-NEXT: [[R:%.*]] = select <4 x i1> [[C]], <4 x i32> [[A]], <4 x i32> <i32 -1, i32 poison, i32 -1, i32 -1>
; CHECK-NEXT: [[R:%.*]] = call <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32> [[X:%.*]], <4 x i32> <i32 42, i32 42, i32 42, i32 42>)
; CHECK-NEXT: ret <4 x i32> [[R]]
;
%a = add <4 x i32> %x, <i32 42, i32 42, i32 42, i32 poison>
Expand Down

0 comments on commit 9542f3f

Please sign in to comment.