Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector lowering improvments in GC placement #29015

Merged
merged 1 commit into from
Sep 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 161 additions & 75 deletions src/llvm-late-gc-lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ struct LateLowerGCFrame: public FunctionPass {
NoteUse(S, BBS, V, BBS.UpExposedUses);
}
Value *MaybeExtractUnion(std::pair<Value*,int> Val, Instruction *InsertBefore);
int LiftPhi(State &S, PHINode *Phi);
int LiftSelect(State &S, SelectInst *SI);
void LiftPhi(State &S, PHINode *Phi, SmallVector<int, 16> &PHINumbers);
bool LiftSelect(State &S, SelectInst *SI);
int Number(State &S, Value *V);
std::vector<int> NumberVector(State &S, Value *Vec);
int NumberBase(State &S, Value *V, Value *Base);
Expand Down Expand Up @@ -383,7 +383,10 @@ struct LateLowerGCFrame: public FunctionPass {
};

static unsigned getValueAddrSpace(Value *V) {
return cast<PointerType>(V->getType())->getAddressSpace();
Type *Ty = V->getType();
if (isa<VectorType>(Ty))
Ty = cast<VectorType>(V->getType())->getElementType();
return cast<PointerType>(Ty)->getAddressSpace();
}

static bool isSpecialPtr(Type *Ty) {
Expand Down Expand Up @@ -508,42 +511,108 @@ Value *LateLowerGCFrame::MaybeExtractUnion(std::pair<Value*,int> Val, Instructio
return Val.first;
}

int LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
Value *TrueBase = MaybeExtractUnion(FindBaseValue(S, SI->getTrueValue(), false), SI);
Value *FalseBase = MaybeExtractUnion(FindBaseValue(S, SI->getFalseValue(), false), SI);
if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked)
TrueBase = ConstantPointerNull::get(cast<PointerType>(FalseBase->getType()));
if (getValueAddrSpace(FalseBase) != AddressSpace::Tracked)
FalseBase = ConstantPointerNull::get(cast<PointerType>(TrueBase->getType()));
if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked)
return -1;
Value *SelectBase = SelectInst::Create(SI->getCondition(),
TrueBase, FalseBase, "gclift", SI);
int Number = ++S.MaxPtrNumber;
S.PtrNumbering[SelectBase] = S.AllPtrNumbering[SelectBase] =
S.AllPtrNumbering[SI] = Number;
S.ReversePtrNumbering[Number] = SelectBase;
return Number;
static Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint)
{
Value *Val = S.ReversePtrNumbering[Num];
if (isSpecialPtrVec(Val->getType())) {
const std::vector<int> &AllNums = S.AllVectorNumbering[Val];
unsigned Idx = 0;
for (; Idx < AllNums.size(); ++Idx) {
if ((unsigned)AllNums[Idx] == Num)
break;
}
Val = ExtractElementInst::Create(Val, ConstantInt::get(
Type::getInt32Ty(Val->getContext()), Idx), "", InsertionPoint);
}
return Val;
}

bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
if (isSpecialPtrVec(SI->getType())) {
VectorType *VT = cast<VectorType>(SI->getType());
std::vector<int> TrueNumbers = NumberVector(S, SI->getTrueValue());
std::vector<int> FalseNumbers = NumberVector(S, SI->getFalseValue());
std::vector<int> Numbers;
for (unsigned i = 0; i < VT->getNumElements(); ++i) {
SelectInst *LSI = SelectInst::Create(SI->getCondition(),
TrueNumbers[i] < 0 ?
ConstantPointerNull::get(cast<PointerType>(T_prjlvalue)) :
GetPtrForNumber(S, TrueNumbers[i], SI),
FalseNumbers[i] < 0 ?
ConstantPointerNull::get(cast<PointerType>(T_prjlvalue)) :
GetPtrForNumber(S, FalseNumbers[i], SI),
"gclift", SI);
int Number = ++S.MaxPtrNumber;
Numbers.push_back(Number);
S.PtrNumbering[LSI] = S.AllPtrNumbering[LSI] = Number;
S.ReversePtrNumbering[Number] = LSI;
}
S.AllVectorNumbering[SI] = Numbers;
} else {
Value *TrueBase = MaybeExtractUnion(FindBaseValue(S, SI->getTrueValue(), false), SI);
Value *FalseBase = MaybeExtractUnion(FindBaseValue(S, SI->getFalseValue(), false), SI);
if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked)
TrueBase = ConstantPointerNull::get(cast<PointerType>(FalseBase->getType()));
if (getValueAddrSpace(FalseBase) != AddressSpace::Tracked)
FalseBase = ConstantPointerNull::get(cast<PointerType>(TrueBase->getType()));
if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked)
return false;
Value *SelectBase = SelectInst::Create(SI->getCondition(),
TrueBase, FalseBase, "gclift", SI);
int Number = ++S.MaxPtrNumber;
S.PtrNumbering[SelectBase] = S.AllPtrNumbering[SelectBase] =
S.AllPtrNumbering[SI] = Number;
S.ReversePtrNumbering[Number] = SelectBase;
}
return true;
}

int LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi)
void LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi, SmallVector<int, 16> &PHINumbers)
{
PHINode *lift = PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi);
for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) {
Value *Incoming = Phi->getIncomingValue(i);
Value *Base = MaybeExtractUnion(FindBaseValue(S, Incoming, false),
Phi->getIncomingBlock(i)->getTerminator());
if (getValueAddrSpace(Base) != AddressSpace::Tracked)
Base = ConstantPointerNull::get(cast<PointerType>(T_prjlvalue));
if (Base->getType() != T_prjlvalue)
Base = new BitCastInst(Base, T_prjlvalue, "", Phi->getIncomingBlock(i)->getTerminator());
lift->addIncoming(Base, Phi->getIncomingBlock(i));
if (isSpecialPtrVec(Phi->getType())) {
VectorType *VT = cast<VectorType>(Phi->getType());
std::vector<PHINode *> lifted;
for (unsigned i = 0; i < VT->getNumElements(); ++i) {
lifted.push_back(PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi));
}
for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) {
std::vector<int> Numbers = NumberVector(S, Phi->getIncomingValue(i));
BasicBlock *IncomingBB = Phi->getIncomingBlock(i);
Instruction *Terminator = IncomingBB->getTerminator();
for (unsigned i = 0; i < VT->getNumElements(); ++i) {
if (Numbers[i] < 0)
lifted[i]->addIncoming(ConstantPointerNull::get(cast<PointerType>(T_prjlvalue)), IncomingBB);
else
lifted[i]->addIncoming(GetPtrForNumber(S, Numbers[i], Terminator), IncomingBB);
}
}
std::vector<int> Numbers;
for (unsigned i = 0; i < VT->getNumElements(); ++i) {
int Number = ++S.MaxPtrNumber;
PHINumbers.push_back(Number);
Numbers.push_back(Number);
S.PtrNumbering[lifted[i]] = S.AllPtrNumbering[lifted[i]] = Number;
S.ReversePtrNumbering[Number] = lifted[i];
}
S.AllVectorNumbering[Phi] = Numbers;
} else {
PHINode *lift = PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi);
for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) {
Value *Incoming = Phi->getIncomingValue(i);
Value *Base = MaybeExtractUnion(FindBaseValue(S, Incoming, false),
Phi->getIncomingBlock(i)->getTerminator());
if (getValueAddrSpace(Base) != AddressSpace::Tracked)
Base = ConstantPointerNull::get(cast<PointerType>(T_prjlvalue));
if (Base->getType() != T_prjlvalue)
Base = new BitCastInst(Base, T_prjlvalue, "", Phi->getIncomingBlock(i)->getTerminator());
lift->addIncoming(Base, Phi->getIncomingBlock(i));
}
int Number = ++S.MaxPtrNumber;
PHINumbers.push_back(Number);
S.PtrNumbering[lift] = S.AllPtrNumbering[lift] =
S.AllPtrNumbering[Phi] = Number;
S.ReversePtrNumbering[Number] = lift;
}
int Number = ++S.MaxPtrNumber;
S.PtrNumbering[lift] = S.AllPtrNumbering[lift] =
S.AllPtrNumbering[Phi] = Number;
S.ReversePtrNumbering[Number] = lift;
return Number;
}

int LateLowerGCFrame::NumberBase(State &S, Value *V, Value *CurrentV)
Expand All @@ -566,12 +635,14 @@ int LateLowerGCFrame::NumberBase(State &S, Value *V, Value *CurrentV)
// input IR)
Number = -1;
} else if (isa<SelectInst>(CurrentV) && !isUnion && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) {
int Number = LiftSelect(S, cast<SelectInst>(CurrentV));
S.AllPtrNumbering[V] = Number;
Number = -1;
if (LiftSelect(S, cast<SelectInst>(CurrentV)))
Number = S.AllPtrNumbering[V] = S.AllPtrNumbering.at(CurrentV);
return Number;
} else if (isa<PHINode>(CurrentV) && !isUnion && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) {
int Number = LiftPhi(S, cast<PHINode>(CurrentV));
S.AllPtrNumbering[V] = Number;
SmallVector<int, 16> PHINumbers;
LiftPhi(S, cast<PHINode>(CurrentV), PHINumbers);
Number = S.AllPtrNumbering[V] = S.AllPtrNumbering.at(CurrentV);
return Number;
} else if (isa<ExtractValueInst>(CurrentV) && !isUnion) {
assert(false && "TODO: Extract");
Expand Down Expand Up @@ -630,15 +701,23 @@ std::vector<int> LateLowerGCFrame::NumberVectorBase(State &S, Value *CurrentV) {
Numbers = NumberVectorBase(S, IEI->getOperand(0));
int ElNumber = Number(S, IEI->getOperand(1));
Numbers[idx] = ElNumber;
} else if (isa<LoadInst>(CurrentV) || isa<CallInst>(CurrentV) || isa<PHINode>(CurrentV)) {
} else if (isa<SelectInst>(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) {
LiftSelect(S, cast<SelectInst>(CurrentV));
Numbers = S.AllVectorNumbering[CurrentV];
} else if (isa<PHINode>(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) {
SmallVector<int, 16> PHINumbers;
LiftPhi(S, cast<PHINode>(CurrentV), PHINumbers);
Numbers = S.AllVectorNumbering[CurrentV];
} else if (isa<LoadInst>(CurrentV) || isa<CallInst>(CurrentV) || isa<PHINode>(CurrentV) ||
isa<SelectInst>(CurrentV)) {
// This is simple, we can just number them sequentially
for (unsigned i = 0; i < cast<VectorType>(CurrentV->getType())->getNumElements(); ++i) {
int Num = ++S.MaxPtrNumber;
Numbers.push_back(Num);
S.ReversePtrNumbering[Num] = CurrentV;
}
} else {
assert(false && "Unexpected vector generating operating");
assert(false && "Unexpected vector generating operation");
}
S.AllVectorNumbering[CurrentV] = Numbers;
return Numbers;
Expand Down Expand Up @@ -1148,40 +1227,63 @@ State LateLowerGCFrame::LocalScan(Function &F) {
NoteOperandUses(S, BBS, I, BBS.UpExposedUsesUnrooted);
} else if (SelectInst *SI = dyn_cast<SelectInst>(&I)) {
// We need to insert an extra select for the GC root
if (!isSpecialPtr(SI->getType()) && !isUnionRep(SI->getType()))
if (!isSpecialPtr(SI->getType()) && !isSpecialPtrVec(SI->getType()) &&
!isUnionRep(SI->getType()))
continue;
if (!isUnionRep(SI->getType()) && getValueAddrSpace(SI) != AddressSpace::Tracked) {
if (S.AllPtrNumbering.find(SI) != S.AllPtrNumbering.end())
if (isSpecialPtrVec(SI->getType()) ?
S.AllVectorNumbering.find(SI) != S.AllVectorNumbering.end() :
S.AllPtrNumbering.find(SI) != S.AllPtrNumbering.end())
continue;
auto Num = LiftSelect(S, SI);
if (Num < 0)
if (!LiftSelect(S, SI))
continue;
auto SelectBase = cast<SelectInst>(S.ReversePtrNumbering[Num]);
SmallVector<int, 1> RefinedPtr{Number(S, SelectBase->getTrueValue()),
Number(S, SelectBase->getFalseValue())};
S.Refinements[Num] = std::move(RefinedPtr);
if (!isSpecialPtrVec(SI->getType())) {
// TODO: Refinements for vector select
int Num = S.AllPtrNumbering[SI];
if (Num < 0)
continue;
auto SelectBase = cast<SelectInst>(S.ReversePtrNumbering[Num]);
SmallVector<int, 2> RefinedPtr{Number(S, SelectBase->getTrueValue()),
Number(S, SelectBase->getFalseValue())};
S.Refinements[Num] = std::move(RefinedPtr);
}
} else {
SmallVector<int, 1> RefinedPtr{Number(S, SI->getTrueValue()),
Number(S, SI->getFalseValue())};
SmallVector<int, 2> RefinedPtr;
if (!isSpecialPtrVec(SI->getType())) {
RefinedPtr = {
Number(S, SI->getTrueValue()),
Number(S, SI->getFalseValue())
};
}
MaybeNoteDef(S, BBS, SI, BBS.Safepoints, std::move(RefinedPtr));
NoteOperandUses(S, BBS, I, BBS.UpExposedUsesUnrooted);
}
} else if (PHINode *Phi = dyn_cast<PHINode>(&I)) {
if (!isSpecialPtr(Phi->getType()) && !isUnionRep(Phi->getType())) {
if (!isSpecialPtr(Phi->getType()) && !isSpecialPtrVec(Phi->getType()) &&
!isUnionRep(Phi->getType())) {
continue;
}
auto nIncoming = Phi->getNumIncomingValues();
// We need to insert an extra phi for the GC root
if (!isUnionRep(Phi->getType()) && getValueAddrSpace(Phi) != AddressSpace::Tracked) {
if (S.AllPtrNumbering.find(Phi) != S.AllPtrNumbering.end())
if (isSpecialPtrVec(Phi->getType()) ?
S.AllVectorNumbering.find(Phi) != S.AllVectorNumbering.end() :
S.AllPtrNumbering.find(Phi) != S.AllPtrNumbering.end())
continue;
auto Num = LiftPhi(S, Phi);
auto lift = cast<PHINode>(S.ReversePtrNumbering[Num]);
S.Refinements[Num] = GetPHIRefinements(lift, S);
PHINumbers.push_back(Num);
LiftPhi(S, Phi, PHINumbers);
} else {
MaybeNoteDef(S, BBS, Phi, BBS.Safepoints, GetPHIRefinements(Phi, S));
PHINumbers.push_back(Number(S, Phi));
SmallVector<int, 1> PHIRefinements;
if (!isSpecialPtrVec(Phi->getType()))
PHIRefinements = GetPHIRefinements(Phi, S);
MaybeNoteDef(S, BBS, Phi, BBS.Safepoints, std::move(PHIRefinements));
if (isSpecialPtrVec(Phi->getType())) {
// TODO: Vector refinements
std::vector<int> Nums = NumberVector(S, Phi);
for (int Num : Nums)
PHINumbers.push_back(Num);
} else {
PHINumbers.push_back(Number(S, Phi));
}
for (unsigned i = 0; i < nIncoming; ++i) {
BBState &IncomingBBS = S.BBStates[Phi->getIncomingBlock(i)];
NoteUse(S, IncomingBBS, Phi->getIncomingValue(i), IncomingBBS.PhiOuts);
Expand Down Expand Up @@ -1776,22 +1878,6 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S) {
return ChangesMade;
}

static Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint)
{
Value *Val = S.ReversePtrNumbering[Num];
if (isSpecialPtrVec(Val->getType())) {
const std::vector<int> &AllNums = S.AllVectorNumbering[Val];
unsigned Idx = 0;
for (; Idx < AllNums.size(); ++Idx) {
if ((unsigned)AllNums[Idx] == Num)
break;
}
Val = ExtractElementInst::Create(Val, ConstantInt::get(
Type::getInt32Ty(Val->getContext()), Idx), "", InsertionPoint);
}
return Val;
}

static void AddInPredLiveOuts(BasicBlock *BB, BitVector &LiveIn, State &S)
{
bool First = true;
Expand Down
Loading