diff --git a/src/llvm-late-gc-lowering.cpp b/src/llvm-late-gc-lowering.cpp index d83f5785f9cb1..902e110cf05cf 100644 --- a/src/llvm-late-gc-lowering.cpp +++ b/src/llvm-late-gc-lowering.cpp @@ -349,8 +349,8 @@ struct LateLowerGCFrame: public FunctionPass { NoteUse(S, BBS, V, BBS.UpExposedUses); } Value *MaybeExtractUnion(std::pair Val, Instruction *InsertBefore); - int LiftPhi(State &S, PHINode *Phi); - int LiftSelect(State &S, SelectInst *SI); + void LiftPhi(State &S, PHINode *Phi, SmallVector &PHINumbers); + bool LiftSelect(State &S, SelectInst *SI); int Number(State &S, Value *V); std::vector NumberVector(State &S, Value *Vec); int NumberBase(State &S, Value *V, Value *Base); @@ -383,7 +383,10 @@ struct LateLowerGCFrame: public FunctionPass { }; static unsigned getValueAddrSpace(Value *V) { - return cast(V->getType())->getAddressSpace(); + Type *Ty = V->getType(); + if (isa(Ty)) + Ty = cast(V->getType())->getElementType(); + return cast(Ty)->getAddressSpace(); } static bool isSpecialPtr(Type *Ty) { @@ -508,42 +511,108 @@ Value *LateLowerGCFrame::MaybeExtractUnion(std::pair Val, Instructio return Val.first; } -int LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) { - Value *TrueBase = MaybeExtractUnion(FindBaseValue(S, SI->getTrueValue(), false), SI); - Value *FalseBase = MaybeExtractUnion(FindBaseValue(S, SI->getFalseValue(), false), SI); - if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked) - TrueBase = ConstantPointerNull::get(cast(FalseBase->getType())); - if (getValueAddrSpace(FalseBase) != AddressSpace::Tracked) - FalseBase = ConstantPointerNull::get(cast(TrueBase->getType())); - if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked) - return -1; - Value *SelectBase = SelectInst::Create(SI->getCondition(), - TrueBase, FalseBase, "gclift", SI); - int Number = ++S.MaxPtrNumber; - S.PtrNumbering[SelectBase] = S.AllPtrNumbering[SelectBase] = - S.AllPtrNumbering[SI] = Number; - S.ReversePtrNumbering[Number] = SelectBase; - return Number; +static Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint) +{ + Value *Val = S.ReversePtrNumbering[Num]; + if (isSpecialPtrVec(Val->getType())) { + const std::vector &AllNums = S.AllVectorNumbering[Val]; + unsigned Idx = 0; + for (; Idx < AllNums.size(); ++Idx) { + if ((unsigned)AllNums[Idx] == Num) + break; + } + Val = ExtractElementInst::Create(Val, ConstantInt::get( + Type::getInt32Ty(Val->getContext()), Idx), "", InsertionPoint); + } + return Val; +} + +bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) { + if (isSpecialPtrVec(SI->getType())) { + VectorType *VT = cast(SI->getType()); + std::vector TrueNumbers = NumberVector(S, SI->getTrueValue()); + std::vector FalseNumbers = NumberVector(S, SI->getFalseValue()); + std::vector Numbers; + for (unsigned i = 0; i < VT->getNumElements(); ++i) { + SelectInst *LSI = SelectInst::Create(SI->getCondition(), + TrueNumbers[i] < 0 ? + ConstantPointerNull::get(cast(T_prjlvalue)) : + GetPtrForNumber(S, TrueNumbers[i], SI), + FalseNumbers[i] < 0 ? + ConstantPointerNull::get(cast(T_prjlvalue)) : + GetPtrForNumber(S, FalseNumbers[i], SI), + "gclift", SI); + int Number = ++S.MaxPtrNumber; + Numbers.push_back(Number); + S.PtrNumbering[LSI] = S.AllPtrNumbering[LSI] = Number; + S.ReversePtrNumbering[Number] = LSI; + } + S.AllVectorNumbering[SI] = Numbers; + } else { + Value *TrueBase = MaybeExtractUnion(FindBaseValue(S, SI->getTrueValue(), false), SI); + Value *FalseBase = MaybeExtractUnion(FindBaseValue(S, SI->getFalseValue(), false), SI); + if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked) + TrueBase = ConstantPointerNull::get(cast(FalseBase->getType())); + if (getValueAddrSpace(FalseBase) != AddressSpace::Tracked) + FalseBase = ConstantPointerNull::get(cast(TrueBase->getType())); + if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked) + return false; + Value *SelectBase = SelectInst::Create(SI->getCondition(), + TrueBase, FalseBase, "gclift", SI); + int Number = ++S.MaxPtrNumber; + S.PtrNumbering[SelectBase] = S.AllPtrNumbering[SelectBase] = + S.AllPtrNumbering[SI] = Number; + S.ReversePtrNumbering[Number] = SelectBase; + } + return true; } -int LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi) +void LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi, SmallVector &PHINumbers) { - PHINode *lift = PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi); - for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) { - Value *Incoming = Phi->getIncomingValue(i); - Value *Base = MaybeExtractUnion(FindBaseValue(S, Incoming, false), - Phi->getIncomingBlock(i)->getTerminator()); - if (getValueAddrSpace(Base) != AddressSpace::Tracked) - Base = ConstantPointerNull::get(cast(T_prjlvalue)); - if (Base->getType() != T_prjlvalue) - Base = new BitCastInst(Base, T_prjlvalue, "", Phi->getIncomingBlock(i)->getTerminator()); - lift->addIncoming(Base, Phi->getIncomingBlock(i)); + if (isSpecialPtrVec(Phi->getType())) { + VectorType *VT = cast(Phi->getType()); + std::vector lifted; + for (unsigned i = 0; i < VT->getNumElements(); ++i) { + lifted.push_back(PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi)); + } + for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) { + std::vector Numbers = NumberVector(S, Phi->getIncomingValue(i)); + BasicBlock *IncomingBB = Phi->getIncomingBlock(i); + Instruction *Terminator = IncomingBB->getTerminator(); + for (unsigned i = 0; i < VT->getNumElements(); ++i) { + if (Numbers[i] < 0) + lifted[i]->addIncoming(ConstantPointerNull::get(cast(T_prjlvalue)), IncomingBB); + else + lifted[i]->addIncoming(GetPtrForNumber(S, Numbers[i], Terminator), IncomingBB); + } + } + std::vector Numbers; + for (unsigned i = 0; i < VT->getNumElements(); ++i) { + int Number = ++S.MaxPtrNumber; + PHINumbers.push_back(Number); + Numbers.push_back(Number); + S.PtrNumbering[lifted[i]] = S.AllPtrNumbering[lifted[i]] = Number; + S.ReversePtrNumbering[Number] = lifted[i]; + } + S.AllVectorNumbering[Phi] = Numbers; + } else { + PHINode *lift = PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi); + for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) { + Value *Incoming = Phi->getIncomingValue(i); + Value *Base = MaybeExtractUnion(FindBaseValue(S, Incoming, false), + Phi->getIncomingBlock(i)->getTerminator()); + if (getValueAddrSpace(Base) != AddressSpace::Tracked) + Base = ConstantPointerNull::get(cast(T_prjlvalue)); + if (Base->getType() != T_prjlvalue) + Base = new BitCastInst(Base, T_prjlvalue, "", Phi->getIncomingBlock(i)->getTerminator()); + lift->addIncoming(Base, Phi->getIncomingBlock(i)); + } + int Number = ++S.MaxPtrNumber; + PHINumbers.push_back(Number); + S.PtrNumbering[lift] = S.AllPtrNumbering[lift] = + S.AllPtrNumbering[Phi] = Number; + S.ReversePtrNumbering[Number] = lift; } - int Number = ++S.MaxPtrNumber; - S.PtrNumbering[lift] = S.AllPtrNumbering[lift] = - S.AllPtrNumbering[Phi] = Number; - S.ReversePtrNumbering[Number] = lift; - return Number; } int LateLowerGCFrame::NumberBase(State &S, Value *V, Value *CurrentV) @@ -566,12 +635,14 @@ int LateLowerGCFrame::NumberBase(State &S, Value *V, Value *CurrentV) // input IR) Number = -1; } else if (isa(CurrentV) && !isUnion && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { - int Number = LiftSelect(S, cast(CurrentV)); - S.AllPtrNumbering[V] = Number; + Number = -1; + if (LiftSelect(S, cast(CurrentV))) + Number = S.AllPtrNumbering[V] = S.AllPtrNumbering.at(CurrentV); return Number; } else if (isa(CurrentV) && !isUnion && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { - int Number = LiftPhi(S, cast(CurrentV)); - S.AllPtrNumbering[V] = Number; + SmallVector PHINumbers; + LiftPhi(S, cast(CurrentV), PHINumbers); + Number = S.AllPtrNumbering[V] = S.AllPtrNumbering.at(CurrentV); return Number; } else if (isa(CurrentV) && !isUnion) { assert(false && "TODO: Extract"); @@ -630,7 +701,15 @@ std::vector LateLowerGCFrame::NumberVectorBase(State &S, Value *CurrentV) { Numbers = NumberVectorBase(S, IEI->getOperand(0)); int ElNumber = Number(S, IEI->getOperand(1)); Numbers[idx] = ElNumber; - } else if (isa(CurrentV) || isa(CurrentV) || isa(CurrentV)) { + } else if (isa(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { + LiftSelect(S, cast(CurrentV)); + Numbers = S.AllVectorNumbering[CurrentV]; + } else if (isa(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) { + SmallVector PHINumbers; + LiftPhi(S, cast(CurrentV), PHINumbers); + Numbers = S.AllVectorNumbering[CurrentV]; + } else if (isa(CurrentV) || isa(CurrentV) || isa(CurrentV) || + isa(CurrentV)) { // This is simple, we can just number them sequentially for (unsigned i = 0; i < cast(CurrentV->getType())->getNumElements(); ++i) { int Num = ++S.MaxPtrNumber; @@ -638,7 +717,7 @@ std::vector LateLowerGCFrame::NumberVectorBase(State &S, Value *CurrentV) { S.ReversePtrNumbering[Num] = CurrentV; } } else { - assert(false && "Unexpected vector generating operating"); + assert(false && "Unexpected vector generating operation"); } S.AllVectorNumbering[CurrentV] = Numbers; return Numbers; @@ -1148,40 +1227,63 @@ State LateLowerGCFrame::LocalScan(Function &F) { NoteOperandUses(S, BBS, I, BBS.UpExposedUsesUnrooted); } else if (SelectInst *SI = dyn_cast(&I)) { // We need to insert an extra select for the GC root - if (!isSpecialPtr(SI->getType()) && !isUnionRep(SI->getType())) + if (!isSpecialPtr(SI->getType()) && !isSpecialPtrVec(SI->getType()) && + !isUnionRep(SI->getType())) continue; if (!isUnionRep(SI->getType()) && getValueAddrSpace(SI) != AddressSpace::Tracked) { - if (S.AllPtrNumbering.find(SI) != S.AllPtrNumbering.end()) + if (isSpecialPtrVec(SI->getType()) ? + S.AllVectorNumbering.find(SI) != S.AllVectorNumbering.end() : + S.AllPtrNumbering.find(SI) != S.AllPtrNumbering.end()) continue; - auto Num = LiftSelect(S, SI); - if (Num < 0) + if (!LiftSelect(S, SI)) continue; - auto SelectBase = cast(S.ReversePtrNumbering[Num]); - SmallVector RefinedPtr{Number(S, SelectBase->getTrueValue()), - Number(S, SelectBase->getFalseValue())}; - S.Refinements[Num] = std::move(RefinedPtr); + if (!isSpecialPtrVec(SI->getType())) { + // TODO: Refinements for vector select + int Num = S.AllPtrNumbering[SI]; + if (Num < 0) + continue; + auto SelectBase = cast(S.ReversePtrNumbering[Num]); + SmallVector RefinedPtr{Number(S, SelectBase->getTrueValue()), + Number(S, SelectBase->getFalseValue())}; + S.Refinements[Num] = std::move(RefinedPtr); + } } else { - SmallVector RefinedPtr{Number(S, SI->getTrueValue()), - Number(S, SI->getFalseValue())}; + SmallVector RefinedPtr; + if (!isSpecialPtrVec(SI->getType())) { + RefinedPtr = { + Number(S, SI->getTrueValue()), + Number(S, SI->getFalseValue()) + }; + } MaybeNoteDef(S, BBS, SI, BBS.Safepoints, std::move(RefinedPtr)); NoteOperandUses(S, BBS, I, BBS.UpExposedUsesUnrooted); } } else if (PHINode *Phi = dyn_cast(&I)) { - if (!isSpecialPtr(Phi->getType()) && !isUnionRep(Phi->getType())) { + if (!isSpecialPtr(Phi->getType()) && !isSpecialPtrVec(Phi->getType()) && + !isUnionRep(Phi->getType())) { continue; } auto nIncoming = Phi->getNumIncomingValues(); // We need to insert an extra phi for the GC root if (!isUnionRep(Phi->getType()) && getValueAddrSpace(Phi) != AddressSpace::Tracked) { - if (S.AllPtrNumbering.find(Phi) != S.AllPtrNumbering.end()) + if (isSpecialPtrVec(Phi->getType()) ? + S.AllVectorNumbering.find(Phi) != S.AllVectorNumbering.end() : + S.AllPtrNumbering.find(Phi) != S.AllPtrNumbering.end()) continue; - auto Num = LiftPhi(S, Phi); - auto lift = cast(S.ReversePtrNumbering[Num]); - S.Refinements[Num] = GetPHIRefinements(lift, S); - PHINumbers.push_back(Num); + LiftPhi(S, Phi, PHINumbers); } else { - MaybeNoteDef(S, BBS, Phi, BBS.Safepoints, GetPHIRefinements(Phi, S)); - PHINumbers.push_back(Number(S, Phi)); + SmallVector PHIRefinements; + if (!isSpecialPtrVec(Phi->getType())) + PHIRefinements = GetPHIRefinements(Phi, S); + MaybeNoteDef(S, BBS, Phi, BBS.Safepoints, std::move(PHIRefinements)); + if (isSpecialPtrVec(Phi->getType())) { + // TODO: Vector refinements + std::vector Nums = NumberVector(S, Phi); + for (int Num : Nums) + PHINumbers.push_back(Num); + } else { + PHINumbers.push_back(Number(S, Phi)); + } for (unsigned i = 0; i < nIncoming; ++i) { BBState &IncomingBBS = S.BBStates[Phi->getIncomingBlock(i)]; NoteUse(S, IncomingBBS, Phi->getIncomingValue(i), IncomingBBS.PhiOuts); @@ -1776,22 +1878,6 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S) { return ChangesMade; } -static Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint) -{ - Value *Val = S.ReversePtrNumbering[Num]; - if (isSpecialPtrVec(Val->getType())) { - const std::vector &AllNums = S.AllVectorNumbering[Val]; - unsigned Idx = 0; - for (; Idx < AllNums.size(); ++Idx) { - if ((unsigned)AllNums[Idx] == Num) - break; - } - Val = ExtractElementInst::Create(Val, ConstantInt::get( - Type::getInt32Ty(Val->getContext()), Idx), "", InsertionPoint); - } - return Val; -} - static void AddInPredLiveOuts(BasicBlock *BB, BitVector &LiveIn, State &S) { bool First = true; diff --git a/test/llvmpasses/gcroots.ll b/test/llvmpasses/gcroots.ll index 958a735b8efca..eb72e2e580d91 100644 --- a/test/llvmpasses/gcroots.ll +++ b/test/llvmpasses/gcroots.ll @@ -414,6 +414,111 @@ top: ret %jl_value_t addrspace(10)* %obj } +define void @vecphi(i1 %cond, <2 x %jl_value_t addrspace(10)*> *%arg) { +; CHECK-LABEL: @vecphi +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 +top: + %ptls = call %jl_value_t*** @julia.ptls_states() + br i1 %cond, label %A, label %B + +A: + br label %common + +B: + %loaded = load <2 x %jl_value_t addrspace(10)*>, <2 x %jl_value_t addrspace(10)*> *%arg + call void @jl_safepoint() + br label %common + +common: + %phi = phi <2 x %jl_value_t addrspace(10)*> [ zeroinitializer, %A ], [ %loaded, %B ] + call void @jl_safepoint() + %el1 = extractelement <2 x %jl_value_t addrspace(10)*> %phi, i32 0 + %el2 = extractelement <2 x %jl_value_t addrspace(10)*> %phi, i32 1 + call void @one_arg_boxed(%jl_value_t addrspace(10)* %el1) + call void @one_arg_boxed(%jl_value_t addrspace(10)* %el2) + unreachable +} + +define i8 @phi_arrayptr(i1 %cond) { +; CHECK-LABEL: @phi_arrayptr +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 +top: + %ptls = call %jl_value_t*** @julia.ptls_states() + br i1 %cond, label %A, label %B + +A: + %obj1 = call %jl_value_t addrspace(10) *@alloc() + %obj2 = call %jl_value_t addrspace(10) *@alloc() + %decayed1 = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11) * + %arrayptrptr1 = bitcast %jl_value_t addrspace(11) *%decayed1 to i8 addrspace(13)* addrspace(11)* + %arrayptr1 = load i8 addrspace(13)*, i8 addrspace(13)* addrspace(11)* %arrayptrptr1 + %decayed2 = addrspacecast %jl_value_t addrspace(10) *%obj2 to %jl_value_t addrspace(11) * + %arrayptrptr2 = bitcast %jl_value_t addrspace(11) *%decayed2 to i8 addrspace(13)* addrspace(11)* + %arrayptr2 = load i8 addrspace(13)*, i8 addrspace(13)* addrspace(11)* %arrayptrptr2 + %insert1 = insertelement <2 x i8 addrspace(13)*> undef, i8 addrspace(13)* %arrayptr1, i32 0 + %insert2 = insertelement <2 x i8 addrspace(13)*> %insert1, i8 addrspace(13)* %arrayptr2, i32 1 + call void @jl_safepoint() + br label %common + +B: + br label %common + +common: +; CHECK: %gclift +; CHECK: %gclift1 +; CHECK-NOT: %gclift2 + %phi = phi <2 x i8 addrspace(13)*> [ %insert2, %A ], [ zeroinitializer, %B ] + call void @jl_safepoint() + %el1 = extractelement <2 x i8 addrspace(13)*> %phi, i32 0 + %el2 = extractelement <2 x i8 addrspace(13)*> %phi, i32 1 + %l1 = load i8, i8 addrspace(13)* %el1 + %l2 = load i8, i8 addrspace(13)* %el2 + %add = add i8 %l1, %l2 + ret i8 %add +} + +define void @vecselect(i1 %cond, <2 x %jl_value_t addrspace(10)*> *%arg) { +; CHECK-LABEL: @vecselect +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 +top: + %ptls = call %jl_value_t*** @julia.ptls_states() + %loaded = load <2 x %jl_value_t addrspace(10)*>, <2 x %jl_value_t addrspace(10)*> *%arg + call void @jl_safepoint() + %select = select i1 %cond, <2 x %jl_value_t addrspace(10)*> zeroinitializer, <2 x %jl_value_t addrspace(10)*> %loaded + call void @jl_safepoint() + %el1 = extractelement <2 x %jl_value_t addrspace(10)*> %select, i32 0 + %el2 = extractelement <2 x %jl_value_t addrspace(10)*> %select, i32 1 + call void @one_arg_boxed(%jl_value_t addrspace(10)* %el1) + call void @one_arg_boxed(%jl_value_t addrspace(10)* %el2) + unreachable +} + +define i8 @select_arrayptr(i1 %cond) { +; CHECK-LABEL: @select_arrayptr +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 4 +top: + %ptls = call %jl_value_t*** @julia.ptls_states() + %obj1 = call %jl_value_t addrspace(10) *@alloc() + %obj2 = call %jl_value_t addrspace(10) *@alloc() + %decayed1 = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11) * + %arrayptrptr1 = bitcast %jl_value_t addrspace(11) *%decayed1 to i8 addrspace(13)* addrspace(11)* + %arrayptr1 = load i8 addrspace(13)*, i8 addrspace(13)* addrspace(11)* %arrayptrptr1 + %decayed2 = addrspacecast %jl_value_t addrspace(10) *%obj2 to %jl_value_t addrspace(11) * + %arrayptrptr2 = bitcast %jl_value_t addrspace(11) *%decayed2 to i8 addrspace(13)* addrspace(11)* + %arrayptr2 = load i8 addrspace(13)*, i8 addrspace(13)* addrspace(11)* %arrayptrptr2 + %insert1 = insertelement <2 x i8 addrspace(13)*> undef, i8 addrspace(13)* %arrayptr1, i32 0 + %insert2 = insertelement <2 x i8 addrspace(13)*> %insert1, i8 addrspace(13)* %arrayptr2, i32 1 + call void @jl_safepoint() + %select = select i1 %cond, <2 x i8 addrspace(13)*> %insert2, <2 x i8 addrspace(13)*> zeroinitializer + call void @jl_safepoint() + %el1 = extractelement <2 x i8 addrspace(13)*> %select, i32 0 + %el2 = extractelement <2 x i8 addrspace(13)*> %select, i32 1 + %l1 = load i8, i8 addrspace(13)* %el1 + %l2 = load i8, i8 addrspace(13)* %el2 + %add = add i8 %l1, %l2 + ret i8 %add +} + !0 = !{!"jtbaa"} !1 = !{!"jtbaa_const", !0, i64 0} !2 = !{!1, !1, i64 0, i64 1}