Skip to content

Commit

Permalink
[InstCombine] Transform (fcmp + fadd + sel) into (fcmp + sel + fadd) (#…
Browse files Browse the repository at this point in the history
…106492)

Transform `fcmp + fadd + sel` into `fcmp + sel + fadd` which enables the
possibility of transforming `fcmp + sel` into `maxnum/minnum`
intrinsics.

Alive2 results:
https://alive2.llvm.org/ce/z/2cmimW
https://alive2.llvm.org/ce/z/Qh9ZJt
https://alive2.llvm.org/ce/z/vtLj3R
  • Loading branch information
rajatbajpai authored Nov 11, 2024
1 parent b816c26 commit ef2d6da
Show file tree
Hide file tree
Showing 3 changed files with 745 additions and 0 deletions.
14 changes: 14 additions & 0 deletions llvm/include/llvm/IR/FMF.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ class FastMathFlags {

/// Print fast-math flags to \p O.
void print(raw_ostream &O) const;

/// Intersect rewrite-based flags
static inline FastMathFlags intersectRewrite(FastMathFlags LHS,
FastMathFlags RHS) {
const unsigned RewriteMask =
AllowReassoc | AllowReciprocal | AllowContract | ApproxFunc;
return FastMathFlags(RewriteMask & LHS.Flags & RHS.Flags);
}

/// Union value flags
static inline FastMathFlags unionValue(FastMathFlags LHS, FastMathFlags RHS) {
const unsigned ValueMask = NoNaNs | NoInfs | NoSignedZeros;
return FastMathFlags(ValueMask & (LHS.Flags | RHS.Flags));
}
};

inline FastMathFlags operator|(FastMathFlags LHS, FastMathFlags RHS) {
Expand Down
57 changes: 57 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3645,6 +3645,60 @@ static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected,
return false;
}

// This transformation enables the possibility of transforming fcmp + sel into
// a fmaxnum/fminnum intrinsic.
static Value *foldSelectIntoAddConstant(SelectInst &SI,
InstCombiner::BuilderTy &Builder) {
// Do this transformation only when select instruction gives NaN and NSZ
// guarantee.
auto *SIFOp = dyn_cast<FPMathOperator>(&SI);
if (!SIFOp || !SIFOp->hasNoSignedZeros() || !SIFOp->hasNoNaNs())
return nullptr;

// select((fcmp Pred, X, 0), (fadd X, C), C)
// => fadd((select (fcmp Pred, X, 0), X, 0), C)
//
// Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE
Instruction *FAdd;
Constant *C;
Value *X, *Z;
CmpInst::Predicate Pred;

// Note: OneUse check for `Cmp` is necessary because it makes sure that other
// InstCombine folds don't undo this transformation and cause an infinite
// loop. Furthermore, it could also increase the operation count.
if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
m_OneUse(m_Instruction(FAdd)), m_Constant(C))) ||
match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
m_Constant(C), m_OneUse(m_Instruction(FAdd))))) {
// Only these relational predicates can be transformed into maxnum/minnum
// intrinsic.
if (!CmpInst::isRelational(Pred) || !match(Z, m_AnyZeroFP()))
return nullptr;

if (!match(FAdd, m_FAdd(m_Specific(X), m_Specific(C))))
return nullptr;

Value *NewSelect = Builder.CreateSelect(SI.getCondition(), X, Z, "", &SI);
NewSelect->takeName(&SI);

Value *NewFAdd = Builder.CreateFAdd(NewSelect, C);
NewFAdd->takeName(FAdd);

// Propagate FastMath flags
FastMathFlags SelectFMF = SI.getFastMathFlags();
FastMathFlags FAddFMF = FAdd->getFastMathFlags();
FastMathFlags NewFMF = FastMathFlags::intersectRewrite(SelectFMF, FAddFMF) |
FastMathFlags::unionValue(SelectFMF, FAddFMF);
cast<Instruction>(NewFAdd)->setFastMathFlags(NewFMF);
cast<Instruction>(NewSelect)->setFastMathFlags(NewFMF);

return NewFAdd;
}

return nullptr;
}

Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
Expand Down Expand Up @@ -4041,6 +4095,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Value *V = foldRoundUpIntegerWithPow2Alignment(SI, Builder))
return replaceInstUsesWith(SI, V);

if (Value *V = foldSelectIntoAddConstant(SI, Builder))
return replaceInstUsesWith(SI, V);

// select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
// Load inst is intentionally not checked for hasOneUse()
if (match(FalseVal, m_Zero()) &&
Expand Down
Loading

0 comments on commit ef2d6da

Please sign in to comment.