-
Notifications
You must be signed in to change notification settings - Fork 12.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
change contents of ScalarEvolution from private to protected #83052
Open
skewballfox
wants to merge
15
commits into
llvm:main
Choose a base branch
from
skewballfox:scev_protect
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+147
−36
Open
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
eea887c
mainly pushing to switch machines
skewballfox e47436b
added AssumeLoopExits bool to SE, lifting MustExit code into SE
skewballfox f55e361
added MustExitcode for computeExitLimit
skewballfox 8e85c06
added enzyme mustExit code to computeExitLimitFromSingleExitSwitch
skewballfox 3f378b5
add enzyme must exit code to computeExitLimitFromCondImpl
skewballfox 14a0c6c
implemented enzyme must exit code in computeExitLimitFromICmp
skewballfox abb0ab4
add Enzyme changes to SE howManyLessThans
skewballfox c1d83de
fixed issue in howManyLessThans where conditions were incorrectly dep…
skewballfox 66ab0c3
incorporating changes from code review
skewballfox 9b57191
removed unrelated change
skewballfox 5776793
moved mustexit code to other computeExitLimitFromICmp definition
skewballfox bdabce8
reran git clang-format HEAD~1
skewballfox a9c9251
removed redundant binOp code from CondImpl
skewballfox 7597a83
implenting requested changes
skewballfox 16d19f6
forgot to remove include for deleted file
skewballfox File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -509,6 +509,8 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) { | |
return S; | ||
} | ||
|
||
/// Set the AssumeLoopFinite flag on this ScalarEvolution instance. | ||
/// While the flag is set, loopIsFiniteByAssumption() returns true for every | ||
/// loop and the exit-limit computations take their AssumeLoopFinite paths. | ||
/// NOTE(review): the method is named "AssumeLoopExits" while the flag it sets | ||
/// is "AssumeLoopFinite"; presumably the same concept -- confirm. | ||
void ScalarEvolution::setAssumeLoopExits() { this->AssumeLoopFinite = true; } | ||
|
||
/// Initialize a SCEVCastExpr: forwards ID, SCEVTy and | ||
/// computeExpressionSize(op) to the SCEV base, and records the operand in Op | ||
/// and the type in Ty. | ||
SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, | ||
const SCEV *op, Type *ty) | ||
: SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {} | ||
|
@@ -7413,7 +7415,8 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { | |
// A mustprogress loop without side effects must be finite. | ||
// TODO: The check used here is very conservative. It's only *specific* | ||
// side effects which are well defined in infinite loops. | ||
return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); | ||
return AssumeLoopFinite || isFinite(L) || | ||
(isMustProgress(L) && loopHasNoSideEffects(L)); | ||
} | ||
|
||
const SCEV *ScalarEvolution::createSCEVIter(Value *V) { | ||
|
@@ -8828,6 +8831,26 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, | |
ScalarEvolution::ExitLimit | ||
ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, | ||
bool AllowPredicates) { | ||
if (AssumeLoopFinite) { | ||
SmallVector<BasicBlock *, 8> ExitingBlocks; | ||
L->getExitingBlocks(ExitingBlocks); | ||
for (auto &ExitingBlock : ExitingBlocks) { | ||
BasicBlock *Exit = nullptr; | ||
for (auto *SBB : successors(ExitingBlock)) { | ||
if (!L->contains(SBB)) { | ||
if (GuaranteedUnreachable.count(SBB)) | ||
continue; | ||
Exit = SBB; | ||
break; | ||
} | ||
} | ||
if (!Exit) | ||
ExitingBlock = nullptr; | ||
} | ||
ExitingBlocks.erase( | ||
std::remove(ExitingBlocks.begin(), ExitingBlocks.end(), nullptr), | ||
ExitingBlocks.end()); | ||
} | ||
assert(L->contains(ExitingBlock) && "Exit count for non-loop block?"); | ||
// If our exiting block does not dominate the latch, then its connection with | ||
// loop's exit limit may be far from trivial. | ||
|
@@ -8853,6 +8876,8 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, | |
BasicBlock *Exit = nullptr; | ||
for (auto *SBB : successors(ExitingBlock)) | ||
if (!L->contains(SBB)) { | ||
if (AssumeLoopFinite and GuaranteedUnreachable.count(SBB)) | ||
continue; | ||
if (Exit) // Multiple exit successors. | ||
return getCouldNotCompute(); | ||
Exit = SBB; | ||
|
@@ -8923,6 +8948,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( | |
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( | ||
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, | ||
bool ControlsOnlyExit, bool AllowPredicates) { | ||
|
||
// Handle BinOp conditions (And, Or). | ||
if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( | ||
Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates)) | ||
|
@@ -8950,6 +8976,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( | |
if (ExitIfTrue == !CI->getZExtValue()) | ||
// The backedge is always taken. | ||
return getCouldNotCompute(); | ||
|
||
// The backedge is never taken. | ||
return getZero(CI->getType()); | ||
} | ||
|
@@ -8961,9 +8988,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( | |
const APInt *C; | ||
if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) && | ||
match(WO->getRHS(), m_APInt(C))) { | ||
ConstantRange NWR = | ||
ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C, | ||
WO->getNoWrapKind()); | ||
ConstantRange NWR = ConstantRange::makeExactNoWrapRegion( | ||
WO->getBinaryOp(), *C, WO->getNoWrapKind()); | ||
CmpInst::Predicate Pred; | ||
APInt NewRHSC, Offset; | ||
NWR.getEquivalentICmp(Pred, NewRHSC, Offset); | ||
|
@@ -9019,13 +9045,15 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( | |
const SCEV *SymbolicMaxBECount = getCouldNotCompute(); | ||
if (EitherMayExit) { | ||
bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond); | ||
|
||
// Both conditions must be same for the loop to continue executing. | ||
// Choose the less conservative count. | ||
if (EL0.ExactNotTaken != getCouldNotCompute() && | ||
EL1.ExactNotTaken != getCouldNotCompute()) { | ||
BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken, | ||
UseSequentialUMin); | ||
} | ||
|
||
if (EL0.ConstantMaxNotTaken == getCouldNotCompute()) | ||
ConstantMaxBECount = EL1.ConstantMaxNotTaken; | ||
else if (EL1.ConstantMaxNotTaken == getCouldNotCompute()) | ||
|
@@ -9045,6 +9073,12 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( | |
// For now, be conservative. | ||
if (EL0.ExactNotTaken == EL1.ExactNotTaken) | ||
BECount = EL0.ExactNotTaken; | ||
// This was executed in Enzyme's must exit code under the | ||
// logic for when the binary op was OR | ||
if (AssumeLoopFinite && !IsAnd) { | ||
if (EL0.ExactNotTaken == EL1.ExactNotTaken) | ||
ConstantMaxBECount = EL0.ExactNotTaken; | ||
} | ||
} | ||
|
||
// There are cases (e.g. PR26207) where computeExitLimitFromCond is able | ||
|
@@ -9053,12 +9087,14 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( | |
// and | ||
// EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and | ||
// EL1.ConstantMaxNotTaken to not. | ||
if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) && | ||
!isa<SCEVCouldNotCompute>(BECount)) | ||
ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); | ||
if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount)) | ||
SymbolicMaxBECount = | ||
isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount; | ||
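if (!AssumeLoopFinite || !IsAnd) { // NOTE(review): as written, this block is skipped when AssumeLoopFinite && IsAnd; the original note said "skip if assume exits and OR" -- confirm which is intended | ||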
if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) && | ||
!isa<SCEVCouldNotCompute>(BECount)) | ||
ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); | ||
if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount)) | ||
SymbolicMaxBECount = | ||
isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount; | ||
} | ||
return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false, | ||
{ &EL0.Predicates, &EL1.Predicates }); | ||
} | ||
|
@@ -9082,8 +9118,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( | |
if (EL.hasAnyInfo()) | ||
return EL; | ||
|
||
auto *ExhaustiveCount = | ||
computeExitCountExhaustively(L, ExitCond, ExitIfTrue); | ||
auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue); | ||
|
||
if (!isa<SCEVCouldNotCompute>(ExhaustiveCount)) | ||
return ExhaustiveCount; | ||
|
@@ -9094,7 +9129,31 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( | |
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( | ||
const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, | ||
bool ControlsOnlyExit, bool AllowPredicates) { | ||
|
||
if (AssumeLoopFinite) { | ||
#define PROP_PHI(LHS) \ | ||
if (auto un = dyn_cast<SCEVUnknown>(LHS)) { \ | ||
if (auto pn = dyn_cast_or_null<PHINode>(un->getValue())) { \ | ||
const SCEV *sc = nullptr; \ | ||
bool failed = false; \ | ||
for (auto &a : pn->incoming_values()) { \ | ||
auto subsc = getSCEV(a); \ | ||
if (sc == nullptr) { \ | ||
sc = subsc; \ | ||
continue; \ | ||
} \ | ||
if (subsc != sc) { \ | ||
failed = true; \ | ||
break; \ | ||
} \ | ||
} \ | ||
if (!failed) { \ | ||
LHS = sc; \ | ||
} \ | ||
} \ | ||
} | ||
PROP_PHI(LHS) | ||
PROP_PHI(RHS) | ||
} | ||
// Try to evaluate any dependencies out of the loop. | ||
LHS = getSCEVAtScope(LHS, L); | ||
RHS = getSCEVAtScope(RHS, L); | ||
|
@@ -9107,6 +9166,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( | |
Pred = ICmpInst::getSwappedPredicate(Pred); | ||
} | ||
|
||
// This check was not present in the Enzyme code. The last condition is | ||
// always true when AssumeLoopFinite is set. | ||
// TODO: will the first two checks cause Enzyme to fail? | ||
bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) && | ||
loopIsFiniteByAssumption(L); | ||
// Simplify the operands before analyzing them. | ||
|
@@ -9184,15 +9246,19 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( | |
if (EL.hasAnyInfo()) return EL; | ||
break; | ||
} | ||
|
||
case ICmpInst::ICMP_SLE: | ||
case ICmpInst::ICMP_ULE: | ||
// Since the loop is finite, an invariant RHS cannot include the boundary | ||
// value, otherwise it would loop forever. | ||
if (!EnableFiniteLoopControl || !ControllingFiniteLoop || | ||
!isLoopInvariant(RHS, L)) | ||
break; | ||
RHS = getAddExpr(getOne(RHS->getType()), RHS); | ||
if (!AssumeLoopFinite) { | ||
// Since the loop is finite, an invariant RHS cannot include the boundary | ||
// value, otherwise it would loop forever. | ||
if (!EnableFiniteLoopControl || !ControllingFiniteLoop || | ||
!isLoopInvariant(RHS, L)) | ||
break; | ||
RHS = getAddExpr(getOne(RHS->getType()), RHS); | ||
} | ||
[[fallthrough]]; | ||
|
||
case ICmpInst::ICMP_SLT: | ||
case ICmpInst::ICMP_ULT: { // while (X < Y) | ||
bool IsSigned = ICmpInst::isSigned(Pred); | ||
|
@@ -9204,16 +9270,33 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( | |
} | ||
case ICmpInst::ICMP_SGE: | ||
case ICmpInst::ICMP_UGE: | ||
// Since the loop is finite, an invariant RHS cannot include the boundary | ||
// value, otherwise it would loop forever. | ||
if (!EnableFiniteLoopControl || !ControllingFiniteLoop || | ||
!isLoopInvariant(RHS, L)) | ||
break; | ||
RHS = getAddExpr(getMinusOne(RHS->getType()), RHS); | ||
if (!AssumeLoopFinite) { | ||
// Since the loop is finite, an invariant RHS cannot include the boundary | ||
// value, otherwise it would loop forever. | ||
if (!EnableFiniteLoopControl || !ControllingFiniteLoop || | ||
!isLoopInvariant(RHS, L)) | ||
break; | ||
RHS = getAddExpr(getMinusOne(RHS->getType()), RHS); | ||
} | ||
[[fallthrough]]; | ||
case ICmpInst::ICMP_SGT: | ||
case ICmpInst::ICMP_UGT: { // while (X > Y) | ||
bool IsSigned = ICmpInst::isSigned(Pred); | ||
if (AssumeLoopFinite) { | ||
if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) { | ||
if (!isa<IntegerType>(RHS->getType())) | ||
break; | ||
SmallVector<const SCEV *, 2> sv = { | ||
RHS, getConstant( | ||
ConstantInt::get(cast<IntegerType>(RHS->getType()), -1))}; | ||
// Since this is not an infinite loop by induction, RHS cannot be | ||
// int_min/uint_min. Therefore subtracting 1 does not wrap. | ||
if (IsSigned) | ||
RHS = getAddExpr(sv, SCEV::FlagNSW); | ||
else | ||
RHS = getAddExpr(sv, SCEV::FlagNUW); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks like duplicated code from above (the EnableFiniteLoopControl bit). |
||
} | ||
} | ||
ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit, | ||
AllowPredicates); | ||
if (EL.hasAnyInfo()) | ||
|
@@ -9238,8 +9321,14 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, | |
if (Switch->getDefaultDest() == ExitingBlock) | ||
return getCouldNotCompute(); | ||
|
||
assert(L->contains(Switch->getDefaultDest()) && | ||
"Default case must not exit the loop!"); | ||
// Without AssumeLoopFinite, this assertion always runs. | ||
// With AssumeLoopFinite, a default destination that is guaranteed | ||
// unreachable is exempt, because it does not matter where it leads. | ||
if (!AssumeLoopFinite || | ||
!GuaranteedUnreachable.count(Switch->getDefaultDest())) { | ||
assert(L->contains(Switch->getDefaultDest()) && | ||
"Default case must not exit the loop!"); | ||
} | ||
const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); | ||
const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); | ||
|
||
|
@@ -12752,9 +12841,9 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, | |
// If RHS <=u Limit, then there must exist a value V in the sequence | ||
// defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and | ||
// V <=u UINT_MAX. Thus, we must exit the loop before unsigned | ||
// overflow occurs. This limit also implies that a signed comparison | ||
// (in the wide bitwidth) is equivalent to an unsigned comparison as | ||
// the high bits on both sides must be zero. | ||
// overflow occurs. This limit also implies that a signed | ||
// comparison (in the wide bitwidth) is equivalent to an unsigned | ||
// comparison as the high bits on both sides must be zero. | ||
APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this)); | ||
APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1); | ||
Limit = Limit.zext(OuterBitWidth); | ||
|
@@ -12765,6 +12854,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, | |
Flags = setFlags(Flags, SCEV::FlagNUW); | ||
|
||
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags); | ||
|
||
if (AR->hasNoUnsignedWrap()) { | ||
// Emulate what getZeroExtendExpr would have done during construction | ||
// if we'd been able to infer the fact just above at that time. | ||
|
@@ -12848,6 +12938,13 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, | |
!loopHasNoAbnormalExits(L)) | ||
return getCouldNotCompute(); | ||
|
||
// This bailout is protecting the logic in computeMaxBECountForLT which | ||
// has not yet been sufficiently audited or tested with negative strides. | ||
// We used to filter out all known-non-positive cases here, we're in the | ||
// process of being less restrictive bit by bit. | ||
if (AssumeLoopFinite && IsSigned && isKnownNonPositive(Stride)) | ||
return getCouldNotCompute(); | ||
|
||
if (!isKnownNonZero(Stride)) { | ||
// If we have a step of zero, and RHS isn't invariant in L, we don't know | ||
// if it might eventually be greater than start and if so, on which | ||
|
@@ -12977,13 +13074,20 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, | |
if (!BECount) { | ||
auto canProveRHSGreaterThanEqualStart = [&]() { | ||
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; | ||
const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L); | ||
const SCEV *GuardedStart = applyLoopGuards(OrigStart, L); | ||
|
||
if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) || | ||
isKnownPredicate(CondGE, GuardedRHS, GuardedStart)) | ||
if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart)) { | ||
return true; | ||
|
||
} | ||
// In Enzyme's MustExitScalarEvolution code, this check was missing. | ||
// I do not have enough context to know whether these two checks should be | ||
// mutually exclusive. If they aren't, then this bool check is unnecessary. | ||
if (!AssumeLoopFinite) { | ||
const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L); | ||
const SCEV *GuardedStart = applyLoopGuards(OrigStart, L); | ||
|
||
if (isKnownPredicate(CondGE, GuardedRHS, GuardedStart)) | ||
return true; | ||
} | ||
// (RHS > Start - 1) implies RHS >= Start. | ||
// * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if | ||
// "Start - 1" doesn't overflow. | ||
|
@@ -13120,7 +13224,10 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, | |
if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) && | ||
!isa<SCEVCouldNotCompute>(BECount)) | ||
ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); | ||
|
||
if (AssumeLoopFinite) { | ||
return ExitLimit(BECount, ConstantMaxBECount, ConstantMaxBECount, MaxOrZero, | ||
Predicates); | ||
} | ||
const SCEV *SymbolicMaxBECount = | ||
isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount; | ||
return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero, | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assignment to RHS should be outside the "if"?