Skip to content

Commit

Permalink
[SCEVExpander] Remove unnecessary expandCodeForImpl() wrapper (NFC)
Browse files Browse the repository at this point in the history
expandCodeFor() was directly calling expandCodeForImpl(). Drop the
Impl variant.
  • Loading branch information
nikic committed Sep 21, 2023
1 parent b5ff71e commit 32e15ae
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 68 deletions.
24 changes: 3 additions & 21 deletions llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,17 +276,16 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {

/// Insert code to directly compute the specified SCEV expression into the
/// program. The code is inserted into the specified block.
Value *expandCodeFor(const SCEV *SH, Type *Ty, BasicBlock::iterator I);
Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I) {
return expandCodeForImpl(SH, Ty, I);
return expandCodeFor(SH, Ty, I->getIterator());
}

/// Insert code to directly compute the specified SCEV expression into the
/// program. The code is inserted into the SCEVExpander's current
/// insertion point. If a type is specified, the result will be expanded to
/// have that type, with a cast if necessary.
Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr) {
return expandCodeForImpl(SH, Ty);
}
Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr);

/// Generates a code sequence that evaluates this predicate. The inserted
/// instructions will be at position \p Loc. The result will be of type i1
Expand Down Expand Up @@ -396,23 +395,6 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
private:
LLVMContext &getContext() const { return SE.getContext(); }

/// Insert code to directly compute the specified SCEV expression into the
/// program. The code is inserted into the SCEVExpander's current
/// insertion point. If a type is specified, the result will be expanded to
/// have that type, with a cast if necessary. If \p Root is true, this
/// indicates that \p SH is the top-level expression to expand passed from
/// an external client call.
Value *expandCodeForImpl(const SCEV *SH, Type *Ty);

/// Insert code to directly compute the specified SCEV expression into the
/// program. The code is inserted into the specified block. If \p
/// Root is true, this indicates that \p SH is the top-level expression to
/// expand passed from an external client call.
Value *expandCodeForImpl(const SCEV *SH, Type *Ty, BasicBlock::iterator I);
Value *expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *I) {
return expandCodeForImpl(SH, Ty, I->getIterator());
}

/// Recursive helper function for isHighCostExpansion.
bool isHighCostExpansionHelper(const SCEVOperand &WorkItem, Loop *L,
const Instruction &At, InstructionCost &Cost,
Expand Down
85 changes: 38 additions & 47 deletions llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *Offset, Type *Ty, Value *V) {
assert(!isa<Instruction>(V) ||
SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint()));

Value *Idx = expandCodeForImpl(Offset, Ty);
Value *Idx = expandCodeFor(Offset, Ty);

// Fold a GEP with constant operands.
if (Constant *CLHS = dyn_cast<Constant>(V))
Expand Down Expand Up @@ -500,17 +500,18 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
Sum = expandAddToGEP(SE.getAddExpr(NewOps), Ty, Sum);
} else if (Op->isNonConstantNegative()) {
// Instead of doing a negate and add, just do a subtract.
Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty);
Value *W = expandCodeFor(SE.getNegativeSCEV(Op), Ty);
Sum = InsertNoopCastOfTo(Sum, Ty);
Sum = InsertBinop(Instruction::Sub, Sum, W, SCEV::FlagAnyWrap,
/*IsSafeToHoist*/ true);
++I;
} else {
// A simple add.
Value *W = expandCodeForImpl(Op, Ty);
Value *W = expandCodeFor(Op, Ty);
Sum = InsertNoopCastOfTo(Sum, Ty);
// Canonicalize a constant to the RHS.
if (isa<Constant>(Sum)) std::swap(Sum, W);
if (isa<Constant>(Sum))
std::swap(Sum, W);
Sum = InsertBinop(Instruction::Add, Sum, W, S->getNoWrapFlags(),
/*IsSafeToHoist*/ true);
++I;
Expand Down Expand Up @@ -558,7 +559,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {

// Calculate powers with exponents 1, 2, 4, 8 etc. and include those of them
// that are needed into the result.
Value *P = expandCodeForImpl(I->second, Ty);
Value *P = expandCodeFor(I->second, Ty);
Value *Result = nullptr;
if (Exponent & 1)
Result = P;
Expand Down Expand Up @@ -617,7 +618,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
Type *Ty = SE.getEffectiveSCEVType(S->getType());

Value *LHS = expandCodeForImpl(S->getLHS(), Ty);
Value *LHS = expandCodeFor(S->getLHS(), Ty);
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) {
const APInt &RHS = SC->getAPInt();
if (RHS.isPowerOf2())
Expand All @@ -626,7 +627,7 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true);
}

Value *RHS = expandCodeForImpl(S->getRHS(), Ty);
Value *RHS = expandCodeFor(S->getRHS(), Ty);
return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
/*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
}
Expand Down Expand Up @@ -994,9 +995,8 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
// Expand code for the start value into the loop preheader.
assert(L->getLoopPreheader() &&
"Can't expand add recurrences without a loop preheader!");
Value *StartV =
expandCodeForImpl(Normalized->getStart(), ExpandTy,
L->getLoopPreheader()->getTerminator());
Value *StartV = expandCodeFor(Normalized->getStart(), ExpandTy,
L->getLoopPreheader()->getTerminator());

// StartV must have been be inserted into L's preheader to dominate the new
// phi.
Expand All @@ -1015,7 +1015,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
Step = SE.getNegativeSCEV(Step);
// Expand the step somewhere that dominates the loop header.
Value *StepV =
expandCodeForImpl(Step, IntTy, L->getHeader()->getFirstInsertionPt());
expandCodeFor(Step, IntTy, L->getHeader()->getFirstInsertionPt());

// The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if
// we actually do emit an addition. It does not apply if we emit a
Expand Down Expand Up @@ -1173,8 +1173,8 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
{
// Expand the step somewhere that dominates the loop header.
SCEVInsertPointGuard Guard(Builder, this);
StepV = expandCodeForImpl(Step, IntTy,
L->getHeader()->getFirstInsertionPt());
StepV =
expandCodeFor(Step, IntTy, L->getHeader()->getFirstInsertionPt());
}
Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract);
}
Expand All @@ -1193,31 +1193,29 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {

// Invert the result.
if (InvertStep)
Result = Builder.CreateSub(
expandCodeForImpl(Normalized->getStart(), TruncTy), Result);
Result = Builder.CreateSub(expandCodeFor(Normalized->getStart(), TruncTy),
Result);
}

// Re-apply any non-loop-dominating scale.
if (PostLoopScale) {
assert(S->isAffine() && "Can't linearly scale non-affine recurrences.");
Result = InsertNoopCastOfTo(Result, IntTy);
Result = Builder.CreateMul(Result,
expandCodeForImpl(PostLoopScale, IntTy));
Result = Builder.CreateMul(Result, expandCodeFor(PostLoopScale, IntTy));
}

// Re-apply any non-loop-dominating offset.
if (PostLoopOffset) {
if (isa<PointerType>(ExpandTy)) {
if (Result->getType()->isIntegerTy()) {
Value *Base = expandCodeForImpl(PostLoopOffset, ExpandTy);
Value *Base = expandCodeFor(PostLoopOffset, ExpandTy);
Result = expandAddToGEP(SE.getUnknown(Result), IntTy, Base);
} else {
Result = expandAddToGEP(PostLoopOffset, IntTy, Result);
}
} else {
Result = InsertNoopCastOfTo(Result, IntTy);
Result = Builder.CreateAdd(
Result, expandCodeForImpl(PostLoopOffset, IntTy));
Result = Builder.CreateAdd(Result, expandCodeFor(PostLoopOffset, IntTy));
}
}

Expand Down Expand Up @@ -1259,8 +1257,8 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
S->getNoWrapFlags(SCEV::FlagNW)));
BasicBlock::iterator NewInsertPt =
findInsertPointAfter(cast<Instruction>(V), &*Builder.GetInsertPoint());
V = expandCodeForImpl(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr,
NewInsertPt);
V = expandCodeFor(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr,
NewInsertPt);
return V;
}

Expand Down Expand Up @@ -1360,33 +1358,29 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
}

Value *SCEVExpander::visitPtrToIntExpr(const SCEVPtrToIntExpr *S) {
Value *V =
expandCodeForImpl(S->getOperand(), S->getOperand()->getType());
Value *V = expandCodeFor(S->getOperand(), S->getOperand()->getType());
return ReuseOrCreateCast(V, S->getType(), CastInst::PtrToInt,
GetOptimalInsertionPointForCastOf(V));
}

Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) {
Type *Ty = SE.getEffectiveSCEVType(S->getType());
Value *V = expandCodeForImpl(
S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType())
);
Value *V = expandCodeFor(S->getOperand(),
SE.getEffectiveSCEVType(S->getOperand()->getType()));
return Builder.CreateTrunc(V, Ty);
}

Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) {
Type *Ty = SE.getEffectiveSCEVType(S->getType());
Value *V = expandCodeForImpl(
S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType())
);
Value *V = expandCodeFor(S->getOperand(),
SE.getEffectiveSCEVType(S->getOperand()->getType()));
return Builder.CreateZExt(V, Ty);
}

Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
Type *Ty = SE.getEffectiveSCEVType(S->getType());
Value *V = expandCodeForImpl(
S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType())
);
Value *V = expandCodeFor(S->getOperand(),
SE.getEffectiveSCEVType(S->getOperand()->getType()));
return Builder.CreateSExt(V, Ty);
}

Expand All @@ -1398,7 +1392,7 @@ Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
if (IsSequential)
LHS = Builder.CreateFreeze(LHS);
for (int i = S->getNumOperands() - 2; i >= 0; --i) {
Value *RHS = expandCodeForImpl(S->getOperand(i), Ty);
Value *RHS = expandCodeFor(S->getOperand(i), Ty);
if (IsSequential && i != 0)
RHS = Builder.CreateFreeze(RHS);
Value *Sel;
Expand Down Expand Up @@ -1439,14 +1433,14 @@ Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
return Builder.CreateVScale(ConstantInt::get(S->getType(), 1));
}

Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty,
BasicBlock::iterator IP) {
Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
BasicBlock::iterator IP) {
setInsertPoint(IP);
Value *V = expandCodeForImpl(SH, Ty);
Value *V = expandCodeFor(SH, Ty);
return V;
}

Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty) {
Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) {
// Expand the code for this SCEV.
Value *V = expand(SH);

Expand Down Expand Up @@ -2103,10 +2097,8 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,

Value *SCEVExpander::expandComparePredicate(const SCEVComparePredicate *Pred,
Instruction *IP) {
Value *Expr0 =
expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP);
Value *Expr1 =
expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP);
Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP);
Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP);

Builder.SetInsertPoint(IP);
auto InvPred = ICmpInst::getInversePredicate(Pred->getPredicate());
Expand Down Expand Up @@ -2140,15 +2132,14 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,

IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits);
Builder.SetInsertPoint(Loc);
Value *TripCountVal = expandCodeForImpl(ExitCount, CountTy, Loc);
Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc);

IntegerType *Ty =
IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy));

Value *StepValue = expandCodeForImpl(Step, Ty, Loc);
Value *NegStepValue =
expandCodeForImpl(SE.getNegativeSCEV(Step), Ty, Loc);
Value *StartValue = expandCodeForImpl(Start, ARTy, Loc);
Value *StepValue = expandCodeFor(Step, Ty, Loc);
Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc);
Value *StartValue = expandCodeFor(Start, ARTy, Loc);

ConstantInt *Zero =
ConstantInt::get(Loc->getContext(), APInt::getZero(DstBits));
Expand Down

0 comments on commit 32e15ae

Please sign in to comment.