Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Transform (fcmp + fadd + sel) into (fcmp + sel + fadd) #106492

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions llvm/include/llvm/IR/FMF.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ class FastMathFlags {

/// Print fast-math flags to \p O.
void print(raw_ostream &O) const;

/// Intersect rewrite-based flags
static inline FastMathFlags intersectRewrite(FastMathFlags LHS,
FastMathFlags RHS) {
const unsigned RewriteMask =
AllowReassoc | AllowReciprocal | AllowContract | ApproxFunc;
return FastMathFlags(RewriteMask & LHS.Flags & RHS.Flags);
}

/// Union value flags
static inline FastMathFlags unionValue(FastMathFlags LHS, FastMathFlags RHS) {
const unsigned ValueMask = NoNaNs | NoInfs | NoSignedZeros;
return FastMathFlags(ValueMask & (LHS.Flags | RHS.Flags));
}
};

inline FastMathFlags operator|(FastMathFlags LHS, FastMathFlags RHS) {
Expand Down
57 changes: 57 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3640,6 +3640,60 @@ static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected,
return false;
}

// This transformation enables the possibility of transforming fcmp + sel into
// a fmaxnum/fminnum intrinsic.
rajatbajpai marked this conversation as resolved.
Show resolved Hide resolved
static Value *foldSelectIntoAddConstant(SelectInst &SI,
InstCombiner::BuilderTy &Builder) {
// Do this transformation only when select instruction gives NaN and NSZ
// guarantee.
auto *SIFOp = dyn_cast<FPMathOperator>(&SI);
if (!SIFOp || !SIFOp->hasNoSignedZeros() || !SIFOp->hasNoNaNs())
return nullptr;

// select((fcmp Pred, X, 0), (fadd X, C), C)
// => fadd((select (fcmp Pred, X, 0), X, 0), C)
//
// Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE
Instruction *FAdd;
Constant *C;
Value *X, *Z;
CmpInst::Predicate Pred;

// Note: OneUse check for `Cmp` is necessary because it makes sure that other
// InstCombine folds don't undo this transformation and cause an infinite
// loop. Furthermore, it could also increase the operation count.
if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
m_OneUse(m_Instruction(FAdd)), m_Constant(C))) ||
match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
m_Constant(C), m_OneUse(m_Instruction(FAdd))))) {
Comment on lines +3665 to +3668
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should have am m_c_Select that handles this commuted case and returns the swapped predicate

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this would make such cases much cleaner. Should we extend the pattern match as part of this change, or should we do it in a separate PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separate, optional

// Only these relational predicates can be transformed into maxnum/minnum
// intrinsic.
if (!CmpInst::isRelational(Pred) || !match(Z, m_AnyZeroFP()))
return nullptr;

if (!match(FAdd, m_FAdd(m_Specific(X), m_Specific(C))))
return nullptr;
Comment on lines +3674 to +3675
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just match the m_FAdd up originally, instead of matching the temporary m_Instruction above. Same with the fcmp check

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestions, but I didn't do so for two reasons:

  1. I believe it will make the original condition a little crowded.
  2. Getting FAdd "fast-math-flags" from both conditions will become complex.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We really ought to add variants of the pattern matchers that extract the flags

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually have a pattern matcher implementation somewhere that extracts the flags, the problem I ran into trying to use it is that our helpers for setting fast-math flags on new instructions only support an Instruction source, not a FastMathFlags variable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, agree. I think these types of pattern matchers will make handling fast-math flags a little easier.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have merged the FCmp into the original condition and kept the FAdd as separate because of the flags.


Value *NewSelect = Builder.CreateSelect(SI.getCondition(), X, Z, "", &SI);
NewSelect->takeName(&SI);

Value *NewFAdd = Builder.CreateFAdd(NewSelect, C);
NewFAdd->takeName(FAdd);

// Propagate FastMath flags
FastMathFlags SelectFMF = SI.getFastMathFlags();
FastMathFlags FAddFMF = FAdd->getFastMathFlags();
FastMathFlags NewFMF = FastMathFlags::intersectRewrite(SelectFMF, FAddFMF) |
FastMathFlags::unionValue(SelectFMF, FAddFMF);
cast<Instruction>(NewFAdd)->setFastMathFlags(NewFMF);
cast<Instruction>(NewSelect)->setFastMathFlags(NewFMF);
Comment on lines +3683 to +3689
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Next API cleanup should move this into the original Create* above


return NewFAdd;
}

return nullptr;
}

Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
Expand Down Expand Up @@ -4036,6 +4090,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Value *V = foldRoundUpIntegerWithPow2Alignment(SI, Builder))
return replaceInstUsesWith(SI, V);

if (Value *V = foldSelectIntoAddConstant(SI, Builder))
return replaceInstUsesWith(SI, V);

// select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
// Load inst is intentionally not checked for hasOneUse()
if (match(FalseVal, m_Zero()) &&
Expand Down
Loading
Loading