diff --git a/src/llvm-late-gc-lowering.cpp b/src/llvm-late-gc-lowering.cpp index f6efa2c763d33..b3b56bedfeca5 100644 --- a/src/llvm-late-gc-lowering.cpp +++ b/src/llvm-late-gc-lowering.cpp @@ -324,12 +324,15 @@ struct LateLowerGCFrame: public FunctionPass, private JuliaPassContext { void NoteUse(State &S, BBState &BBS, Value *V) { NoteUse(S, BBS, V, BBS.UpExposedUses); } - Value *MaybeExtractUnion(std::pair Val, Instruction *InsertBefore); - void LiftPhi(State &S, PHINode *Phi, SmallVector &PHINumbers); + void LiftPhi(State &S, PHINode *Phi); bool LiftSelect(State &S, SelectInst *SI); + Value *MaybeExtractScalar(State &S, std::pair ValExpr, Instruction *InsertBefore); + std::vector MaybeExtractVector(State &S, Value *BaseVec, Instruction *InsertBefore); + Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint); + int Number(State &S, Value *V); std::vector NumberVector(State &S, Value *Vec); - int NumberBase(State &S, Value *V, Value *Base); + int NumberBase(State &S, Value *Base); std::vector NumberVectorBase(State &S, Value *Base); void NoteOperandUses(State &S, BBState &BBS, User &UI); State LocalScan(Function &F); @@ -354,6 +357,12 @@ static unsigned getValueAddrSpace(Value *V) { return V->getType()->getScalarType()->getPointerAddressSpace(); } +static unsigned isTrackedValue(Value *V) { + PointerType *PT = dyn_cast(V->getType()->getScalarType()); + return PT && PT->getAddressSpace() == AddressSpace::Tracked; +} + + static bool isSpecialPtr(Type *Ty) { PointerType *PTy = dyn_cast(Ty); if (!PTy) @@ -466,121 +475,198 @@ static std::pair FindBaseValue(const State &S, Value *V, bool UseCac return std::make_pair(CurrentV, fld_idx); } -Value *LateLowerGCFrame::MaybeExtractUnion(std::pair Val, Instruction *InsertBefore) { - if (isUnionRep(Val.first->getType())) { - assert(Val.second == -1); - return ExtractValueInst::Create(Val.first, {(unsigned)0}, "", InsertBefore); +Value *LateLowerGCFrame::MaybeExtractScalar(State &S, std::pair ValExpr, 
Instruction *InsertBefore) { + Value *V = ValExpr.first; + if (isUnionRep(V->getType())) { + assert(ValExpr.second == -1); + V = ExtractValueInst::Create(V, {(unsigned)0}, "", InsertBefore); } - else if (Val.second != -1) { - return ExtractElementInst::Create(Val.first, ConstantInt::get(T_int32, Val.second), - "", InsertBefore); + else if (!isTrackedValue(V)) { + // if V isn't tracked, get the shadow def + int BaseNumber; + if (isa(V->getType())) { + assert(ValExpr.second == -1); + BaseNumber = NumberBase(S, V); + } else if (ValExpr.second != -1) { + auto Numbers = NumberVectorBase(S, V); + BaseNumber = Numbers.size() == 0 ? -1 : Numbers[ValExpr.second]; + } else { + return V; // the user must handle this aggregate instead + } + if (BaseNumber >= 0) + V = GetPtrForNumber(S, BaseNumber, InsertBefore); + else + V = ConstantPointerNull::get(cast(T_prjlvalue)); + } + else if (ValExpr.second != -1) { + V = ExtractElementInst::Create(V, + ConstantInt::get(Type::getInt32Ty(V->getContext()), ValExpr.second), + "", InsertBefore); + } + return V; +} + +std::vector LateLowerGCFrame::MaybeExtractVector(State &S, Value *BaseVec, Instruction *InsertBefore) { + auto Numbers = NumberVectorBase(S, BaseVec); + std::vector V{Numbers.size()}; + Value *V_null = ConstantPointerNull::get(cast(T_prjlvalue)); + for (unsigned i = 0; i < V.size(); ++i) { + if (Numbers[i] >= 0) + V[i] = GetPtrForNumber(S, Numbers[i], InsertBefore); + else + V[i] = V_null; } - return Val.first; + return V; } -static Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint) +Value *LateLowerGCFrame::GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint) { Value *Val = S.ReversePtrNumbering[Num]; - if (isSpecialPtrVec(Val->getType())) { + assert(isTrackedValue(Val) || isUnionRep(Val->getType())); + unsigned Idx = -1; + if (isa(Val->getType())) { const std::vector &AllNums = S.AllVectorNumbering[Val]; - unsigned Idx = 0; - for (; Idx < AllNums.size(); ++Idx) { + for (Idx = 0; Idx < 
AllNums.size(); ++Idx) { if ((unsigned)AllNums[Idx] == Num) break; } - Val = ExtractElementInst::Create(Val, ConstantInt::get( - Type::getInt32Ty(Val->getContext()), Idx), "", InsertionPoint); + assert(Idx < AllNums.size()); } - return Val; + return MaybeExtractScalar(S, std::make_pair(Val, Idx), InsertionPoint); } bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) { - if (isSpecialPtrVec(SI->getType())) { - VectorType *VT = cast(SI->getType()); - std::vector TrueNumbers = NumberVector(S, SI->getTrueValue()); - std::vector FalseNumbers = NumberVector(S, SI->getFalseValue()); - std::vector Numbers; - for (unsigned i = 0; i < VT->getNumElements(); ++i) { - SelectInst *LSI = SelectInst::Create(SI->getCondition(), - TrueNumbers[i] < 0 ? - ConstantPointerNull::get(cast(T_prjlvalue)) : - GetPtrForNumber(S, TrueNumbers[i], SI), - FalseNumbers[i] < 0 ? - ConstantPointerNull::get(cast(T_prjlvalue)) : - GetPtrForNumber(S, FalseNumbers[i], SI), - "gclift", SI); - int Number = ++S.MaxPtrNumber; - Numbers.push_back(Number); - S.AllPtrNumbering[LSI] = Number; - S.ReversePtrNumbering[Number] = LSI; - } - S.AllVectorNumbering[SI] = Numbers; - } else { - Value *TrueBase = MaybeExtractUnion(FindBaseValue(S, SI->getTrueValue(), false), SI); - Value *FalseBase = MaybeExtractUnion(FindBaseValue(S, SI->getFalseValue(), false), SI); - if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked) - TrueBase = ConstantPointerNull::get(cast(FalseBase->getType())); - if (getValueAddrSpace(FalseBase) != AddressSpace::Tracked) - FalseBase = ConstantPointerNull::get(cast(TrueBase->getType())); - if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked) - return false; - Value *SelectBase = SelectInst::Create(SI->getCondition(), - TrueBase, FalseBase, "gclift", SI); - int Number = ++S.MaxPtrNumber; - S.AllPtrNumbering[SelectBase] = S.AllPtrNumbering[SI] = Number; - S.ReversePtrNumbering[Number] = SelectBase; + if (isa(SI->getType()) ? 
+ S.AllPtrNumbering.count(SI) : + S.AllVectorNumbering.count(SI)) { + // already visited here--nothing to do + return true; } - return true; -} - -void LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi, SmallVector &PHINumbers) -{ - if (isSpecialPtrVec(Phi->getType())) { - VectorType *VT = cast(Phi->getType()); - std::vector lifted; - for (unsigned i = 0; i < VT->getNumElements(); ++i) { - lifted.push_back(PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi)); + std::vector Numbers; + unsigned NumRoots = 1; + if (isa(SI->getType())) + Numbers.resize(SI->getType()->getVectorNumElements(), -1); + assert(!isTrackedValue(SI)); + // find the base root for the arguments + Value *TrueBase = MaybeExtractScalar(S, FindBaseValue(S, SI->getTrueValue(), false), SI); + Value *FalseBase = MaybeExtractScalar(S, FindBaseValue(S, SI->getFalseValue(), false), SI); + Value *V_null = ConstantPointerNull::get(cast(T_prjlvalue)); + bool didsplit = false; + if (TrueBase != V_null && FalseBase != V_null) { + std::vector TrueBases; + std::vector FalseBases; + if (!isa(TrueBase->getType())) { + TrueBases = MaybeExtractVector(S, TrueBase, SI); + assert(TrueBases.size() == Numbers.size()); + NumRoots = TrueBases.size(); + } + if (!isa(FalseBase->getType())) { + FalseBases = MaybeExtractVector(S, FalseBase, SI); + assert(FalseBases.size() == Numbers.size()); + NumRoots = FalseBases.size(); + } + if (isa(SI->getType()) ? 
+ S.AllPtrNumbering.count(SI) : + S.AllVectorNumbering.count(SI)) { + // MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, through a PHINode) + return true; } - for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) { - std::vector Numbers = NumberVector(S, Phi->getIncomingValue(i)); - BasicBlock *IncomingBB = Phi->getIncomingBlock(i); - Instruction *Terminator = IncomingBB->getTerminator(); - for (unsigned i = 0; i < VT->getNumElements(); ++i) { - if (Numbers[i] < 0) - lifted[i]->addIncoming(ConstantPointerNull::get(cast(T_prjlvalue)), IncomingBB); + // need to handle each element (may just be one scalar) + for (unsigned i = 0; i < NumRoots; ++i) { + Value *TrueElem; + if (isa(TrueBase->getType())) + TrueElem = TrueBase; + else + TrueElem = TrueBases[i]; + Value *FalseElem; + if (isa(FalseBase->getType())) + FalseElem = FalseBase; + else + FalseElem = FalseBases[i]; + if (TrueElem != V_null || FalseElem != V_null) { + Value *Cond = SI->getCondition(); + if (isa(Cond->getType())) { + Cond = ExtractElementInst::Create(Cond, + ConstantInt::get(Type::getInt32Ty(Cond->getContext()), i), + "", SI); + } + SelectInst *SelectBase = SelectInst::Create(Cond, TrueElem, FalseElem, "gclift", SI); + int Number = ++S.MaxPtrNumber; + S.AllPtrNumbering[SelectBase] = Number; + S.ReversePtrNumbering[Number] = SelectBase; + if (isa(SI->getType())) + S.AllPtrNumbering[SI] = Number; else - lifted[i]->addIncoming(GetPtrForNumber(S, Numbers[i], Terminator), IncomingBB); + Numbers[i] = Number; + didsplit = true; } } - std::vector Numbers; - for (unsigned i = 0; i < VT->getNumElements(); ++i) { - int Number = ++S.MaxPtrNumber; - PHINumbers.push_back(Number); - Numbers.push_back(Number); - S.AllPtrNumbering[lifted[i]] = Number; - S.ReversePtrNumbering[Number] = lifted[i]; + if (isa(SI->getType()) && NumRoots != Numbers.size()) { + // broadcast the scalar root number to fill the vector + assert(NumRoots == 1); + int Number = Numbers[0]; + Numbers.resize(0); + 
Numbers.resize(SI->getType()->getVectorNumElements(), Number); } - S.AllVectorNumbering[Phi] = Numbers; - } else { + } + if (isa(SI->getType())) + S.AllVectorNumbering[SI] = Numbers; + return didsplit; +} + +void LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi) +{ + if (isSpecialPtrVec(Phi->getType()) ? + S.AllVectorNumbering.count(Phi) : + S.AllPtrNumbering.count(Phi)) + return; + // need to handle each element (may just be one scalar) + SmallVector lifted; + std::vector Numbers; + unsigned NumRoots = 1; + if (isa(Phi->getType())) { + NumRoots = Phi->getType()->getVectorNumElements(); + Numbers.resize(NumRoots); + } + for (unsigned i = 0; i < NumRoots; ++i) { PHINode *lift = PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi); - for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) { - Value *Incoming = Phi->getIncomingValue(i); - Value *Base = MaybeExtractUnion(FindBaseValue(S, Incoming, false), - Phi->getIncomingBlock(i)->getTerminator()); - if (getValueAddrSpace(Base) != AddressSpace::Tracked) - Base = ConstantPointerNull::get(cast(T_prjlvalue)); - if (Base->getType() != T_prjlvalue) - Base = new BitCastInst(Base, T_prjlvalue, "", Phi->getIncomingBlock(i)->getTerminator()); - lift->addIncoming(Base, Phi->getIncomingBlock(i)); - } int Number = ++S.MaxPtrNumber; - PHINumbers.push_back(Number); - S.AllPtrNumbering[lift] = S.AllPtrNumbering[Phi] = Number; + S.AllPtrNumbering[lift] = Number; S.ReversePtrNumbering[Number] = lift; + if (!isa(Phi->getType())) + S.AllPtrNumbering[Phi] = Number; + else + Numbers[i] = Number; + lifted.push_back(lift); + } + if (!isa(Phi->getType())) + S.AllVectorNumbering[Phi] = Numbers; + for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) { + Value *Incoming = Phi->getIncomingValue(i); + BasicBlock *IncomingBB = Phi->getIncomingBlock(i); + Instruction *Terminator = IncomingBB->getTerminator(); + Value *Base = MaybeExtractScalar(S, FindBaseValue(S, Incoming, false), Terminator); + std::vector 
IncomingBases; + if (!isa(Base->getType())) { + IncomingBases = MaybeExtractVector(S, Base, Terminator); + assert(IncomingBases.size() == NumRoots); + } + for (unsigned i = 0; i < NumRoots; ++i) { + PHINode *lift = lifted[i]; + Value *BaseElem; + if (isa(Base->getType())) { + BaseElem = Base; + if (BaseElem->getType() != T_prjlvalue) + BaseElem = new BitCastInst(BaseElem, T_prjlvalue, "", Terminator); + } else { + BaseElem = IncomingBases[i]; + } + lift->addIncoming(BaseElem, IncomingBB); + } } } -int LateLowerGCFrame::NumberBase(State &S, Value *V, Value *CurrentV) +int LateLowerGCFrame::NumberBase(State &S, Value *CurrentV) { auto it = S.AllPtrNumbering.find(CurrentV); if (it != S.AllPtrNumbering.end()) @@ -590,48 +676,48 @@ int LateLowerGCFrame::NumberBase(State &S, Value *V, Value *CurrentV) if (isa(CurrentV)) { // Perm rooted Number = -2; - } else if (isa(CurrentV) || - ((isa(CurrentV) || isa(CurrentV)) && - getValueAddrSpace(CurrentV) != AddressSpace::Tracked)) { + } else if (isa(CurrentV) || isa(CurrentV) || + (isa(CurrentV) && !isTrackedValue(CurrentV))) { // We know this is rooted in the parent + // future note: we could choose to exclude arguments of type CalleeRooted here Number = -1; } else if (!isSpecialPtr(CurrentV->getType()) && !isUnion) { // Externally rooted somehow hopefully (otherwise there's a bug in the // input IR) Number = -1; - } else if (isa(CurrentV) && !isUnion && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { + } else if (isa(CurrentV) && !isUnion && !isTrackedValue(CurrentV)) { Number = -1; - if (LiftSelect(S, cast(CurrentV))) - Number = S.AllPtrNumbering[V] = S.AllPtrNumbering.at(CurrentV); + if (LiftSelect(S, cast(CurrentV))) // lifting a scalar pointer + Number = S.AllPtrNumbering.at(CurrentV); return Number; - } else if (isa(CurrentV) && !isUnion && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { - SmallVector PHINumbers; - LiftPhi(S, cast(CurrentV), PHINumbers); - Number = S.AllPtrNumbering[V] = 
S.AllPtrNumbering.at(CurrentV); + } else if (isa(CurrentV) && !isUnion && !isTrackedValue(CurrentV)) { + LiftPhi(S, cast(CurrentV)); + Number = S.AllPtrNumbering.at(CurrentV); return Number; } else if (isa(CurrentV) && !isUnion) { assert(false && "TODO: Extract"); abort(); } else { - assert( - (CurrentV->getType()->isPointerTy() && - getValueAddrSpace(CurrentV) == AddressSpace::Tracked) || - isUnion); + assert((CurrentV->getType()->isPointerTy() && isTrackedValue(CurrentV)) || isUnion); Number = ++S.MaxPtrNumber; S.ReversePtrNumbering[Number] = CurrentV; } - S.AllPtrNumbering[CurrentV] = S.AllPtrNumbering[V] = Number; + S.AllPtrNumbering[CurrentV] = Number; return Number; } int LateLowerGCFrame::Number(State &S, Value *V) { assert(isSpecialPtr(V->getType()) || isUnionRep(V->getType())); auto CurrentV = FindBaseValue(S, V); - if (CurrentV.second == -1) - return NumberBase(S, V, CurrentV.first); - auto Numbers = NumberVectorBase(S, CurrentV.first); - auto Number = Numbers.size() == 0 ? -1 : Numbers[CurrentV.second]; - S.AllPtrNumbering[V] = Number; + int Number; + if (CurrentV.second == -1) { + Number = NumberBase(S, CurrentV.first); + } else { + auto Numbers = NumberVectorBase(S, CurrentV.first); + Number = Numbers.size() == 0 ? 
-1 : Numbers.at(CurrentV.second); + } + if (V != CurrentV.first) + S.AllPtrNumbering[V] = Number; return Number; } @@ -640,10 +726,8 @@ std::vector LateLowerGCFrame::NumberVectorBase(State &S, Value *CurrentV) { if (it != S.AllVectorNumbering.end()) return it->second; std::vector Numbers{}; - if (isa(CurrentV) || - ((isa(CurrentV) || isa(CurrentV) || - isa(CurrentV)) && - getValueAddrSpace(CurrentV) != AddressSpace::Tracked)) { + if (isa(CurrentV) || isa(CurrentV) || + (isa(CurrentV) && !isTrackedValue(CurrentV))) { Numbers.resize(CurrentV->getType()->getVectorNumElements(), -1); } /* We (the frontend) don't insert either of these, but it would be legal - @@ -666,12 +750,11 @@ std::vector LateLowerGCFrame::NumberVectorBase(State &S, Value *CurrentV) { Numbers = NumberVectorBase(S, IEI->getOperand(0)); int ElNumber = Number(S, IEI->getOperand(1)); Numbers[idx] = ElNumber; - } else if (isa(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { + } else if (isa(CurrentV) && !isTrackedValue(CurrentV)) { LiftSelect(S, cast(CurrentV)); - Numbers = S.AllVectorNumbering[CurrentV]; - } else if (isa(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { - SmallVector PHINumbers; - LiftPhi(S, cast(CurrentV), PHINumbers); + Numbers = S.AllVectorNumbering.at(CurrentV); + } else if (isa(CurrentV) && !isTrackedValue(CurrentV)) { + LiftPhi(S, cast(CurrentV)); Numbers = S.AllVectorNumbering[CurrentV]; } else if (isa(CurrentV) || isa(CurrentV) || isa(CurrentV) || isa(CurrentV)) { @@ -682,7 +765,8 @@ std::vector LateLowerGCFrame::NumberVectorBase(State &S, Value *CurrentV) { S.ReversePtrNumbering[Num] = CurrentV; } } else { - assert(false && "Unexpected vector generating operation"); + CurrentV->print(errs()); + llvm_unreachable("Unexpected vector generating operation"); } S.AllVectorNumbering[CurrentV] = Numbers; return Numbers; @@ -695,16 +779,15 @@ std::vector LateLowerGCFrame::NumberVector(State &S, Value *V) { auto CurrentV = FindBaseValue(S, V); 
assert(CurrentV.second == -1); // E.g. if this is a gep, it's possible for the base to be a single ptr + std::vector Numbers{}; if (isSpecialPtrVec(CurrentV.first->getType())) { - auto Numbers = NumberVectorBase(S, CurrentV.first); - S.AllVectorNumbering[V] = Numbers; - return Numbers; + Numbers = NumberVectorBase(S, CurrentV.first); } else { - std::vector Numbers{}; - Numbers.resize(cast(V->getType())->getNumElements(), - NumberBase(S, V, CurrentV.first)); - return Numbers; + int Number = NumberBase(S, CurrentV.first); + Numbers.resize(V->getType()->getVectorNumElements(), Number); } + S.AllVectorNumbering[V] = Numbers; + return Numbers; } static void MaybeResize(BBState &BBS, unsigned Idx) { @@ -737,8 +820,7 @@ void LateLowerGCFrame::MaybeNoteDef(State &S, BBState &BBS, Value *Def, const st int Num = -1; Type *RT = Def->getType(); if (isSpecialPtr(RT)) { - assert(getValueAddrSpace(Def) == AddressSpace::Tracked && - "Returned value of GC interest, but not tracked?"); + assert(isTrackedValue(Def) && "Returned value of GC interest, but not tracked?"); Num = Number(S, Def); } else if (isUnionRep(RT)) { @@ -1182,30 +1264,11 @@ State LateLowerGCFrame::LocalScan(Function &F) { } NoteOperandUses(S, BBS, I); } else if (SelectInst *SI = dyn_cast(&I)) { - // We need to insert an extra select for the GC root - if (!isSpecialPtr(SI->getType()) && !isSpecialPtrVec(SI->getType()) && - !isUnionRep(SI->getType())) - continue; - if (!isUnionRep(SI->getType()) && getValueAddrSpace(SI) != AddressSpace::Tracked) { - if (isSpecialPtrVec(SI->getType()) ? 
- S.AllVectorNumbering.find(SI) != S.AllVectorNumbering.end() : - S.AllPtrNumbering.find(SI) != S.AllPtrNumbering.end()) - continue; - if (!LiftSelect(S, SI)) - continue; - if (!isSpecialPtrVec(SI->getType())) { - // TODO: Refinements for vector select - int Num = S.AllPtrNumbering[SI]; - if (Num < 0) - continue; - auto SelectBase = cast(S.ReversePtrNumbering[Num]); - SmallVector RefinedPtr{Number(S, SelectBase->getTrueValue()), - Number(S, SelectBase->getFalseValue())}; - S.Refinements[Num] = std::move(RefinedPtr); - } - } else { + if (isUnionRep(SI->getType()) || isTrackedValue(SI)) { + // record the select definition of these values SmallVector RefinedPtr; if (!isSpecialPtrVec(SI->getType())) { + // TODO: Refinements for vector select RefinedPtr = { Number(S, SI->getTrueValue()), Number(S, SI->getFalseValue()) @@ -1213,44 +1276,40 @@ State LateLowerGCFrame::LocalScan(Function &F) { } MaybeNoteDef(S, BBS, SI, BBS.Safepoints, std::move(RefinedPtr)); NoteOperandUses(S, BBS, I); + } else if (isSpecialPtr(SI->getType()) || isSpecialPtrVec(SI->getType())) { + // We need to insert extra selects for the GC roots + LiftSelect(S, SI); } } else if (PHINode *Phi = dyn_cast(&I)) { - if (!isSpecialPtr(Phi->getType()) && !isSpecialPtrVec(Phi->getType()) && - !isUnionRep(Phi->getType())) { - continue; - } - auto nIncoming = Phi->getNumIncomingValues(); - // We need to insert an extra phi for the GC root - if (!isUnionRep(Phi->getType()) && getValueAddrSpace(Phi) != AddressSpace::Tracked) { - if (isSpecialPtrVec(Phi->getType()) ? 
- S.AllVectorNumbering.find(Phi) != S.AllVectorNumbering.end() : - S.AllPtrNumbering.find(Phi) != S.AllPtrNumbering.end()) - continue; - LiftPhi(S, Phi, PHINumbers); - } else { + if (isUnionRep(Phi->getType()) || isTrackedValue(Phi)) { + // record the phi definition of these values SmallVector PHIRefinements; - if (!isSpecialPtrVec(Phi->getType())) + if (isa(Phi->getType())) PHIRefinements = GetPHIRefinements(Phi, S); MaybeNoteDef(S, BBS, Phi, BBS.Safepoints, std::move(PHIRefinements)); - if (isSpecialPtrVec(Phi->getType())) { + if (!isa(Phi->getType())) { + PHINumbers.push_back(Number(S, Phi)); + } else { // TODO: Vector refinements std::vector Nums = NumberVector(S, Phi); for (int Num : Nums) PHINumbers.push_back(Num); - } else { - PHINumbers.push_back(Number(S, Phi)); } + unsigned nIncoming = Phi->getNumIncomingValues(); for (unsigned i = 0; i < nIncoming; ++i) { BBState &IncomingBBS = S.BBStates[Phi->getIncomingBlock(i)]; NoteUse(S, IncomingBBS, Phi->getIncomingValue(i), IncomingBBS.PhiOuts); } + } else if (isSpecialPtr(Phi->getType()) || isSpecialPtrVec(Phi->getType())) { + // We need to insert extra phis for the GC roots + LiftPhi(S, Phi); } } else if (isa(&I)) { NoteOperandUses(S, BBS, I); } else if (isa(&I)) { NoteOperandUses(S, BBS, I); } else if (auto *ASCI = dyn_cast(&I)) { - if (getValueAddrSpace(ASCI) == AddressSpace::Tracked) { + if (isTrackedValue(ASCI)) { SmallVector RefinedPtr{}; auto origin = ASCI->getPointerOperand()->stripPointerCasts(); if (auto LI = dyn_cast(origin)) { @@ -1906,8 +1965,6 @@ static void AddInPredLiveOuts(BasicBlock *BB, BitVector &LiveIn, State &S) void LateLowerGCFrame::PlaceGCFrameStore(State &S, unsigned R, unsigned MinColorRoot, const std::vector &Colors, Value *GCFrame, Instruction *InsertionPoint) { - Value *Val = GetPtrForNumber(S, R, InsertionPoint); - // Get the slot address. 
auto slotAddress = CallInst::Create( getOrDeclare(jl_intrinsics::getGCFrameSlot), @@ -1915,7 +1972,7 @@ void LateLowerGCFrame::PlaceGCFrameStore(State &S, unsigned R, unsigned MinColor slotAddress->insertBefore(InsertionPoint); - Val = MaybeExtractUnion(std::make_pair(Val, -1), InsertionPoint); + Value *Val = GetPtrForNumber(S, R, InsertionPoint); // Pointee types don't have semantics, so the optimizer is // free to rewrite them if convenient. We need to change // it back here for the store. diff --git a/test/llvmpasses/gcroots.ll b/test/llvmpasses/gcroots.ll index 7b484df611fc1..8a154c3adacad 100644 --- a/test/llvmpasses/gcroots.ll +++ b/test/llvmpasses/gcroots.ll @@ -112,6 +112,7 @@ define void @select_lift(i64 %a, i64 %b) { define void @phi_lift(i64 %a, i64 %b) { top: ; CHECK-LABEL: @phi_lift +; CHECK: %gclift = phi %jl_value_t addrspace(10)* [ %aboxed, %alabel ], [ %bboxed, %blabel ], [ %gclift, %common ] %ptls = call %jl_value_t*** @julia.ptls_states() %cmp = icmp eq i64 %a, %b br i1 %cmp, label %alabel, label %blabel @@ -124,11 +125,12 @@ blabel: %bdecayed = addrspacecast %jl_value_t addrspace(10)* %bboxed to i64 addrspace(12)* br label %common common: - %phi = phi i64 addrspace(12)* [ %adecayed, %alabel ], [ %bdecayed, %blabel ] + %phi = phi i64 addrspace(12)* [ %adecayed, %alabel ], [ %bdecayed, %blabel ], [ %phi, %common ] call void @one_arg_decayed(i64 addrspace(12)* %phi) - ret void + br label %common } + define void @phi_lift_union(i64 %a, i64 %b) { top: ; CHECK-LABEL: @phi_lift_union @@ -512,7 +514,7 @@ top: %ptls = call %jl_value_t*** @julia.ptls_states() %loaded = load <2 x %jl_value_t addrspace(10)*>, <2 x %jl_value_t addrspace(10)*> *%arg call void @jl_safepoint() - %select = select i1 %cond, <2 x %jl_value_t addrspace(10)*> zeroinitializer, <2 x %jl_value_t addrspace(10)*> %loaded + %select = select i1 %cond, <2 x %jl_value_t addrspace(10)*> zeroinitializer, <2 x %jl_value_t addrspace(10)*> %loaded call void @jl_safepoint() %el1 = extractelement 
<2 x %jl_value_t addrspace(10)*> %select, i32 0 %el2 = extractelement <2 x %jl_value_t addrspace(10)*> %select, i32 1 @@ -521,6 +523,76 @@ top: unreachable } +define void @vecselect_lift(i1 %cond, <2 x %jl_value_t addrspace(10)*> *%arg) { +; CHECK-LABEL: @vecselect_lift +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 + %ptls = call %jl_value_t*** @julia.ptls_states() + %loaded = load <2 x %jl_value_t addrspace(10)*>, <2 x %jl_value_t addrspace(10)*> *%arg + %decayed = addrspacecast <2 x %jl_value_t addrspace(10)*> %loaded to <2 x i64 addrspace(12)*> + call void @jl_safepoint() +; CHECK: %gclift = select i1 %cond, %jl_value_t addrspace(10)* null, %jl_value_t addrspace(10)* %{{[0-9]+}} + %select = select i1 %cond, <2 x i64 addrspace(12)*> zeroinitializer, <2 x i64 addrspace(12)*> %decayed + call void @jl_safepoint() + %el1 = extractelement <2 x i64 addrspace(12)*> %select, i32 0 + %el2 = extractelement <2 x i64 addrspace(12)*> %select, i32 1 + call void @one_arg_decayed(i64 addrspace(12)* %el1) + call void @one_arg_decayed(i64 addrspace(12)* %el2) + unreachable +} + +define void @vecvecselect_lift(<2 x i1> %cond, <2 x %jl_value_t addrspace(10)*> *%arg) { +; CHECK-LABEL: @vecvecselect_lift +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 + %ptls = call %jl_value_t*** @julia.ptls_states() + %loaded = load <2 x %jl_value_t addrspace(10)*>, <2 x %jl_value_t addrspace(10)*> *%arg + %decayed = addrspacecast <2 x %jl_value_t addrspace(10)*> %loaded to <2 x i64 addrspace(12)*> + call void @jl_safepoint() +; CHECK: %gclift = select i1 %{{[0-9]+}}, %jl_value_t addrspace(10)* null, %jl_value_t addrspace(10)* %{{[0-9]+}} + %select = select <2 x i1> %cond, <2 x i64 addrspace(12)*> zeroinitializer, <2 x i64 addrspace(12)*> %decayed + call void @jl_safepoint() + %el1 = extractelement <2 x i64 addrspace(12)*> %select, i32 0 + %el2 = extractelement <2 x i64 addrspace(12)*> %select, i32 1 + call void @one_arg_decayed(i64 addrspace(12)* %el1) + call void 
@one_arg_decayed(i64 addrspace(12)* %el2) + unreachable +} + +define void @vecscalarselect_lift(<2 x i1> %cond, i64 %a) { +; CHECK-LABEL: @vecscalarselect_lift +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 + %ptls = call %jl_value_t*** @julia.ptls_states() + %aboxed = call %jl_value_t addrspace(10)* @jl_box_int64(i64 signext %a) + %adecayed = addrspacecast %jl_value_t addrspace(10)* %aboxed to i64 addrspace(12)* + %avec = getelementptr i64, i64 addrspace(12)* %adecayed, <2 x i32> zeroinitializer + call void @jl_safepoint() +; CHECK: %gclift = select i1 %{{[0-9]+}}, %jl_value_t addrspace(10)* null, %jl_value_t addrspace(10)* %aboxed + %select = select <2 x i1> %cond, <2 x i64 addrspace(12)*> zeroinitializer, <2 x i64 addrspace(12)*> %avec + call void @jl_safepoint() + %el1 = extractelement <2 x i64 addrspace(12)*> %select, i32 0 + %el2 = extractelement <2 x i64 addrspace(12)*> %select, i32 1 + call void @one_arg_decayed(i64 addrspace(12)* %el1) + call void @one_arg_decayed(i64 addrspace(12)* %el2) + unreachable +} + +define void @scalarvecselect_lift(i1 %cond, i64 %a) { +; CHECK-LABEL: @scalarvecselect_lift +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 + %ptls = call %jl_value_t*** @julia.ptls_states() + %aboxed = call %jl_value_t addrspace(10)* @jl_box_int64(i64 signext %a) + %adecayed = addrspacecast %jl_value_t addrspace(10)* %aboxed to i64 addrspace(12)* + %avec = getelementptr i64, i64 addrspace(12)* %adecayed, <2 x i32> zeroinitializer + call void @jl_safepoint() +; CHECK: %gclift = select i1 %cond, %jl_value_t addrspace(10)* null, %jl_value_t addrspace(10)* %aboxed + %select = select i1 %cond, <2 x i64 addrspace(12)*> zeroinitializer, <2 x i64 addrspace(12)*> %avec + call void @jl_safepoint() + %el1 = extractelement <2 x i64 addrspace(12)*> %select, i32 0 + %el2 = extractelement <2 x i64 addrspace(12)*> %select, i32 1 + call void @one_arg_decayed(i64 addrspace(12)* %el1) + call void @one_arg_decayed(i64 addrspace(12)* %el2) + 
unreachable +} + define i8 @select_arrayptr(i1 %cond) { ; CHECK-LABEL: @select_arrayptr ; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4