Skip to content

Commit

Permalink
[LateGCLowering] Fix skipped Select lifting (#35387)
Browse files Browse the repository at this point in the history
Fixes #35341
  • Loading branch information
Keno authored and staticfloat committed Apr 21, 2020
1 parent 48231a0 commit 116e2ef
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 65 deletions.
121 changes: 56 additions & 65 deletions src/llvm-late-gc-lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ struct LateLowerGCFrame: public FunctionPass, private JuliaPassContext {
}

void LiftPhi(State &S, PHINode *Phi);
bool LiftSelect(State &S, SelectInst *SI);
void LiftSelect(State &S, SelectInst *SI);
Value *MaybeExtractScalar(State &S, std::pair<Value*,int> ValExpr, Instruction *InsertBefore);
std::vector<Value*> MaybeExtractVector(State &S, Value *BaseVec, Instruction *InsertBefore);
Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertBefore);
Expand Down Expand Up @@ -600,12 +600,12 @@ Value *LateLowerGCFrame::GetPtrForNumber(State &S, unsigned Num, Instruction *In
return MaybeExtractScalar(S, std::make_pair(Val, Idx), InsertBefore);
}

bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
void LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
if (isa<PointerType>(SI->getType()) ?
S.AllPtrNumbering.count(SI) :
S.AllCompositeNumbering.count(SI)) {
// already visited here--nothing to do
return true;
return;
}
std::vector<int> Numbers;
unsigned NumRoots = 1;
Expand All @@ -617,68 +617,60 @@ bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
// find the base root for the arguments
Value *TrueBase = MaybeExtractScalar(S, FindBaseValue(S, SI->getTrueValue(), false), SI);
Value *FalseBase = MaybeExtractScalar(S, FindBaseValue(S, SI->getFalseValue(), false), SI);
Value *V_null = ConstantPointerNull::get(cast<PointerType>(T_prjlvalue));
bool didsplit = false;
if (TrueBase != V_null && FalseBase != V_null) {
std::vector<Value*> TrueBases;
std::vector<Value*> FalseBases;
if (!isa<PointerType>(TrueBase->getType())) {
TrueBases = MaybeExtractVector(S, TrueBase, SI);
assert(TrueBases.size() == Numbers.size());
NumRoots = TrueBases.size();
}
if (!isa<PointerType>(FalseBase->getType())) {
FalseBases = MaybeExtractVector(S, FalseBase, SI);
assert(FalseBases.size() == Numbers.size());
NumRoots = FalseBases.size();
}
if (isa<PointerType>(SI->getType()) ?
S.AllPtrNumbering.count(SI) :
S.AllCompositeNumbering.count(SI)) {
// MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, though a PHINode)
return true;
}
// need to handle each element (may just be one scalar)
for (unsigned i = 0; i < NumRoots; ++i) {
Value *TrueElem;
if (isa<PointerType>(TrueBase->getType()))
TrueElem = TrueBase;
else
TrueElem = TrueBases[i];
Value *FalseElem;
if (isa<PointerType>(FalseBase->getType()))
FalseElem = FalseBase;
else
FalseElem = FalseBases[i];
if (TrueElem != V_null || FalseElem != V_null) {
Value *Cond = SI->getCondition();
if (isa<VectorType>(Cond->getType())) {
Cond = ExtractElementInst::Create(Cond,
ConstantInt::get(Type::getInt32Ty(Cond->getContext()), i),
"", SI);
}
SelectInst *SelectBase = SelectInst::Create(Cond, TrueElem, FalseElem, "gclift", SI);
int Number = ++S.MaxPtrNumber;
S.AllPtrNumbering[SelectBase] = Number;
S.ReversePtrNumbering[Number] = SelectBase;
if (isa<PointerType>(SI->getType()))
S.AllPtrNumbering[SI] = Number;
else
Numbers[i] = Number;
didsplit = true;
}
}
if (isa<VectorType>(SI->getType()) && NumRoots != Numbers.size()) {
// broadcast the scalar root number to fill the vector
assert(NumRoots == 1);
int Number = Numbers[0];
Numbers.resize(0);
Numbers.resize(SI->getType()->getVectorNumElements(), Number);
}
std::vector<Value*> TrueBases;
std::vector<Value*> FalseBases;
if (!isa<PointerType>(TrueBase->getType())) {
TrueBases = MaybeExtractVector(S, TrueBase, SI);
assert(TrueBases.size() == Numbers.size());
NumRoots = TrueBases.size();
}
if (!isa<PointerType>(FalseBase->getType())) {
FalseBases = MaybeExtractVector(S, FalseBase, SI);
assert(FalseBases.size() == Numbers.size());
NumRoots = FalseBases.size();
}
if (isa<PointerType>(SI->getType()) ?
S.AllPtrNumbering.count(SI) :
S.AllCompositeNumbering.count(SI)) {
// MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, though a PHINode)
return;
}
// need to handle each element (may just be one scalar)
for (unsigned i = 0; i < NumRoots; ++i) {
Value *TrueElem;
if (isa<PointerType>(TrueBase->getType()))
TrueElem = TrueBase;
else
TrueElem = TrueBases[i];
Value *FalseElem;
if (isa<PointerType>(FalseBase->getType()))
FalseElem = FalseBase;
else
FalseElem = FalseBases[i];
Value *Cond = SI->getCondition();
if (isa<VectorType>(Cond->getType())) {
Cond = ExtractElementInst::Create(Cond,
ConstantInt::get(Type::getInt32Ty(Cond->getContext()), i),
"", SI);
}
SelectInst *SelectBase = SelectInst::Create(Cond, TrueElem, FalseElem, "gclift", SI);
int Number = ++S.MaxPtrNumber;
S.AllPtrNumbering[SelectBase] = Number;
S.ReversePtrNumbering[Number] = SelectBase;
if (isa<PointerType>(SI->getType()))
S.AllPtrNumbering[SI] = Number;
else
Numbers[i] = Number;
}
if (isa<VectorType>(SI->getType()) && NumRoots != Numbers.size()) {
// broadcast the scalar root number to fill the vector
assert(NumRoots == 1);
int Number = Numbers[0];
Numbers.resize(0);
Numbers.resize(SI->getType()->getVectorNumElements(), Number);
}
if (!isa<PointerType>(SI->getType()))
S.AllCompositeNumbering[SI] = Numbers;
return didsplit;
}

void LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi) {
Expand Down Expand Up @@ -754,9 +746,8 @@ int LateLowerGCFrame::NumberBase(State &S, Value *CurrentV)
// input IR)
Number = -1;
} else if (isa<SelectInst>(CurrentV) && !isTrackedValue(CurrentV)) {
Number = -1;
if (LiftSelect(S, cast<SelectInst>(CurrentV))) // lifting a scalar pointer (if necessary)
Number = S.AllPtrNumbering.at(CurrentV);
LiftSelect(S, cast<SelectInst>(CurrentV));
Number = S.AllPtrNumbering.at(CurrentV);
return Number;
} else if (isa<PHINode>(CurrentV) && !isTrackedValue(CurrentV)) {
LiftPhi(S, cast<PHINode>(CurrentV));
Expand Down
16 changes: 16 additions & 0 deletions test/llvmpasses/gcroots.ll
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,22 @@ top:
ret i8 %val
}

define i8 @lost_select_decayed(i1 %arg1) {
; CHECK-LABEL: @lost_select_decayed
; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 3
; CHECK: [[GEP0:%.*]] = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
; CHECK: store %jl_value_t addrspace(10)* [[SOMETHING:%.*]], %jl_value_t addrspace(10)** [[GEP0]]
top:
%ptls = call %jl_value_t*** @julia.ptls_states()
%obj1 = call %jl_value_t addrspace(10) *@alloc()
%decayed = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11)*
%selected = select i1 %arg1, %jl_value_t addrspace(11)* null, %jl_value_t addrspace(11)* %decayed
%casted = bitcast %jl_value_t addrspace(11)* %selected to i8 addrspace(11)*
call void @jl_safepoint()
%val = load i8, i8 addrspace(11)* %casted
ret i8 %val
}

!0 = !{!"jtbaa"}
!1 = !{!"jtbaa_const", !0, i64 0}
!2 = !{!1, !1, i64 0, i64 1}

0 comments on commit 116e2ef

Please sign in to comment.