From 29d5d4175c443fbb75ad00dcf6e143e7d6bd5653 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Wed, 8 Apr 2020 14:55:11 -0400 Subject: [PATCH] [LateGCLowering] Fix skipped Select lifting (#35387) Fixes #35341 --- src/llvm-late-gc-lowering.cpp | 121 ++++++++++++++++------------------ test/llvmpasses/gcroots.ll | 16 +++++ 2 files changed, 72 insertions(+), 65 deletions(-) diff --git a/src/llvm-late-gc-lowering.cpp b/src/llvm-late-gc-lowering.cpp index 65d8c0c84f851..66cc169c8609e 100644 --- a/src/llvm-late-gc-lowering.cpp +++ b/src/llvm-late-gc-lowering.cpp @@ -331,7 +331,7 @@ struct LateLowerGCFrame: public FunctionPass, private JuliaPassContext { } void LiftPhi(State &S, PHINode *Phi); - bool LiftSelect(State &S, SelectInst *SI); + void LiftSelect(State &S, SelectInst *SI); Value *MaybeExtractScalar(State &S, std::pair ValExpr, Instruction *InsertBefore); std::vector MaybeExtractVector(State &S, Value *BaseVec, Instruction *InsertBefore); Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertBefore); @@ -600,12 +600,12 @@ Value *LateLowerGCFrame::GetPtrForNumber(State &S, unsigned Num, Instruction *In return MaybeExtractScalar(S, std::make_pair(Val, Idx), InsertBefore); } -bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) { +void LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) { if (isa(SI->getType()) ? S.AllPtrNumbering.count(SI) : S.AllCompositeNumbering.count(SI)) { // already visited here--nothing to do - return true; + return; } std::vector Numbers; unsigned NumRoots = 1; @@ -617,68 +617,60 @@ bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) { // find the base root for the arguments Value *TrueBase = MaybeExtractScalar(S, FindBaseValue(S, SI->getTrueValue(), false), SI); Value *FalseBase = MaybeExtractScalar(S, FindBaseValue(S, SI->getFalseValue(), false), SI); - Value *V_null = ConstantPointerNull::get(cast(T_prjlvalue)); - bool didsplit = false; - if (TrueBase != V_null && FalseBase != V_null) { - std::vector TrueBases; - std::vector FalseBases; - if (!isa(TrueBase->getType())) { - TrueBases = MaybeExtractVector(S, TrueBase, SI); - assert(TrueBases.size() == Numbers.size()); - NumRoots = TrueBases.size(); - } - if (!isa(FalseBase->getType())) { - FalseBases = MaybeExtractVector(S, FalseBase, SI); - assert(FalseBases.size() == Numbers.size()); - NumRoots = FalseBases.size(); - } - if (isa(SI->getType()) ? - S.AllPtrNumbering.count(SI) : - S.AllCompositeNumbering.count(SI)) { - // MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, though a PHINode) - return true; - } - // need to handle each element (may just be one scalar) - for (unsigned i = 0; i < NumRoots; ++i) { - Value *TrueElem; - if (isa(TrueBase->getType())) - TrueElem = TrueBase; - else - TrueElem = TrueBases[i]; - Value *FalseElem; - if (isa(FalseBase->getType())) - FalseElem = FalseBase; - else - FalseElem = FalseBases[i]; - if (TrueElem != V_null || FalseElem != V_null) { - Value *Cond = SI->getCondition(); - if (isa(Cond->getType())) { - Cond = ExtractElementInst::Create(Cond, - ConstantInt::get(Type::getInt32Ty(Cond->getContext()), i), - "", SI); - } - SelectInst *SelectBase = SelectInst::Create(Cond, TrueElem, FalseElem, "gclift", SI); - int Number = ++S.MaxPtrNumber; - S.AllPtrNumbering[SelectBase] = Number; - S.ReversePtrNumbering[Number] = SelectBase; - if (isa(SI->getType())) - S.AllPtrNumbering[SI] = Number; - else - Numbers[i] = Number; - didsplit = true; - } - } - if (isa(SI->getType()) && NumRoots != Numbers.size()) { - // broadcast the scalar root number to fill the vector - assert(NumRoots == 1); - int Number = Numbers[0]; - Numbers.resize(0); - Numbers.resize(SI->getType()->getVectorNumElements(), Number); - } + std::vector TrueBases; + std::vector FalseBases; + if (!isa(TrueBase->getType())) { + TrueBases = MaybeExtractVector(S, TrueBase, SI); + assert(TrueBases.size() == Numbers.size()); + NumRoots = TrueBases.size(); + } + if (!isa(FalseBase->getType())) { + FalseBases = MaybeExtractVector(S, FalseBase, SI); + assert(FalseBases.size() == Numbers.size()); + NumRoots = FalseBases.size(); + } + if (isa(SI->getType()) ? + S.AllPtrNumbering.count(SI) : + S.AllCompositeNumbering.count(SI)) { + // MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, though a PHINode) + return; + } + // need to handle each element (may just be one scalar) + for (unsigned i = 0; i < NumRoots; ++i) { + Value *TrueElem; + if (isa(TrueBase->getType())) + TrueElem = TrueBase; + else + TrueElem = TrueBases[i]; + Value *FalseElem; + if (isa(FalseBase->getType())) + FalseElem = FalseBase; + else + FalseElem = FalseBases[i]; + Value *Cond = SI->getCondition(); + if (isa(Cond->getType())) { + Cond = ExtractElementInst::Create(Cond, + ConstantInt::get(Type::getInt32Ty(Cond->getContext()), i), + "", SI); + } + SelectInst *SelectBase = SelectInst::Create(Cond, TrueElem, FalseElem, "gclift", SI); + int Number = ++S.MaxPtrNumber; + S.AllPtrNumbering[SelectBase] = Number; + S.ReversePtrNumbering[Number] = SelectBase; + if (isa(SI->getType())) + S.AllPtrNumbering[SI] = Number; + else + Numbers[i] = Number; + } + if (isa(SI->getType()) && NumRoots != Numbers.size()) { + // broadcast the scalar root number to fill the vector + assert(NumRoots == 1); + int Number = Numbers[0]; + Numbers.resize(0); + Numbers.resize(SI->getType()->getVectorNumElements(), Number); } if (!isa(SI->getType())) S.AllCompositeNumbering[SI] = Numbers; - return didsplit; } void LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi) { @@ -754,9 +746,8 @@ int LateLowerGCFrame::NumberBase(State &S, Value *CurrentV) // input IR) Number = -1; } else if (isa(CurrentV) && !isTrackedValue(CurrentV)) { - Number = -1; - if (LiftSelect(S, cast(CurrentV))) // lifting a scalar pointer (if necessary) - Number = S.AllPtrNumbering.at(CurrentV); + LiftSelect(S, cast(CurrentV)); + Number = S.AllPtrNumbering.at(CurrentV); return Number; } else if (isa(CurrentV) && !isTrackedValue(CurrentV)) { LiftPhi(S, cast(CurrentV)); diff --git a/test/llvmpasses/gcroots.ll b/test/llvmpasses/gcroots.ll index ede6f9ed88fc4..bcaa0c7a2b70e 100644 --- a/test/llvmpasses/gcroots.ll +++ b/test/llvmpasses/gcroots.ll @@ -703,6 +703,22 @@ top: ret i8 %val } +define i8 @lost_select_decayed(i1 %arg1) { +; CHECK-LABEL: @lost_select_decayed +; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 3 +; CHECK: [[GEP0:%.*]] = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2 +; CHECK: store %jl_value_t addrspace(10)* [[SOMETHING:%.*]], %jl_value_t addrspace(10)** [[GEP0]] +top: + %ptls = call %jl_value_t*** @julia.ptls_states() + %obj1 = call %jl_value_t addrspace(10) *@alloc() + %decayed = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11)* + %selected = select i1 %arg1, %jl_value_t addrspace(11)* null, %jl_value_t addrspace(11)* %decayed + %casted = bitcast %jl_value_t addrspace(11)* %selected to i8 addrspace(11)* + call void @jl_safepoint() + %val = load i8, i8 addrspace(11)* %casted + ret i8 %val +} + !0 = !{!"jtbaa"} !1 = !{!"jtbaa_const", !0, i64 0} !2 = !{!1, !1, i64 0, i64 1}