Skip to content

Commit

Permalink
[InstCombine] Fold fcmp pred sqrt(X), 0.0 -> fcmp pred2 X, 0.0 (llv…
Browse files Browse the repository at this point in the history
…m#101626)

Proof (Please run alive-tv with larger smt-to):
https://alive2.llvm.org/ce/z/-aqixk
FMF propagation: https://alive2.llvm.org/ce/z/zyKK_p

```
sqrt(X) < 0.0 --> false
sqrt(X) u>= 0.0 --> true
sqrt(X) u< 0.0 --> X u< 0.0
sqrt(X) u<= 0.0 --> X u<= 0.0
sqrt(X) > 0.0 --> X > 0.0
sqrt(X) >= 0.0 --> X >= 0.0
sqrt(X) == 0.0 --> X == 0.0
sqrt(X) u!= 0.0 --> X u!= 0.0
sqrt(X) <= 0.0 --> X == 0.0
sqrt(X) u> 0.0 --> X u!= 0.0
sqrt(X) u== 0.0 --> X u<= 0.0
sqrt(X) != 0.0 --> X > 0.0
!isnan(sqrt(X)) --> X >= 0.0
isnan(sqrt(X)) --> X u< 0.0
```

In most cases, `sqrt` cannot be eliminated since it has multiple uses.
However, this patch breaks the data dependency between the comparison and
the `sqrt` call, which allows the optimizer to sink expensive `sqrt` calls
into successor blocks.
  • Loading branch information
dtcxzyw authored Aug 3, 2024
1 parent ea18a40 commit 8bd9ade
Show file tree
Hide file tree
Showing 3 changed files with 298 additions and 3 deletions.
64 changes: 64 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7980,6 +7980,67 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
}
}

/// Fold `fcmp pred sqrt(X), 0.0` into an equivalent `fcmp pred2 X, 0.0`.
///
/// \param I  The compare being visited. It is rewritten in place: its
///           predicate, its ninf flag and its operand 0 may all change.
/// \param IC The InstCombine instance used to replace operand 0 of \p I.
/// \returns the result of IC.replaceOperand(I, 0, X) when the fold applies, or
///          nullptr when operand 0 is not matched by m_Sqrt or operand 1 is
///          not matched by m_PosZeroFP.
static Instruction *foldSqrtWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
  Value *SqrtArg;
  if (!match(I.getOperand(0), m_Sqrt(m_Value(SqrtArg))) ||
      !match(I.getOperand(1), m_PosZeroFP()))
    return nullptr;

  // The rewritten compare looks at X instead of sqrt(X), so it may only keep
  // ninf when the sqrt itself carries ninf.
  // NOTE(review): presumably this guards against X == -inf, where sqrt(X) is
  // NaN rather than an infinity -- confirm against the LangRef.
  auto *SqrtInst = cast<Instruction>(I.getOperand(0));
  if (!SqrtInst->hasNoInfs())
    I.setHasNoInfs(false);

  FCmpInst::Predicate NewPred;
  switch (I.getPredicate()) {
  case FCmpInst::FCMP_OLT: // sqrt(X) < 0.0   --> false
  case FCmpInst::FCMP_UGE: // sqrt(X) u>= 0.0 --> true
    llvm_unreachable("fcmp should have simplified");

  // Predicates that carry over to X unchanged; only operand 0 is replaced.
  case FCmpInst::FCMP_ULT: // sqrt(X) u< 0.0  --> X u< 0.0
  case FCmpInst::FCMP_ULE: // sqrt(X) u<= 0.0 --> X u<= 0.0
  case FCmpInst::FCMP_OGT: // sqrt(X) > 0.0   --> X > 0.0
  case FCmpInst::FCMP_OGE: // sqrt(X) >= 0.0  --> X >= 0.0
  case FCmpInst::FCMP_OEQ: // sqrt(X) == 0.0  --> X == 0.0
  case FCmpInst::FCMP_UNE: // sqrt(X) u!= 0.0 --> X u!= 0.0
    return IC.replaceOperand(I, 0, SqrtArg);

  // Predicates that must be remapped before the operand is replaced.
  case FCmpInst::FCMP_OLE: // sqrt(X) <= 0.0 --> X == 0.0
    NewPred = FCmpInst::FCMP_OEQ;
    break;
  case FCmpInst::FCMP_UGT: // sqrt(X) u> 0.0 --> X u!= 0.0
    NewPred = FCmpInst::FCMP_UNE;
    break;
  case FCmpInst::FCMP_UEQ: // sqrt(X) u== 0.0 --> X u<= 0.0
    NewPred = FCmpInst::FCMP_ULE;
    break;
  case FCmpInst::FCMP_ONE: // sqrt(X) != 0.0 --> X > 0.0
    NewPred = FCmpInst::FCMP_OGT;
    break;
  case FCmpInst::FCMP_ORD: // !isnan(sqrt(X)) --> X >= 0.0
    NewPred = FCmpInst::FCMP_OGE;
    break;
  case FCmpInst::FCMP_UNO: // isnan(sqrt(X)) --> X u< 0.0
    NewPred = FCmpInst::FCMP_ULT;
    break;

  default:
    llvm_unreachable("Unexpected predicate!");
  }

  I.setPredicate(NewPred);
  return IC.replaceOperand(I, 0, SqrtArg);
}

static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) {
CmpInst::Predicate Pred = I.getPredicate();
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Expand Down Expand Up @@ -8247,6 +8308,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
if (Instruction *R = foldFabsWithFcmpZero(I, *this))
return R;

if (Instruction *R = foldSqrtWithFcmpZero(I, *this))
return R;

if (match(Op0, m_FNeg(m_Value(X)))) {
// fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
Constant *C;
Expand Down
233 changes: 233 additions & 0 deletions llvm/test/Transforms/InstCombine/fcmp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2117,3 +2117,236 @@ define <8 x i1> @fcmp_ogt_fsub_const_vec_denormal_preserve-sign(<8 x float> %x,
%cmp = fcmp ogt <8 x float> %fs, zeroinitializer
ret <8 x i1> %cmp
}

; sqrt(X) < 0.0 is expected to fold to the constant false.
define i1 @fcmp_sqrt_zero_olt(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_olt(
; CHECK-NEXT: ret i1 false
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp olt half %sqrt, 0.0
ret i1 %cmp
}

; sqrt(X) u< 0.0 --> X u< 0.0: the predicate is kept and the sqrt call is gone.
define i1 @fcmp_sqrt_zero_ult(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ult(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ult half %sqrt, 0.0
ret i1 %cmp
}

; The fcmp has ninf and nsz but the sqrt call has no ninf, so the folded
; compare is expected to keep nsz and drop ninf.
define i1 @fcmp_sqrt_zero_ult_fmf(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ult_fmf(
; CHECK-NEXT: [[CMP:%.*]] = fcmp nsz ult half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ninf nsz ult half %sqrt, 0.0
ret i1 %cmp
}

; Both the sqrt call and the fcmp carry ninf, so the folded compare is
; expected to keep ninf (together with nsz).
define i1 @fcmp_sqrt_zero_ult_fmf_sqrt_ninf(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ult_fmf_sqrt_ninf(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ninf nsz ult half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call ninf half @llvm.sqrt.f16(half %x)
%cmp = fcmp ninf nsz ult half %sqrt, 0.0
ret i1 %cmp
}

; Comparing against -0.0 is expected to fold as well; the resulting compare
; uses +0.0 (0xH0000) as its constant.
define i1 @fcmp_sqrt_zero_ult_nzero(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ult_nzero(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ult half %sqrt, -0.0
ret i1 %cmp
}

; Vector variant: sqrt(X) u< zeroinitializer --> X u< zeroinitializer.
define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%sqrt = call <2 x half> @llvm.sqrt.v2f16(<2 x half> %x)
%cmp = fcmp ult <2 x half> %sqrt, zeroinitializer
ret <2 x i1> %cmp
}

; Vector constant mixing +0.0 and -0.0 lanes: the fold is still expected to
; apply, and the resulting compare uses zeroinitializer.
define <2 x i1> @fcmp_sqrt_zero_ult_vec_mixed_zero(<2 x half> %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec_mixed_zero(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%sqrt = call <2 x half> @llvm.sqrt.v2f16(<2 x half> %x)
%cmp = fcmp ult <2 x half> %sqrt, <half 0.0, half -0.0>
ret <2 x i1> %cmp
}

; sqrt(X) <= 0.0 --> X == 0.0 (ole becomes oeq).
define i1 @fcmp_sqrt_zero_ole(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ole(
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ole half %sqrt, 0.0
ret i1 %cmp
}

; sqrt(X) u<= 0.0 --> X u<= 0.0 (predicate unchanged).
define i1 @fcmp_sqrt_zero_ule(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ule(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ule half %sqrt, 0.0
ret i1 %cmp
}

; sqrt(X) > 0.0 --> X > 0.0 (predicate unchanged).
define i1 @fcmp_sqrt_zero_ogt(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ogt(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ogt half %sqrt, 0.0
ret i1 %cmp
}

; sqrt(X) u> 0.0 --> X u!= 0.0 (ugt becomes une).
define i1 @fcmp_sqrt_zero_ugt(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ugt(
; CHECK-NEXT: [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ugt half %sqrt, 0.0
ret i1 %cmp
}

; sqrt(X) >= 0.0 --> X >= 0.0 (predicate unchanged).
define i1 @fcmp_sqrt_zero_oge(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_oge(
; CHECK-NEXT: [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp oge half %sqrt, 0.0
ret i1 %cmp
}

; sqrt(X) u>= 0.0 is expected to fold to the constant true.
define i1 @fcmp_sqrt_zero_uge(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_uge(
; CHECK-NEXT: ret i1 true
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp uge half %sqrt, 0.0
ret i1 %cmp
}

; sqrt(X) == 0.0 --> X == 0.0 (predicate unchanged).
define i1 @fcmp_sqrt_zero_oeq(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_oeq(
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp oeq half %sqrt, 0.0
ret i1 %cmp
}

; sqrt(X) u== 0.0 --> X u<= 0.0 (ueq becomes ule).
define i1 @fcmp_sqrt_zero_ueq(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ueq(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ueq half %sqrt, 0.0
ret i1 %cmp
}

; sqrt(X) != 0.0 --> X > 0.0 (one becomes ogt).
define i1 @fcmp_sqrt_zero_one(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_one(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp one half %sqrt, 0.0
ret i1 %cmp
}

; sqrt(X) u!= 0.0 --> X u!= 0.0 (predicate unchanged).
define i1 @fcmp_sqrt_zero_une(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_une(
; CHECK-NEXT: [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp une half %sqrt, 0.0
ret i1 %cmp
}

; !isnan(sqrt(X)) --> X >= 0.0 (ord becomes oge).
define i1 @fcmp_sqrt_zero_ord(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ord(
; CHECK-NEXT: [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ord half %sqrt, 0.0
ret i1 %cmp
}

; isnan(sqrt(X)) --> X u< 0.0 (uno becomes ult).
define i1 @fcmp_sqrt_zero_uno(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_uno(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp uno half %sqrt, 0.0
ret i1 %cmp
}

; Make sure that ninf is cleared from the folded compare when the sqrt call
; itself does not carry ninf.
define i1 @fcmp_sqrt_zero_uno_fmf(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_uno_fmf(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ninf uno half %sqrt, 0.0
ret i1 %cmp
}

; The sqrt call carries ninf here, so the folded compare is expected to keep
; the ninf flag.
define i1 @fcmp_sqrt_zero_uno_fmf_sqrt_ninf(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_uno_fmf_sqrt_ninf(
; CHECK-NEXT: [[CMP:%.*]] = fcmp ninf ult half [[X:%.*]], 0xH0000
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call ninf half @llvm.sqrt.f16(half %x)
%cmp = fcmp ninf uno half %sqrt, 0.0
ret i1 %cmp
}

; negative tests

; Negative test: the right-hand side is a variable rather than zero, so both
; the sqrt call and the original compare are expected to remain.
define i1 @fcmp_sqrt_zero_ult_var(half %x, half %y) {
; CHECK-LABEL: @fcmp_sqrt_zero_ult_var(
; CHECK-NEXT: [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[SQRT]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ult half %sqrt, %y
ret i1 %cmp
}

; Negative test: the constant is 1.0 (0xH3C00) rather than zero, so both the
; sqrt call and the original compare are expected to remain.
define i1 @fcmp_sqrt_zero_ult_nonzero(half %x) {
; CHECK-LABEL: @fcmp_sqrt_zero_ult_nonzero(
; CHECK-NEXT: [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[SQRT]], 0xH3C00
; CHECK-NEXT: ret i1 [[CMP]]
;
%sqrt = call half @llvm.sqrt.f16(half %x)
%cmp = fcmp ult half %sqrt, 1.000000e+00
ret i1 %cmp
}
4 changes: 1 addition & 3 deletions llvm/test/Transforms/InstCombine/known-never-nan.ll
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@

define i1 @fabs_sqrt_src_maybe_nan(double %arg0, double %arg1) {
; CHECK-LABEL: @fabs_sqrt_src_maybe_nan(
; CHECK-NEXT: [[FABS:%.*]] = call double @llvm.fabs.f64(double [[ARG0:%.*]])
; CHECK-NEXT: [[OP:%.*]] = call double @llvm.sqrt.f64(double [[FABS]])
; CHECK-NEXT: [[TMP:%.*]] = fcmp ord double [[OP]], 0.000000e+00
; CHECK-NEXT: [[TMP:%.*]] = fcmp ord double [[ARG0:%.*]], 0.000000e+00
; CHECK-NEXT: ret i1 [[TMP]]
;
%fabs = call double @llvm.fabs.f64(double %arg0)
Expand Down

0 comments on commit 8bd9ade

Please sign in to comment.